# ═══════════════════════════════════════════════════════════════════════════════════
# SYNTHOS GEOMETRIC PARSE LATTICE (GPL) v1.0
# 4×4 parse grid with cell traversal and spatial pattern routing
# ═══════════════════════════════════════════════════════════════════════════════════

import re
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np

class CellType(Enum):
    """Role label attached to each lattice cell.

    Every member's value equals its own name, so ``CellType("EMIT")``
    returns ``CellType.EMIT``.
    """

    MATCH = 'MATCH'
    ROUTE = 'ROUTE'
    STORE = 'STORE'
    EMIT = 'EMIT'

class Direction(Enum):
    """Direction of an edge between lattice cells.

    The values are the arrow glyphs that appear in textual traversal rules.
    """

    RIGHT = '→'
    DOWN = '↓'
    # NOTE(review): '↗' points up-right and '↘' points down-right; the member
    # names do not describe the glyphs but are kept for compatibility.
    DIAGONAL_RIGHT = '↗'
    DIAGONAL_LEFT = '↘'

@dataclass
class LatticeCell:
    """Represents a cell in the geometric parse lattice.

    ``cell_id`` must be exactly two ASCII uppercase letters followed by two
    ASCII digits (e.g. ``"AB01"``); anything else raises ``ValueError`` at
    construction time.
    """
    cell_id: str          # e.g. "AB01"
    pattern: str          # regex source searched against the input text
    cell_type: CellType
    edges_out: List[str]  # IDs of the cells this cell links to
    weight: int
    row: int
    col: int

    def __post_init__(self):
        # Validate cell ID format [A-Z]{2}[0-9]{2}.
        # fullmatch instead of match + '$': '$' also matches before a trailing
        # newline, so "AA00\n" used to pass.  [0-9] instead of \d: \d accepts
        # any Unicode decimal digit (e.g. Arabic-Indic digits).
        if not re.fullmatch(r'[A-Z]{2}[0-9]{2}', self.cell_id):
            raise ValueError(f"Invalid cell ID format: {self.cell_id}")

@dataclass()
class TraversalStep:
    """One directed edge of a traversal path.

    When a path is walked, ``pattern`` is searched in the input and, if it
    matches, the optional ``condition`` gate is checked before moving on.
    """
    cell_id: str              # source cell ID
    pattern: str              # regex source for this step
    direction: Direction      # direction of the edge towards next_cell
    next_cell: str            # target cell ID
    condition: Optional[str]  # gate such as "CAPTURE_EXISTS:K" / "MATCH_LENGTH:3", or None
    weight: int

