"""
Migration Script — Ollama to vLLM.

Handles zero-downtime migration from Ollama local inference
to production vLLM on Kubernetes.

Phases:
1. SHADOW   — Route to Ollama, shadow-test vLLM (log only)
2. CANARY   — Route 10% traffic to vLLM, 90% Ollama
3. SPLIT    — Route 50/50
4. PRIMARY  — Route 90% vLLM, 10% Ollama (fallback)
5. COMPLETE — Route 100% vLLM, decommission Ollama

Usage:
    python -m ai.migrate_ollama_to_vllm --phase shadow
    python -m ai.migrate_ollama_to_vllm --phase canary
    python -m ai.migrate_ollama_to_vllm --phase complete
    python -m ai.migrate_ollama_to_vllm --status
    python -m ai.migrate_ollama_to_vllm --health
"""

import argparse
import asyncio
import json
import logging
import os
import random
import time
from enum import Enum
from pathlib import Path
from typing import Optional

import httpx

# Named logger shared by the state store, the router and the CLI in this module.
logger: logging.Logger = logging.getLogger(name="fogbreak.migrate")


class MigrationPhase(str, Enum):
    """Rollout stages of the Ollama -> vLLM migration, in rollout order.

    Members subclass ``str``, so each compares equal to its raw value
    (``MigrationPhase.CANARY == "canary"``). The raw value is what the
    state file stores and what the CLI ``--phase`` option accepts.
    """

    SHADOW = "shadow"        # Ollama answers; vLLM is only called in the background
    CANARY = "canary"        # small share of traffic goes to vLLM
    SPLIT = "split"          # traffic divided evenly between backends
    PRIMARY = "primary"      # vLLM takes most traffic; Ollama remains as fallback
    COMPLETE = "complete"    # vLLM answers everything


# Fraction (0.0-1.0) of live traffic answered by vLLM in each phase; Ollama
# answers the remainder. SHADOW serves nothing from vLLM (vLLM is exercised
# only by background shadow tests) and COMPLETE sends every request to vLLM.
PHASE_SPLITS = {
    MigrationPhase.SHADOW: 0.00,
    MigrationPhase.CANARY: 0.10,
    MigrationPhase.SPLIT: 0.50,
    MigrationPhase.PRIMARY: 0.90,
    MigrationPhase.COMPLETE: 1.00,
}

# JSON file holding the persisted migration state (phase + counters); it is
# read and written by both the CLI and the router. Override with the
# FOGBREAK_MIGRATION_STATE environment variable.
STATE_FILE = os.getenv(
    "FOGBREAK_MIGRATION_STATE",
    "/var/lib/fogbreak/migration_state.json",
)

# Backend base URLs. Request paths ("/api/chat", "/v1/chat/completions",
# "/health", ...) are appended with a leading "/", so a trailing slash in an
# override such as "http://host:11434/" is stripped. Otherwise it would
# produce "//api/chat", which the server may not route.
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434").rstrip("/")
VLLM_URL = os.getenv("VLLM_URL", "http://vllm-llama:8000").rstrip("/")


