"""
Fogbreak AI Inference Server — FastAPI proxy for vLLM.

Routes requests to the correct model based on task type.
Provides health checks, metrics, and tenant-aware logging.

Usage:
    uvicorn ai.server:app --host 0.0.0.0 --port 8080 --workers 4
"""

import hmac
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse, StreamingResponse
from pydantic import BaseModel, Field

from ai.models.vllm_client import (
    InferenceRequest,
    ModelType,
    VLLMClient,
    create_client,
)

# Root logging configuration requested at import time: INFO level with
# "timestamp [LEVEL] logger-name: message" formatted records.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger used by the rest of this module.
logger = logging.getLogger("fogbreak.server")

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# SECURITY: this fallback key is committed to source control and is therefore
# not a secret. It is kept only so deployments that currently rely on it keep
# working; set FOGBREAK_AI_API_KEY to a private value in every real environment.
_DEFAULT_API_KEY = "MTH_PORTAL_2026"
API_KEY = os.getenv("FOGBREAK_AI_API_KEY", _DEFAULT_API_KEY)
if API_KEY == _DEFAULT_API_KEY:
    # Make the insecure configuration visible instead of silently accepting it.
    logging.getLogger("fogbreak.server").warning(
        "FOGBREAK_AI_API_KEY is unset or equals the built-in default key; "
        "configure a private key before exposing this service"
    )

# Comma-separated list of origins allowed by CORS. Each entry is trimmed and
# empty entries are dropped, so values such as "https://a.io, https://b.io,"
# still compare equal to the Origin header sent by browsers.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv(
        "FOGBREAK_CORS_ORIGINS", "http://localhost:3000,https://fogbreak.io"
    ).split(",")
    if origin.strip()
]

def _endpoint(*, name, model_id, base_url, model_type, max_tokens):
    """Build one endpoint entry; every endpoint uses temperature 0.7."""
    return {
        "name": name,
        "model_id": model_id,
        "base_url": base_url,
        "model_type": model_type,
        "max_tokens": max_tokens,
        "temperature": 0.7,
    }


# Endpoints that are always registered. Model ids and base URLs can be
# overridden through environment variables.
ENDPOINTS_CONFIG = [
    _endpoint(
        name="llama-reasoning",
        model_id=os.getenv("VLLM_REASONING_MODEL", "meta-llama/Llama-3.3-70B"),
        base_url=os.getenv("VLLM_REASONING_URL", "http://vllm-llama:8000"),
        model_type="reasoning",
        max_tokens=8192,
    ),
    _endpoint(
        name="mistral-fast",
        model_id=os.getenv("VLLM_FAST_MODEL", "mistralai/Mistral-7B-Instruct-v0.3"),
        base_url=os.getenv("VLLM_FAST_URL", "http://vllm-mistral:8001"),
        model_type="fast",
        max_tokens=4096,
    ),
]

# The fine-tuned endpoint is optional: it is registered only when both its URL
# and its model id are provided.
FINETUNED_URL = os.getenv("VLLM_FINETUNED_URL")
FINETUNED_MODEL = os.getenv("VLLM_FINETUNED_MODEL")
if FINETUNED_URL and FINETUNED_MODEL:
    ENDPOINTS_CONFIG.append(
        _endpoint(
            name="finetuned",
            model_id=FINETUNED_MODEL,
            base_url=FINETUNED_URL,
            model_type="finetuned",
            max_tokens=4096,
        )
    )

# ---------------------------------------------------------------------------
# App lifecycle
# ---------------------------------------------------------------------------

