"""
Tool Executor — the action dispatch layer for SYNTHOS.

Provides a registry of callable tools that SynthoLM can invoke when it
determines (via reasoning) that a system action is needed. Tools cover
file I/O, directory management, shell execution, and code generation.
"""

from __future__ import annotations

import json
import os
import shlex
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple


class ToolStatus(Enum):
    """Outcome category attached to a tool result."""
    # The string values are what ToolResult.to_dict() emits under "status".
    SUCCESS = "success"
    ERROR = "error"
    # Not produced by any of the built-in tools visible in this file.
    SKIPPED = "skipped"


@dataclass
class ToolResult:
    """Record describing what happened when one tool was invoked."""
    # Name of the tool that produced this result.
    tool: str
    # Outcome category; to_dict() reads its ``.value``.
    status: ToolStatus
    # Short human-readable summary.
    message: str = ""
    # Captured payload (file contents, command output, listings, ...).
    output: str = ""
    # Filesystem path the tool acted on, when there is one.
    path: Optional[str] = None
    # Wall-clock duration in milliseconds.
    elapsed_ms: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the result.

        ``output`` is cut to its first 500 characters and ``elapsed_ms`` is
        rounded to one decimal place.
        """
        preview = (self.output or "")[:500]
        rounded_ms = round(self.elapsed_ms, 1)
        return dict(
            tool=self.tool,
            status=self.status.value,
            message=self.message,
            output=preview,
            path=self.path,
            elapsed_ms=rounded_ms,
        )


@dataclass
class ToolCall:
    """Request to run the tool called ``name`` with keyword arguments ``args``."""
    name: str
    # Keyword arguments for the tool; each instance gets its own dict.
    args: Dict[str, Any] = field(default_factory=dict)

    def __repr__(self) -> str:
        """Return the compact form ``ToolCall(<name>, <args>)``."""
        return "ToolCall({}, {})".format(self.name, self.args)


# ── Allowed base directories (safety) ─────────────────────────────────────────

# Workspace used when the executor is not given one explicitly.
_DEFAULT_SANDBOX = Path.home() / "synthos_workspace"

# Directories that _is_safe_path() accepts: the sandbox plus a few
# well-known folders under the user's home directory.
_ALLOWED_ROOTS: List[Path] = [_DEFAULT_SANDBOX] + [
    Path.home() / folder
    for folder in ("CascadeProjects", "Desktop", "Documents", "projects")
]

# Substrings that cause _is_safe_command() to reject a shell command.
_BLOCKED_COMMANDS = {
    "rm -rf /",
    "rm -rf ~",
    "sudo rm",
    "mkfs",
    "dd if=",
    ":(){",
    "fork bomb",
}


def _is_safe_path(p: Path) -> bool:
    """Check that *p* is inside an allowed directory.

    Both *p* and each entry of ``_ALLOWED_ROOTS`` are resolved before they
    are compared. Resolving only *p* would reject everything beneath a root
    that is (or sits behind) a symlink, because the resolved candidate would
    never have the unresolved root among its parents.

    Args:
        p: Path to test; it does not need to exist.

    Returns:
        True when the resolved path equals an allowed root or lies below one.
    """
    resolved = p.resolve()
    for root in _ALLOWED_ROOTS:
        real_root = root.resolve()
        if resolved == real_root or real_root in resolved.parents:
            return True
    return False


def _is_safe_command(cmd: str) -> bool:
    """Return False when *cmd* contains any entry of ``_BLOCKED_COMMANDS``.

    The command is lower-cased and every run of whitespace is collapsed to a
    single space before matching, so variants such as ``rm   -rf  /`` or a
    tab-separated ``rm\\t-rf /`` are caught as well.

    This is a best-effort substring filter, not a sandbox.

    Args:
        cmd: Shell command line to inspect.

    Returns:
        True when no blocked substring occurs in the normalized command.
    """
    normalized = " ".join(cmd.lower().split())
    return not any(blocked in normalized for blocked in _BLOCKED_COMMANDS)


class ToolExecutor:
    """
    Central registry and dispatcher for SYNTHOS tools.

    Tools are callables registered under a name. The executor validates
    safety constraints before running anything destructive.
    """

    def __init__(self, workspace: Optional[Path] = None, safe_mode: bool = True):
        self.workspace = Path(workspace) if workspace else _DEFAULT_SANDBOX
        self.safe_mode = safe_mode
        self._tools: Dict[str, Callable[..., ToolResult]] = {}
        self._history: List[ToolResult] = []

        # Ensure workspace exists
        self.workspace.mkdir(parents=True, exist_ok=True)

        # Register built-in tools
        self._register_builtins()

    # ── Registration ───────────────────────────────────────────────────────

    def register(self, name: str, fn: Callable[..., ToolResult]):
        self._tools[name] = fn

    def _register_builtins(self):
        self.register("create_directory", self.tool_create_directory)
        self.register("create_file", self.tool_create_file)
        self.register("write_file", self.tool_create_file)       # alias
        self.register("append_file", self.tool_append_file)
        self.register("read_file", self.tool_read_file)
        self.register("list_directory", self.tool_list_directory)
        self.register("delete_file", self.tool_delete_file)
        self.register("run_shell", self.tool_run_shell)
        self.register("run_python", self.tool_run_python)
        self.register("tree", self.tool_tree)
        self.register("file_exists", self.tool_file_exists)

    @property
    def available_tools(self) -> List[str]:
        return sorted(self._tools.keys())

    @property
    def history(self) -> List[ToolResult]:
        return list(self._history)

    # ── Dispatch ───────────────────────────────────────────────────────────

    def execute(self, call: ToolCall) -> ToolResult:
        """Execute a single tool call."""
        t0 = time.perf_counter()
        fn = self._tools.get(call.name)
        if fn is None:
            result = ToolResult(
                tool=call.name,
                status=ToolStatus.ERROR,
                message=f"Unknown tool '{call.name}'. Available: {', '.join(self.available_tools)}",
            )
        else:
            try:
                result = fn(**call.args)
            except TypeError as exc:
                result = ToolResult(tool=call.name, status=ToolStatus.ERROR, message=f"Bad arguments: {exc}")
            except Exception as exc:
                result = ToolResult(tool=call.name, status=ToolStatus.ERROR, message=f"Tool error: {exc}")

        result.elapsed_ms = (time.perf_counter() - t0) * 1000
        self._history.append(result)
        return result

    def execute_batch(self, calls: List[ToolCall]) -> List[ToolResult]:
        """Execute a list of tool calls sequentially."""
        return [self.execute(c) for c in calls]

    # ── Resolve paths relative to workspace ────────────────────────────────

    def _resolve(self, path: str) -> Path:
        p = Path(path)
        if not p.is_absolute():
            p = self.workspace / p
        return p.resolve()

    def _check_path(self, p: Path) -> Optional[str]:
        """Return error message if path is unsafe, else None."""
        if self.safe_mode and not _is_safe_path(p):
            return f"Path '{p}' is outside allowed directories."
        return None

    # ── Built-in tools ─────────────────────────────────────────────────────

    def tool_create_directory(self, path: str) -> ToolResult:
        p = self._resolve(path)
        err = self._check_path(p)
        if err:
            return ToolResult(tool="create_directory", status=ToolStatus.ERROR, message=err)
        p.mkdir(parents=True, exist_ok=True)
        return ToolResult(tool="create_directory", status=ToolStatus.SUCCESS,
                          message=f"Created directory: {p}", path=str(p))

    def tool_create_file(self, path: str, content: str = "") -> ToolResult:
        p = self._resolve(path)
        err = self._check_path(p)
        if err:
            return ToolResult(tool="create_file", status=ToolStatus.ERROR, message=err)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(content, encoding="utf-8")
        return ToolResult(tool="create_file", status=ToolStatus.SUCCESS,
                          message=f"Created file: {p} ({len(content)} bytes)", path=str(p))

    def tool_append_file(self, path: str, content: str = "") -> ToolResult:
        p = self._resolve(path)
        err = self._check_path(p)
        if err:
            return ToolResult(tool="append_file", status=ToolStatus.ERROR, message=err)
        with p.open("a", encoding="utf-8") as f:
            f.write(content)
        return ToolResult(tool="append_file", status=ToolStatus.SUCCESS,
                          message=f"Appended {len(content)} bytes to {p}", path=str(p))

    def tool_read_file(self, path: str) -> ToolResult:
        p = self._resolve(path)
        err = self._check_path(p)
        if err:
            return ToolResult(tool="read_file", status=ToolStatus.ERROR, message=err)
        if not p.exists():
            return ToolResult(tool="read_file", status=ToolStatus.ERROR, message=f"File not found: {p}")
        text = p.read_text(encoding="utf-8", errors="replace")
        return ToolResult(tool="read_file", status=ToolStatus.SUCCESS,
                          message=f"Read {len(text)} bytes from {p}", output=text, path=str(p))

    def tool_list_directory(self, path: str = ".") -> ToolResult:
        p = self._resolve(path)
        err = self._check_path(p)
        if err:
            return ToolResult(tool="list_directory", status=ToolStatus.ERROR, message=err)
        if not p.is_dir():
            return ToolResult(tool="list_directory", status=ToolStatus.ERROR, message=f"Not a directory: {p}")
        entries = sorted(p.iterdir())
        listing = []
        for e in entries[:100]:
            kind = "dir" if e.is_dir() else "file"
            size = e.stat().st_size if e.is_file() else sum(1 for _ in e.iterdir()) if e.is_dir() else 0
            listing.append(f"  [{kind}] {e.name}" + (f"  ({size} bytes)" if kind == "file" else f"  ({size} items)"))
        output = f"{p}/\n" + "\n".join(listing)
        return ToolResult(tool="list_directory", status=ToolStatus.SUCCESS,
                          message=f"Listed {len(entries)} entries in {p}", output=output, path=str(p))

    def tool_delete_file(self, path: str) -> ToolResult:
        p = self._resolve(path)
        err = self._check_path(p)
        if err:
            return ToolResult(tool="delete_file", status=ToolStatus.ERROR, message=err)
        if not p.exists():
            return ToolResult(tool="delete_file", status=ToolStatus.ERROR, message=f"File not found: {p}")
        if p.is_dir():
            return ToolResult(tool="delete_file", status=ToolStatus.ERROR, message=f"Use rmdir for directories: {p}")
        p.unlink()
        return ToolResult(tool="delete_file", status=ToolStatus.SUCCESS,
                          message=f"Deleted: {p}", path=str(p))

    def tool_file_exists(self, path: str) -> ToolResult:
        p = self._resolve(path)
        exists = p.exists()
        kind = "directory" if p.is_dir() else "file" if p.is_file() else "absent"
        return ToolResult(tool="file_exists", status=ToolStatus.SUCCESS,
                          message=f"{'Exists' if exists else 'Does not exist'}: {p} ({kind})",
                          output=json.dumps({"exists": exists, "type": kind}), path=str(p))

    def tool_tree(self, path: str = ".", depth: int = 3) -> ToolResult:
        p = self._resolve(path)
        err = self._check_path(p)
        if err:
            return ToolResult(tool="tree", status=ToolStatus.ERROR, message=err)
        lines = [str(p)]
        self._tree_walk(p, "", depth, lines)
        output = "\n".join(lines[:200])
        return ToolResult(tool="tree", status=ToolStatus.SUCCESS,
                          message=f"Tree of {p} (depth={depth})", output=output, path=str(p))

    def _tree_walk(self, directory: Path, prefix: str, depth: int, lines: List[str]):
        if depth <= 0:
            return
        try:
            entries = sorted(directory.iterdir())
        except PermissionError:
            return
        dirs = [e for e in entries if e.is_dir() and not e.name.startswith(".")]
        files = [e for e in entries if e.is_file()]
        for i, f in enumerate(files):
            connector = "└── " if (i == len(files) - 1 and not dirs) else "├── "
            lines.append(f"{prefix}{connector}{f.name}")
        for i, d in enumerate(dirs):
            connector = "└── " if i == len(dirs) - 1 else "├── "
            extension = "    " if i == len(dirs) - 1 else "│   "
            lines.append(f"{prefix}{connector}{d.name}/")
            self._tree_walk(d, prefix + extension, depth - 1, lines)

    def tool_run_shell(self, command: str, timeout: int = 30) -> ToolResult:
        if self.safe_mode and not _is_safe_command(command):
            return ToolResult(tool="run_shell", status=ToolStatus.ERROR,
                              message=f"Command blocked by safety filter: {command}")
        try:
            proc = subprocess.run(
                command, shell=True, capture_output=True, text=True,
                timeout=timeout, cwd=str(self.workspace),
            )
            output = (proc.stdout + proc.stderr).strip()
            status = ToolStatus.SUCCESS if proc.returncode == 0 else ToolStatus.ERROR
            return ToolResult(tool="run_shell", status=status,
                              message=f"Exit code {proc.returncode}", output=output[:2000])
        except subprocess.TimeoutExpired:
            return ToolResult(tool="run_shell", status=ToolStatus.ERROR, message=f"Timeout after {timeout}s")

    def tool_run_python(self, code: str, timeout: int = 30) -> ToolResult:
        """Execute a Python code snippet and capture output."""
        tmp = self.workspace / ".synthos_tmp_exec.py"
        tmp.write_text(code, encoding="utf-8")
        try:
            proc = subprocess.run(
                ["python3", str(tmp)], capture_output=True, text=True,
                timeout=timeout, cwd=str(self.workspace),
            )
            output = (proc.stdout + proc.stderr).strip()
            status = ToolStatus.SUCCESS if proc.returncode == 0 else ToolStatus.ERROR
            return ToolResult(tool="run_python", status=status,
                              message=f"Exit code {proc.returncode}", output=output[:2000])
        except subprocess.TimeoutExpired:
            return ToolResult(tool="run_python", status=ToolStatus.ERROR, message=f"Timeout after {timeout}s")
        finally:
            tmp.unlink(missing_ok=True)