class MigrationState:
    """Migration state persisted as JSON at ``STATE_FILE``.

    Holds the current phase, when the migration and the current phase
    started, and the per-backend request/error and shadow-test counters.

    The file is written by two kinds of process. The CLI changes the phase
    through :meth:`set_phase`. Running routers call :meth:`save` after every
    request to persist counters. :meth:`save` therefore first adopts a phase
    change that another process wrote to disk. Without that step it would
    overwrite the change with this object's stale in-memory phase. A tiny
    read-then-write window remains; no file locking is done.

    Writes go to a temporary file that is then moved into place with
    ``os.replace``, so readers do not observe a partially written file.
    """

    def __init__(self):
        self.phase: MigrationPhase = MigrationPhase.SHADOW
        self.started_at: Optional[str] = None
        self.phase_started_at: Optional[str] = None
        self.ollama_requests: int = 0
        self.vllm_requests: int = 0
        self.ollama_errors: int = 0
        self.vllm_errors: int = 0
        self.shadow_mismatches: int = 0
        self.shadow_total: int = 0
        self._load()

    @staticmethod
    def _read_file() -> Optional[dict]:
        """Return the parsed state file, or None when it does not exist.

        Raises:
            OSError: the file exists but cannot be read.
            ValueError: the content is not valid UTF-8 / JSON.
        """
        try:
            text = Path(STATE_FILE).read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def _load(self):
        """Populate fields from the state file; keep the defaults if it is absent.

        A corrupt file raises rather than silently resetting the phase.
        """
        data = self._read_file()
        if data is None:
            return
        self.phase = MigrationPhase(data.get("phase", "shadow"))
        self.started_at = data.get("started_at")
        self.phase_started_at = data.get("phase_started_at")
        self.ollama_requests = data.get("ollama_requests", 0)
        self.vllm_requests = data.get("vllm_requests", 0)
        self.ollama_errors = data.get("ollama_errors", 0)
        self.vllm_errors = data.get("vllm_errors", 0)
        self.shadow_mismatches = data.get("shadow_mismatches", 0)
        self.shadow_total = data.get("shadow_total", 0)

    def _adopt_external_phase(self):
        """Take over a phase change that another process wrote to the state file."""
        try:
            data = self._read_file()
        except (OSError, ValueError) as e:
            # Nothing trustworthy to adopt; the write that follows replaces the file.
            logger.warning(f"Could not read migration state before saving: {e}")
            return
        if not isinstance(data, dict):
            return
        disk_phase = data.get("phase", "shadow")
        disk_phase_started = data.get("phase_started_at")
        if (disk_phase, disk_phase_started) == (self.phase.value, self.phase_started_at):
            return
        try:
            phase = MigrationPhase(disk_phase)
        except ValueError:
            logger.warning(f"Ignoring unknown migration phase in state file: {disk_phase!r}")
            return
        self.phase = phase
        self.phase_started_at = disk_phase_started
        self.started_at = data.get("started_at") or self.started_at
        logger.info(f"Adopted migration phase from state file: {phase.value}")

    def _write(self):
        """Atomically write the in-memory state to ``STATE_FILE``."""
        path = Path(STATE_FILE)
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps({
            "phase": self.phase.value,
            "started_at": self.started_at,
            "phase_started_at": self.phase_started_at,
            "ollama_requests": self.ollama_requests,
            "vllm_requests": self.vllm_requests,
            "ollama_errors": self.ollama_errors,
            "vllm_errors": self.vllm_errors,
            "shadow_mismatches": self.shadow_mismatches,
            "shadow_total": self.shadow_total,
        }, indent=2)
        # Per-process temp name so two writers never share a temp file; it lives
        # in the same directory so os.replace stays on one filesystem.
        tmp = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        try:
            tmp.write_text(payload, encoding="utf-8")
            os.replace(tmp, path)
        except OSError:
            tmp.unlink(missing_ok=True)
            raise

    def save(self):
        """Persist the state, adopting a phase change made by another process first."""
        self._adopt_external_phase()
        self._write()

    def set_phase(self, phase: MigrationPhase):
        """Switch to *phase*, timestamp it in UTC, and persist immediately.

        The first phase change also records ``started_at``.
        """
        self.phase = phase
        # gmtime(): the trailing "Z" declares UTC, so local time must not be used.
        self.phase_started_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        if not self.started_at:
            self.started_at = self.phase_started_at
        # Write directly: this call *is* the phase change, there is nothing to adopt.
        self._write()
        logger.info(f"Migration phase set to: {phase.value}")


