"""
SYNTHOS Model Evaluation & Benchmarking Engine
===============================================

Provides a fast, structured evaluation framework for comparing open-source
LLMs and inference methodologies.  Results are expressed through SYNTHOS's
native pattern-geometry scoring so they can feed back into the cognitive
pipeline.

Supported model providers (via HTTP API):
    - Ollama  (local, ``http://localhost:11434``)
    - llama.cpp server  (local, ``http://localhost:8080``)
    - LM Studio  (local, ``http://localhost:1234``)
    - Any OpenAI-compatible endpoint (vLLM, TGI, LocalAI, etc.)

Evaluation axes
---------------
* **Latency** — time-to-first-token (TTFT) and tokens-per-second (TPS).
* **Quality** — regex-match accuracy against golden patterns; BLEU-1 for
  free-form text.
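* **Throughput** — tokens/s sustained over a whole suite run (prompts are
  currently sent one at a time, so this is not yet a concurrent-load figure).
* **Coherence** — SYNTHOS-native backreference coherence score (currently
  approximated by the exact-match rate).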
* **Cost** — estimated $/M tokens (if pricing metadata is available).
"""

from __future__ import annotations

import asyncio
import dataclasses
import json
import math
import re
import statistics
import time
from collections import Counter
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

import httpx

# ── Data structures ────────────────────────────────────────────────────────────


class Provider(Enum):
    """
    Inference backends the benchmark can drive.

    ``OLLAMA`` is spoken to through its native ``/api/generate`` endpoint;
    every other member goes through the OpenAI-compatible
    ``/v1/chat/completions`` endpoint.
    """

    OLLAMA = "ollama"                # native Ollama HTTP API
    LLAMACPP = "llamacpp"            # llama.cpp server
    LMSTUDIO = "lmstudio"            # LM Studio local server
    OPENAI_COMPAT = "openai_compat"  # any other OpenAI-style endpoint


@dataclass
class ModelCard:
    """
    Static description of one model under test.

    Only ``name``, ``provider``, ``base_url``, ``model_id`` and
    ``cost_per_mtok`` are consulted by the benchmark runner; the remaining
    fields are carried along for reporting.
    """

    # Display label used in results and the leaderboard, e.g. "Llama-3.1-8B-Q4".
    name: str
    # Backend that serves the model; selects the request format.
    provider: Provider = Provider.OLLAMA
    # API root; endpoint paths such as "/api/generate" are appended to it.
    base_url: str = "http://localhost:11434"
    # Identifier sent to the backend; when empty, ``name`` is sent instead.
    model_id: str = ""
    # Maximum context window in tokens (informational).
    ctx_len: int = 4096
    # Parameter count in billions (informational).
    params_b: float = 0.0
    # Quantisation label such as "Q4_K_M" (informational).
    quant: str = "fp16"
    # Estimated $ per million tokens; 0.0 for local models.
    cost_per_mtok: float = 0.0


@dataclass
class BenchResult:
    """
    Aggregated metrics for one (model, suite) pair.

    Every timing value is expressed in seconds unless a comment says
    otherwise.
    """

    model: str
    suite: str
    samples: int = 0

    # -- latency ----------------------------------------------------------
    ttft_mean: float = 0.0          # mean time to first token, seconds
    ttft_p50: float = 0.0           # median time to first token
    ttft_p95: float = 0.0           # 95th-percentile time to first token
    tps_mean: float = 0.0           # mean tokens per second
    tps_p50: float = 0.0            # median tokens per second

    # -- quality ----------------------------------------------------------
    exact_match_rate: float = 0.0   # share of answers matching their regex
    bleu1: float = 0.0              # mean unigram BLEU
    coherence: float = 0.0          # SYNTHOS coherence score

    # -- throughput -------------------------------------------------------
    throughput_tps: float = 0.0     # total tokens / wall-clock suite time

    # -- cost -------------------------------------------------------------
    est_cost_per_mtok: float = 0.0

    # -- failures ---------------------------------------------------------
    errors: int = 0

    def score(self) -> float:
        """
        Return the weighted composite score; larger means better.

        score    = 0.35*quality + 0.30*speed + 0.20*coherence + 0.15*cost_eff
        quality  = 0.7*exact_match_rate + 0.3*bleu1
        speed    = tps_mean / 100, capped at 1.0
        cost_eff = 1 / (1 + est_cost_per_mtok)
        """
        weighted_quality = 0.7 * self.exact_match_rate + 0.3 * self.bleu1

        speed_norm = self.tps_mean / 100.0
        if speed_norm > 1.0:
            speed_norm = 1.0            # 100 tok/s or more earns full marks

        cost_efficiency = 1.0 / (1.0 + self.est_cost_per_mtok)

        return (
            0.35 * weighted_quality
            + 0.30 * speed_norm
            + 0.20 * self.coherence
            + 0.15 * cost_efficiency
        )


@dataclass
class EvalPrompt:
    """A single benchmark prompt together with how its answer is graded."""

    # Text sent to the model.
    prompt: str
    # Regex searched for in the answer; an empty string disables regex grading.
    expected_pattern: str = ""
    # Golden answer used for BLEU-1; an empty string disables BLEU.
    reference: str = ""
    # Free-form grouping label.
    category: str = "general"
    # Generation cap forwarded to the provider.
    max_tokens: int = 256


# ── Built-in evaluation suites ────────────────────────────────────────────────

# Built-in evaluation suites, keyed by suite name.  Each EvalPrompt is built
# positionally as (prompt, expected_pattern, reference, category).
#
# Patterns are applied with ``re.search``.  In Python regex ``.`` does not
# match a newline unless the DOTALL flag ``(?s)`` is set, so every pattern
# that may need to span a multi-line answer (code, essays, lists) carries it.
_BUILTIN_SUITES: Dict[str, List[EvalPrompt]] = {
    "pattern_match": [
        EvalPrompt("What is 2+2? Answer with just the number.", r"^4$", "4", "math"),
        # Accept the three colours in any order and any capitalisation
        # ("Red, Blue, Yellow" is a correct answer).
        EvalPrompt(
            "Name the three primary colours, comma-separated.",
            r"(?is)^(?=.*\bred\b)(?=.*\bblue\b)(?=.*\byellow\b)",
            "red, blue, yellow",
            "knowledge",
        ),
        # (?s): the ``return`` line normally follows the ``def`` line.
        EvalPrompt(
            "Write a Python function that returns True.",
            r"(?s)def\s+\w+.*return\s+True",
            "def f():\n    return True",
            "code",
        ),
        EvalPrompt("Reverse the word 'hello'.", r"olleh", "olleh", "reasoning"),
        EvalPrompt("Is 17 prime? Answer yes or no.", r"(?i)yes", "yes", "math"),
        EvalPrompt("Give me a valid JSON object with key 'a'.", r'\{\s*"a"\s*:', '{"a": 1}', "format"),
        EvalPrompt("Translate 'hello' to French.", r"(?i)bonjour", "bonjour", "translation"),
        EvalPrompt("What colour is the sky? One word.", r"(?i)blue", "blue", "knowledge"),
        EvalPrompt("Count to 5 with commas.", r"(?s)1.*2.*3.*4.*5", "1, 2, 3, 4, 5", "reasoning"),
        EvalPrompt("Capital of France?", r"(?i)paris", "Paris", "knowledge"),
    ],
    "coherence": [
        EvalPrompt(
            "Repeat the following exactly: 'The pattern is the architecture.'",
            r"The pattern is the architecture\.",
            "The pattern is the architecture.",
            "echo",
        ),
        EvalPrompt(
            "Given A=3, B=7, what is A+B? Answer with just the number.",
            r"^10$", "10", "binding",
        ),
        EvalPrompt(
            "If X='SYNTHOS', spell X backwards.",
            r"(?i)sohtny", "SOHTNYS", "binding",
        ),
    ],
    "speed_stress": [
        # (?s): an essay is multi-line; count characters across paragraphs.
        EvalPrompt(
            "Write a 200-word essay on artificial intelligence.",
            r"(?s).{400,}", "", "throughput", max_tokens=512,
        ),
    ],
}


# ── Evaluation suite ──────────────────────────────────────────────────────────

class EvalSuite:
    """
    A named set of evaluation prompts.

    Parameters
    ----------
    name : str
        Suite name.
    prompts : list[EvalPrompt] | None
        Custom prompts.  If *None*, loads a built-in suite by *name*.

    Raises
    ------
    ValueError
        If *prompts* is None and *name* is not a built-in suite.
    """

    def __init__(self, name: str, prompts: Optional[List[EvalPrompt]] = None):
        self.name = name
        if prompts is not None:
            # A caller-supplied list is used as given (not copied).
            self.prompts = prompts
        elif name in _BUILTIN_SUITES:
            # Copy so that editing this suite's prompts cannot corrupt the
            # module-level built-in definition shared by every other suite.
            self.prompts = list(_BUILTIN_SUITES[name])
        else:
            raise ValueError(f"Unknown built-in suite '{name}'. Available: {list(_BUILTIN_SUITES)}")

    def __len__(self) -> int:
        """Return the number of prompts in the suite."""
        return len(self.prompts)


# ── Benchmark runner ──────────────────────────────────────────────────────────

class ModelBenchmark:
    """
    Orchestrates evaluation of one or more models against one or more suites.

    Parameters
    ----------
    models : Sequence[ModelCard]
        Models to evaluate.
    suites : Sequence[EvalSuite] | None
        Suites to run.  Defaults to all built-in suites (an empty sequence
        also selects the defaults).
    concurrency : int
        Max parallel requests during throughput tests.  NOTE: stored but not
        yet used by this class -- prompts are currently sent one at a time.
    timeout : float
        Timeout in seconds passed to ``httpx.AsyncClient``.  httpx applies it
        to each connect / read / write / pool phase, not to a whole streamed
        request.
    verbose : bool
        Collect per-prompt diagnostics into ``self.trace`` (nothing is
        printed).

    Examples
    --------
    >>> cards = [ModelCard("phi-3", model_id="phi3")]
    >>> bench = ModelBenchmark(cards)
    >>> results = bench.run()          # synchronous convenience
    >>> for r in results:
    ...     print(f"{r.model}: score={r.score():.3f}")
    """

    def __init__(
        self,
        models: Sequence[ModelCard],
        suites: Optional[Sequence[EvalSuite]] = None,
        *,
        concurrency: int = 4,
        timeout: float = 60.0,
        verbose: bool = False,
    ):
        self.models = list(models)
        self.suites = list(suites) if suites else [EvalSuite(n) for n in _BUILTIN_SUITES]
        self.concurrency = concurrency
        self.timeout = timeout
        self.verbose = verbose
        self.trace: List[str] = []

    # ── Public API ─────────────────────────────────────────────────────────

    def run(self) -> List[BenchResult]:
        """
        Run all evaluations synchronously (wraps the async runner).

        Uses ``asyncio.run``, so it cannot be called from inside an already
        running event loop; await :meth:`run_async` there instead.
        """
        return asyncio.run(self.run_async())

    async def run_async(self) -> List[BenchResult]:
        """Run every (model, suite) pair in order and return one result each."""
        results: List[BenchResult] = []
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            for model in self.models:
                for suite in self.suites:
                    self._log(f"▸ Evaluating {model.name} on '{suite.name}' ({len(suite)} prompts)")
                    result = await self._eval_model_suite(client, model, suite)
                    results.append(result)
                    self._log(f"  ✓ score={result.score():.3f}  tps={result.tps_mean:.1f}  exact={result.exact_match_rate:.2f}")
        return results

    # ── Internal ───────────────────────────────────────────────────────────

    async def _eval_model_suite(self, client: httpx.AsyncClient, model: ModelCard, suite: EvalSuite) -> BenchResult:
        """
        Evaluate a single model against a single suite.

        Prompts are sent sequentially.  A prompt whose request fails is
        counted in ``errors`` and, if it carries an ``expected_pattern``,
        also as a miss -- failing requests must not inflate the match rate.
        """
        ttfts: List[float] = []
        tps_list: List[float] = []
        exact_matches = 0
        graded = 0                      # prompts that carry an expected_pattern
        bleu_scores: List[float] = []
        errors = 0
        total_tokens = 0
        wall_start = time.monotonic()

        for prompt in suite.prompts:
            if prompt.expected_pattern:
                graded += 1
            try:
                resp_text, ttft, tok_count, elapsed = await self._generate(client, model, prompt)
            except Exception as exc:    # benchmark boundary: record and keep going
                self._log(f"    ✗ error: {exc}")
                errors += 1
                continue

            ttfts.append(ttft)
            # End-to-end rate: ``elapsed`` includes the time to first token.
            tps_list.append(tok_count / elapsed if elapsed > 0 else 0.0)
            total_tokens += tok_count

            # Models often pad answers with spaces/newlines; strip them so
            # anchored patterns such as ``^4$`` still match " 4\n".
            answer = resp_text.strip()

            # Quality: regex match
            if prompt.expected_pattern:
                if re.search(prompt.expected_pattern, answer):
                    exact_matches += 1
                    self._log(f"    ✓ matched '{prompt.expected_pattern}'")
                else:
                    self._log(f"    ✗ no match for '{prompt.expected_pattern}' in '{answer[:80]}…'")

            # Quality: BLEU-1
            if prompt.reference:
                bleu_scores.append(self._bleu1(prompt.reference, answer))

        wall_elapsed = time.monotonic() - wall_start
        match_rate = exact_matches / graded if graded else 0.0

        return BenchResult(
            model=model.name,
            suite=suite.name,
            samples=len(suite.prompts),
            ttft_mean=statistics.mean(ttfts) if ttfts else 0.0,
            ttft_p50=self._percentile(ttfts, 50),
            ttft_p95=self._percentile(ttfts, 95),
            tps_mean=statistics.mean(tps_list) if tps_list else 0.0,
            tps_p50=self._percentile(tps_list, 50),
            exact_match_rate=match_rate,
            bleu1=statistics.mean(bleu_scores) if bleu_scores else 0.0,
            coherence=match_rate,       # simplified coherence proxy
            throughput_tps=total_tokens / wall_elapsed if wall_elapsed > 0 else 0.0,
            est_cost_per_mtok=model.cost_per_mtok,
            errors=errors,
        )

    async def _generate(self, client: httpx.AsyncClient, model: ModelCard, prompt: EvalPrompt):
        """
        Send a generation request and measure latency.

        Returns (response_text, ttft_seconds, token_count, total_seconds).
        Ollama uses its native API; every other provider is treated as
        OpenAI-compatible.
        """
        if model.provider == Provider.OLLAMA:
            return await self._gen_ollama(client, model, prompt)
        return await self._gen_openai_compat(client, model, prompt)

    async def _gen_ollama(self, client, model, prompt):
        """
        Stream a completion from Ollama's ``/api/generate`` endpoint.

        Each streamed line is a JSON object whose ``response`` field holds the
        next text fragment; the last one has ``done: true``.  If that final
        object reports ``eval_count`` (generated-token count) it replaces the
        fragment count.  An ``error`` object is raised as RuntimeError so the
        prompt is recorded as a failure rather than an empty answer.
        """
        url = f"{model.base_url.rstrip('/')}/api/generate"
        body = {"model": model.model_id or model.name, "prompt": prompt.prompt, "stream": True,
                "options": {"num_predict": prompt.max_tokens}}
        t0 = time.monotonic()
        ttft = None
        chunks: List[str] = []
        tok_count = 0
        reported_tokens = None

        async with client.stream("POST", url, json=body) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                obj = json.loads(line)
                if "error" in obj:
                    raise RuntimeError(f"Ollama error: {obj['error']}")
                token = obj.get("response", "")
                # Only non-empty fragments count: the final ``done`` object
                # usually carries an empty ``response``.
                if token:
                    if ttft is None:
                        ttft = time.monotonic() - t0
                    chunks.append(token)
                    tok_count += 1
                if obj.get("done"):
                    reported_tokens = obj.get("eval_count")
                    break

        total = time.monotonic() - t0
        if isinstance(reported_tokens, int):
            tok_count = reported_tokens
        return "".join(chunks), ttft if ttft is not None else total, tok_count, total

    async def _gen_openai_compat(self, client, model, prompt):
        """
        Stream a chat completion from an OpenAI-compatible endpoint (SSE).

        Each content-bearing chunk is counted as one token unless some chunk
        carries ``usage.completion_tokens``, which then takes precedence.
        Chunks with an empty ``choices`` list (e.g. usage-only chunks) are
        tolerated.
        """
        url = f"{model.base_url.rstrip('/')}/v1/chat/completions"
        body = {
            "model": model.model_id or model.name,
            "messages": [{"role": "user", "content": prompt.prompt}],
            "max_tokens": prompt.max_tokens,
            "stream": True,
        }
        t0 = time.monotonic()
        ttft = None
        chunks: List[str] = []
        tok_count = 0
        reported_tokens = None

        async with client.stream("POST", url, json=body) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line.startswith("data:"):
                    continue
                data = line[5:].strip()
                if not data:
                    continue
                if data == "[DONE]":
                    break
                obj = json.loads(data)
                usage = obj.get("usage")
                if isinstance(usage, dict) and isinstance(usage.get("completion_tokens"), int):
                    reported_tokens = usage["completion_tokens"]
                choices = obj.get("choices") or [{}]
                delta = choices[0].get("delta") or {}
                content = delta.get("content")
                if content:
                    if ttft is None:
                        ttft = time.monotonic() - t0
                    chunks.append(content)
                    tok_count += 1

        total = time.monotonic() - t0
        if reported_tokens is not None:
            tok_count = reported_tokens
        return "".join(chunks), ttft if ttft is not None else total, tok_count, total

    # ── Scoring helpers ────────────────────────────────────────────────────

    @staticmethod
    def _bleu1(reference: str, hypothesis: str) -> float:
        """
        Unigram BLEU: clipped (modified) unigram precision, no brevity penalty.

        Tokens are lower-cased and split on whitespace.  Each hypothesis token
        is credited at most as many times as it occurs in the reference, so
        repeating one correct word cannot inflate the score.
        """
        hyp_tokens = hypothesis.lower().split()
        if not hyp_tokens:
            return 0.0
        ref_counts = Counter(reference.lower().split())
        hits = sum(min(count, ref_counts[tok]) for tok, count in Counter(hyp_tokens).items())
        return hits / len(hyp_tokens)

    @staticmethod
    def _percentile(values: List[float], pct: float) -> float:
        """Linear-interpolated percentile of *values* (0.0 for an empty list)."""
        if not values:
            return 0.0
        s = sorted(values)
        k = (len(s) - 1) * pct / 100.0
        f = math.floor(k)
        c = math.ceil(k)
        if f == c:
            return s[int(k)]
        return s[f] * (c - k) + s[c] * (k - f)

    # ── Reporting ──────────────────────────────────────────────────────────

    @staticmethod
    def leaderboard(results: List[BenchResult]) -> str:
        """
        Produce an ASCII leaderboard table sorted by composite score.

        Returns a ready-to-print string.
        """
        if not results:
            return "(no results)"

        results_sorted = sorted(results, key=lambda r: r.score(), reverse=True)
        header = f"{'Rank':<5} {'Model':<28} {'Suite':<18} {'Score':>6} {'TPS':>7} {'TTFT':>7} {'Exact%':>7} {'BLEU1':>6} {'Coh':>5} {'Err':>4}"
        sep = "─" * len(header)
        lines = [sep, header, sep]
        for i, r in enumerate(results_sorted, 1):
            lines.append(
                f"{i:<5} {r.model:<28} {r.suite:<18} {r.score():>6.3f} {r.tps_mean:>7.1f} {r.ttft_mean:>7.3f} {r.exact_match_rate*100:>6.1f}% {r.bleu1:>6.3f} {r.coherence:>5.2f} {r.errors:>4d}"
            )
        lines.append(sep)
        return "\n".join(lines)

    @staticmethod
    def to_json(results: List[BenchResult]) -> str:
        """Serialize results (plus each composite ``score``) to indented JSON."""
        return json.dumps([{**dataclasses.asdict(r), "score": r.score()} for r in results], indent=2)

    def _log(self, msg: str):
        """Append *msg* to ``self.trace`` when verbose mode is on."""
        if self.verbose:
            self.trace.append(msg)