# Shared vLLM client. It stays None until the lifespan handler below creates it.
client: Optional[VLLMClient] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create and start the vLLM client on startup and stop it on shutdown.

    The client is built from ENDPOINTS_CONFIG and stored in the module-level
    ``client`` global, which the request handlers read.
    """
    global client
    client = create_client(ENDPOINTS_CONFIG)
    await client.start()
    logger.info("Fogbreak AI server started with %d endpoints", len(ENDPOINTS_CONFIG))
    try:
        yield
    finally:
        # Run in ``finally`` so the client is stopped even when an exception
        # or cancellation is raised into the generator at the yield point.
        await client.stop()
        logger.info("Fogbreak AI server stopped")


# ASGI application object; client start-up and shutdown are handled by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="Fogbreak AI Inference API",
    description="Production inference proxy for Fogbreak real estate platform",
)

# CORS policy: only the configured origins are allowed; for those, credentials
# and every method and header are permitted.
_CORS_OPTIONS = {
    "allow_origins": ALLOWED_ORIGINS,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# ---------------------------------------------------------------------------
# Auth
# ---------------------------------------------------------------------------


def verify_api_key(authorization: Optional[str] = Header(None)) -> str:
    """Validate the API key carried in the Authorization header.

    The header may be ``Bearer <key>`` or the bare key; a leading ``Bearer ``
    prefix and surrounding whitespace are removed before comparison.

    Args:
        authorization: Raw Authorization header value, or None when absent.

    Returns:
        The presented token, once it has matched the configured API key.

    Raises:
        HTTPException: 401 when the header is missing or empty, 403 when the
            token does not match the configured key.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    token = authorization.removeprefix("Bearer ").strip()
    expected = API_KEY or ""
    # Compare in constant time so response timing does not reveal how much of
    # the key matched. Bytes are compared because compare_digest rejects str
    # values containing non-ASCII characters. An empty configured key never
    # authenticates.
    if not expected or not hmac.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid API key")
    return token


# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------


# One chat turn. The chat endpoints forward each message to the inference
# client as a {"role": ..., "content": ...} dict.
class Message(BaseModel):
    # `role` is a plain string: the values listed in the description are not
    # enforced by this schema.
    role: str = Field(..., description="Role: system, user, or assistant")
    content: str = Field(..., description="Message content")


# Request body accepted by POST /v1/chat/completions and POST /v1/task.
class ChatRequest(BaseModel):
    # Passed through to the inference request unchanged; may be omitted.
    tenant_id: Optional[int] = Field(None, description="Tenant ID for isolation")
    # /v1/chat/completions converts this string to a ModelType and answers 400
    # when the value is not a valid ModelType.
    model_type: str = Field("fast", description="Model type: reasoning, fast, or finetuned")
    messages: list[Message] = Field(..., description="Chat messages")
    # Sampling limits validated here: 1..16384 tokens, temperature 0.0..2.0,
    # top_p 0.0..1.0.
    max_tokens: int = Field(4096, ge=1, le=16384)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.95, ge=0.0, le=1.0)
    # When true, /v1/chat/completions answers with a text/event-stream response.
    stream: bool = Field(False, description="Stream response tokens")
    stop: Optional[list[str]] = Field(None, description="Stop sequences")


# JSON body returned by the non-streaming inference endpoints. Every field is
# copied from the inference client's result.
class ChatResponse(BaseModel):
    content: str
    model: str
    usage: dict
    # The endpoints round this to two decimal places before returning it.
    latency_ms: float
    tenant_id: Optional[int]
    finish_reason: str


# Body returned by GET /health.
class HealthResponse(BaseModel):
    # Either "healthy" or "degraded".
    status: str
    # Per-endpoint details keyed by endpoint name; empty when no client exists.
    endpoints: dict
    # Seconds elapsed since this module was imported, rounded to two decimals.
    uptime_seconds: float


# ---------------------------------------------------------------------------
# Task-to-model routing
# ---------------------------------------------------------------------------

# Reasoning tasks → Llama 3.3 70B
_REASONING_TASKS = (
    "listing_description",
    "market_analysis",
    "contract_review",
    "compliance_check",
    "coaching_insight",
    "lead_scoring",
    "property_matching",
)

# Fast tasks → Mistral 7B
_FAST_TASKS = (
    "chatbot",
    "email_draft",
    "social_post",
    "showing_summary",
    "quick_reply",
    "fair_housing_check",
    "notification_text",
)

