"""
Fine-Tuning Runner — Fogbreak AI.

Fine-tunes Mistral 7B on Fogbreak-specific training data using
HuggingFace Transformers + PEFT (LoRA). Produces a fine-tuned
model ready for deployment alongside base models in vLLM.

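Usage (the `train` / `merge` subcommand is required):
    python -m ai.finetune.train train \
        --data-dir training_data \
        --output-dir models/fogbreak-mistral-ft \
        --epochs 3 \
        --batch-size 4 \
        --learning-rate 2e-5

    # Merge the trained LoRA adapter into the base model for vLLM:
    python -m ai.finetune.train merge \
        --adapter-dir models/fogbreak-mistral-ft \
        --output-dir models/fogbreak-mistral-merged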
"""

import argparse
import json
import logging
import os
from pathlib import Path

# Logger shared by every function in this fine-tuning script.
logger: logging.Logger = logging.getLogger("fogbreak.finetune.train")
# Hugging Face Hub id of the model that is LoRA-fine-tuned and later merged.
BASE_MODEL: str = "mistralai/Mistral-7B-Instruct-v0.3"


def load_training_data(data_dir: str) -> list[dict]:
    """Load all JSONL training files from a directory.

    Every ``*.jsonl`` file directly inside ``data_dir`` is read in sorted
    filename order, and each non-blank line is parsed as one JSON example.
    Sorting matters because ``Path.glob`` yields entries in arbitrary,
    filesystem-dependent order. A stable order keeps the seeded train/eval
    split reproducible across machines.

    Args:
        data_dir: Directory to scan (non-recursive). A directory with no
            matching files yields an empty list.

    Returns:
        The parsed examples, ordered by file name and then by line.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with the offending file path and 1-based line number.
    """
    examples = []

    for jsonl_file in sorted(Path(data_dir).glob("*.jsonl")):
        logger.info("Loading %s", jsonl_file)
        # Explicit UTF-8: the platform's default encoding may be something else.
        with open(jsonl_file, encoding="utf-8") as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                if not line:
                    continue
                try:
                    examples.append(json.loads(line))
                except json.JSONDecodeError as err:
                    # Keep the same exception type, but say where the bad line is.
                    raise json.JSONDecodeError(
                        f"{jsonl_file}:{lineno}: {err.msg}", err.doc, err.pos
                    ) from err

    logger.info("Loaded %d total training examples", len(examples))
    return examples


def prepare_dataset(examples: list[dict]):
    """Convert training examples to HuggingFace Dataset format."""
    from datasets import Dataset

    texts = []
    for ex in examples:
        messages = ex["messages"]
        # Format as Mistral instruction template
        text = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "system":
                text += f"[INST] <<SYS>>\n{content}\n<</SYS>>\n\n"
            elif role == "user":
                if not text:
                    text += "[INST] "
                text += f"{content} [/INST] "
            elif role == "assistant":
                text += f"{content}</s>"

        texts.append({"text": text})

    dataset = Dataset.from_list(texts)

    # 90/10 train/eval split
    split = dataset.train_test_split(test_size=0.1, seed=42)
    return split["train"], split["test"]


def run_finetuning(
    data_dir: str,
    output_dir: str,
    epochs: int = 3,
    batch_size: int = 4,
    learning_rate: float = 2e-5,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    max_seq_length: int = 2048,
    gradient_accumulation_steps: int = 4,
):
    """
    Fine-tune Mistral 7B with LoRA on Fogbreak training data.

    Loads ``BASE_MODEL`` 4-bit quantized (bitsandbytes NF4 with double
    quantization). LoRA adapters are attached to the q/k/v/o and
    gate/up/down projections, and training runs through TRL's
    ``SFTTrainer``. The adapter, tokenizer and ``training_metrics.json`` are
    written to ``output_dir``.

    Args:
        data_dir: Directory containing ``*.jsonl`` chat-format training files.
        output_dir: Destination for checkpoints, the final adapter and metrics.
        epochs: Number of training epochs.
        batch_size: Per-device train and eval batch size.
        learning_rate: Peak learning rate (cosine schedule, 10% warmup).
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: Dropout applied inside the LoRA layers.
        max_seq_length: Maximum sequence length passed to ``SFTTrainer``.
        gradient_accumulation_steps: Micro-batches accumulated per optimizer step.

    Returns:
        ``train_result.metrics`` from the trainer, or ``None`` if no training
        examples were found.
    """
    import torch
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TrainingArguments,
    )
    from trl import SFTTrainer

    logger.info(f"Fine-tuning {BASE_MODEL}")
    logger.info(f"Data: {data_dir} | Output: {output_dir}")
    logger.info(f"Epochs: {epochs} | Batch: {batch_size} | LR: {learning_rate}")

    # Load training data
    examples = load_training_data(data_dir)
    if not examples:
        logger.error("No training data found. Run export_training_data.py first.")
        return

    train_dataset, eval_dataset = prepare_dataset(examples)
    logger.info(f"Train: {len(train_dataset)} | Eval: {len(eval_dataset)}")

    # Decide precision once and use it for both the 4-bit compute dtype and
    # the Trainer's mixed-precision flag, so they never disagree (bf16 compute
    # while autocasting to fp16 on GPUs without bf16). Without CUDA neither
    # bf16 nor fp16 mixed precision is requested.
    has_cuda = torch.cuda.is_available()
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    use_fp16 = has_cuda and not use_bf16
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    # Quantization config for 4-bit training
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        token=os.getenv("HF_TOKEN"),
    )
    model = prepare_model_for_kbit_training(model)

    # LoRA configuration
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, lora_config)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(
        f"Trainable: {trainable_params:,} / {total_params:,} "
        f"({100 * trainable_params / total_params:.2f}%)"
    )

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        token=os.getenv("HF_TOKEN"),
    )
    # Prefer a pad token distinct from EOS.
    # NOTE(review): assumes this TRL version's default collator is
    # transformers' DataCollatorForLanguageModeling(mlm=False), which sets
    # labels to -100 wherever input_ids == pad_token_id. If pad == eos and the
    # "</s>" closing each assistant turn tokenizes to EOS, those labels are
    # masked and the model is never trained to stop generating.
    if tokenizer.unk_token is not None and tokenizer.unk_token_id != tokenizer.eos_token_id:
        tokenizer.pad_token = tokenizer.unk_token
    else:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        # load_best_model_at_end requires save_steps to be a multiple of eval_steps.
        save_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=use_bf16,
        fp16=use_fp16,
        report_to="none",
        dataloader_num_workers=4,
        group_by_length=True,
    )

    # Trainer.
    # NOTE(review): `tokenizer=` and `max_seq_length=` are keyword names from
    # older TRL releases. Newer releases moved them (`processing_class=`,
    # `SFTConfig`), so pin trl to a compatible version.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        max_seq_length=max_seq_length,
    )

    # Train
    logger.info("Starting fine-tuning...")
    train_result = trainer.train()

    # Save (with load_best_model_at_end, this is the best checkpoint by eval_loss)
    logger.info(f"Saving model to {output_dir}")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Save training metrics
    metrics = train_result.metrics
    metrics_path = Path(output_dir) / "training_metrics.json"
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)

    logger.info(f"Fine-tuning complete. Loss: {metrics.get('train_loss', 'N/A')}")
    return metrics


def merge_and_export(
    adapter_dir: str,
    output_dir: str,
):
    """
    Merge LoRA adapter with base model for vLLM deployment.

    The base model is loaded without quantization, so the adapter from
    ``adapter_dir`` is folded into ordinary (non-4-bit) weights. The merged
    Transformers model and the base tokenizer are written to ``output_dir``,
    and vLLM can serve the merged model directly.

    Args:
        adapter_dir: Directory holding the trained LoRA adapter, e.g. the
            ``output_dir`` of a training run.
        output_dir: Destination for the merged model and tokenizer (created
            if missing).
    """
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info(f"Merging adapter from {adapter_dir}")

    # Load the base model unquantized for merging. torch_dtype="auto" keeps
    # the checkpoint's stored dtype (presumably bfloat16 for this model).
    # Without it Transformers defaults to float32, which doubles memory use
    # and the size of the exported weights with no benefit for serving.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype="auto",
        device_map="auto",
        token=os.getenv("HF_TOKEN"),
    )

    # Load and merge LoRA
    model = PeftModel.from_pretrained(model, adapter_dir)
    model = model.merge_and_unload()

    # Save merged model
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=os.getenv("HF_TOKEN"))
    tokenizer.save_pretrained(output_dir)

    logger.info(f"Merged model saved to {output_dir}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description="Fogbreak AI Fine-Tuning")
    subparsers = parser.add_subparsers(dest="command")

    # Train command: LoRA fine-tune BASE_MODEL and save the adapter.
    # Defaults mirror run_finetuning's keyword defaults.
    train_parser = subparsers.add_parser("train", help="Fine-tune model")
    train_parser.add_argument("--data-dir", default="training_data")
    train_parser.add_argument("--output-dir", default="models/fogbreak-mistral-ft")
    train_parser.add_argument("--epochs", type=int, default=3)
    train_parser.add_argument("--batch-size", type=int, default=4)
    train_parser.add_argument("--learning-rate", type=float, default=2e-5)
    train_parser.add_argument("--lora-r", type=int, default=16)
    train_parser.add_argument("--lora-alpha", type=int, default=32)
    train_parser.add_argument("--lora-dropout", type=float, default=0.05)
    train_parser.add_argument("--max-seq-length", type=int, default=2048)
    train_parser.add_argument("--gradient-accumulation-steps", type=int, default=4)

    # Merge command: fold a trained adapter into the base model for serving.
    merge_parser = subparsers.add_parser("merge", help="Merge LoRA adapter")
    merge_parser.add_argument("--adapter-dir", required=True)
    merge_parser.add_argument("--output-dir", required=True)

    args = parser.parse_args()

    if args.command == "train":
        run_finetuning(
            data_dir=args.data_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            max_seq_length=args.max_seq_length,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
        )
    elif args.command == "merge":
        merge_and_export(
            adapter_dir=args.adapter_dir,
            output_dir=args.output_dir,
        )
    else:
        # No subcommand given: show usage instead of silently doing nothing.
        parser.print_help()
