"""
vLLM Client — Production inference client for Fogbreak AI.

Replaces the Ollama client with one that talks to vLLM's OpenAI-compatible API.
Supports multi-model routing, health checks with failover, retries with
exponential backoff, and tenant ID propagation from requests to responses.
"""

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Optional

import httpx

# Module-wide logger shared by every VLLMClient instance in this file.
logger = logging.getLogger("fogbreak.vllm")


class ModelType(str, Enum):
    """Routing key selecting which family of endpoints serves a request.

    Members subclass ``str``, so they compare equal to their raw values
    (``ModelType.FAST == "fast"``) and can be built from them
    (``ModelType("fast")``).
    """

    # Large model for complex tasks (Llama 3.3 70B).
    REASONING = "reasoning"
    # Small model for chatbot and routine tasks (Mistral 7B).
    FAST = "fast"
    # Fine-tuned Mistral serving tenant-specific workloads.
    FINETUNED = "finetuned"


@dataclass
class ModelEndpoint:
    """A registered vLLM server plus the running stats used to route to it.

    The first six fields are static configuration. The remaining fields are
    mutated by the client as health probes run and requests complete.
    """
    name: str
    model_id: str
    base_url: str
    model_type: ModelType
    max_tokens: int = 4096
    temperature: float = 0.7
    # --- runtime state ---
    healthy: bool = True
    last_health_check: float = 0.0  # time.monotonic() of the most recent probe
    request_count: int = 0          # successful requests only
    error_count: int = 0
    total_latency: float = 0.0      # summed latency (ms) of successful requests

    @property
    def avg_latency(self) -> float:
        """Mean latency per successful request, or 0.0 before the first one."""
        completed = self.request_count
        return self.total_latency / completed if completed else 0.0


@dataclass
class InferenceRequest:
    """A model-agnostic chat-completion request.

    ``model_type`` selects the endpoint family. The sampling fields are copied
    into the OpenAI-style request body, and ``tenant_id`` is echoed back on the
    resulting response.
    """
    tenant_id: Optional[int]
    model_type: ModelType
    # OpenAI-style chat messages; presumably {"role": ..., "content": ...} dicts.
    messages: list[dict[str, str]]
    # NOTE(review): this default is truthy, so the client's fallback to the
    # endpoint's own max_tokens only applies when a caller passes 0/None.
    max_tokens: int = 4096
    temperature: float = 0.7
    top_p: float = 0.95
    # NOTE(review): VLLMClient does not read this flag; infer() always sends a
    # non-streaming request and stream() a streaming one.
    stream: bool = False
    stop: Optional[list[str]] = None
    # Free-form caller data; not read by VLLMClient.
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class InferenceResponse:
    """The result of one completed inference call, independent of the model."""
    content: str                  # generated text
    model: str                    # model name reported by the server
    usage: dict[str, int]         # token counts as reported by the server
    latency_ms: float             # wall-clock request time in milliseconds
    tenant_id: Optional[int]      # copied from the originating request
    finish_reason: str = "stop"


