"""
LM Distillation Techniques for SynthoLM
=========================================

Implements three distillation strategies that allow the syntax-based
pipeline to produce richer responses:

1. **Chain-of-Thought (CoT)** — decompose a complex request into
   explicit reasoning steps before answering.
2. **Knowledge Distillation** — compress domain knowledge into
   compact templates keyed by topic signature.
3. **Self-Consistency** — generate multiple candidate answers via
   different reasoning paths, then select the most consistent one.

These techniques work in pure-syntax mode (no external LLM needed)
and also enrich the system prompt when delegating to a real LLM
in hybrid mode.
"""

from __future__ import annotations

import re
import hashlib
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


# ═══════════════════════════════════════════════════════════════════════════════
# 1. Chain-of-Thought
# ═══════════════════════════════════════════════════════════════════════════════

@dataclass
class ThoughtStep:
    """One reasoning step in a chain-of-thought.

    ``ChainOfThought.to_text`` renders a step as ``Step <index>: <thought>``
    and, when ``conclusion`` is non-empty, adds an indented
    ``→ <conclusion>`` line beneath it.
    """
    index: int               # step number shown in the rendered trace
    thought: str             # text of the reasoning step
    conclusion: str = ""     # optional outcome of the step; empty string means "none"
    confidence: float = 1.0  # exported by ChainOfThought.to_dict(); not used when rendering text


