# ═══════════════════════════════════════════════════════════════════════════════════
# SYNTHOS TOPOLOGICAL ATTENTION MESH (TAM) v1.0
# Pattern intersection geometry for multi-head attention computation
# ═══════════════════════════════════════════════════════════════════════════════════

import re
import numpy as np
from typing import Dict, List, Tuple, Optional, Set, Any
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict
import math

class AttentionType(Enum):
    """Kinds of geometric relationship between a query match and a key match.

    Each member's value is its lower-case name.
    """

    # Self-attention: exact backreference
    FULL_OVERLAP = "full_overlap"
    # Cross-attention: lookahead overlap
    PARTIAL_OVERLAP = "partial_overlap"
    # Nested attention: inner ⊂ outer
    CONTAINMENT = "containment"
    # Zero attention: mutual exclusion
    DISJOINT = "disjoint"
    # Temporal attention: concatenation
    SEQUENTIAL = "sequential"

class ScopeType(Enum):
    """Scope label attached to an attention head.

    The values are the lower-case member names, which is what lets a
    textual ``LOCAL`` / ``GLOBAL`` / ``WINDOW`` keyword be converted with
    ``ScopeType(keyword.lower())``.
    """

    LOCAL = "local"

    GLOBAL = "global"

    WINDOW = "window"

@dataclass
class AttentionHead:
    """Multi-pattern attention head.

    A head is described by three regular-expression sources (query, key and
    value) plus a scope label and an attention type.  After construction the
    compiled patterns are available as the instance attributes
    ``query_regex``, ``key_regex`` and ``value_regex``; they are plain
    attributes, not dataclass fields, so they do not take part in the
    generated ``__init__``, ``__repr__`` or ``__eq__``.

    Raises:
        ValueError: If any of the three patterns is not a valid regular
            expression.  The underlying ``re.error`` is attached as the
            exception's ``__cause__``.
    """
    head_id: int
    query_pattern: str
    key_pattern: str
    value_pattern: str
    scope: ScopeType
    attention_type: AttentionType
    weight: float = 1.0   # defaults to 1.0
    active: bool = True   # heads are active unless switched off

    def __post_init__(self):
        # Compile eagerly so a malformed pattern is reported at construction
        # time instead of on first use.
        try:
            self.query_regex = re.compile(self.query_pattern)
            self.key_regex = re.compile(self.key_pattern)
            self.value_regex = re.compile(self.value_pattern)
        except re.error as e:
            # Chain explicitly so callers can reach the original regex error.
            raise ValueError(
                f"Invalid regex pattern in attention head {self.head_id}: {e}"
            ) from e

@dataclass
class PatternIntersection:
    """Outcome of comparing one query match with one key match.

    ``pattern1`` and ``pattern2`` hold the two compared texts,
    ``intersection_type`` the attention type used for the comparison,
    ``intersection_area`` the resulting score, ``overlap_ratio`` the plain
    overlap fraction and ``geometric_form`` a one-character glyph.
    """
    pattern1: str
    pattern2: str
    intersection_type: AttentionType
    intersection_area: float
    overlap_ratio: float
    geometric_form: str

    def __str__(self):
        # Rendered as "<p1> ∩ <p2> = <type value> (<area, two decimals>)".
        kind = self.intersection_type.value
        area = format(self.intersection_area, ".2f")
        return f"{self.pattern1} ∩ {self.pattern2} = {kind} ({area})"

@dataclass
class AttentionResult:
    """Everything one attention head produced for one input text."""

    # Id of the head that produced this result.
    head_id: int

    # Matches of the head's query pattern.
    query_matches: List[re.Match]

    # Matches of the head's key pattern.
    key_matches: List[re.Match]

    # Intersections computed between query matches and key matches.
    intersections: List[PatternIntersection]

    # Matrix of attention weights.
    attention_weights: np.ndarray

    # Extracted output strings.
    output_values: List[str]

    # Text diagram describing the attention geometry.
    geometric_visualization: str