class GeometricParseLattice:
    """4×4 geometric parse lattice with traversal algorithms.

    Cell IDs have the form ``<two letters><two digits>``.  This class reads
    the two letters as a base-26 column number (``AA`` = 0, ``AB`` = 1, ...,
    ``AZ`` = 25, ``BA`` = 26) and the digits as the row, matching the default
    layout in which ``AA00``..``AD00`` form row 0.  The flat index used by
    ``adjacency_matrix`` is ``row * cols + col``.

    The default lattice occupies a 4×4 grid, so ``rows`` and ``cols`` must be
    at least 4; smaller values make ``add_cell`` raise ``IndexError``.
    """

    def __init__(self, rows: int = 4, cols: int = 4):
        self.rows = rows
        self.cols = cols
        self.cells: Dict[str, LatticeCell] = {}
        # adjacency_matrix[i, j] holds the weight of source cell i when i has an
        # edge to j, else 0 (an edge from a weight-0 cell looks like no edge).
        self.adjacency_matrix: np.ndarray = np.zeros((rows * cols, rows * cols))
        self.traversal_paths: List[List[TraversalStep]] = []
        self._initialize_default_lattice()

    def _initialize_default_lattice(self):
        """Initialize the default SYNTHOS 4×4 lattice and its main path.

        NOTE: Python's ``re`` rejects ``\\k<K>`` (AD02; .NET/PCRE syntax) with
        ``re.error``, and atomic groups ``(?>...)`` (AD03) need Python 3.11+.
        Both are kept verbatim from the original configuration; traversal
        reports them as regex errors instead of matching.
        """

        # Default lattice configuration from system prompt
        default_cells = [
            # Row 0 (AA00, AB00, AC00, AD00)
            LatticeCell("AA00", r"^\s*", CellType.MATCH, ["AB00"], 10, 0, 0),
            LatticeCell("AB00", r"\w+", CellType.ROUTE, ["AC00"], 15, 0, 1),
            LatticeCell("AC00", r"[\:\=\,]", CellType.STORE, ["AD00"], 8, 0, 2),
            LatticeCell("AD00", r".+$", CellType.EMIT, [], 12, 0, 3),

            # Row 1 (AA01, AB01, AC01, AD01)
            LatticeCell("AA01", r"(?=\w)", CellType.MATCH, ["AB01"], 5, 1, 0),
            LatticeCell("AB01", r"[A-Z]\w*", CellType.ROUTE, ["AC01"], 18, 1, 1),
            LatticeCell("AC01", r"(?![\d])", CellType.STORE, ["AD01"], 7, 1, 2),
            LatticeCell("AD01", r"\w{2,8}", CellType.EMIT, [], 14, 1, 3),

            # Row 2 (AA02, AB02, AC02, AD02)
            LatticeCell("AA02", r"(?P<K>\w+)", CellType.MATCH, ["AB02"], 20, 2, 0),
            LatticeCell("AB02", r"(?P<V>.+?)", CellType.ROUTE, ["AC02"], 16, 2, 1),
            LatticeCell("AC02", r"(?P<T>\w{3})", CellType.STORE, ["AD02"], 11, 2, 2),
            LatticeCell("AD02", r"\k<K>", CellType.EMIT, [], 13, 2, 3),

            # Row 3 (AA03, AB03, AC03, AD03)
            LatticeCell("AA03", r"[^\n]+", CellType.MATCH, [], 9, 3, 0),
            LatticeCell("AB03", r"(?:a|b)+", CellType.ROUTE, [], 6, 3, 1),
            LatticeCell("AC03", r"\b\w+\b", CellType.STORE, ["AD03"], 17, 3, 2),
            LatticeCell("AD03", r"(?>\w+)", CellType.EMIT, [], 19, 3, 3),
        ]

        for cell in default_cells:
            self.add_cell(cell)

        # Build default lattice path
        self._build_default_paths()

    def add_cell(self, cell: "LatticeCell"):
        """Add (or replace) a cell and rebuild the adjacency matrix.

        Raises:
            IndexError: if the cell ID addresses a position outside the
                ``rows × cols`` grid.  The lattice is left unchanged.
        """
        row, col = self._cell_position(cell.cell_id)
        # Validate before mutating so a bad cell never lands in self.cells.
        if not (0 <= row < self.rows and 0 <= col < self.cols):
            raise IndexError(
                f"Cell {cell.cell_id} at (row={row}, col={col}) is outside "
                f"the {self.rows}x{self.cols} lattice"
            )
        self.cells[cell.cell_id] = cell
        self._update_adjacency_matrix()

    def _update_adjacency_matrix(self):
        """Rebuild adjacency_matrix from every cell's ``edges_out``.

        Edges to IDs that are not lattice cells are ignored.
        """
        self.adjacency_matrix.fill(0)

        for cell_id, cell in self.cells.items():
            from_idx = self._cell_index(cell_id)
            for edge_id in cell.edges_out:
                if edge_id in self.cells:
                    self.adjacency_matrix[from_idx, self._cell_index(edge_id)] = cell.weight

    def _cell_position(self, cell_id: str) -> Tuple[int, int]:
        """Decode a cell ID into ``(row, col)``."""
        # Column lives in the two letters (base 26), e.g. "AB" -> 1.  The
        # previous code used only cell_id[0], which is 'A' for every default
        # cell, so AA00..AD00 all collapsed onto index 0.
        col = (ord(cell_id[0]) - ord('A')) * 26 + (ord(cell_id[1]) - ord('A'))
        row = int(cell_id[2:4])
        return row, col

    def _cell_index(self, cell_id: str) -> int:
        """Convert cell ID to flat matrix index ``row * cols + col``."""
        row, col = self._cell_position(cell_id)
        return row * self.cols + col

    def _cell_from_index(self, index: int) -> str:
        """Convert a flat matrix index back to a four-character cell ID."""
        row, col = divmod(index, self.cols)
        return f"{chr(ord('A') + col // 26)}{chr(ord('A') + col % 26)}{row:02d}"

    def _build_default_paths(self):
        """Build the default main traversal path.

        Route: AA00→AB00→AC00→AD00↓AA01→AB01→AC01→AD01↓AA02→AB02→AC02→AD02↓AD03.
        Each step's pattern and weight are read from its source cell, so the
        path cannot drift out of sync with the lattice definition.
        """
        route = [
            ("AA00", Direction.RIGHT, "AB00"),
            ("AB00", Direction.RIGHT, "AC00"),
            ("AC00", Direction.RIGHT, "AD00"),
            ("AD00", Direction.DOWN, "AA01"),
            ("AA01", Direction.RIGHT, "AB01"),
            ("AB01", Direction.RIGHT, "AC01"),
            ("AC01", Direction.RIGHT, "AD01"),
            ("AD01", Direction.DOWN, "AA02"),
            ("AA02", Direction.RIGHT, "AB02"),
            ("AB02", Direction.RIGHT, "AC02"),
            ("AC02", Direction.RIGHT, "AD02"),
            ("AD02", Direction.DOWN, "AD03"),
        ]
        main_path = [
            TraversalStep(
                cell_id,
                self.cells[cell_id].pattern,
                direction,
                next_cell,
                None,
                self.cells[cell_id].weight,
            )
            for cell_id, direction, next_cell in route
        ]
        self.traversal_paths.append(main_path)

    def traverse_lattice(self, input_text: str, path_index: int = 0) -> Dict[str, Any]:
        """
        Traverse lattice with input text.

        Two passes are made:

        1. Every lattice cell's pattern is searched independently against the
           whole input.  Cells whose pattern does not compile are skipped and
           reported under ``"cell_errors"``.
        2. The fixed path ``self.traversal_paths[path_index]`` is walked.  Each
           step's pattern is searched starting where the previous match ended.
           The walk stops at the first step that does not match, whose
           condition fails, or whose regex does not compile.  A compile
           failure is recorded under ``"error"``.

        Cells matched in pass 1 that the walk did not reach are then appended
        to ``"path_taken"``, so the path length tracks the number of matched
        lattice cells (mirroring the number of matched primitives in L0/LPE).

        Args:
            input_text: Text to parse.
            path_index: Index into ``self.traversal_paths``.

        Returns:
            Dict with ``input``, ``matches`` (cell -> matched text), ``captures``
            (named groups), ``path_taken``, ``success`` (any cell matched),
            ``cells_matched`` (count from pass 1), ``cell_errors``, and
            ``error`` only when a path regex failed to compile.

        Raises:
            ValueError: if ``path_index`` is beyond the last path.
            KeyError: if a path step names a cell that is not in the lattice.
        """
        if path_index >= len(self.traversal_paths):
            raise ValueError(f"Invalid path index: {path_index}")

        path = self.traversal_paths[path_index]
        results: Dict[str, Any] = {
            "input": input_text,
            "matches": {},
            "captures": {},
            "path_taken": [],
            "success": False,
        }

        # Pass 1: scan ALL lattice cells, not just the fixed path.
        all_cell_matches: List[str] = []
        cell_errors: Dict[str, str] = {}
        for cid, cell in self.cells.items():
            try:
                if re.search(cell.pattern, input_text):
                    all_cell_matches.append(cid)
            except re.error as exc:
                # Best effort: an invalid cell pattern must not abort the
                # traversal, but it is reported rather than silently dropped.
                cell_errors[cid] = str(exc)

        # Pass 2: walk the fixed path.
        position = 0
        for step in path:
            if step.cell_id not in self.cells:
                raise KeyError(step.cell_id)
            try:
                match = re.compile(step.pattern).search(input_text, position)
            except re.error as exc:
                results["error"] = f"Regex error in cell {step.cell_id}: {exc}"
                break
            if match is None:
                break

            results["matches"][step.cell_id] = match.group(0)
            results["path_taken"].append(step.cell_id)
            results["captures"].update(match.groupdict())
            position = match.end()

            if step.condition and not self._evaluate_condition(step.condition, match, results):
                break

        # Merge any additional cell matches not already on the path.
        taken = set(results["path_taken"])
        for cid in all_cell_matches:
            if cid not in taken:
                results["path_taken"].append(cid)
                taken.add(cid)

        results["success"] = bool(results["path_taken"])
        results["cells_matched"] = len(all_cell_matches)
        results["cell_errors"] = cell_errors
        return results

    def _evaluate_condition(self, condition: str, match: re.Match, results: Dict) -> bool:
        """Evaluate a traversal gate condition.

        Supported forms:

        * ``CAPTURE_EXISTS:<name>``: true if ``<name>`` is in
          ``results["captures"]``.  The current match's groups have already
          been merged by the caller.
        * ``MATCH_LENGTH:<n>``: true if the current match is exactly ``n``
          characters long.  ``int()`` raises ValueError for a non-integer ``n``.

        Any other condition string evaluates to True.
        """
        if condition.startswith("CAPTURE_EXISTS:"):
            capture_name = condition.split(":")[1]
            return capture_name in results["captures"]
        if condition.startswith("MATCH_LENGTH:"):
            expected_length = int(condition.split(":")[1])
            return len(match.group(0)) == expected_length
        return True

    def find_paths(self, start_cell: str, end_cell: str) -> List[List[str]]:
        """Find all simple (cycle-free) paths from ``start_cell`` to ``end_cell``.

        Paths are enumerated breadth-first, so shorter paths come first.  A
        path is not extended past ``end_cell``.  Edges to IDs that are not
        lattice cells are ignored.  Returns ``[]`` if either endpoint is not
        a cell.
        """
        if start_cell not in self.cells or end_cell not in self.cells:
            return []

        paths: List[List[str]] = []
        queue = deque([[start_cell]])
        # No global "visited" set: it pruned valid alternative routes through
        # an already-expanded cell.  The per-path membership test below is
        # enough to avoid cycles.
        while queue:
            path = queue.popleft()
            current = path[-1]

            if current == end_cell:
                paths.append(path)
                continue

            for neighbor in self.cells[current].edges_out:
                if neighbor in self.cells and neighbor not in path:
                    queue.append(path + [neighbor])

        return paths

    def get_cell_neighbors(self, cell_id: str) -> List[str]:
        """Return a copy of a cell's outgoing neighbour IDs (``[]`` if unknown)."""
        if cell_id not in self.cells:
            return []
        return list(self.cells[cell_id].edges_out)

    def visualize_lattice(self) -> str:
        """Create an ASCII visualization of the lattice.

        Each cell is drawn as three lines: its ID, the first 10 characters of
        its pattern, and its cell type.  Missing cells show ``[EMPTY]``.  A
        ``┬`` on the border below a cell marks that it has outgoing edges.
        """
        width = 10
        indent = " " * 9
        plain = "─" * (width + 2)
        with_edges = "─" * 4 + "┬" + "─" * (width - 3)

        header = " " * 10 + "".join(
            f" {'COL_' + str(col):<{width}}  " for col in range(self.cols)
        )
        out = [header.rstrip(), indent + "┌" + "┬".join([plain] * self.cols) + "┐"]

        for row in range(self.rows):
            ids = [self._cell_from_index(row * self.cols + col) for col in range(self.cols)]

            id_line: List[str] = []
            pattern_line: List[str] = []
            type_line: List[str] = []
            for cid in ids:
                cell = self.cells.get(cid)
                id_line.append(cid)
                if cell is None:
                    pattern_line.append("[EMPTY]")
                    type_line.append("---")
                else:
                    pattern_line.append(cell.pattern[:width])
                    type_line.append(cell.cell_type.value)

            prefixes = (f"ROW_{row:02d}   ", indent, indent)
            for prefix, texts in zip(prefixes, (id_line, pattern_line, type_line)):
                out.append(prefix + "│" + "".join(f" {text:<{width}} │" for text in texts))

            segments = [
                with_edges if cid in self.cells and self.cells[cid].edges_out else plain
                for cid in ids
            ]
            if row < self.rows - 1:
                out.append(indent + "├" + "┼".join(segments) + "┤")
            else:
                out.append(indent + "└" + "┴".join(segments) + "┘")

        return "\n".join(out) + "\n"

    def get_lattice_statistics(self) -> Dict[str, Any]:
        """Summarize the lattice.

        Returns a dict with ``total_cells``, ``rows``, ``cols``, per-type
        counts in ``cell_types``, ``total_edges`` (sum of all ``edges_out``
        lengths, including edges to IDs that are not cells), ``average_weight``
        (0 for an empty lattice) and ``connected_components`` (weakly
        connected, i.e. edge direction ignored).
        """
        stats: Dict[str, Any] = {
            "total_cells": len(self.cells),
            "rows": self.rows,
            "cols": self.cols,
            "cell_types": {},
            "total_edges": 0,
            "average_weight": 0,
            "connected_components": 0,
        }

        for cell in self.cells.values():
            cell_type = cell.cell_type.value
            stats["cell_types"][cell_type] = stats["cell_types"].get(cell_type, 0) + 1
            stats["total_edges"] += len(cell.edges_out)
            stats["average_weight"] += cell.weight

        if self.cells:
            stats["average_weight"] /= len(self.cells)

        stats["connected_components"] = self._count_weak_components()
        return stats

    def _count_weak_components(self) -> int:
        """Count weakly connected components, treating edges as undirected.

        Following directed edges from each unvisited cell (the old approach)
        made the count depend on dict order: with ``B`` inserted before
        ``A→B``, one component was counted twice.
        """
        neighbours: Dict[str, Set[str]] = {cid: set() for cid in self.cells}
        for cid, cell in self.cells.items():
            for other in cell.edges_out:
                if other in neighbours:
                    neighbours[cid].add(other)
                    neighbours[other].add(cid)

        visited: Set[str] = set()
        components = 0
        for cid in self.cells:
            if cid in visited:
                continue
            components += 1
            visited.add(cid)
            stack = [cid]
            while stack:
                for other in neighbours[stack.pop()]:
                    if other not in visited:
                        visited.add(other)
                        stack.append(other)
        return components

