"""
A/B Test Evaluator — Fogbreak AI.

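Compares fine-tuned model outputs against base model outputs.
Generates side-by-side responses (with latency and token counts) for human
review. Reviewers record a per-prompt preference, from which agent approval
rates can be tracked to measure fine-tuning quality.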

Usage:
    python -m ai.finetune.evaluate \
        --base-url http://vllm-mistral:8001 \
        --finetuned-url http://vllm-finetuned:8002 \
        --data-dir training_data \
        --output results/ab_test.json
"""

import argparse
import asyncio
import json
import logging
import random
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

import httpx

logger = logging.getLogger("fogbreak.finetune.evaluate")


@dataclass
class ABTestResult:
    """Result of a single A/B test comparison.

    One record per prompt: the same prompt is sent to the base and the
    fine-tuned endpoint, and both responses are stored side by side together
    with latency and token counts for later human review.

    Field order matters: it defines the positional constructor signature and
    the key order produced by ``dataclasses.asdict`` in the saved JSON.
    """
    prompt_id: int  # index of the prompt within this run's (possibly sampled) prompt list
    user_input: str  # user message text that was sent to both models
    base_output: str
    finetuned_output: str
    base_latency_ms: float  # wall-clock request latency in milliseconds
    finetuned_latency_ms: float
    base_tokens: int  # completion tokens reported by the server (0 if not reported)
    finetuned_tokens: int
    preferred: Optional[str] = None  # "base", "finetuned", or None (pending)
    notes: str = ""  # free-form evaluator notes; empty until filled in


async def generate_output(
    client: "httpx.AsyncClient",
    base_url: str,
    model_id: str,
    messages: list[dict],
    max_tokens: int = 2048,
    *,
    temperature: float = 0.7,
    timeout: float = 120.0,
) -> tuple[str, float, int]:
    """Generate output from a vLLM endpoint. Returns (text, latency_ms, tokens).

    Sends one request to the OpenAI-compatible ``/v1/chat/completions`` route.

    Args:
        client: shared async HTTP client used to issue the POST.
        base_url: server root, e.g. ``http://vllm-mistral:8001``; a trailing
            slash is tolerated.
        model_id: model name the server should route the request to.
        messages: chat messages (``{"role": ..., "content": ...}`` dicts).
        max_tokens: completion length cap passed to the server.
        temperature: sampling temperature (default 0.7).
        timeout: per-request timeout in seconds (default 120).

    Returns:
        ``(text, latency_ms, tokens)``: the first choice's message content
        (``""`` if the server returned null content), the wall-clock time of
        the request in milliseconds, and ``usage.completion_tokens`` (0 when
        the server omits or nulls the usage block).

    Raises:
        httpx.HTTPStatusError: on a non-2xx response (via ``raise_for_status``).
        KeyError / IndexError: if the response has no ``choices[0].message``.
    """
    start = time.monotonic()

    response = await client.post(
        # rstrip avoids a "//v1/..." path when the caller passes a trailing slash.
        f"{base_url.rstrip('/')}/v1/chat/completions",
        json={
            "model": model_id,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        },
        timeout=timeout,
    )
    response.raise_for_status()

    latency_ms = (time.monotonic() - start) * 1000
    data = response.json()
    # "content" may be null (e.g. tool-call responses); keep the str contract.
    content = data["choices"][0]["message"]["content"] or ""
    # "usage" can be absent or explicitly null depending on the server.
    usage = data.get("usage") or {}
    tokens = usage.get("completion_tokens") or 0

    return content, latency_ms, tokens


def _load_examples(data_dir: str) -> list[dict]:
    """Parse every non-blank line of every ``*.jsonl`` file in *data_dir* as JSON."""
    examples = []
    # Sorted so load order (and therefore seeded sampling) does not depend on
    # the filesystem's directory iteration order.
    for jsonl_file in sorted(Path(data_dir).glob("*.jsonl")):
        with open(jsonl_file, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    examples.append(json.loads(line))
    return examples


def _prompt_messages(example: dict) -> tuple[list[dict], str]:
    """Return the turns before the first assistant reply and the last user text among them."""
    prompt = []
    for message in example["messages"]:
        # Stop at the first assistant turn so the reference answer is never
        # sent to the models under test.
        if message.get("role") == "assistant":
            break
        prompt.append(message)
    user_turns = [m for m in prompt if m.get("role") == "user"]
    if not user_turns:
        raise ValueError("example has no user message before the first assistant turn")
    return prompt, user_turns[-1]["content"]


def _mean_latency(values: list[float]) -> float:
    """Average of *values* rounded to 2 decimals, or 0 for an empty list."""
    return round(sum(values) / len(values), 2) if values else 0


async def run_ab_test(
    base_url: str,
    base_model: str,
    finetuned_url: str,
    finetuned_model: str,
    data_dir: str,
    output_path: str,
    sample_size: int = 50,
    seed: Optional[int] = None,
):
    """
    Run A/B test comparing base and fine-tuned model outputs.

    Generates side-by-side outputs for human evaluation and writes them,
    together with average latencies, to ``output_path`` as JSON.

    Args:
        base_url, base_model: endpoint and model id of the baseline.
        finetuned_url, finetuned_model: endpoint and model id of the candidate.
        data_dir: directory whose ``*.jsonl`` files hold chat-format examples
            (one JSON object with a ``messages`` list per line).
        output_path: JSON file to write; parent directories are created.
        sample_size: maximum number of prompts; larger datasets are randomly
            subsampled.
        seed: optional seed for that subsampling so runs are reproducible;
            ``None`` keeps using the global ``random`` state.

    Prompts whose request fails, or that have no user turn, are logged and
    skipped; how many were skipped is recorded as ``failed_prompts``.
    """
    # NOTE(review): every *.jsonl in data_dir is loaded, so point data_dir at
    # held-out data; prompts seen during fine-tuning would bias the comparison.
    examples = _load_examples(data_dir)

    if not examples:
        logger.error("No test data found")
        return

    # Sample for testing
    if len(examples) > sample_size:
        rng = random if seed is None else random.Random(seed)
        examples = rng.sample(examples, sample_size)

    logger.info(f"Running A/B test with {len(examples)} prompts")

    results: list[ABTestResult] = []
    failed = 0
    async with httpx.AsyncClient() as client:
        for i, ex in enumerate(examples):
            try:
                messages, user_input = _prompt_messages(ex)
                # Run both models in parallel on the same prompt.
                (base_text, base_lat, base_tok), (ft_text, ft_lat, ft_tok) = (
                    await asyncio.gather(
                        generate_output(client, base_url, base_model, messages),
                        generate_output(client, finetuned_url, finetuned_model, messages),
                    )
                )
            except Exception as e:
                # Best-effort per prompt: one bad example or request must not
                # discard everything collected so far (results are saved at the end).
                failed += 1
                logger.warning(f"Prompt {i} failed: {e}")
            else:
                results.append(
                    ABTestResult(
                        prompt_id=i,
                        user_input=user_input,
                        base_output=base_text,
                        finetuned_output=ft_text,
                        base_latency_ms=round(base_lat, 2),
                        finetuned_latency_ms=round(ft_lat, 2),
                        base_tokens=base_tok,
                        finetuned_tokens=ft_tok,
                    )
                )

            # Progress is reported by position, whether or not this prompt succeeded.
            if (i + 1) % 10 == 0:
                logger.info(f"Completed {i + 1}/{len(examples)}")

    avg_base = _mean_latency([r.base_latency_ms for r in results])
    avg_ft = _mean_latency([r.finetuned_latency_ms for r in results])

    # Save results
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    with open(output, "w", encoding="utf-8") as f:
        json.dump(
            {
                "test_date": time.strftime("%Y-%m-%d %H:%M:%S"),
                "base_model": base_model,
                "finetuned_model": finetuned_model,
                "total_prompts": len(results),
                "failed_prompts": failed,
                "avg_base_latency_ms": avg_base,
                "avg_finetuned_latency_ms": avg_ft,
                "results": [asdict(r) for r in results],
            },
            f,
            indent=2,
        )

    logger.info(f"A/B test results saved to {output}")
    logger.info(f"Total comparisons: {len(results)}")

    if results:
        logger.info(f"Avg latency — Base: {avg_base:.0f}ms | Fine-tuned: {avg_ft:.0f}ms")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    cli = argparse.ArgumentParser(description="Fogbreak A/B Test Evaluator")

    # Plain string options as (flag, default), registered in --help order.
    string_options = (
        ("--base-url", "http://vllm-mistral:8001"),
        ("--base-model", "mistralai/Mistral-7B-Instruct-v0.3"),
        ("--finetuned-url", "http://vllm-finetuned:8002"),
        ("--finetuned-model", "fogbreak/mistral-7b-finetuned"),
        ("--data-dir", "training_data"),
        ("--output", "results/ab_test.json"),
    )
    for flag, default_value in string_options:
        cli.add_argument(flag, default=default_value)
    cli.add_argument("--sample-size", type=int, default=50)

    opts = cli.parse_args()

    evaluation = run_ab_test(
        base_url=opts.base_url,
        base_model=opts.base_model,
        finetuned_url=opts.finetuned_url,
        finetuned_model=opts.finetuned_model,
        data_dir=opts.data_dir,
        output_path=opts.output,
        sample_size=opts.sample_size,
    )
    asyncio.run(evaluation)