@dataclass
class ChainOfThought:
    """Complete chain-of-thought trace for a single query.

    Attributes:
        query: The request text the trace was built for.
        steps: Reasoning steps in the order they should be rendered.
        final_answer: Answer text; empty until a caller fills it in.
    """
    query: str
    steps: List[ThoughtStep] = field(default_factory=list)
    final_answer: str = ""

    def to_text(self, include_steps: bool = True) -> str:
        """Render the trace as plain text.

        Each step becomes ``Step <index>: <thought>``, followed by an
        indented ``→ <conclusion>`` line when the step has a conclusion.
        A single blank line separates the steps from the final answer.

        Args:
            include_steps: When false, only the final answer is returned.

        Returns:
            The rendered text. When there are no steps to show, this is
            exactly ``final_answer`` with no leading blank line.
        """
        lines: List[str] = []
        # Only emit the step section (and its blank separator) when there is
        # at least one step; otherwise the output would start with "\n".
        if include_steps and self.steps:
            for step in self.steps:
                lines.append(f"Step {step.index}: {step.thought}")
                if step.conclusion:
                    lines.append(f"  → {step.conclusion}")
            lines.append("")
        lines.append(self.final_answer)
        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict representation of the trace.

        Returns:
            A dict with ``query``, ``final_answer`` and ``steps``; each step
            is a dict with ``index``, ``thought``, ``conclusion`` and
            ``confidence``.
        """
        return {
            "query": self.query,
            "steps": [{"index": s.index, "thought": s.thought,
                        "conclusion": s.conclusion, "confidence": s.confidence}
                       for s in self.steps],
            "final_answer": self.final_answer,
        }


# Decomposition patterns — map query shapes to reasoning skeletons.
#
# Each entry is (regex, step templates). Capture group 1 is the subject of
# the query and is substituted for "{subject}" in the templates.
# build_chain_of_thought() tries the entries in order and uses the first
# pattern that matches anywhere in the query, so order matters.
#
# Every trigger keyword is anchored with \b so it only fires on a whole word:
# without the anchor "journalist wrote ..." triggers the "list" skeleton and
# "The show is ..." triggers the "how ... is" skeleton.
_COT_SKELETONS: List[Tuple[str, List[str]]] = [
    (r"(?i)\b(?:how|why|explain)\s+(?:does|do|is|are|can|would)\s+(.+)",
     ["Identify the subject: {subject}",
      "Recall relevant knowledge about {subject}",
      "Reason about the mechanism or cause",
      "Synthesize into a clear explanation"]),

    (r"(?i)\b(?:compare|difference|versus)\s+(.+)",
     ["Identify the items being compared",
      "List key properties of each item",
      "Find similarities and differences",
      "Summarize the comparison"]),

    # The optional "re" keeps "rewrite", "rebuild", "recreate", "remake" and
    # "regenerate" on this skeleton now that the verb must start a word.
    (r"(?i)\b(?:re)?(?:create|make|build|write|generate)\s+(.+)",
     ["Understand what needs to be created: {subject}",
      "Determine required components and structure",
      "Plan the implementation approach",
      "Execute the creation steps"]),

    (r"(?i)\b(?:what|define)\s+(?:is|are)\s+(.+)",
     ["Identify the concept: {subject}",
      "Recall the definition or description",
      "Provide relevant context or examples",
      "Deliver a clear answer"]),

    (r"(?i)\b(?:list|enumerate|show)\s+(.+)",
     ["Identify what to list: {subject}",
      "Gather all relevant items",
      "Organize in a logical order",
      "Present the list"]),
]

# Fallback skeleton used when no pattern above matches; "{subject}" is then
# the whole query.
_DEFAULT_SKELETON = [
    "Parse the request: {subject}",
    "Identify key concepts and requirements",
    "Reason about the best approach",
    "Formulate the response",
]


def build_chain_of_thought(query: str) -> ChainOfThought:
    """
    Turn *query* into a chain-of-thought trace.

    The patterns in ``_COT_SKELETONS`` are searched in order; the first one
    that matches supplies the step templates, and its first capture group
    (whitespace-stripped) becomes the subject. When nothing matches,
    ``_DEFAULT_SKELETON`` is used with the whole query as the subject.

    Returns a ``ChainOfThought`` whose steps are numbered from 1 and whose
    ``final_answer`` is left empty.
    """
    # Lazily search each pattern so evaluation stops at the first hit.
    searched = (
        (re.search(regex, query), templates)
        for regex, templates in _COT_SKELETONS
    )
    found, templates = next(
        ((hit, steps) for hit, steps in searched if hit),
        (None, _DEFAULT_SKELETON),
    )

    if found is not None and found.lastindex:
        topic = found.group(1).strip()
    else:
        topic = query

    trace = ChainOfThought(query=query)
    trace.steps.extend(
        ThoughtStep(index=position, thought=text.format(subject=topic))
        for position, text in enumerate(templates, 1)
    )
    return trace


# ═══════════════════════════════════════════════════════════════════════════════
# 2. Knowledge Distillation
# ═══════════════════════════════════════════════════════════════════════════════

@dataclass
class KnowledgeEntry:
    """A compressed knowledge fragment about a single topic.

    Instances are created by ``KnowledgeStore.add`` and returned by
    ``KnowledgeStore.query``.
    """
    topic: str              # topic name; KnowledgeStore.query looks for its words in the input text
    signature: str          # hash-based lookup key (KnowledgeStore derives it from the topic)
    summary: str            # short summary; always included by KnowledgeStore.distill
    details: List[str] = field(default_factory=list)  # supporting statements; distill uses at most the first two
    related: List[str] = field(default_factory=list)  # related terms; stored, but not read by the code visible here


class KnowledgeStore:
    """
    In-memory knowledge base of ``KnowledgeEntry`` records.

    Each entry is stored under a short hash of its normalised topic, so
    adding a topic that differs only in case or surrounding whitespace
    replaces the earlier entry. A new store is pre-populated with the
    built-in topics from ``_load_defaults``.
    """

    def __init__(self):
        self._entries: Dict[str, KnowledgeEntry] = {}
        self._load_defaults()

    def _sig(self, topic: str) -> str:
        """Return the 12-hex-character storage key for *topic*."""
        normalised = topic.lower().strip()
        digest = hashlib.md5(normalised.encode()).hexdigest()
        return digest[:12]

    def add(self, topic: str, summary: str, details: Optional[List[str]] = None,
            related: Optional[List[str]] = None):
        """Insert an entry for *topic*, overwriting any entry with the same key."""
        key = self._sig(topic)
        entry = KnowledgeEntry(
            topic=topic,
            signature=key,
            summary=summary,
            details=details if details else [],
            related=related if related else [],
        )
        self._entries[key] = entry

    def query(self, text: str, top_k: int = 3) -> List[KnowledgeEntry]:
        """
        Return up to *top_k* entries relevant to *text*, best first.

        An entry scores one point for every whitespace-separated word of
        its topic that occurs (case-insensitively, as a substring) in
        *text*. Entries scoring zero are dropped; ties keep insertion order.
        """
        haystack = text.lower()
        hits: List[Tuple[int, KnowledgeEntry]] = []
        for candidate in self._entries.values():
            overlap = sum(
                1 for token in candidate.topic.lower().split()
                if token in haystack
            )
            if overlap > 0:
                hits.append((overlap, candidate))
        # sorted() is stable, also with reverse=True, so equal scores keep
        # their original relative order.
        ranked = sorted(hits, key=lambda pair: pair[0], reverse=True)
        return [candidate for _, candidate in ranked[:top_k]]

    def distill(self, entries: List[KnowledgeEntry]) -> str:
        """
        Merge *entries* into one space-joined context string.

        For each entry the summary is emitted, followed by its first two
        details when it has any. An empty input yields an empty string.
        """
        if not entries:
            return ""
        fragments: List[str] = []
        for entry in entries:
            fragments.append(entry.summary)
            if entry.details:
                leading_details = entry.details[:2]
                fragments.append(" ".join(leading_details))
        return " ".join(fragments)

    def _load_defaults(self):
        """Seed the store with the built-in topics, in declaration order."""
        # Each row is (topic, summary, details, related).
        defaults = (
            ("SYNTHOS",
             "SYNTHOS is a syntax-driven AI architecture where intelligence is encoded in regex pattern geometry across 7 layers.",
             ["The 7 layers are: LPE, GPL, SCM, TAM, RGE, SCF, OPS.",
              "Each layer processes input through regex patterns, not neural weights.",
              "The system is deterministic — same input always produces same output."],
             ["regex", "pipeline", "architecture"]),
            ("attention",
             "In SYNTHOS, attention is pattern intersection area across 8 heads in the Topological Attention Mesh.",
             ["Each head has a query pattern, key pattern, and value pattern.",
              "Overlap area between query and key matches determines relevance.",
              "Heads can have LOCAL, WINDOW, or GLOBAL scope."],
             ["TAM", "pattern intersection", "heads"]),
            ("encryption",
             "The Symbolic Topology Cipher (STC) is a 256-bit block cipher with 12 rounds of lattice permutation, S-box substitution, and Möbius fold.",
             ["Supports ECB, CBC (default), and CTR modes.",
              "Uses HMAC-SHA256 for authentication.",
              "Key derivation via PBKDF2 with 600,000 iterations."],
             ["STC", "cipher", "crypto"]),
            ("pipeline",
             "The SYNTHOS pipeline processes input through 7 sequential layers: LPE→GPL→SCM→TAM→RGE→SCF→OPS.",
             ["L0 LPE: 48+ regex primitives with ASCII geometric forms.",
              "L1 GPL: 4×4 lattice grid with directional transitions.",
              "L2 SCM: Named captures become concept nodes in a semantic graph.",
              "L3 TAM: 8-head attention via pattern intersection geometry.",
              "L4 RGE: EBNF grammar encoded as recursive regex.",
              "L5 SCF: Rank-3 tensor with short/long/episodic memory.",
              "L6 OPS: Substitution chains produce structured ASCII output."],
             ["layers", "architecture", "cognitive"]),
            ("regex",
             "Regular expressions are the atomic cognitive operations of SYNTHOS. Each primitive has an ASCII geometric form.",
             ["48+ primitives across 6 types: ATOMIC, QUANTIFIER, GROUP, LOOKAROUND, REFERENCE, CHARACTER_CLASS.",
              "Geometric forms: · for dot, ─── for sequence, ○ for groups, ∞ for Kleene star."],
             ["primitives", "LPE", "patterns"]),
            ("memory",
             "The State Crystallization Field maintains three memory types: short-term (sliding window), long-term (named registers), and episodic (match history stack).",
             ["Short-term memory is a sliding window of 8 recent matches.",
              "Long-term memory stores named values that persist across turns.",
              "Episodic memory records a stack of interaction summaries."],
             ["SCF", "state", "crystallization"]),
            ("grammar",
             "The Recursive Grammar Engine encodes EBNF productions as recursive regex patterns with FSM parsing.",
             ["Built-in rules: SYNTHOS_ROOT, STATEMENT, EXPRESSION, GEOMETRY, LITERAL, IDENTIFIER, NUMBER.",
              "Parse trees are constructed from named capture groups."],
             ["RGE", "EBNF", "parsing"]),
            ("distillation",
             "LM distillation compresses the knowledge and reasoning patterns of a larger model into a smaller, faster one.",
             ["Chain-of-thought (CoT) decomposes complex queries into reasoning steps.",
              "Knowledge distillation stores topic summaries for fast retrieval.",
              "Self-consistency generates multiple answers and picks the most consistent."],
             ["CoT", "reasoning", "compression"]),
            ("transformer",
             "Transformers use self-attention with Q·K^T/√d followed by softmax. SYNTHOS replaces this with regex pattern intersection.",
             ["Neural attention learns weights via backpropagation.",
              "SYNTHOS attention uses structural overlap — no training needed."],
             ["neural network", "attention", "comparison"]),
            ("tools",
             "SynthoLM includes a tool system for executing system tasks: creating files, directories, running shell commands, and scaffolding projects.",
             ["Tools are invoked via natural language: 'create a python file called X'.",
              "Multi-step tasks are decomposed into chains of tool calls.",
              "File generators produce well-structured .py, .sh, and .md files."],
             ["filesystem", "batch", "generator"]),
        )
        for topic, summary, details, related in defaults:
            self.add(topic, summary, details, related)


# ═══════════════════════════════════════════════════════════════════════════════
# 3. Self-Consistency
# ═══════════════════════════════════════════════════════════════════════════════

@dataclass
class Candidate:
    """One candidate answer from a reasoning path.

    ``self_consistency_check`` overwrites ``score`` on every candidate it
    is given.
    """
    path_id: str     # label of the reasoning path; generate_candidates uses "path_<i>"
    answer: str      # full answer text
    key_claims: List[str] = field(default_factory=list)  # statements compared across candidates
    score: float = 0.0  # consistency score; stays 0.0 until self_consistency_check assigns one


def self_consistency_check(candidates: List[Candidate]) -> Candidate:
    """
    Pick the most self-consistent candidate by counting how many
    key claims are shared across answers.

    Claims are compared case-insensitively with surrounding whitespace
    removed. The support of a claim is the number of *candidates* that
    state it; a candidate's score is the average support of its distinct
    claims (0.0 when it has none). A claim repeated inside one candidate
    counts once, so a candidate cannot raise its own score by repetition.

    Side effects: ``score`` is set on every candidate and *candidates* is
    re-ordered in place, best first (ties keep their original order).

    Args:
        candidates: Candidate answers to rank.

    Returns:
        The highest-scoring candidate. A single candidate is returned
        as-is with a score of 1.0.

    Raises:
        ValueError: If *candidates* is empty.
    """
    if not candidates:
        raise ValueError("No candidates to check")
    if len(candidates) == 1:
        candidates[0].score = 1.0
        return candidates[0]

    # Normalise once. A set per candidate makes duplicates within a single
    # answer count once; blank claims carry no information and are dropped.
    claim_sets = [
        {claim.lower().strip() for claim in c.key_claims if claim.strip()}
        for c in candidates
    ]

    # support[claim] = number of candidates asserting that claim
    support: Dict[str, int] = {}
    for claims in claim_sets:
        for claim in claims:
            support[claim] = support.get(claim, 0) + 1

    for c, claims in zip(candidates, claim_sets):
        c.score = sum(support[claim] for claim in claims) / max(len(claims), 1)

    # Stable sort: among equal scores the earlier candidate stays first.
    candidates.sort(key=lambda c: c.score, reverse=True)
    return candidates[0]


def generate_candidates(query: str, knowledge: KnowledgeStore, n: int = 3) -> List[Candidate]:
    """
    Generate *n* candidate answers using different reasoning paths.

    Up to five knowledge entries are retrieved for *query*. Path ``i``
    rotates that list by ``i`` positions (wrapping around, so paths beyond
    the number of entries keep varying instead of all repeating path 0),
    distills the first three entries into a context string, and appends
    the chain-of-thought steps for the query.

    Args:
        query: The request text.
        knowledge: Store used to retrieve and distill knowledge entries.
        n: Number of candidates to produce; a value <= 0 yields an empty list.

    Returns:
        Candidates labelled ``path_0`` .. ``path_{n-1}``. Each carries at
        most five key claims: the context fragments, split on ``.`` and
        ``!``, that are longer than 10 characters after stripping.
    """
    entries = knowledge.query(query, top_k=5)
    cot = build_chain_of_thought(query)
    # The CoT steps are identical for every path, so render them once.
    step_sentences = [f"{step.thought}." for step in cot.steps]

    candidates: List[Candidate] = []
    for i in range(n):
        # Vary the knowledge subset: rotate the ranked entries, wrapping so
        # that i >= len(entries) still produces a real rotation.
        if entries:
            shift = i % len(entries)
            subset = entries[shift:] + entries[:shift]
        else:
            subset = []
        context = knowledge.distill(subset[:3])

        # Build the answer from context + CoT
        answer_parts: List[str] = []
        claims: List[str] = []
        if context:
            answer_parts.append(context)
            # Key claims are the sentence-like fragments of the context;
            # very short fragments are ignored.
            fragments = (piece.strip() for piece in re.split(r'[.!]', context))
            claims = [fragment for fragment in fragments if len(fragment) > 10]
        answer_parts.extend(step_sentences)

        candidates.append(Candidate(
            path_id=f"path_{i}",
            answer=" ".join(answer_parts),
            key_claims=claims[:5],
        ))

    return candidates
