"""Tests for LM distillation techniques — CoT, knowledge store, self-consistency."""

import pytest
from synthos.lm.distill import (
    build_chain_of_thought, ChainOfThought, ThoughtStep,
    KnowledgeStore, KnowledgeEntry,
    generate_candidates, self_consistency_check, Candidate,
)


class TestChainOfThought:
    """Checks on ``build_chain_of_thought`` and the ``ChainOfThought`` it returns."""

    @staticmethod
    def _lowered_thoughts(chain):
        """Return the thought text of every step in *chain*, lower-cased."""
        return [step.thought.lower() for step in chain.steps]

    def test_basic_cot(self):
        """A question produces a ChainOfThought with 3+ steps that keeps the query."""
        question = "how does attention work in SYNTHOS"
        chain = build_chain_of_thought(question)
        assert isinstance(chain, ChainOfThought)
        assert len(chain.steps) >= 3
        assert chain.query == question

    def test_cot_explain_pattern(self):
        """An 'explain' query has a step mentioning 'Identify' or 'subject'."""
        chain = build_chain_of_thought("explain how the pipeline processes input")
        # Unlike the other pattern tests, this match is case-sensitive.
        raw_thoughts = [step.thought for step in chain.steps]
        assert any("Identify" in text or "subject" in text for text in raw_thoughts)

    def test_cot_compare_pattern(self):
        """A 'compare' query has a step about comparing or similarity."""
        chain = build_chain_of_thought("compare SYNTHOS versus transformers")
        thoughts = self._lowered_thoughts(chain)
        assert any("compar" in text or "similar" in text for text in thoughts)

    def test_cot_create_pattern(self):
        """A 'create' query has a step about creating or implementing."""
        chain = build_chain_of_thought("create a data processing pipeline")
        thoughts = self._lowered_thoughts(chain)
        assert any("creat" in text or "implement" in text for text in thoughts)

    def test_cot_define_pattern(self):
        """A 'what is' query has a step mentioning a concept or definition."""
        chain = build_chain_of_thought("what is regex topology")
        thoughts = self._lowered_thoughts(chain)
        assert any("concept" in text or "definition" in text for text in thoughts)

    def test_cot_list_pattern(self):
        """A 'list' query has a step about listing or gathering."""
        chain = build_chain_of_thought("list the seven SYNTHOS layers")
        thoughts = self._lowered_thoughts(chain)
        assert any("list" in text or "gather" in text for text in thoughts)

    def test_cot_fallback(self):
        """A query matching none of the tested patterns still gets 3+ steps."""
        chain = build_chain_of_thought("hello world")
        step_count = len(chain.steps)
        assert step_count >= 3

    def test_cot_to_text(self):
        """Rendering with ``include_steps=True`` emits a 'Step 1:' label."""
        chain = build_chain_of_thought("what is SYNTHOS")
        rendered = chain.to_text(include_steps=True)
        assert "Step 1:" in rendered

    def test_cot_to_dict(self):
        """The dict form exposes 'query' and a non-empty 'steps' entry."""
        payload = build_chain_of_thought("test query").to_dict()
        for expected_key in ("query", "steps"):
            assert expected_key in payload
        assert len(payload["steps"]) > 0