# Task name → model type, consulted by the /v1/task endpoint.
TASK_MODEL_MAP = {
    **dict.fromkeys(_REASONING_TASKS, ModelType.REASONING),
    **dict.fromkeys(_FAST_TASKS, ModelType.FAST),
}

# Monotonic clock reading taken when this module is imported; /health reports
# uptime as the time elapsed since this value.
_start_time = time.monotonic()

# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(
    req: ChatRequest,
    authorization: Optional[str] = Header(None),
):
    """
    Chat completions endpoint.

    Routes the request to the vLLM model selected by ``req.model_type``.
    When ``req.stream`` is true the reply is a ``text/event-stream`` response;
    otherwise it is a single ``ChatResponse`` JSON body.

    Raises:
        HTTPException: 401/403 from API-key verification, 503 when the
            inference client has not been initialized, 400 when
            ``req.model_type`` is not a valid ModelType.
    """
    verify_api_key(authorization)

    # Match the other endpoints, which tolerate a missing client: answer with
    # an explicit 503 rather than failing with an AttributeError (HTTP 500).
    if client is None:
        raise HTTPException(status_code=503, detail="Inference client not initialized")

    try:
        model_type = ModelType(req.model_type)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model_type: {req.model_type}. Use: reasoning, fast, finetuned",
        ) from None

    inference_req = InferenceRequest(
        tenant_id=req.tenant_id,
        model_type=model_type,
        messages=[{"role": m.role, "content": m.content} for m in req.messages],
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        stream=req.stream,
        stop=req.stop,
    )

    if req.stream:
        # Returning a Response object directly bypasses response_model handling.
        return StreamingResponse(
            _stream_response(inference_req),
            media_type="text/event-stream",
        )

    result = await client.infer(inference_req)

    return ChatResponse(
        content=result.content,
        model=result.model,
        usage=result.usage,
        latency_ms=round(result.latency_ms, 2),
        tenant_id=result.tenant_id,
        finish_reason=result.finish_reason,
    )


async def _stream_response(req: InferenceRequest):
    """Yield server-sent-event frames for a streamed inference request.

    Each token produced by ``client.stream`` becomes one ``data:`` frame whose
    payload is a JSON object with a ``content`` key; a final ``[DONE]`` frame
    marks the end of the stream.
    """
    import json

    token_stream = client.stream(req)
    async for piece in token_stream:
        payload = json.dumps({'content': piece})
        yield f"data: {payload}\n\n"
    yield "data: [DONE]\n\n"