class TopologicalAttentionMesh:
    """Multi-head attention mesh with pattern intersection geometry"""
    
    def __init__(self, num_heads: int = 8):
        self.num_heads = num_heads
        self.attention_heads: Dict[int, AttentionHead] = {}
        self.attention_history: List[AttentionResult] = []
        self.intersection_cache: Dict[str, PatternIntersection] = {}
        
        # Regex patterns for attention head definition
        self.head_def_pattern = re.compile(
            r'\[\[ATTN_HEAD_(?P<HEAD_ID>\d+)\]\]\s*'
            r'\[QUERY\:\s*(?P<QUERY>[^\]]+)\]\s*'
            r'\[KEY\:\s*(?P<KEY>[^\]]+)\]\s*'
            r'\[VALUE\:\s*(?P<VALUE>[^\]]+)\]\s*'
            r'\[SCOPE\:\s*(?P<SCOPE>LOCAL|GLOBAL|WINDOW)\]'
        )
        
        # Initialize default attention heads
        self._initialize_default_heads()
    
    def _initialize_default_heads(self):
        """Initialize default SYNTHOS attention heads"""
        default_heads = [
            AttentionHead(
                head_id=0,
                query_pattern=r'\w+',
                key_pattern=r'\w+',
                value_pattern=r'.+?',
                scope=ScopeType.LOCAL,
                attention_type=AttentionType.FULL_OVERLAP
            ),
            AttentionHead(
                head_id=1,
                query_pattern=r'[A-Z]\w+',
                key_pattern=r'(?P<C>\w+)',
                value_pattern=r'\w+',
                scope=ScopeType.GLOBAL,
                attention_type=AttentionType.PARTIAL_OVERLAP
            ),
            AttentionHead(
                head_id=2,
                query_pattern=r'\d+',
                key_pattern=r'-?\d+(?:\.\d*)?',
                value_pattern=r'.+',
                scope=ScopeType.LOCAL,
                attention_type=AttentionType.FULL_OVERLAP
            ),
            AttentionHead(
                head_id=3,
                query_pattern=r'[.!?]',
                key_pattern=r'[^.!?]+[.!?]',
                value_pattern=r'.{0,200}',
                scope=ScopeType.WINDOW,
                attention_type=AttentionType.CONTAINMENT
            ),
            AttentionHead(
                head_id=4,
                query_pattern=r'\b\w+\b',
                key_pattern=r'\w{3,}',
                value_pattern=r'.+?',
                scope=ScopeType.GLOBAL,
                attention_type=AttentionType.PARTIAL_OVERLAP
            ),
            AttentionHead(
                head_id=5,
                query_pattern=r'[A-Z_]+',
                key_pattern=r'[A-Z_][A-Z0-9_]*',
                value_pattern=r'.+?',
                scope=ScopeType.LOCAL,
                attention_type=AttentionType.FULL_OVERLAP
            ),
            AttentionHead(
                head_id=6,
                query_pattern=r'\s+',
                key_pattern=r'\s+',
                value_pattern=r'\s*',
                scope=ScopeType.WINDOW,
                attention_type=AttentionType.DISJOINT
            ),
            AttentionHead(
                head_id=7,
                query_pattern=r'.+?',
                key_pattern=r'.+?',
                value_pattern=r'.+?',
                scope=ScopeType.GLOBAL,
                attention_type=AttentionType.SEQUENTIAL
            )
        ]
        
        for head in default_heads:
            self.attention_heads[head.head_id] = head
    
    def add_attention_head(self, head: AttentionHead):
        """Add a new attention head"""
        if head.head_id >= self.num_heads:
            raise ValueError(f"Head ID {head.head_id} exceeds maximum {self.num_heads}")
        self.attention_heads[head.head_id] = head
    
    def parse_attention_head_definition(self, definition: str) -> Optional[AttentionHead]:
        """Parse attention head definition from text"""
        match = self.head_def_pattern.match(definition)
        if not match:
            return None
        
        head_id = int(match.group('HEAD_ID'))
        query_pattern = match.group('QUERY')
        key_pattern = match.group('KEY')
        value_pattern = match.group('VALUE')
        scope_str = match.group('SCOPE')
        
        scope = ScopeType(scope_str.lower())
        
        # Determine attention type based on patterns
        attention_type = self._determine_attention_type(query_pattern, key_pattern)
        
        return AttentionHead(
            head_id=head_id,
            query_pattern=query_pattern,
            key_pattern=key_pattern,
            value_pattern=value_pattern,
            scope=scope,
            attention_type=attention_type
        )
    
    def _determine_attention_type(self, query_pattern: str, key_pattern: str) -> AttentionType:
        """Determine attention type from pattern analysis"""
        # Full overlap: identical patterns or backreferences
        if query_pattern == key_pattern or '\\k<' in query_pattern or '\\1' in query_pattern:
            return AttentionType.FULL_OVERLAP
        
        # Partial overlap: lookahead patterns
        if '(?=' in query_pattern or '(?=' in key_pattern:
            return AttentionType.PARTIAL_OVERLAP
        
        # Containment: nested patterns
        if '(?P<' in query_pattern and '(?P<' in key_pattern:
            return AttentionType.CONTAINMENT
        
        # Disjoint: negative lookahead
        if '(?!)' in query_pattern or '(?!)' in key_pattern:
            return AttentionType.DISJOINT
        
        # Default to sequential
        return AttentionType.SEQUENTIAL
    
    def compute_attention(self, input_text: str, head_ids: Optional[List[int]] = None) -> List[AttentionResult]:
        """Compute attention for specified heads"""
        if head_ids is None:
            head_ids = list(self.attention_heads.keys())
        
        results = []
        
        for head_id in head_ids:
            if head_id not in self.attention_heads:
                continue
            
            head = self.attention_heads[head_id]
            if not head.active:
                continue
            
            result = self._compute_head_attention(head, input_text)
            results.append(result)
            self.attention_history.append(result)
        
        return results
    
    def _compute_head_attention(self, head: AttentionHead, input_text: str) -> AttentionResult:
        """Compute attention for a single head"""
        # Find query matches
        query_matches = list(head.query_regex.finditer(input_text))
        
        # Find key matches
        key_matches = list(head.key_regex.finditer(input_text))
        
        # Compute pattern intersections
        intersections = []
        attention_weights = np.zeros((len(query_matches), len(key_matches)))
        
        for i, q_match in enumerate(query_matches):
            for j, k_match in enumerate(key_matches):
                intersection = self._compute_pattern_intersection(
                    q_match, k_match, head.attention_type
                )
                intersections.append(intersection)
                attention_weights[i, j] = intersection.intersection_area
        
        # Normalize attention weights
        if attention_weights.sum() > 0:
            attention_weights = attention_weights / attention_weights.sum()
        
        # Extract output values
        output_values = []
        for match in key_matches:
            value_match = head.value_regex.search(match.group(0))
            if value_match:
                output_values.append(value_match.group(0))
            else:
                output_values.append(match.group(0))
        
        # Generate geometric visualization
        geometric_viz = self._generate_attention_visualization(
            query_matches, key_matches, intersections, head.attention_type
        )
        
        return AttentionResult(
            head_id=head.head_id,
            query_matches=query_matches,
            key_matches=key_matches,
            intersections=intersections,
            attention_weights=attention_weights,
            output_values=output_values,
            geometric_visualization=geometric_viz
        )
    
    def _compute_pattern_intersection(self, query_match: re.Match, key_match: re.Match, 
                                    attention_type: AttentionType) -> PatternIntersection:
        """Compute intersection between two pattern matches"""
        q_start, q_end = query_match.span()
        k_start, k_end = key_match.span()
        
        # Calculate overlap
        overlap_start = max(q_start, k_start)
        overlap_end = min(q_end, k_end)
        overlap_length = max(0, overlap_end - overlap_start)
        
        # Calculate intersection area based on attention type
        if attention_type == AttentionType.FULL_OVERLAP:
            # Exact match or backreference
            if query_match.group(0) == key_match.group(0):
                intersection_area = 1.0
                geometric_form = "≡"
            else:
                intersection_area = overlap_length / max(q_end - q_start, k_end - k_start)
                geometric_form = "∩"
        
        elif attention_type == AttentionType.PARTIAL_OVERLAP:
            # Lookahead overlap
            intersection_area = overlap_length / max(q_end - q_start, k_end - k_start)
            geometric_form = "┼"
        
        elif attention_type == AttentionType.CONTAINMENT:
            # Nested patterns
            if q_start <= k_start and k_end <= q_end:
                intersection_area = 0.8
                geometric_form = "⊂"
            elif k_start <= q_start and q_end <= k_end:
                intersection_area = 0.8
                geometric_form = "⊃"
            else:
                intersection_area = overlap_length / max(q_end - q_start, k_end - k_start)
                geometric_form = "∩"
        
        elif attention_type == AttentionType.DISJOINT:
            # No overlap
            intersection_area = 0.0
            geometric_form = "∅"
        
        else:  # SEQUENTIAL
            # Sequential binding
            distance = abs(q_start - k_start)
            max_distance = len(query_match.string)
            intersection_area = 1.0 - (distance / max_distance)
            geometric_form = "→"
        
        overlap_ratio = overlap_length / max(q_end - q_start, k_end - k_start)
        
        return PatternIntersection(
            pattern1=query_match.group(0),
            pattern2=key_match.group(0),
            intersection_type=attention_type,
            intersection_area=intersection_area,
            overlap_ratio=overlap_ratio,
            geometric_form=geometric_form
        )
    
    def _generate_attention_visualization(self, query_matches: List[re.Match], 
                                       key_matches: List[re.Match],
                                       intersections: List[PatternIntersection],
                                       attention_type: AttentionType) -> str:
        """Generate ASCII visualization of attention pattern"""
        viz = []
        
        if attention_type == AttentionType.FULL_OVERLAP:
            viz.append("┌─────────────────┐")
            viz.append("│  P1 ≡ P2        │")
            viz.append("└─────────────────┘")
        
        elif attention_type == AttentionType.PARTIAL_OVERLAP:
            viz.append("┌──────────┐")
            viz.append("│  P1   ┌──┼──┐")
            viz.append("│       │  │P2│")
            viz.append("└───────┼──┘  │")
            viz.append("        └─────┘")
        
        elif attention_type == AttentionType.CONTAINMENT:
            viz.append("┌──────────────────┐")
            viz.append("│  P1              │")
            viz.append("│    ┌──────────┐  │")
            viz.append("│    │  P2      │  │")
            viz.append("│    └──────────┘  │")
            viz.append("└──────────────────┘")
        
        elif attention_type == AttentionType.DISJOINT:
            viz.append("┌────────┐    ┌────────┐")
            viz.append("│  P1    │    │  P2    │")
            viz.append("└────────┘    └────────┘")
        
        elif attention_type == AttentionType.SEQUENTIAL:
            viz.append("┌────────┐──►┌────────┐──►┌────────┐")
            viz.append("│  P1    │   │  P2    │   │  P3    │")
            viz.append("└────────┘   └────────┘   └────────┘")
        
        return "\n".join(viz)
    
    def multi_head_attention(self, input_text: str) -> Dict[str, Any]:
        """
        Compute multi-head attention with load-adaptive head activation.

        The number of heads that fire scales with input complexity:
          - Very short inputs  (< 3 words)  → min 2 heads
          - Short inputs       (3-8 words)  → ~4 heads
          - Medium inputs      (9-20 words) → ~6 heads
          - Long/complex inputs (20+ words) → all heads

        Heads are selected by matching their query pattern against the
        input — only heads with actual query hits are activated.
        """
        # ── Determine load-adaptive head count ──
        words = input_text.split()
        word_count = len(words)
        unique_chars = len(set(input_text.lower()))

        # Base head count from word count
        if word_count <= 2:
            target_heads = 2
        elif word_count <= 8:
            target_heads = max(2, min(4, word_count // 2 + 1))
        elif word_count <= 20:
            target_heads = max(4, min(6, word_count // 3 + 1))
        else:
            target_heads = self.num_heads

        # Boost for character diversity (code, mixed case, symbols)
        if unique_chars > 30:
            target_heads = min(self.num_heads, target_heads + 1)

        # ── Select which heads to fire ──
        # Score each head by how many query matches it gets
        head_scores: List[Tuple[int, int]] = []
        for hid, head in self.attention_heads.items():
            if not head.active:
                continue
            try:
                hits = len(list(head.query_regex.finditer(input_text)))
                head_scores.append((hid, hits))
            except Exception:
                head_scores.append((hid, 0))

        # Sort by relevance (most matches first), then take top N
        head_scores.sort(key=lambda x: -x[1])
        active_ids = [hid for hid, _ in head_scores[:target_heads]]

        # Always include head 0 (general word attention) if available
        if 0 in self.attention_heads and 0 not in active_ids:
            active_ids.insert(0, 0)
            if len(active_ids) > target_heads:
                active_ids = active_ids[:target_heads]

        head_results = self.compute_attention(input_text, head_ids=sorted(active_ids))
        
        if not head_results:
            return {"concatenated_output": "", "attention_weights": [], "heads_used": [],
                    "target_heads": target_heads, "load_words": word_count}
        
        # Concatenate outputs from all heads
        concatenated_output = []
        all_weights = []
        heads_used = []
        
        for result in head_results:
            concatenated_output.extend(result.output_values)
            all_weights.append(result.attention_weights)
            heads_used.append(result.head_id)
        
        # Apply final projection
        final_output = " ".join(concatenated_output)
        
        return {
            "concatenated_output": final_output,
            "attention_weights": all_weights,
            "heads_used": heads_used,
            "individual_results": head_results,
            "target_heads": target_heads,
            "load_words": word_count,
        }
    
    def get_attention_statistics(self) -> Dict[str, Any]:
        """Get statistics about attention computation"""
        if not self.attention_history:
            return {"total_computations": 0}
        
        total_computations = len(self.attention_history)
        head_usage = defaultdict(int)
        intersection_types = defaultdict(int)
        average_attention_weights = []
        
        for result in self.attention_history:
            head_usage[result.head_id] += 1
            average_attention_weights.append(result.attention_weights.mean())
            
            for intersection in result.intersections:
                intersection_types[intersection.intersection_type.value] += 1
        
        return {
            "total_computations": total_computations,
            "head_usage": dict(head_usage),
            "intersection_types": dict(intersection_types),
            "average_attention_weight": np.mean(average_attention_weights) if average_attention_weights else 0,
            "active_heads": [hid for hid, head in self.attention_heads.items() if head.active]
        }
    
    def visualize_attention_mesh(self, input_text: str, output_file: str = "attention_mesh.png"):
        """Visualize the attention mesh computation"""
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
            
            # Compute attention for all heads
            results = self.compute_attention(input_text)
            
            if not results:
                print("No attention results to visualize")
                return
            
            # Create subplot grid
            fig, axes = plt.subplots(2, 4, figsize=(20, 10))
            axes = axes.flatten()
            
            for i, result in enumerate(results[:8]):  # Limit to 8 heads
                ax = axes[i]
                
                # Create heatmap of attention weights
                if result.attention_weights.size > 0:
                    sns.heatmap(result.attention_weights, ax=ax, cmap='viridis', 
                               cbar=True, square=True)
                    ax.set_title(f"Head {result.head_id} ({result.intersections[0].intersection_type.value if result.intersections else 'None'})")
                    ax.set_xlabel("Key Matches")
                    ax.set_ylabel("Query Matches")
                else:
                    ax.text(0.5, 0.5, f"Head {result.head_id}\nNo matches", 
                           ha='center', va='center', transform=ax.transAxes)
                    ax.set_title(f"Head {result.head_id}")
            
            plt.suptitle(f"SYNTHOS Attention Mesh\nInput: '{input_text[:50]}...'", fontsize=16)
            plt.tight_layout()
            plt.savefig(output_file, dpi=300, bbox_inches='tight')
            plt.close()
            
            print(f"Attention mesh visualization saved to {output_file}")
            
        except ImportError:
            print("Matplotlib/seaborn not available for visualization")
    
    def export_attention_history(self, filename: str):
        """Export attention computation history"""
        import json
        
        export_data = {
            "attention_heads": {
                str(hid): {
                    "head_id": head.head_id,
                    "query_pattern": head.query_pattern,
                    "key_pattern": head.key_pattern,
                    "value_pattern": head.value_pattern,
                    "scope": head.scope.value,
                    "attention_type": head.attention_type.value,
                    "weight": head.weight,
                    "active": head.active
                }
                for hid, head in self.attention_heads.items()
            },
            "statistics": self.get_attention_statistics(),
            "computations": len(self.attention_history)
        }
        
        with open(filename, 'w') as f:
            json.dump(export_data, f, indent=2)
        
        print(f"Attention history exported to {filename}")

# Example usage and demonstration
def main():
    """Run the TAM demo.

    Computes attention for a sample sentence, prints per-head results and
    statistics, tries to register a custom head from a textual definition,
    then writes ``synthos_attention.png`` (when plotting is possible) and
    ``synthos_attention.json`` to the current directory.
    """
    print("=== SYNTHOS TOPOLOGICAL ATTENTION MESH DEMO ===")

    # Create attention mesh
    tam = TopologicalAttentionMesh()

    # Sample input text
    sample_text = "Neural networks process data through layers of interconnected nodes."

    print(f"\nComputing attention for: '{sample_text}'")

    # Compute multi-head attention
    mha_result = tam.multi_head_attention(sample_text)
    print("\nMulti-Head Attention Results:")
    print(f"Heads used: {mha_result['heads_used']}")
    print(f"Concatenated output: '{mha_result['concatenated_output']}'")

    # Get individual head results
    results = tam.compute_attention(sample_text)
    print("\nIndividual Head Results:")
    for result in results:
        print(f"\nHead {result.head_id}:")
        print(f"  Query matches: {len(result.query_matches)}")
        print(f"  Key matches: {len(result.key_matches)}")
        print(f"  Intersections: {len(result.intersections)}")
        print(f"  Output values: {result.output_values[:3]}...")  # Show first 3

        # Show geometric visualization
        if result.geometric_visualization:
            print(f"  Geometry:\n{result.geometric_visualization}")

    # Get attention statistics
    stats = tam.get_attention_statistics()
    print("\nAttention Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    # Test custom attention head.  Both steps can legitimately fail — the
    # definition may not parse, and the mesh rejects head ids it has no room
    # for — so report the outcome instead of skipping silently or aborting
    # the rest of the demo.
    custom_head_def = "[[ATTN_HEAD_8]] [QUERY:\\b[A-Z]+\\b] [KEY:\\w+] [VALUE:.+?] [SCOPE:GLOBAL]"
    try:
        custom_head = tam.parse_attention_head_definition(custom_head_def)
        if custom_head is None:
            print(f"\nCould not parse custom attention head definition: {custom_head_def}")
        else:
            print(f"\nParsed custom attention head: {custom_head.head_id}")
            tam.add_attention_head(custom_head)
            print(f"Added custom attention head {custom_head.head_id}")
    except ValueError as e:
        print(f"Custom attention head rejected: {e}")

    # Visualize attention mesh (if matplotlib available)
    try:
        tam.visualize_attention_mesh(sample_text, "synthos_attention.png")
    except Exception as e:
        print(f"Visualization failed: {e}")

    # Export attention history
    tam.export_attention_history("synthos_attention.json")

    print("\n=== TAM DEMO COMPLETE ===")


if __name__ == "__main__":
    main()
