"""
SYNTHOS Symbolic Topology Cipher (STC)
======================================

A novel encryption algorithm rooted in SYNTHOS's core philosophy: intelligence
(and now confidentiality) is encoded in **pattern geometry**, not numeric
weights.

Design
------
Each of the ``ROUNDS`` rounds applied to a 16-byte block first XORs in a
round key and then runs three phases that mirror the SYNTHOS pipeline (in
execution order: substitution, lattice permutation, Möbius fold):

1. **Lattice Permutation** — the state bytes are scattered across an N×N
   lattice whose traversal order is derived from the key.  This provides
   positional confusion (analogous to GPL cell routing).

2. **Regex-Topology Substitution** — each byte is transformed through a
   key-dependent S-box: a permutation of 0..255 produced by a Fisher–Yates
   style shuffle seeded with SHA-512 of the key material.  This provides
   value confusion (analogous to SCM semantic binding).

3. **Möbius Fold** — the right half of the state is replaced by
   ``(left + right) mod 256`` XORed with the round key, twisting the two
   halves against each other (inspired by the Möbius self-referential
   construct).  Multiple rounds amplify the avalanche effect.

Modes
-----
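* ``STCMode.ECB`` (``STC``) — Electronic-codebook style; each block is
  encrypted independently, so equal plaintext blocks give equal ciphertext.
* ``STCMode.CBC`` (``STC-CBC``) — Cipher-block-chaining; each plaintext
  block is XORed with the previous ciphertext block (the random IV for the
  first block) before encryption.  This is the default mode.
* ``STCMode.CTR`` (``STC-CTR``) — Counter mode; the first 12 bytes of the
  random IV (the nonce) plus a 32-bit big-endian block counter are encrypted
  to produce a keystream that is XORed with the plaintext (parallelisable,
  seekable).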

Security notes
--------------
This is a *research / educational* cipher designed to showcase SYNTHOS
concepts.  It has **not** been formally audited.  Do not rely on it for
production secrets without independent cryptanalysis.
"""

from __future__ import annotations

import hashlib
import hmac
import os
import re
import struct
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Tuple

# ── Constants ──────────────────────────────────────────────────────────────────

BLOCK_SIZE: int = 16      # bytes per cipher block (128 bits); one byte per lattice cell
KEY_SIZE: int = 32        # bytes of raw key material (256 bits)
ROUNDS: int = 12          # key-XOR / S-box / lattice / fold rounds applied to each block
LATTICE_DIM: int = 4      # side of the square lattice: 4 × 4 == BLOCK_SIZE cells (mirrors GPL default)
MAGIC: bytes = b"STC1"    # magic bytes that open every envelope
VERSION: int = 1          # envelope format version written right after MAGIC


class STCMode(Enum):
    """
    Operating mode of the cipher.

    The declaration order is part of the envelope wire format: the MODE byte
    stores ``list(STCMode).index(mode)``.  Append new modes at the end and
    never reorder existing ones.
    """

    ECB = "ecb"   # wire index 0: independent blocks
    CBC = "cbc"   # wire index 1: cipher-block chaining
    CTR = "ctr"   # wire index 2: counter-mode keystream


# ── Key material ───────────────────────────────────────────────────────────────

@dataclass(frozen=True)
class STCKey:
    """
    Encapsulates derived key material for the Symbolic Topology Cipher.

    Every table is derived deterministically from ``raw`` in
    ``__post_init__``, so two keys with equal ``raw`` compare equal.

    Parameters
    ----------
    raw : bytes
        The raw key bytes (must be exactly ``KEY_SIZE`` bytes).  A
        ``bytearray`` or ``memoryview`` is accepted and copied to immutable
        ``bytes``.
    salt : bytes or None
        PBKDF2 salt recorded by :meth:`from_passphrase` so the key can be
        re-derived later.  It does not influence any derived table and is
        ignored by ``==`` and ``hash``.

    Attributes
    ----------
    sbox : tuple[int, ...]
        256-entry substitution box derived from the key.
    inv_sbox : tuple[int, ...]
        Inverse substitution box for decryption.
    lattice_order : tuple[int, ...]
        Permutation of 0..BLOCK_SIZE-1 controlling lattice scatter.
    inv_lattice_order : tuple[int, ...]
        Inverse permutation for un-scattering.
    round_keys : tuple[bytes, ...]
        Per-round subkeys (ROUNDS entries, each BLOCK_SIZE bytes).

    Raises
    ------
    TypeError
        If ``raw`` is not bytes-like.
    ValueError
        If ``raw`` is not exactly ``KEY_SIZE`` bytes long.
    """

    # repr=False keeps the secret key bytes out of reprs, logs and tracebacks.
    raw: bytes = field(repr=False)
    salt: Optional[bytes] = field(default=None, compare=False)
    sbox: Tuple[int, ...] = field(init=False, repr=False)
    inv_sbox: Tuple[int, ...] = field(init=False, repr=False)
    lattice_order: Tuple[int, ...] = field(init=False, repr=False)
    inv_lattice_order: Tuple[int, ...] = field(init=False, repr=False)
    round_keys: Tuple[bytes, ...] = field(init=False, repr=False)

    def __post_init__(self):
        if not isinstance(self.raw, (bytes, bytearray, memoryview)):
            raise TypeError(f"Key must be bytes-like, got {type(self.raw).__name__}")
        # Copy to immutable bytes: the frozen dataclass hashes on ``raw``,
        # and a caller-held bytearray must not be able to mutate the key.
        raw = bytes(self.raw)
        object.__setattr__(self, "raw", raw)
        if len(raw) != KEY_SIZE:
            raise ValueError(f"Key must be exactly {KEY_SIZE} bytes, got {len(raw)}")

        # ── Derive S-box ───────────────────────────────────────────────────
        # Fisher–Yates-style shuffle of 0..255 driven by a SHA-512 seed.
        # The 64-byte seed is reused cyclically; this derivation must stay
        # byte-identical or existing ciphertexts become undecryptable.
        sbox = list(range(256))
        seed = hashlib.sha512(b"STC-SBOX:" + raw).digest()
        j = 0
        for i in range(255, 0, -1):
            j = (j + sbox[i] + seed[i % len(seed)]) % (i + 1)
            sbox[i], sbox[j] = sbox[j], sbox[i]
        object.__setattr__(self, "sbox", tuple(sbox))

        # Only swaps were applied, so sbox is a permutation and is invertible.
        inv = [0] * 256
        for i, v in enumerate(sbox):
            inv[v] = i
        object.__setattr__(self, "inv_sbox", tuple(inv))

        # ── Derive lattice permutation ─────────────────────────────────────
        perm = list(range(BLOCK_SIZE))
        lseed = hashlib.sha256(b"STC-LATTICE:" + raw).digest()
        k = 0
        for i in range(BLOCK_SIZE - 1, 0, -1):
            k = (k + lseed[i % len(lseed)]) % (i + 1)
            perm[i], perm[k] = perm[k], perm[i]
        object.__setattr__(self, "lattice_order", tuple(perm))

        inv_perm = [0] * BLOCK_SIZE
        for i, v in enumerate(perm):
            inv_perm[v] = i
        object.__setattr__(self, "inv_lattice_order", tuple(inv_perm))

        # ── Derive round subkeys ───────────────────────────────────────────
        # Subkey r = first BLOCK_SIZE bytes of SHA-256("STC-RK:" || raw || r as u32 BE).
        rkeys: list[bytes] = [
            hashlib.sha256(b"STC-RK:" + raw + struct.pack(">I", r)).digest()[:BLOCK_SIZE]
            for r in range(ROUNDS)
        ]
        object.__setattr__(self, "round_keys", tuple(rkeys))

    # ── Convenience constructors ───────────────────────────────────────────

    @classmethod
    def generate(cls) -> "STCKey":
        """Generate a cryptographically random key (``os.urandom``)."""
        return cls(raw=os.urandom(KEY_SIZE))

    @classmethod
    def from_passphrase(cls, passphrase: str, salt: bytes | None = None) -> "STCKey":
        """
        Derive a key from a passphrase using PBKDF2-HMAC-SHA256.

        Uses 600 000 iterations over the UTF-8 encoded passphrase.  If *salt*
        is None, a fresh random 16-byte salt is generated.  The salt actually
        used is stored on the returned key as ``key.salt``.  Persist it next
        to the ciphertext, because without it the key cannot be re-derived.
        """
        if salt is None:
            salt = os.urandom(16)
        raw = hashlib.pbkdf2_hmac("sha256", passphrase.encode(), salt, iterations=600_000, dklen=KEY_SIZE)
        return cls(raw=raw, salt=bytes(salt))

    @classmethod
    def from_hex(cls, hex_str: str) -> "STCKey":
        """Load key from hex string (raises ValueError on malformed hex)."""
        return cls(raw=bytes.fromhex(hex_str))

    def hex(self) -> str:
        """Return the raw key as a hex string (this exposes the secret)."""
        return self.raw.hex()


# ── Core cipher ────────────────────────────────────────────────────────────────

class SymbolicTopologyCipher:
    """
    The Symbolic Topology Cipher (STC).

    Implements encrypt / decrypt for arbitrary byte strings with PKCS#7
    padding and optional HMAC-SHA256 authentication.

    Parameters
    ----------
    key : STCKey
        Pre-derived key material.
    mode : STCMode
        Operating mode (ECB, CBC, or CTR) used by :meth:`encrypt`.  A mode
        value string (``"ecb"``, ``"cbc"``, ``"ctr"``) is also accepted and
        converted to the matching :class:`STCMode`.  :meth:`decrypt` ignores
        this setting and uses the mode recorded in the envelope.
    verbose : bool
        If True, append per-round diagnostics to ``self.trace``.  The cipher
        never clears that list, and its entries contain intermediate block
        states, so enable this for debugging only.

    Examples
    --------
    >>> key = STCKey.generate()
    >>> cipher = SymbolicTopologyCipher(key)
    >>> ct = cipher.encrypt(b"hello SYNTHOS")
    >>> cipher.decrypt(ct)
    b'hello SYNTHOS'
    """

    _TAG_SIZE = 32      # HMAC-SHA256 tag length in bytes
    _FLAG_AUTH = 0x01   # FLAGS bit: envelope carries an HMAC tag
    _NONCE_SIZE = 12    # CTR: IV prefix used as nonce; a 4-byte counter fills the rest of the block

    def __init__(self, key: STCKey, mode: STCMode = STCMode.CBC, *, verbose: bool = False):
        self.key = key
        # Convenience: accept mode value strings.  STCMode(...) raises
        # ValueError for unknown strings; other values are stored as given
        # and validated by encrypt().
        self.mode = STCMode(mode) if isinstance(mode, str) else mode
        self.verbose = verbose
        self.trace: List[str] = []

    # ── Public API ─────────────────────────────────────────────────────────

    def encrypt(self, plaintext: bytes, *, authenticate: bool = True) -> bytes:
        """
        Encrypt *plaintext* and return the ciphertext envelope.

        The envelope layout is::

            MAGIC (4 B) | VERSION (1 B) | MODE (1 B) | FLAGS (1 B)
            | IV/NONCE (16 B) | CT_LEN (4 B, big-endian) | CIPHERTEXT
            [| HMAC-SHA256 (32 B)]

        MODE is the index of the mode in ``list(STCMode)``.  FLAGS bit 0 is
        set when the HMAC tag is present.  A fresh random IV is drawn on every
        call.

        Parameters
        ----------
        plaintext : bytes
        authenticate : bool
            Append HMAC-SHA256 tag over the envelope (default True).  The tag
            is keyed with the same raw key as the cipher itself.

        Raises
        ------
        ValueError
            If ``self.mode`` is not an :class:`STCMode`.
        """
        mode = self.mode
        # Validate before logging, because the log line dereferences mode.value.
        if not isinstance(mode, STCMode):
            raise ValueError(f"Unsupported mode: {mode}")

        padded = self._pkcs7_pad(plaintext)
        blocks = self._split_blocks(padded)
        iv = os.urandom(BLOCK_SIZE)
        flags = self._FLAG_AUTH if authenticate else 0x00

        self._log(f"[ENC] mode={mode.value}  blocks={len(blocks)}  auth={authenticate}")

        if mode is STCMode.ECB:
            ct_blocks = [self._encrypt_block(blk) for blk in blocks]
        elif mode is STCMode.CBC:
            ct_blocks = []
            prev = iv
            for blk in blocks:
                # C_i = E(P_i xor C_{i-1}), with C_{-1} = IV.
                prev = self._encrypt_block(self._xor(blk, prev))
                ct_blocks.append(prev)
        elif mode is STCMode.CTR:
            ct_blocks = [self._xor(blk, self._ctr_keystream(iv, idx)) for idx, blk in enumerate(blocks)]
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        ct = b"".join(ct_blocks)
        envelope = MAGIC + struct.pack("BBB", VERSION, list(STCMode).index(mode), flags)
        envelope += iv + struct.pack(">I", len(ct)) + ct

        if authenticate:
            # NOTE: the MAC reuses the cipher key (no key separation).  This is
            # kept as-is for compatibility with existing envelopes.
            envelope += hmac.new(self.key.raw, envelope, hashlib.sha256).digest()

        self._log(f"[ENC] envelope={len(envelope)} B  ct={len(ct)} B")
        return envelope

    def decrypt(self, envelope: bytes, *, require_auth: bool = False) -> bytes:
        """
        Decrypt an STC envelope produced by :meth:`encrypt`.

        The mode comes from the envelope's MODE byte, not from ``self.mode``.
        For authenticated envelopes the HMAC tag is verified before the mode,
        IV or ciphertext are interpreted.

        Parameters
        ----------
        envelope : bytes
            Complete envelope as returned by :meth:`encrypt`.
        require_auth : bool
            If True, reject envelopes that carry no HMAC tag.  The FLAGS byte
            is only protected when a tag is present.  Without this check, an
            attacker can strip the tag and clear the flag, and an
            authenticated message then decrypts unauthenticated.

        Raises
        ------
        ValueError
            On bad magic, version mismatch, a truncated or malformed envelope,
            an unknown mode, a missing tag when *require_auth* is set, HMAC
            failure, or invalid padding.
        """
        magic_len = len(MAGIC)
        iv_off = magic_len + 3          # after MAGIC | VERSION | MODE | FLAGS
        len_off = iv_off + BLOCK_SIZE   # CT_LEN follows the IV
        ct_off = len_off + 4            # ciphertext follows CT_LEN

        if envelope[:magic_len] != MAGIC:
            raise ValueError("Bad STC envelope magic")
        if len(envelope) < ct_off:
            raise ValueError("Truncated STC envelope")
        ver, mode_idx, flags = struct.unpack_from("BBB", envelope, magic_len)
        if ver != VERSION:
            raise ValueError(f"Unsupported STC version {ver}")

        authenticate = bool(flags & self._FLAG_AUTH)
        if require_auth and not authenticate:
            raise ValueError("STC envelope is not authenticated")

        if authenticate:
            if len(envelope) < ct_off + self._TAG_SIZE:
                raise ValueError("Truncated STC envelope")
            tag = envelope[-self._TAG_SIZE:]
            body = envelope[:-self._TAG_SIZE]
            expected = hmac.new(self.key.raw, body, hashlib.sha256).digest()
            if not hmac.compare_digest(tag, expected):
                raise ValueError("HMAC verification failed — data may be tampered")
        else:
            body = envelope

        modes = list(STCMode)
        if mode_idx >= len(modes):
            raise ValueError(f"Unsupported STC mode index {mode_idx}")
        mode = modes[mode_idx]

        iv = bytes(body[iv_off:len_off])
        (ct_len,) = struct.unpack_from(">I", body, len_off)
        ct = bytes(body[ct_off:ct_off + ct_len])
        if len(ct) != ct_len:
            raise ValueError("Truncated STC ciphertext")
        # encrypt() always pads, so a genuine ciphertext is a non-empty whole
        # number of blocks.  Anything else would otherwise fail deep inside the
        # block functions with an IndexError.
        if ct_len == 0 or ct_len % BLOCK_SIZE:
            raise ValueError("STC ciphertext length is not a positive multiple of the block size")
        blocks = self._split_blocks(ct)

        self._log(f"[DEC] mode={mode.value}  blocks={len(blocks)}  auth={authenticate}")

        if mode is STCMode.ECB:
            pt_blocks = [self._decrypt_block(blk) for blk in blocks]
        elif mode is STCMode.CBC:
            pt_blocks = []
            prev = iv
            for blk in blocks:
                # P_i = D(C_i) xor C_{i-1}, with C_{-1} = IV.
                pt_blocks.append(self._xor(self._decrypt_block(blk), prev))
                prev = blk
        elif mode is STCMode.CTR:
            # CTR regenerates the same keystream using block *encryption*.
            pt_blocks = [self._xor(blk, self._ctr_keystream(iv, idx)) for idx, blk in enumerate(blocks)]
        else:
            raise ValueError(f"Unsupported mode: {mode}")

        return self._pkcs7_unpad(b"".join(pt_blocks))

    # ── Mode helpers ───────────────────────────────────────────────────────

    @staticmethod
    def _split_blocks(data: bytes) -> List[bytes]:
        """Split *data* into consecutive BLOCK_SIZE-byte chunks (the last may be short)."""
        return [data[i:i + BLOCK_SIZE] for i in range(0, len(data), BLOCK_SIZE)]

    @staticmethod
    def _xor(a: bytes, b: bytes) -> bytes:
        """Bytewise XOR of *a* and *b*, truncated to the shorter input (``zip`` semantics)."""
        return bytes(x ^ y for x, y in zip(a, b))

    def _ctr_keystream(self, iv: bytes, index: int) -> bytes:
        """
        Return CTR keystream block *index*: E(iv[:12] || index as big-endian u32).

        ``struct.pack`` raises ``struct.error`` for index >= 2**32, which
        bounds a message at 2**32 blocks.
        """
        return self._encrypt_block(iv[:self._NONCE_SIZE] + struct.pack(">I", index))

    # ── Block-level operations ─────────────────────────────────────────────

    def _encrypt_block(self, block: bytes) -> bytes:
        """Encrypt a single BLOCK_SIZE block through all rounds."""
        state = bytearray(block)
        for r in range(ROUNDS):
            state = self._round_encrypt(state, r)
        return bytes(state)

    def _decrypt_block(self, block: bytes) -> bytes:
        """Decrypt a single BLOCK_SIZE block (rounds in reverse)."""
        state = bytearray(block)
        for r in range(ROUNDS - 1, -1, -1):
            state = self._round_decrypt(state, r)
        return bytes(state)

    def _round_encrypt(self, state: bytearray, r: int) -> bytearray:
        """
        Single encryption round:
          1. Add round key (XOR)
          2. S-box substitution
          3. Lattice permutation (scatter)
          4. Möbius fold: right := ((left + right) mod 256) XOR round key.
             The left half passes through this step unchanged.
        """
        rk = self.key.round_keys[r]

        # 1. AddRoundKey
        state = bytearray(a ^ b for a, b in zip(state, rk))

        # 2. SubBytes (S-box)
        state = bytearray(self.key.sbox[b] for b in state)

        # 3. Lattice permutation: the byte at position i moves to lattice_order[i].
        tmp = bytearray(BLOCK_SIZE)
        for i in range(BLOCK_SIZE):
            tmp[self.key.lattice_order[i]] = state[i]
        state = tmp

        # 4. Möbius fold (r_ avoids shadowing the round index r)
        half = BLOCK_SIZE // 2
        left = state[:half]
        right = state[half:]
        new_right = bytearray(((l + r_) % 256) ^ rk[i % len(rk)] for i, (l, r_) in enumerate(zip(left, right)))
        state = bytearray(left) + new_right

        self._log(f"  round {r:2d} enc  state={state.hex()}")
        return state

    def _round_decrypt(self, state: bytearray, r: int) -> bytearray:
        """Inverse of ``_round_encrypt``: undo steps 4, 3, 2, 1 in that order."""
        rk = self.key.round_keys[r]

        # 4'. Inverse Möbius fold: left is untouched by the fold, so
        #     right = ((new_right XOR rk) - left) mod 256.
        half = BLOCK_SIZE // 2
        left = state[:half]
        new_right_enc = state[half:]
        right = bytearray(((nr ^ rk[i % len(rk)]) - l) % 256 for i, (l, nr) in enumerate(zip(left, new_right_enc)))
        state = bytearray(left) + right

        # 3'. Inverse lattice permutation: gather back from lattice_order[i].
        tmp = bytearray(BLOCK_SIZE)
        for i in range(BLOCK_SIZE):
            tmp[i] = state[self.key.lattice_order[i]]
        state = tmp

        # 2'. Inverse S-box
        state = bytearray(self.key.inv_sbox[b] for b in state)

        # 1'. Remove round key
        state = bytearray(a ^ b for a, b in zip(state, rk))

        self._log(f"  round {r:2d} dec  state={state.hex()}")
        return state

    # ── Padding ────────────────────────────────────────────────────────────

    @staticmethod
    def _pkcs7_pad(data: bytes) -> bytes:
        """Append 1..BLOCK_SIZE bytes, each equal to the pad length (always pads)."""
        pad_len = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
        return data + bytes([pad_len] * pad_len)

    @staticmethod
    def _pkcs7_unpad(data: bytes) -> bytes:
        """
        Strip PKCS#7 padding added by :meth:`_pkcs7_pad`.

        Raises
        ------
        ValueError
            If *data* is empty or the padding bytes are inconsistent.
        """
        if not data:
            raise ValueError("Invalid PKCS#7 padding")
        pad_len = data[-1]
        if pad_len < 1 or pad_len > BLOCK_SIZE:
            raise ValueError("Invalid PKCS#7 padding")
        if data[-pad_len:] != bytes([pad_len] * pad_len):
            raise ValueError("Corrupt PKCS#7 padding")
        return data[:-pad_len]

    # ── Diagnostics ────────────────────────────────────────────────────────

    def _log(self, msg: str):
        """Append *msg* to ``self.trace`` when verbose mode is on."""
        if self.verbose:
            self.trace.append(msg)