@app.post("/v1/task", response_model=ChatResponse)
async def task_inference(
    req: ChatRequest,
    task: str = "chatbot",
    authorization: Optional[str] = Header(None),
):
    """
    Task-based routing endpoint.

    Selects the model from TASK_MODEL_MAP using ``task``; tasks that are not
    in the map fall back to the fast model. ``req.model_type`` and
    ``req.stream`` are not used here: the reply is always a single
    ``ChatResponse`` JSON body.

    Raises:
        HTTPException: 401/403 from API-key verification, 503 when the
            inference client has not been initialized.
    """
    verify_api_key(authorization)

    # Match the other endpoints, which tolerate a missing client: answer with
    # an explicit 503 rather than failing with an AttributeError (HTTP 500).
    if client is None:
        raise HTTPException(status_code=503, detail="Inference client not initialized")

    model_type = TASK_MODEL_MAP.get(task, ModelType.FAST)

    inference_req = InferenceRequest(
        tenant_id=req.tenant_id,
        model_type=model_type,
        messages=[{"role": m.role, "content": m.content} for m in req.messages],
        max_tokens=req.max_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
        # Forward caller-supplied stop sequences instead of silently dropping
        # them, as /v1/chat/completions already does.
        stop=req.stop,
        metadata={"task": task},
    )

    result = await client.infer(inference_req)

    return ChatResponse(
        content=result.content,
        model=result.model,
        usage=result.usage,
        latency_ms=round(result.latency_ms, 2),
        tenant_id=result.tenant_id,
        finish_reason=result.finish_reason,
    )


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check for load balancer and Kubernetes probes.

    Reports ``"healthy"`` only when a client exists, it has at least one
    endpoint, and every endpoint is healthy; otherwise ``"degraded"``. The
    response also carries per-endpoint counters and the time elapsed since
    this module was imported.
    """
    endpoints_status = {}
    all_healthy = False

    if client:
        for name, ep in client.endpoints.items():
            endpoints_status[name] = {
                "healthy": ep.healthy,
                "model": ep.model_id,
                "type": ep.model_type.value,
                "requests": ep.request_count,
                "errors": ep.error_count,
                "avg_latency_ms": round(ep.avg_latency, 2),
            }
        # Require every endpoint to be healthy: with any(), one healthy
        # endpoint would hide failed ones. The emptiness check stops all()
        # from reporting an endpoint-less client as healthy.
        all_healthy = bool(endpoints_status) and all(
            entry["healthy"] for entry in endpoints_status.values()
        )

    status = "healthy" if all_healthy else "degraded"
    uptime = time.monotonic() - _start_time

    return HealthResponse(
        status=status,
        endpoints=endpoints_status,
        uptime_seconds=round(uptime, 2),
    )


def _escape_label_value(value) -> str:
    """Escape backslash, double quote and newline for a Prometheus label value."""
    return str(value).replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")


@app.get("/metrics")
async def prometheus_metrics():
    """Prometheus-compatible metrics endpoint.

    Renders the client's metrics as plain text: four global totals followed
    by four samples per endpoint, labelled with endpoint name, model id and
    model type.

    Raises:
        HTTPException: 503 when the inference client has not been initialized.
    """
    if not client:
        # A 200 response with a JSON error body cannot be parsed by a scraper;
        # report the condition as a failed scrape instead.
        raise HTTPException(status_code=503, detail="Client not initialized")

    metrics = client.get_metrics()
    lines = []

    # Global counters
    lines.append(f"fogbreak_ai_requests_total {metrics['total_requests']}")
    lines.append(f"fogbreak_ai_errors_total {metrics['total_errors']}")
    lines.append(f"fogbreak_ai_tokens_in_total {metrics['total_tokens_in']}")
    lines.append(f"fogbreak_ai_tokens_out_total {metrics['total_tokens_out']}")

    # Per-endpoint gauges
    for name, ep in metrics.get("endpoints", {}).items():
        labels = (
            f'endpoint="{_escape_label_value(name)}",'
            f'model="{_escape_label_value(ep["model_id"])}",'
            f'type="{_escape_label_value(ep["model_type"])}"'
        )
        lines.append(f'fogbreak_ai_endpoint_healthy{{{labels}}} {1 if ep["healthy"] else 0}')
        lines.append(f'fogbreak_ai_endpoint_requests{{{labels}}} {ep["request_count"]}')
        lines.append(f'fogbreak_ai_endpoint_errors{{{labels}}} {ep["error_count"]}')
        lines.append(f'fogbreak_ai_endpoint_latency_ms{{{labels}}} {ep["avg_latency_ms"]}')

    # Returning a bare str would be JSON-encoded by FastAPI (quoted, with
    # escaped newlines); send the exposition text verbatim as text/plain.
    return PlainTextResponse(
        "\n".join(lines) + "\n",
        media_type="text/plain; version=0.0.4",
    )


@app.get("/models")
async def list_models(authorization: Optional[str] = Header(None)):
    """Return every configured model endpoint with its current status.

    Requires a valid API key. The list is empty when no client exists.
    """
    verify_api_key(authorization)

    # With no client there are no endpoints to describe.
    endpoints = client.endpoints if client else {}

    return {
        "models": [
            {
                "name": name,
                "model_id": ep.model_id,
                "type": ep.model_type.value,
                "healthy": ep.healthy,
                "max_tokens": ep.max_tokens,
            }
            for name, ep in endpoints.items()
        ]
    }