# Traversal rule parser for regex-encoded paths
class TraversalRuleParser:
    """Parse regex-encoded traversal rules.

    A rule is ``<ID><arrow><NEXT>`` optionally followed by a bracketed
    condition gate, e.g. ``AA00→AB00`` or ``AB02→AC02[CAPTURE_EXISTS:K]``.
    The arrows are ``→`` (RIGHT), ``↓`` (DOWN), ``↗`` (DIAGONAL_RIGHT) and
    ``↘`` (DIAGONAL_LEFT).  Rules may be chained on one line
    (``AA00→AB00→AC00``).  Each further arrow yields another step whose
    source is the previous step's target.
    """

    def __init__(self):
        self.rule_pattern = re.compile(
            r'(?P<ENTRY>^\s*)'                    # Anchor to row start
            r'(?P<CELL>(?P<ID>[A-Z]{2}\d{2}))'   # Match cell identifier
            r'(?P<ARROW>(?:→|↓|↗|↘))'           # Directional edge
            r'(?P<NEXT>[A-Z]{2}\d{2})'           # Next cell
            r'(?P<COND>(?:\[.+?\])?)'            # Optional condition gate
        )
        # One further link of a chained rule, matched immediately after the
        # previous link (no leading anchor or cell ID).
        self._link_pattern = re.compile(
            r'(?P<ARROW>(?:→|↓|↗|↘))'
            r'(?P<NEXT>[A-Z]{2}\d{2})'
            r'(?P<COND>(?:\[.+?\])?)'
        )

    @staticmethod
    def _link(cell_id: str, match: re.Match) -> Tuple[str, str, str, Optional[str]]:
        """Build an ``(id, arrow, next, condition)`` tuple from a rule/link match."""
        gate = match.group('COND')
        # The regex guarantees exactly one leading '[' and trailing ']'; slice
        # them off instead of strip('[]'), which could also eat brackets that
        # belong to the condition itself.
        condition = gate[1:-1] if gate else None
        return cell_id, match.group('ARROW'), match.group('NEXT'), condition

    def _split_chain(self, rule_text: str) -> List[Tuple[str, str, str, Optional[str]]]:
        """Split a (possibly chained) rule into ``(id, arrow, next, condition)`` tuples.

        Returns ``[]`` if the text does not start with a valid rule.  Parsing
        stops at the first text that is not another ``<arrow><NEXT>`` link.
        """
        head = self.rule_pattern.match(rule_text)
        if head is None:
            return []

        links = [self._link(head.group('ID'), head)]
        source, pos = head.group('NEXT'), head.end()
        while True:
            link = self._link_pattern.match(rule_text, pos)
            if link is None:
                break
            links.append(self._link(source, link))
            source, pos = link.group('NEXT'), link.end()
        return links

    @staticmethod
    def _make_step(cell_id: str, arrow: str, next_cell: str,
                   condition: Optional[str]) -> "TraversalStep":
        """Create a TraversalStep; pattern and weight are left for the lattice to fill."""
        direction_map = {
            "→": Direction.RIGHT,
            "↓": Direction.DOWN,
            "↗": Direction.DIAGONAL_RIGHT,
            "↘": Direction.DIAGONAL_LEFT,
        }
        return TraversalStep(
            cell_id=cell_id,
            pattern="",  # Will be filled from lattice
            direction=direction_map[arrow],
            next_cell=next_cell,
            condition=condition,
            weight=0,  # Will be filled from lattice
        )

    def parse_rule(self, rule_text: str) -> Optional["TraversalStep"]:
        """Parse a single traversal rule.

        For a chained rule only the first edge is returned; use
        :meth:`parse_path` to get every edge.  Returns ``None`` if the text
        does not start with a valid rule.
        """
        links = self._split_chain(rule_text)
        return self._make_step(*links[0]) if links else None

    def parse_path(self, path_text: str) -> List["TraversalStep"]:
        """Parse a full path: one rule (or rule chain) per line.

        Lines that are not valid rules are skipped.  A chained line such as
        ``AA00→AB00→AC00`` contributes one step per arrow.
        """
        steps = []
        for line in path_text.strip().split('\n'):
            for link in self._split_chain(line.strip()):
                steps.append(self._make_step(*link))
        return steps