class TestKnowledgeStore:
    """Checks on ``KnowledgeStore`` querying, insertion and distillation."""

    def test_default_entries(self):
        """A freshly constructed store answers a SYNTHOS pipeline query."""
        store = KnowledgeStore()
        hits = store.query("SYNTHOS pipeline")
        assert len(hits) > 0

    def test_query_attention(self):
        """An attention question returns an entry whose topic mentions attention."""
        store = KnowledgeStore()
        hits = store.query("how does attention work")
        topics = [hit.topic.lower() for hit in hits]
        assert any("attention" in topic for topic in topics)

    def test_query_encryption(self):
        """An STC cipher query returns an entry about encryption or STC."""

        def mentions_encryption(entry):
            # Topic is matched case-insensitively; the summary check is exact-case.
            return "encrypt" in entry.topic.lower() or "STC" in entry.summary

        store = KnowledgeStore()
        hits = store.query("STC cipher encryption")
        assert any(mentions_encryption(hit) for hit in hits)

    def test_query_no_match(self):
        """A nonsense query returns no entries."""
        store = KnowledgeStore()
        hits = store.query("xyzzy foobar baz")
        assert len(hits) == 0

    def test_add_custom_entry(self):
        """An entry added by hand is returned first for its own topic."""
        store = KnowledgeStore()
        store.add("custom topic", "This is custom knowledge.", ["Detail 1"], ["related1"])
        hits = store.query("custom topic")
        assert len(hits) >= 1
        best = hits[0]
        assert "custom" in best.summary.lower()

    def test_distill(self):
        """Distilling SYNTHOS hits yields non-empty text that names SYNTHOS."""
        store = KnowledgeStore()
        hits = store.query("SYNTHOS", top_k=2)
        distilled = store.distill(hits)
        assert len(distilled) > 0
        assert "SYNTHOS" in distilled

    def test_distill_empty(self):
        """Distilling an empty entry list gives the empty string."""
        store = KnowledgeStore()
        distilled = store.distill([])
        assert distilled == ""


class TestSelfConsistency:
    """Checks on ``self_consistency_check`` and ``generate_candidates``."""

    def test_single_candidate(self):
        """A lone candidate is returned with a score of exactly 1.0."""
        only = Candidate(path_id="p0", answer="SYNTHOS is great", key_claims=["claim1"])
        chosen = self_consistency_check([only])
        assert chosen.score == 1.0

    def test_multiple_candidates(self):
        """The winner is one of the candidates holding the most widely shared claims."""
        specs = (
            ("p0", "A", ["regex is key", "7 layers"]),
            ("p1", "B", ["regex is key", "pattern geometry"]),
            ("p2", "C", ["regex is key", "7 layers"]),
        )
        pool = [
            Candidate(path, answer, key_claims=claims)
            for path, answer, claims in specs
        ]
        chosen = self_consistency_check(pool)
        # "regex is key" is claimed by all three and "7 layers" by two,
        # so either p0 or p2 is expected to come out on top.
        assert chosen.path_id in ("p0", "p2")

    def test_no_candidates_raises(self):
        """An empty candidate list is rejected with ValueError."""
        nothing = []
        with pytest.raises(ValueError):
            self_consistency_check(nothing)

    def test_generate_candidates(self):
        """Asking for n=3 yields three candidates, each with a non-empty answer."""
        store = KnowledgeStore()
        pool = generate_candidates("what is SYNTHOS", store, n=3)
        assert len(pool) == 3
        answer_lengths = [len(candidate.answer) for candidate in pool]
        assert all(length > 0 for length in answer_lengths)

    def test_generate_and_check(self):
        """Generated candidates fed to the consistency check give a positive score."""
        store = KnowledgeStore()
        pool = generate_candidates("explain attention in SYNTHOS", store, n=3)
        chosen = self_consistency_check(pool)
        assert chosen.score > 0


class TestDistillIntegration:
    """Checks that ``SynthoLM`` exposes and uses the distillation knowledge store."""

    @staticmethod
    def _quiet_lm():
        """Build a ``SynthoLM`` with verbosity turned off.

        The imports stay inside the helper so the engine modules are only
        loaded when one of these tests actually runs.
        """
        from synthos.lm.engine import SynthoLM
        from synthos.utils.config import SynthosConfig
        from synthos.utils.log import VerboseLevel

        return SynthoLM(config=SynthosConfig(verbose=0), verbose=VerboseLevel.QUIET)

    def test_lm_uses_knowledge_store(self):
        """The model carries a knowledge store that answers a pipeline query."""
        model = self._quiet_lm()
        assert model.knowledge is not None
        hits = model.knowledge.query("SYNTHOS pipeline")
        assert len(hits) > 0

    def test_lm_explain_uses_distilled_knowledge(self):
        """Generated text for 'what is SYNTHOS' names SYNTHOS and a core keyword."""
        model = self._quiet_lm()
        output = model.generate("what is SYNTHOS")
        # The answer is expected to carry distilled knowledge about SYNTHOS.
        assert "SYNTHOS" in output.text
        lowered = output.text.lower()
        assert any(keyword in lowered for keyword in ("syntax", "regex", "layer"))
