"""Tests for the Symbolic Topology Cipher (STC)."""

import os
import pytest
from synthos.crypto.stc import SymbolicTopologyCipher, STCKey, STCMode, BLOCK_SIZE, KEY_SIZE


class TestSTCKey:
    """Checks on STCKey construction and the lookup tables it exposes."""

    def test_generate_random_key(self):
        """A generated key exposes raw bytes and tables of the expected sizes."""
        fresh = STCKey.generate()
        observed_sizes = (
            len(fresh.raw),
            len(fresh.sbox),
            len(fresh.inv_sbox),
            len(fresh.lattice_order),
            len(fresh.round_keys),
        )
        assert observed_sizes == (KEY_SIZE, 256, 256, BLOCK_SIZE, 12)

    def test_sbox_is_permutation(self):
        """The S-box contains each value in 0..255 exactly once."""
        fresh = STCKey.generate()
        assert sorted(fresh.sbox) == [*range(256)]

    def test_inv_sbox_inverts(self):
        """Looking up sbox[i] in inv_sbox yields i for every index."""
        fresh = STCKey.generate()
        forward, backward = fresh.sbox, fresh.inv_sbox
        assert all(backward[forward[idx]] == idx for idx in range(256))

    def test_lattice_order_is_permutation(self):
        """lattice_order contains each value in 0..BLOCK_SIZE-1 exactly once."""
        fresh = STCKey.generate()
        assert sorted(fresh.lattice_order) == [*range(BLOCK_SIZE)]

    def test_inv_lattice_inverts(self):
        """Looking up lattice_order[i] in inv_lattice_order yields i."""
        fresh = STCKey.generate()
        forward, backward = fresh.lattice_order, fresh.inv_lattice_order
        assert all(backward[forward[idx]] == idx for idx in range(BLOCK_SIZE))

    def test_from_hex_roundtrip(self):
        """A key rebuilt from its hex form has the same raw bytes and S-box."""
        original = STCKey.generate()
        hex_form = original.hex()
        rebuilt = STCKey.from_hex(hex_form)
        assert (rebuilt.raw, rebuilt.sbox) == (original.raw, original.sbox)

    def test_from_passphrase(self):
        """The same passphrase and salt give the same raw key on both calls."""
        shared_salt = os.urandom(16)
        first, second = (
            STCKey.from_passphrase("test-passphrase", salt=shared_salt)
            for _ in range(2)
        )
        assert first.raw == second.raw

    def test_bad_key_length_raises(self):
        """Constructing a key from 5 raw bytes raises ValueError."""
        too_short = b"short"
        with pytest.raises(ValueError):
            STCKey(raw=too_short)

    def test_different_keys_produce_different_sboxes(self):
        """Two independently generated keys do not share an S-box."""
        first, second = STCKey.generate(), STCKey.generate()
        assert first.sbox != second.sbox


class TestSymbolicTopologyCipher:
    """Round-trip, tamper-detection and uniqueness checks for the cipher."""

    @pytest.fixture
    def key(self):
        """Provide a freshly generated STCKey to each test that requests it."""
        return STCKey.generate()

    @staticmethod
    def _roundtrip(cipher, plaintext, **encrypt_kwargs):
        """Encrypt *plaintext* with *cipher*, then return the decrypted result."""
        sealed = cipher.encrypt(plaintext, **encrypt_kwargs)
        return cipher.decrypt(sealed)

    @pytest.mark.parametrize("mode", [STCMode.ECB, STCMode.CBC, STCMode.CTR])
    def test_roundtrip_all_modes(self, key, mode):
        """A short message survives encrypt/decrypt in ECB, CBC and CTR."""
        message = b"hello SYNTHOS world"
        engine = SymbolicTopologyCipher(key, mode)
        assert self._roundtrip(engine, message) == message

    def test_roundtrip_empty(self, key):
        """The empty message decrypts back to the empty message."""
        engine = SymbolicTopologyCipher(key)
        assert self._roundtrip(engine, b"") == b""

    def test_roundtrip_exact_block(self, key):
        """A message of exactly BLOCK_SIZE bytes round-trips."""
        message = b"A" * BLOCK_SIZE
        engine = SymbolicTopologyCipher(key)
        assert self._roundtrip(engine, message) == message

    def test_roundtrip_multi_block(self, key):
        """A random message of five blocks plus seven bytes round-trips."""
        engine = SymbolicTopologyCipher(key)
        message = os.urandom(BLOCK_SIZE * 5 + 7)
        assert self._roundtrip(engine, message) == message

    def test_roundtrip_large(self, key):
        """A random 10,000-byte message round-trips."""
        engine = SymbolicTopologyCipher(key)
        message = os.urandom(10_000)
        assert self._roundtrip(engine, message) == message

    def test_authentication_detects_tamper(self, key):
        """Inverting the final ciphertext byte makes decrypt raise ValueError."""
        engine = SymbolicTopologyCipher(key)
        sealed = engine.encrypt(b"secret")
        # Invert every bit of the last byte; presumably this lands inside the
        # trailing HMAC tag, since the expected error message mentions "HMAC".
        tampered = bytes(sealed[:-1]) + bytes([sealed[-1] ^ 0xFF])
        with pytest.raises(ValueError, match="HMAC"):
            engine.decrypt(tampered)

    def test_no_auth_mode(self, key):
        """A message encrypted with authenticate=False still decrypts."""
        message = b"no-auth"
        engine = SymbolicTopologyCipher(key)
        assert self._roundtrip(engine, message, authenticate=False) == message

    def test_bad_magic_raises(self, key):
        """Input starting with b"BAD!" is rejected with a "magic" ValueError."""
        engine = SymbolicTopologyCipher(key)
        bogus = b"BAD!" + b"\x00" * 100
        with pytest.raises(ValueError, match="magic"):
            engine.decrypt(bogus)

    def test_verbose_trace(self, key):
        """With verbose=True, encrypting leaves at least one trace entry."""
        engine = SymbolicTopologyCipher(key, verbose=True)
        engine.encrypt(b"trace me")
        assert len(engine.trace) >= 1

    def test_different_keys_produce_different_ciphertext(self):
        """The same plaintext under two generated keys gives different output."""
        message = b"same plaintext"
        outputs = [
            SymbolicTopologyCipher(STCKey.generate()).encrypt(message)
            for _ in range(2)
        ]
        # NOTE(review): this compares the complete outputs, header included.
        # If the header carries per-call random data (e.g. an IV), the
        # inequality would hold even with identical keys -- confirm the wire
        # format and consider comparing only the ciphertext bodies.
        assert outputs[0] != outputs[1]

    def test_cbc_iv_makes_ciphertext_unique(self, key):
        """Encrypting the same plaintext twice in CBC gives different output."""
        message = b"deterministic?"
        engine = SymbolicTopologyCipher(key, STCMode.CBC)
        first = engine.encrypt(message)
        second = engine.encrypt(message)
        # Presumably each encrypt call draws a fresh IV, so outputs differ.
        assert first != second