# Example usage and demonstration
if __name__ == "__main__":
    print("=== SYNTHOS GEOMETRIC PARSE LATTICE DEMO ===")

    # Create lattice
    lattice = GeometricParseLattice()

    # Display lattice visualization
    print("\nLATTICE VISUALIZATION:")
    print(lattice.visualize_lattice())

    # Get statistics
    stats = lattice.get_lattice_statistics()
    print("\nLATTICE STATISTICS:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # Test traversal
    test_input = "  variable=value  "
    print("\nTRAVERSAL TEST:")
    print(f"Input: '{test_input}'")

    result = lattice.traverse_lattice(test_input)
    print(f"Success: {result['success']}")
    print(f"Path taken: {' → '.join(result['path_taken'])}")
    print(f"Matches: {result['matches']}")
    print(f"Captures: {result['captures']}")
    # Surface regex failures (e.g. patterns Python's re does not support)
    # instead of letting the path silently stop early.
    if "error" in result:
        print(f"Path error: {result['error']}")
    if result.get("cell_errors"):
        print(f"Cell pattern errors: {result['cell_errors']}")

    # Find paths
    paths = lattice.find_paths("AA00", "AD03")
    print(f"\nPATHS FROM AA00 TO AD03: {len(paths)} found")

    # Parse traversal rules (previously the parser was created but never used)
    parser = TraversalRuleParser()
    rule_text = "AA00→AB00→AC00→AD00"
    print(f"\nPARSED TRAVERSAL RULE: {rule_text}")
    for step in parser.parse_path(rule_text):
        print(
            f"  {step.cell_id} {step.direction.value} {step.next_cell} "
            f"(direction={step.direction.name}, condition={step.condition})"
        )

    print("\n=== GPL DEMO COMPLETE ===")