class VLLMClient:
    """
    Async client for vLLM servers that expose the OpenAI-compatible API.

    - Routes each request by ``ModelType`` (reasoning / fast / finetuned) to the
      healthy endpoint with the lowest mean latency.
    - Probes every endpoint's ``/health`` route in the background and skips
      endpoints whose probe fails.
    - Retries failed requests with exponential backoff, preferring an endpoint
      that has not already failed for the same request. 4xx rejections other
      than 408/429 are not retried.
    - Propagates ``tenant_id`` from request to response. No per-tenant routing
      or accounting is done here.
    - Streams tokens via server-sent events.
    - Keeps in-process counters, returned as a plain dict by ``get_metrics()``.

    Call ``await start()`` before issuing requests and ``await stop()`` on shutdown.
    """

    HEALTH_CHECK_INTERVAL = 30.0  # seconds
    REQUEST_TIMEOUT = 120.0       # seconds
    MAX_RETRIES = 3
    RETRY_BASE_DELAY = 1.0        # seconds; doubled after each failed attempt
    # 4xx statuses still worth retrying (request timeout, rate limit). Any other
    # 4xx means the request itself was rejected and would fail the same way again.
    RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(self, endpoints: Optional[list[dict]] = None):
        """Create an unstarted client.

        Args:
            endpoints: Optional list of keyword-argument dicts, each forwarded
                to ``add_endpoint``.
        """
        self.endpoints: dict[str, ModelEndpoint] = {}
        self._client: Optional[httpx.AsyncClient] = None
        self._health_task: Optional[asyncio.Task] = None

        # Process-local counters surfaced by get_metrics().
        self.metrics = {
            "total_requests": 0,
            "total_errors": 0,
            "total_tokens_in": 0,
            "total_tokens_out": 0,
        }

        for ep in endpoints or ():
            self.add_endpoint(**ep)

    def add_endpoint(
        self,
        name: str,
        model_id: str,
        base_url: str,
        model_type: str,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> None:
        """Register a vLLM endpoint, replacing any existing one with the same name.

        Raises:
            ValueError: if ``model_type`` is not a valid ``ModelType`` value.
        """
        self.endpoints[name] = ModelEndpoint(
            name=name,
            model_id=model_id,
            base_url=base_url.rstrip("/"),
            model_type=ModelType(model_type),
            max_tokens=max_tokens,
            temperature=temperature,
        )
        logger.info("Registered endpoint: %s (%s) at %s", name, model_id, base_url)

    async def start(self) -> None:
        """Create the shared HTTP client and launch the background health checker.

        Calling this on an already-started client is a no-op, so a second call
        cannot leak the first connection pool or spawn a duplicate checker.
        """
        if self._client is not None:
            return
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(self.REQUEST_TIMEOUT, connect=10.0),
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )
        self._health_task = asyncio.create_task(self._health_check_loop())
        logger.info("vLLM client started")

    async def stop(self) -> None:
        """Cancel the health checker and close the HTTP client.

        Both handles are reset to ``None``. A stopped client then reports "not
        started" instead of failing on a closed pool, and can be started again.
        """
        task, self._health_task = self._health_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()
        logger.info("vLLM client stopped")

    def _select_endpoint(
        self, model_type: ModelType, exclude: Optional[set[str]] = None
    ) -> ModelEndpoint:
        """Pick the healthy endpoint with the lowest mean latency for ``model_type``.

        Args:
            model_type: Preferred endpoint family. If none of that type is
                healthy, any healthy endpoint is used.
            exclude: Endpoint names to avoid (e.g. those that already failed
                this request). They are used only if no other candidate remains.

        Raises:
            RuntimeError: if no endpoint is healthy.
        """
        healthy = [ep for ep in self.endpoints.values() if ep.healthy]
        candidates = [ep for ep in healthy if ep.model_type == model_type]
        if not candidates:
            # Fallback: accept any healthy endpoint, whatever its model type.
            # NOTE(review): this can route to a FINETUNED (tenant-specific)
            # endpoint; confirm that is acceptable for tenant isolation.
            candidates = healthy

        if not candidates:
            raise RuntimeError(
                f"No healthy endpoints available for model type: {model_type.value}"
            )

        if exclude:
            candidates = [ep for ep in candidates if ep.name not in exclude] or candidates

        # Endpoints with no successful request yet report 0.0 and are tried
        # first. That is why retries must exclude endpoints that already failed.
        return min(candidates, key=lambda ep: ep.avg_latency)

    @classmethod
    def _is_retryable(cls, error: Exception) -> bool:
        """Return False for 4xx responses that would fail the same way on retry."""
        if isinstance(error, httpx.HTTPStatusError):
            status = error.response.status_code
            if 400 <= status < 500:
                return status in cls.RETRYABLE_CLIENT_STATUSES
        return True

    async def infer(self, request: InferenceRequest) -> InferenceResponse:
        """Run a non-streaming completion, retrying with exponential backoff.

        Each retry prefers an endpoint that has not already failed for this
        request. Every failed attempt increments the endpoint's ``error_count``
        and the global ``total_errors``.

        Raises:
            RuntimeError: if the client is not started, no endpoint is healthy,
                the server rejects the request with a non-retryable 4xx, or all
                ``MAX_RETRIES`` attempts fail.
        """
        if not self._client:
            raise RuntimeError("Client not started. Call start() first.")

        self.metrics["total_requests"] += 1
        last_error: Optional[Exception] = None
        failed: set[str] = set()

        for attempt in range(self.MAX_RETRIES):
            endpoint = self._select_endpoint(request.model_type, exclude=failed)

            try:
                return await self._send_request(endpoint, request)
            # httpx.HTTPError also covers timeouts and non-2xx statuses.
            except (httpx.HTTPError, RuntimeError) as e:
                last_error = e
                failed.add(endpoint.name)
                endpoint.error_count += 1
                self.metrics["total_errors"] += 1
                logger.warning(
                    "Inference attempt %d/%d failed on %s: %s",
                    attempt + 1, self.MAX_RETRIES, endpoint.name, e,
                )

                if not self._is_retryable(e):
                    raise RuntimeError(
                        f"Inference request rejected by {endpoint.name}: {e}"
                    ) from e

                if attempt < self.MAX_RETRIES - 1:
                    await asyncio.sleep(self.RETRY_BASE_DELAY * (2 ** attempt))

        raise RuntimeError(
            f"All {self.MAX_RETRIES} inference attempts failed. Last error: {last_error}"
        ) from last_error

    @staticmethod
    def _build_payload(
        endpoint: ModelEndpoint, request: InferenceRequest, *, stream: bool
    ) -> dict[str, Any]:
        """Build the chat-completions request body shared by infer() and stream()."""
        payload: dict[str, Any] = {
            "model": endpoint.model_id,
            "messages": request.messages,
            # A falsy request value (None/0) falls back to the endpoint's limit.
            "max_tokens": request.max_tokens or endpoint.max_tokens,
            # NOTE(review): endpoint.temperature is never consulted here.
            "temperature": request.temperature,
            "top_p": request.top_p,
            "stream": stream,
        }
        if request.stop:
            payload["stop"] = request.stop
        return payload

    async def _send_request(
        self, endpoint: ModelEndpoint, request: InferenceRequest
    ) -> InferenceResponse:
        """Send one non-streaming completion request to ``endpoint``.

        On success, updates the endpoint's latency stats and the global token
        counters.

        Raises:
            httpx.HTTPError: on transport failure, timeout, or a non-2xx status.
            RuntimeError: if the client is not started or the body is malformed.
        """
        client = self._client
        if client is None:
            raise RuntimeError("Client not started. Call start() first.")

        url = f"{endpoint.base_url}/v1/chat/completions"
        payload = self._build_payload(endpoint, request, stream=False)

        start_time = time.monotonic()
        response = await client.post(url, json=payload)
        response.raise_for_status()
        latency_ms = (time.monotonic() - start_time) * 1000

        # Turn malformed bodies into RuntimeError so infer() retries and counts them.
        try:
            data = response.json()
            choice = data["choices"][0]
            content = choice["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            raise RuntimeError(
                f"Malformed completion response from {endpoint.name}: {e}"
            ) from e

        endpoint.request_count += 1
        endpoint.total_latency += latency_ms

        # "usage" may be absent or null.
        usage = data.get("usage") or {}
        self.metrics["total_tokens_in"] += usage.get("prompt_tokens", 0)
        self.metrics["total_tokens_out"] += usage.get("completion_tokens", 0)

        return InferenceResponse(
            # content may be null (e.g. tool-call-only replies).
            content=content or "",
            model=data.get("model") or endpoint.model_id,
            usage=usage,
            latency_ms=latency_ms,
            tenant_id=request.tenant_id,
            finish_reason=choice.get("finish_reason") or "stop",
        )

    async def stream(
        self, request: InferenceRequest
    ) -> AsyncIterator[str]:
        """Yield completion text fragments as the endpoint streams them.

        Streams are not retried, because fragments may already have been
        yielded. A failure is counted in the error metrics and re-raised.

        Raises:
            RuntimeError: if the client is not started or no endpoint is healthy.
            httpx.HTTPError: on transport failure or a non-2xx status.
            ValueError: if an event payload is not valid JSON.
        """
        client = self._client
        if client is None:
            raise RuntimeError("Client not started. Call start() first.")

        self.metrics["total_requests"] += 1
        endpoint = self._select_endpoint(request.model_type)
        url = f"{endpoint.base_url}/v1/chat/completions"
        payload = self._build_payload(endpoint, request, stream=True)

        try:
            async with client.stream("POST", url, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # SSE payload lines look like "data:<json>" (the space is
                    # optional). Comments and blank lines are skipped.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if not data_str:
                        continue
                    if data_str == "[DONE]":
                        break
                    chunk = json.loads(data_str)
                    # Chunks without choices (e.g. a trailing usage-only chunk)
                    # carry no text.
                    choices = chunk.get("choices") or []
                    if not choices:
                        continue
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content")
                    if content:
                        yield content
        except (httpx.HTTPError, ValueError):
            endpoint.error_count += 1
            self.metrics["total_errors"] += 1
            raise

    async def check_health(self, endpoint_name: str) -> bool:
        """Probe ``GET {base_url}/health`` and record the result on the endpoint.

        Returns:
            True only for a 200 response. False for unknown endpoints, for an
            unstarted client, for other statuses, and for transport errors.
            The endpoint's state is left untouched in the first two cases.
        """
        endpoint = self.endpoints.get(endpoint_name)
        client = self._client
        if endpoint is None or client is None:
            return False

        try:
            # vLLM exposes a /health route.
            response = await client.get(f"{endpoint.base_url}/health", timeout=5.0)
            healthy = response.status_code == 200
        except httpx.HTTPError:
            healthy = False

        endpoint.healthy = healthy
        endpoint.last_health_check = time.monotonic()
        return healthy

    async def _health_check_loop(self) -> None:
        """Probe every endpoint each HEALTH_CHECK_INTERVAL seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.HEALTH_CHECK_INTERVAL)
                names = list(self.endpoints)
                # Probe concurrently so one slow endpoint (up to the 5 s probe
                # timeout) does not delay the others.
                results = await asyncio.gather(*(self.check_health(n) for n in names))
                for name, healthy in zip(names, results):
                    if not healthy:
                        logger.warning("Endpoint %s is unhealthy", name)
            except asyncio.CancelledError:
                break
            except Exception:
                # Keep the loop alive. A dead checker would freeze every
                # endpoint's health flag at its last value.
                logger.exception("Health check error")

    def get_metrics(self) -> dict[str, Any]:
        """Return a snapshot of the global counters plus per-endpoint stats.

        The result is a plain nested dict, not Prometheus exposition format;
        any exporter has to convert it.
        """
        endpoint_metrics = {
            name: {
                "model_id": ep.model_id,
                "model_type": ep.model_type.value,
                "healthy": ep.healthy,
                "request_count": ep.request_count,
                "error_count": ep.error_count,
                "avg_latency_ms": round(ep.avg_latency, 2),
            }
            for name, ep in self.endpoints.items()
        }
        return {**self.metrics, "endpoints": endpoint_metrics}


# Production endpoints used by create_client() when no list is supplied.
# NOTE(review): Meta publishes Llama 3.3 70B as "meta-llama/Llama-3.3-70B-Instruct";
# model_id must match the name vLLM serves it under. Confirm against the deployment.
DEFAULT_ENDPOINTS = [
    dict(
        name="llama-reasoning",
        model_id="meta-llama/Llama-3.3-70B",
        base_url="http://vllm-llama:8000",
        model_type="reasoning",
        max_tokens=8192,
        temperature=0.7,
    ),
    dict(
        name="mistral-fast",
        model_id="mistralai/Mistral-7B-Instruct-v0.3",
        base_url="http://vllm-mistral:8001",
        model_type="fast",
        max_tokens=4096,
        temperature=0.7,
    ),
]


def create_client(endpoints: Optional[list[dict]] = None) -> VLLMClient:
    """Build an unstarted VLLMClient.

    Any falsy ``endpoints`` (None or an empty list) selects DEFAULT_ENDPOINTS.
    """
    chosen = endpoints if endpoints else DEFAULT_ENDPOINTS
    return VLLMClient(endpoints=chosen)
