# ═══════════════════════════════════════════════════════════════════════════════════
# SYNTHOS STATE CRYSTALLIZATION FIELD (SCF) v1.0
# Multi-dimensional state tensor with crystallized match geometry
# ═══════════════════════════════════════════════════════════════════════════════════

import re
import numpy as np
from typing import Dict, List, Tuple, Optional, Any, Union
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict, deque
import hashlib
import json

class LayerType(Enum):
    """Identifiers for the layers of the state tensor.

    Members are used as the keys of ``StateTensor.layers`` and are iterated
    in declaration order by ``StateCrystallizationField``. The string values
    are what ends up in history entries and exported files (via ``.value``).
    """
    LEXICAL = "lexical"
    SEMANTIC = "semantic"
    TOPOLOGICAL = "topological"  # the layer inspected by the TOPOLOGICAL.EMIT gate condition

class MemoryType(Enum):
    """Identifiers for kinds of memory.

    NOTE(review): not referenced anywhere else in the visible source;
    presumably mirrors the three stores of ``MemoryLattice`` (short-term
    buffer, long-term registers, episodic stack) -- confirm against callers.
    """
    SHORT_TERM = "short_term"
    LONG_TERM = "long_term"
    EPISODIC = "episodic"

@dataclass
class StateTensor:
    """Rank-3 tensor of object cells addressed as T[LAYER][ROW][COL].

    Each layer is a 2-D numpy array indexed as ``layer[row, col]``. All
    layers supplied at construction time must share one shape.
    """
    # Mapping from layer identifier to that layer's 2-D cell array.
    layers: Dict[LayerType, np.ndarray]
    # Free-form bookkeeping; crystallize_pattern() keeps a "history" list here.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate that every supplied layer has the same shape.

        Raises:
            ValueError: if two layers differ in shape.
        """
        shapes = {layer_array.shape for layer_array in self.layers.values()}
        if len(shapes) > 1:
            raise ValueError("All layers must have the same shape")

    def _default_layer_shape(self) -> Tuple[int, int]:
        """Return the shape a newly created layer should get.

        Uses the shape of an existing 2-D layer so the same-shape invariant
        is preserved; falls back to 4x4 when the tensor has no layers yet.
        """
        for layer_array in self.layers.values():
            if layer_array.ndim == 2:
                return (layer_array.shape[0], layer_array.shape[1])
        return (4, 4)

    def get_layer(self, layer_type: LayerType) -> np.ndarray:
        """Return the array for ``layer_type``.

        A missing layer yields an empty 0x0 array (not a 1-D one) so callers
        can always read both ``shape[0]`` and ``shape[1]``.
        """
        layer = self.layers.get(layer_type)
        if layer is None:
            return np.empty((0, 0), dtype=object)
        return layer

    def set_layer(self, layer_type: LayerType, data: np.ndarray):
        """Replace (or create) the array stored for ``layer_type``.

        No shape validation is performed here.
        """
        self.layers[layer_type] = data

    def get_cell(self, layer_type: LayerType, row: int, col: int) -> Any:
        """Return the value at ``[row, col]`` of a layer, or None if the layer is missing.

        Raises:
            IndexError: if the position lies outside an existing layer.
        """
        layer = self.layers.get(layer_type)
        if layer is None:
            return None
        return layer[row, col]

    def set_cell(self, layer_type: LayerType, row: int, col: int, value: Any):
        """Store ``value`` at ``[row, col]``, creating the layer if needed.

        A layer created on demand takes the shape of the existing layers
        (4x4 only when the tensor is empty) instead of a fixed 4x4.

        Raises:
            IndexError: if the position lies outside the layer.
        """
        if layer_type not in self.layers:
            self.layers[layer_type] = np.empty(
                self._default_layer_shape(), dtype=object
            )

        self.layers[layer_type][row, col] = value

    def crystallize_pattern(self, pattern: str, match_result: re.Match,
                          layer_type: LayerType, position: Tuple[int, int]):
        """Record a regex match in a cell and append it to the history.

        Args:
            pattern: Pattern text that produced the match.
            match_result: The match; its full text, named groups and span
                are copied into the cell.
            layer_type: Layer receiving the cell.
            position: ``(row, col)`` of the target cell.
        """
        row, col = position

        # The timestamp is the number of history entries recorded so far, so
        # it increases by one with every crystallization on this tensor.
        pattern_info = {
            "pattern": pattern,
            "match": match_result.group(0),
            "groups": match_result.groupdict(),
            "span": match_result.span(),
            "timestamp": len(self.metadata.get("history", []))
        }

        self.set_cell(layer_type, row, col, pattern_info)

        # Only log the event once the cell write has succeeded.
        self.metadata.setdefault("history", []).append({
            "layer": layer_type.value,
            "position": position,
            "pattern": pattern,
            "match": match_result.group(0)
        })

@dataclass
class MemoryLattice:
    """Symbolic working memory: a bounded buffer, named registers and an episode log."""
    # Sliding window of recent items; by default the 9th append drops the oldest.
    short_term_buffer: deque = field(default_factory=lambda: deque(maxlen=8))
    # Named values that persist until overwritten.
    long_term_registers: Dict[str, Any] = field(default_factory=dict)
    # Append-only list of episode records built by add_episode().
    episodic_stack: List[Dict[str, Any]] = field(default_factory=list)

    def add_short_term(self, item: Any):
        """Append ``item`` to the short-term buffer."""
        self.short_term_buffer.append(item)

    def get_short_term_pattern(self) -> str:
        """Return the fixed sliding-window pattern with its ``FOCUS`` named group."""
        sliding_window = "(?:.*?){0,7}(?P<FOCUS>.+)$"
        return sliding_window

    def set_long_term(self, register_name: str, value: Any):
        """Store ``value`` under ``register_name``, replacing any previous value."""
        self.long_term_registers[register_name] = value

    def get_long_term(self, register_name: str) -> Any:
        """Return the value stored under ``register_name``, or None if unset."""
        return self.long_term_registers.get(register_name)

    def add_episode(self, episode_data: Dict[str, Any]):
        """Append an episode record and return its id.

        The id is ``EPISODE_<n>`` where ``n`` is the number of episodes
        stored before this call; the same ``n`` is the record's timestamp.
        """
        sequence_number = len(self.episodic_stack)
        episode_id = f"EPISODE_{sequence_number}"
        self.episodic_stack.append({
            "id": episode_id,
            "timestamp": sequence_number,
            "data": episode_data
        })
        return episode_id

    def get_episode_pattern(self, episode_n: int) -> str:
        """Return a ``\\k<EPISODE_i>`` backreference for the ``episode_n``-th episode.

        ``episode_n`` is 1-based; an empty string is returned when it does
        not refer to a stored episode.
        """
        if not 0 < episode_n <= len(self.episodic_stack):
            return ""
        referenced_id = self.episodic_stack[episode_n - 1]["id"]
        return f"\\k<{referenced_id}>"

class StateCrystallizationField:
    """Multi-dimensional state tensor with crystallized match geometry.

    Owns a ``StateTensor`` holding one ``tensor_shape`` grid per
    ``LayerType`` and a ``MemoryLattice``. Regex matches are "crystallized"
    into grid cells as dicts carrying the keys ``pattern``, ``match``,
    ``groups``, ``span`` and ``timestamp``; empty cells hold None.
    """

    # Backreference syntax inside a stored pattern: a backslash followed by
    # a digit 1-9, or by "k<".
    _BACKREFERENCE = re.compile(r'\\[1-9]|\\k<')

    def __init__(self, tensor_shape: Tuple[int, int] = (4, 4)):
        """Create an empty field.

        Args:
            tensor_shape: ``(rows, cols)`` of every layer.
        """
        self.tensor_shape = tensor_shape
        self.state_tensor = StateTensor(
            layers={
                layer_type: np.empty(tensor_shape, dtype=object)
                for layer_type in LayerType
            }
        )
        self.memory_lattice = MemoryLattice()
        # Stored by add_coherence_rule(); not consulted by compute_coherence().
        self.coherence_rules: List[str] = []
        self.gate_conditions: List[str] = []

        # Initialize default tensor state
        self._initialize_tensor()

    # ------------------------------------------------------------------
    # Layer traversal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _layer_dims(layer: np.ndarray) -> Tuple[int, int]:
        """Return ``(rows, cols)`` of a layer, or ``(0, 0)`` if it is not 2-D."""
        if layer.ndim != 2:
            return (0, 0)
        return (layer.shape[0], layer.shape[1])

    @classmethod
    def _iter_cells(cls, layer: np.ndarray):
        """Yield ``(row, col, value)`` for every cell of a layer, row by row."""
        rows, cols = cls._layer_dims(layer)
        for row in range(rows):
            for col in range(cols):
                yield row, col, layer[row, col]

    @classmethod
    def _filled_cells(cls, layer: np.ndarray):
        """Yield ``(row, col, cell)`` for cells holding a non-empty dict."""
        for row, col, cell in cls._iter_cells(layer):
            if isinstance(cell, dict) and cell:
                yield row, col, cell

    @staticmethod
    def _encode_pattern(pattern: Any) -> float:
        """Map a pattern to a value in [0, 0.999] that is stable across runs.

        The built-in ``hash()`` is salted per process for strings, so an
        md5 digest is used to keep state vectors and state hashes
        reproducible between interpreter runs.
        """
        digest = hashlib.md5(str(pattern).encode("utf-8")).hexdigest()
        return int(digest, 16) % 1000 / 1000.0

    @staticmethod
    def _restore_tuple(entry: Any, key: str) -> Any:
        """Turn ``entry[key]`` back into a tuple if JSON loaded it as a list.

        ``entry`` is returned unchanged when it is not a dict or has no list
        under ``key``.
        """
        if isinstance(entry, dict) and isinstance(entry.get(key), list):
            entry[key] = tuple(entry[key])
        return entry

    # ------------------------------------------------------------------
    # Crystallization
    # ------------------------------------------------------------------

    def _initialize_tensor(self):
        """Reset every layer to a ``tensor_shape`` grid of None."""
        for layer_type in LayerType:
            self.state_tensor.set_layer(
                layer_type, np.full(self.tensor_shape, None, dtype=object)
            )

    def crystallize_match(self, pattern: str, match_result: re.Match,
                         layer_type: LayerType, position: Optional[Tuple[int, int]] = None):
        """Crystallize a regex match into the state tensor.

        Args:
            pattern: Pattern text that produced the match.
            match_result: The match to record.
            layer_type: Layer receiving the cell.
            position: Target ``(row, col)``; chosen by
                ``_find_empty_position`` when omitted.

        Returns:
            The ``(row, col)`` the match was written to.
        """
        if position is None:
            position = self._find_empty_position(layer_type)

        self.state_tensor.crystallize_pattern(pattern, match_result, layer_type, position)

        # Also add to short-term memory
        self.memory_lattice.add_short_term({
            "pattern": pattern,
            "match": match_result.group(0),
            "layer": layer_type.value,
            "position": position
        })

        return position

    def _find_empty_position(self, layer_type: LayerType) -> Tuple[int, int]:
        """Pick the cell the next match should be written to.

        Returns the first None cell in row-major order. When the layer is
        full, returns the cell with the smallest ``timestamp`` (the one
        written longest ago) so repeated overflow does not keep overwriting
        the same cell. Cells without a numeric timestamp count as oldest.
        Falls back to ``(0, 0)`` for a layer with no cells.
        """
        layer = self.state_tensor.get_layer(layer_type)

        oldest_position = (0, 0)
        oldest_stamp = None
        for row, col, cell in self._iter_cells(layer):
            if cell is None:
                return (row, col)

            stamp = cell.get("timestamp") if isinstance(cell, dict) else None
            if not isinstance(stamp, (int, float)):
                stamp = -1
            if oldest_stamp is None or stamp < oldest_stamp:
                oldest_position, oldest_stamp = (row, col), stamp

        return oldest_position

    def extract_state_pattern(self, layer_type: LayerType) -> str:
        """Build one regex that alternates over the patterns stored in a layer.

        Empty patterns are skipped and duplicates are emitted once (in first
        occurrence order); repeating an alternative adds nothing to the
        regex and would redefine any named group it contains.

        Returns:
            ``(?:p1|p2|...)+`` or an empty string when the layer holds no
            patterns.
        """
        layer = self.state_tensor.get_layer(layer_type)
        stored = (cell.get("pattern", "") for _, _, cell in self._filled_cells(layer))
        patterns = list(dict.fromkeys(p for p in stored if p))

        if patterns:
            return f"(?:{'|'.join(patterns)})+"
        return ""

    # ------------------------------------------------------------------
    # Coherence and gating
    # ------------------------------------------------------------------

    def compute_coherence(self) -> float:
        """Score the state by how many stored patterns contain backreferences.

        Each filled cell whose pattern contains a backreference contributes
        0.5; the sum is divided by the number of filled cells, so the result
        lies in [0, 0.5]. An empty tensor scores 0.0.
        """
        total_checks = 0
        with_backreference = 0

        for layer_type in LayerType:
            layer = self.state_tensor.get_layer(layer_type)
            for _, _, cell in self._filled_cells(layer):
                total_checks += 1
                if self._BACKREFERENCE.search(cell.get("pattern", "")):
                    with_backreference += 1

        if total_checks == 0:
            return 0.0
        return min(1.0, 0.5 * with_backreference / total_checks)

    def apply_gate_conditions(self) -> bool:
        """Return True when every registered gate condition evaluates true.

        With no conditions registered the gate is open (returns True).
        """
        return all(
            self._evaluate_gate_condition(condition)
            for condition in self.gate_conditions
        )

    def _evaluate_gate_condition(self, condition: str) -> bool:
        """Evaluate one gate condition string.

        Only conditions mentioning ``TOPOLOGICAL.EMIT`` are recognized, e.g.
        ``"STATE_TENSOR.TOPOLOGICAL.EMIT == TRUE"``: they hold when some
        pattern in the topological layer contains the text ``EMIT``. The
        right-hand side of the comparison is not parsed, and any other
        condition evaluates to False.
        """
        if "TOPOLOGICAL.EMIT" not in condition:
            return False

        topo_layer = self.state_tensor.get_layer(LayerType.TOPOLOGICAL)
        return any(
            "EMIT" in str(cell.get("pattern", ""))
            for _, _, cell in self._filled_cells(topo_layer)
        )

    def add_coherence_rule(self, rule_pattern: str):
        """Add a coherence assertion rule"""
        self.coherence_rules.append(rule_pattern)

    def add_gate_condition(self, condition: str):
        """Add an output gate condition"""
        self.gate_conditions.append(condition)

    # ------------------------------------------------------------------
    # Numeric views, hashing and diffing
    # ------------------------------------------------------------------

    def get_state_vector(self) -> np.ndarray:
        """Return the tensor flattened to one float vector.

        Layers are concatenated in ``LayerType`` order, each in row-major
        order. Empty cells encode as 0.0; filled cells encode their pattern
        text only (see ``_encode_pattern``), so different patterns can
        collide and match text does not influence the vector.
        """
        vectors = []

        for layer_type in LayerType:
            layer = self.state_tensor.get_layer(layer_type)
            _, cols = self._layer_dims(layer)
            layer_vector = np.zeros(layer.size)

            for row, col, cell in self._filled_cells(layer):
                layer_vector[row * cols + col] = self._encode_pattern(
                    cell.get("pattern", "")
                )

            vectors.append(layer_vector)

        # Concatenate all layers
        return np.concatenate(vectors)

    def update_from_attention(self, attention_results: List[Dict[str, Any]]):
        """Crystallize attention-mesh intersections into the topological layer.

        Args:
            attention_results: Result dicts; each may carry an
                ``"intersections"`` list whose items expose ``pattern1``
                and ``pattern2`` attributes.
        """
        for result in attention_results:
            for intersection in result.get("intersections") or []:
                pattern_text = f"{intersection.pattern1}∩{intersection.pattern2}"
                # The text is matched against itself only to obtain a Match
                # object; escaping keeps regex metacharacters in the source
                # patterns from breaking (or silently skipping) the match.
                mock_match = re.match(re.escape(pattern_text), pattern_text)
                if mock_match:
                    self.crystallize_match(
                        pattern_text,
                        mock_match,
                        LayerType.TOPOLOGICAL
                    )

    def get_memory_state(self) -> Dict[str, Any]:
        """Return a snapshot of the memory lattice.

        The short-term list and long-term dict are shallow copies; episodes
        are reported as a count only.
        """
        lattice = self.memory_lattice
        return {
            "short_term": list(lattice.short_term_buffer),
            "long_term": lattice.long_term_registers.copy(),
            "episodic_count": len(lattice.episodic_stack),
            "pattern": lattice.get_short_term_pattern()
        }

    def create_state_hash(self) -> str:
        """Return an md5 hex digest of the state vector.

        This is a coarse fingerprint of the stored patterns: it ignores
        match text and memory, and distinct states may share a hash.
        """
        state_vector = self.get_state_vector()
        state_str = json.dumps(state_vector.tolist(), sort_keys=True)
        return hashlib.md5(state_str.encode()).hexdigest()

    def diff_states(self, other_state: 'StateCrystallizationField') -> Dict[str, Any]:
        """Compare this field's cells with another field's, cell by cell.

        Cells are compared directly rather than via ``create_state_hash``,
        because that hash is lossy and would hide real differences. When the
        shapes differ, positions missing on one side are treated as None.

        Returns:
            ``{"changed": False, "differences": []}`` when all cells are
            equal; otherwise a dict with ``changed`` True, the list of
            differences (``layer``, ``position``, ``from`` = other's cell,
            ``to`` = this field's cell) and ``coherence_delta`` (this minus
            other).
        """
        differences: List[Dict[str, Any]] = []

        for layer_type in LayerType:
            current_layer = self.state_tensor.get_layer(layer_type)
            other_layer = other_state.state_tensor.get_layer(layer_type)
            cur_rows, cur_cols = self._layer_dims(current_layer)
            oth_rows, oth_cols = self._layer_dims(other_layer)

            for row in range(max(cur_rows, oth_rows)):
                for col in range(max(cur_cols, oth_cols)):
                    in_current = row < cur_rows and col < cur_cols
                    in_other = row < oth_rows and col < oth_cols
                    current_cell = current_layer[row, col] if in_current else None
                    other_cell = other_layer[row, col] if in_other else None

                    if current_cell != other_cell:
                        differences.append({
                            "layer": layer_type.value,
                            "position": (row, col),
                            "from": other_cell,
                            "to": current_cell
                        })

        if not differences:
            return {"changed": False, "differences": []}

        return {
            "changed": True,
            "differences": differences,
            "coherence_delta": self.compute_coherence() - other_state.compute_coherence()
        }

    def visualize_tensor(self) -> str:
        """Render every layer as a text grid.

        Filled cells show the first 8 characters of their pattern, other
        cells show ``EMPTY``.
        """
        visualization = ["STATE TENSOR VISUALIZATION", "=" * 50]

        for layer_type in LayerType:
            visualization.append(f"\n{layer_type.value.upper()} LAYER:")
            layer = self.state_tensor.get_layer(layer_type)
            rows, cols = self._layer_dims(layer)

            # Every cell renders as 10 padded characters plus " │ " (13 wide),
            # so the horizontal rules must span 13 characters per column to
            # line up with the row borders.
            rule = "─" * (13 * cols - 1)

            for row in range(rows):
                row_str = "│ "
                for col in range(cols):
                    cell_value = layer[row, col]
                    if cell_value and isinstance(cell_value, dict):
                        label = str(cell_value.get("pattern", ""))[:8]
                    else:
                        label = "EMPTY"
                    row_str += f"{label:<10} │ "
                visualization.append(row_str)

                # Add separator
                if row < rows - 1:
                    visualization.append("├" + rule + "┤")

            # Add bottom border
            visualization.append("└" + rule + "┘")

        return "\n".join(visualization)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def export_state(self, filename: str):
        """Write the tensor, memory lattice, coherence and metadata as JSON.

        Cells that are not dicts are exported as null. The document is
        serialized in memory first, so a value JSON cannot encode raises
        before the file is opened instead of leaving a truncated file.

        Raises:
            TypeError: if the state contains values JSON cannot encode.
        """
        export_data = {
            "tensor_shape": list(self.tensor_shape),
            "layers": {},
            "memory_lattice": {
                "short_term": list(self.memory_lattice.short_term_buffer),
                "long_term": self.memory_lattice.long_term_registers,
                "episodic": self.memory_lattice.episodic_stack
            },
            "coherence": self.compute_coherence(),
            "metadata": self.state_tensor.metadata
        }

        # Export layers as serializable data
        for layer_type in LayerType:
            layer = self.state_tensor.get_layer(layer_type)
            rows, cols = self._layer_dims(layer)
            layer_data = []
            for row in range(rows):
                row_data = []
                for col in range(cols):
                    cell_value = layer[row, col]
                    if cell_value and isinstance(cell_value, dict):
                        # Copy so the stored cell is untouched; force the
                        # matched text to a plain string for JSON.
                        serializable_cell = cell_value.copy()
                        if "match" in serializable_cell:
                            serializable_cell["match"] = str(serializable_cell["match"])
                        row_data.append(serializable_cell)
                    else:
                        row_data.append(None)
                layer_data.append(row_data)
            export_data["layers"][layer_type.value] = layer_data

        payload = json.dumps(export_data, indent=2)
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(payload)

        print(f"State exported to {filename}")

    def import_state(self, filename: str):
        """Replace this field's state with one written by ``export_state``.

        Everything is rebuilt from the file before ``self`` is modified, so
        a malformed file leaves the current state intact. Layers absent from
        the file are reset to empty grids of the imported shape. Values that
        were tuples before export (cell ``span``, short-term and history
        ``position``) are converted back from JSON lists.

        Raises:
            KeyError: if a required top-level section is missing.
            ValueError: if the file names an unknown layer.
            IndexError: if a layer holds more rows/columns than the shape.
        """
        with open(filename, 'r', encoding='utf-8') as f:
            import_data = json.load(f)

        tensor_shape = tuple(import_data["tensor_shape"])

        # LayerType(name) rejects layer names this version does not know.
        stored_layers = {
            LayerType(layer_name): layer_data
            for layer_name, layer_data in import_data["layers"].items()
        }

        restored_layers = {}
        for layer_type in LayerType:
            layer_array = np.empty(tensor_shape, dtype=object)
            for row, row_data in enumerate(stored_layers.get(layer_type, [])):
                for col, cell_data in enumerate(row_data):
                    layer_array[row, col] = self._restore_tuple(cell_data, "span")
            restored_layers[layer_type] = layer_array

        memory_data = import_data["memory_lattice"]
        # Keep the capacity of the buffer being replaced (8 by default).
        capacity = getattr(self.memory_lattice.short_term_buffer, "maxlen", 8)
        short_term = deque(
            (self._restore_tuple(item, "position") for item in memory_data["short_term"]),
            maxlen=capacity
        )
        long_term = memory_data["long_term"]
        episodic = memory_data["episodic"]

        metadata = import_data.get("metadata", {})
        if isinstance(metadata, dict):
            for entry in metadata.get("history", []):
                self._restore_tuple(entry, "position")

        # Commit only after the whole file has been processed successfully.
        self.tensor_shape = tensor_shape
        for layer_type, layer_array in restored_layers.items():
            self.state_tensor.set_layer(layer_type, layer_array)
        self.memory_lattice.short_term_buffer = short_term
        self.memory_lattice.long_term_registers = long_term
        self.memory_lattice.episodic_stack = episodic
        self.state_tensor.metadata = metadata

        print(f"State imported from {filename}")

# Example usage and demonstration
if __name__ == "__main__":
    print("=== SYNTHOS STATE CRYSTALLIZATION FIELD DEMO ===")

    # Create state crystallization field
    scf = StateCrystallizationField()

    # Add some coherence rules and gate conditions
    scf.add_coherence_rule("(?P<A>\\w+).*\\k<A>")  # Backreference coherence
    scf.add_gate_condition("STATE_TENSOR.TOPOLOGICAL.EMIT == TRUE")

    print(f"\nInitial state coherence: {scf.compute_coherence():.2f}")

    # Simulate some pattern matches and crystallize them.
    # Raw strings keep backslashes literal, so regex escapes are written with
    # a single backslash; r"\\w+" would look for a literal backslash and
    # never match these inputs, leaving nothing to crystallize.
    demo_cases = [
        (r"\w+", "word", LayerType.LEXICAL),
        (r"(?P<concept>\w+)", "concept", LayerType.SEMANTIC),
        (r"\b\w+\b", "boundary", LayerType.TOPOLOGICAL),
    ]
    test_patterns = [
        (pattern, re.match(pattern, text), layer_type)
        for pattern, text, layer_type in demo_cases
    ]

    print("\nCrystallizing patterns...")
    for pattern, match, layer_type in test_patterns:
        if match:
            position = scf.crystallize_match(pattern, match, layer_type)
            print(f"  Crystallized '{match.group(0)}' at {position} in {layer_type.value}")

    # Show state visualization
    print(f"\n{scf.visualize_tensor()}")

    # Check coherence after crystallization
    print(f"\nUpdated state coherence: {scf.compute_coherence():.2f}")

    # Test memory lattice
    print("\nMemory lattice state:")
    memory_state = scf.get_memory_state()
    print(f"  Short-term items: {len(memory_state['short_term'])}")
    print(f"  Long-term registers: {list(memory_state['long_term'].keys())}")
    print(f"  Episodic count: {memory_state['episodic_count']}")

    # Add some long-term registers
    scf.memory_lattice.set_long_term("GOAL", "ACHIEVE understanding")
    scf.memory_lattice.set_long_term("CONTEXT", "SYNTHOS demonstration")

    # Add episodic memory
    episode_id = scf.memory_lattice.add_episode({
        "action": "pattern_crystallization",
        "patterns": [p[0] for p in test_patterns],
        "coherence": scf.compute_coherence()
    })
    print(f"  Added episode: {episode_id}")

    # Test state vector
    state_vector = scf.get_state_vector()
    print(f"\nState vector shape: {state_vector.shape}")
    print(f"State vector sample: {state_vector[:5]}...")

    # Test state hashing
    state_hash = scf.create_state_hash()
    print(f"State hash: {state_hash}")

    # Test state diff against a freshly created (empty) field
    new_scf = StateCrystallizationField()
    diff = scf.diff_states(new_scf)
    print("\nState diff:")
    print(f"  Changed: {diff['changed']}")
    print(f"  Differences: {len(diff['differences'])}")
    if diff['changed']:
        print(f"  Coherence delta: {diff['coherence_delta']:.3f}")

    # Export and import test (writes synthos_state.json to the working directory)
    scf.export_state("synthos_state.json")

    # Create new SCF and import state
    test_scf = StateCrystallizationField()
    test_scf.import_state("synthos_state.json")

    # Verify import
    imported_coherence = test_scf.compute_coherence()
    print(f"\nImported state coherence: {imported_coherence:.2f}")

    print("\n=== SCF DEMO COMPLETE ===")