class MigrationRouter:
    """
    Routes inference requests during the Ollama→vLLM migration.
    Supports shadow testing, canary rollout, and full cutover.

    Lifecycle:
    - ``await start()`` creates the shared HTTP client.
    - ``infer()`` routes each request according to ``self.state.phase``.
    - ``await stop()`` cancels outstanding shadow tests and closes the client.
    """

    # Relative output-length difference above which a shadow comparison is
    # counted as a mismatch.
    SHADOW_MISMATCH_RATIO = 0.5

    def __init__(self):
        self.state = MigrationState()
        self._client: Optional[httpx.AsyncClient] = None
        # The event loop keeps only weak references to tasks, so fire-and-forget
        # shadow tasks are held here until they finish, else they may be GC'd.
        self._shadow_tasks: set[asyncio.Task] = set()

    async def start(self):
        """Create the shared HTTP client (120 s timeout); calling it again is a no-op."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=120.0)
        logger.info(f"Migration router started in phase: {self.state.phase.value}")

    async def stop(self):
        """Cancel in-flight shadow tests, then close the HTTP client.

        Shadow tests are cancelled rather than left running. Otherwise they
        would fail against the closed client and be recorded as vLLM errors.
        """
        pending = list(self._shadow_tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        if self._client:
            await self._client.aclose()
            self._client = None

    async def infer(
        self,
        messages: list[dict],
        model: str = "llama3.3:70b",
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> dict:
        """
        Route one chat request according to the current migration phase.

        - SHADOW: answer from Ollama; call vLLM in a background task for comparison.
        - COMPLETE: answer from vLLM only (no fallback).
        - CANARY / SPLIT / PRIMARY: send a random ``PHASE_SPLITS[phase]``
          fraction of requests to vLLM. Fall back to Ollama if vLLM fails.

        Args:
            messages: chat messages forwarded unchanged to the backend.
            model: Ollama model tag (the vLLM request uses a fixed model name).
            max_tokens: maximum number of tokens to generate.
            temperature: sampling temperature.

        Returns:
            ``{"content", "model", "backend"}`` from the backend that answered.

        Raises:
            The answering backend's exception when no fallback applies.
            Ollama failures are never retried on vLLM.
        """
        phase = self.state.phase
        vllm_share = PHASE_SPLITS[phase]

        if phase == MigrationPhase.SHADOW:
            result = await self._ollama_infer(messages, model, max_tokens, temperature)
            task = asyncio.create_task(
                self._shadow_test(messages, model, max_tokens, temperature, result)
            )
            self._shadow_tasks.add(task)
            task.add_done_callback(self._shadow_tasks.discard)
            return result

        if phase == MigrationPhase.COMPLETE:
            return await self._vllm_infer(messages, max_tokens, temperature)

        # Canary / Split / Primary — probabilistic routing
        if random.random() < vllm_share:
            try:
                return await self._vllm_infer(messages, max_tokens, temperature)
            except Exception as e:
                # _vllm_infer has already recorded this failure in vllm_errors.
                logger.warning(f"vLLM failed, falling back to Ollama: {e}")
                return await self._ollama_infer(messages, model, max_tokens, temperature)
        return await self._ollama_infer(messages, model, max_tokens, temperature)

    async def _ollama_infer(
        self,
        messages: list[dict],
        model: str,
        max_tokens: int,
        temperature: float,
    ) -> dict:
        """POST a non-streaming request to Ollama's ``/api/chat``.

        Counts the outcome in ``ollama_requests`` / ``ollama_errors`` and
        re-raises any failure.
        """
        try:
            response = await self._client.post(
                f"{OLLAMA_URL}/api/chat",
                json={
                    "model": model,
                    "messages": messages,
                    "options": {
                        "num_predict": max_tokens,
                        "temperature": temperature,
                    },
                    "stream": False,
                },
            )
            response.raise_for_status()
            data = response.json()
            self.state.ollama_requests += 1
            self.state.save()
            return {
                "content": data["message"]["content"],
                "model": model,
                "backend": "ollama",
            }
        except Exception:
            self.state.ollama_errors += 1
            self.state.save()
            raise

    async def _vllm_infer(
        self,
        messages: list[dict],
        max_tokens: int,
        temperature: float,
    ) -> dict:
        """POST to vLLM's OpenAI-compatible ``/v1/chat/completions``.

        Counts the outcome in ``vllm_requests`` / ``vllm_errors`` (the only
        place vLLM errors are counted) and re-raises any failure.
        """
        try:
            response = await self._client.post(
                f"{VLLM_URL}/v1/chat/completions",
                json={
                    "model": "meta-llama/Llama-3.3-70B",
                    "messages": messages,
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                },
            )
            response.raise_for_status()
            data = response.json()
            self.state.vllm_requests += 1
            self.state.save()
            return {
                "content": data["choices"][0]["message"]["content"],
                "model": data.get("model", "vllm"),
                "backend": "vllm",
            }
        except Exception:
            self.state.vllm_errors += 1
            self.state.save()
            raise

    async def _shadow_test(
        self,
        messages: list[dict],
        model: str,
        max_tokens: int,
        temperature: float,
        ollama_result: dict,
    ):
        """Replay the request against vLLM and compare with Ollama (no user impact).

        The comparison is a coarse length check. A relative length difference
        above ``SHADOW_MISMATCH_RATIO`` counts as a mismatch, as does any vLLM
        output when Ollama returned nothing. A failed vLLM call is already
        counted in ``vllm_errors`` by ``_vllm_infer`` and is only logged here.
        """
        try:
            vllm_result = await self._vllm_infer(messages, max_tokens, temperature)
        except Exception as e:
            logger.warning(f"Shadow test failed: {e}")
        else:
            ollama_len = len(ollama_result.get("content", ""))
            vllm_len = len(vllm_result.get("content", ""))
            if ollama_len > 0:
                ratio = abs(vllm_len - ollama_len) / ollama_len
            else:
                # No baseline to normalise by: any vLLM output is a divergence.
                ratio = float("inf") if vllm_len > 0 else 0.0
            if ratio > self.SHADOW_MISMATCH_RATIO:
                self.state.shadow_mismatches += 1
                logger.info(
                    f"Shadow mismatch: Ollama={ollama_len} chars, "
                    f"vLLM={vllm_len} chars (ratio={ratio:.2f})"
                )

        # Counted once the test has finished (success or failure), so tests
        # cancelled by stop() do not inflate the total.
        self.state.shadow_total += 1
        self.state.save()


def print_status(state: MigrationState):
    """Print a human-readable migration summary for *state* to stdout.

    Shows the phase and its configured vLLM traffic share, start timestamps,
    and per-backend request/error counters. It also shows the overall error
    rate once any request has been counted, and the shadow mismatch rate
    once any shadow test has run.
    """
    rule = "=" * 50
    served = state.ollama_requests + state.vllm_requests
    failed = state.ollama_errors + state.vllm_errors

    lines = [
        f"\n{rule}",
        "  Fogbreak Migration: Ollama → vLLM",
        rule,
        f"  Phase:          {state.phase.value.upper()}",
        f"  vLLM share:     {PHASE_SPLITS[state.phase] * 100:.0f}%",
        f"  Started:        {state.started_at or 'Not started'}",
        f"  Phase started:  {state.phase_started_at or 'N/A'}",
        rule,
        f"  Ollama requests: {state.ollama_requests}",
        f"  vLLM requests:   {state.vllm_requests}",
        f"  Total requests:  {served}",
        f"  Ollama errors:   {state.ollama_errors}",
        f"  vLLM errors:     {state.vllm_errors}",
        f"  Total errors:    {failed}",
    ]
    if served > 0:
        lines.append(f"  Error rate:      {failed / served * 100:.2f}%")
    if state.shadow_total > 0:
        shadow_rate = state.shadow_mismatches / state.shadow_total * 100
        lines += [
            f"  Shadow tests:    {state.shadow_total}",
            f"  Shadow mismatches: {state.shadow_mismatches}",
            f"  Mismatch rate:   {shadow_rate:.2f}%",
        ]
    lines.append(f"{rule}\n")

    print("\n".join(lines))


async def run_health_check():
    """Probe both backends and print whether each is reachable and healthy.

    Probes Ollama's ``/api/tags`` and vLLM's ``/health`` with a 10 s timeout.
    A backend is reported UP only when its probe answers with a 2xx status.
    Any other HTTP status is reported as DOWN with that status, and
    connection/timeout errors as DOWN with the exception.

    Returns:
        ``{"Ollama": bool, "vLLM": bool}``; True means the backend is healthy.
    """
    probes = (
        ("Ollama", OLLAMA_URL, "/api/tags"),
        ("vLLM", VLLM_URL, "/health"),
    )
    results: dict[str, bool] = {}
    async with httpx.AsyncClient(timeout=10.0) as client:
        print("\nHealth Check:")
        for name, base_url, path in probes:
            try:
                resp = await client.get(f"{base_url}{path}")
            except Exception as e:
                # A health check must never crash; report the failure instead.
                print(f"  {name} ({base_url}): DOWN ({e})")
                results[name] = False
                continue
            # Any HTTP answer used to count as UP; an error status is not healthy.
            healthy = resp.is_success
            label = "UP" if healthy else "DOWN"
            print(f"  {name} ({base_url}): {label} ({resp.status_code})")
            results[name] = healthy
    return results


def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point: set the phase, show status, or check health.

    ``--phase``, ``--status`` and ``--health`` are mutually exclusive.
    argparse rejects combinations instead of silently acting on only one
    of them. With no option, the help text is printed.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description="Fogbreak Ollama→vLLM Migration")
    action = parser.add_mutually_exclusive_group()
    action.add_argument(
        "--phase",
        # Derived from the enum so the CLI choices cannot drift from MigrationPhase.
        choices=[p.value for p in MigrationPhase],
        help="Set migration phase",
    )
    action.add_argument("--status", action="store_true", help="Show migration status")
    action.add_argument("--health", action="store_true", help="Check backend health")
    args = parser.parse_args(argv)

    if args.status:
        print_status(MigrationState())
    elif args.health:
        # Health checks do not need the persisted state, so it is not loaded.
        asyncio.run(run_health_check())
    elif args.phase:
        state = MigrationState()
        state.set_phase(MigrationPhase(args.phase))
        print_status(state)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
