#!/usr/bin/env python3
"""
AI Code Review — Pre-Commit Advisory Hook

Uses an AI CLI tool to review staged changes for common issues.
Supports multiple backends: codex (default), gemini, claude, ollama.

Behavior:
  - Small changes (<=30 LOC, docs/tests/config only) or lock/generated files only: SKIP
  - Normal changes: Run AI review, WARN on findings
  - Advisory only — never blocks (exit 0 always)
  - Graceful degradation if backend unavailable

Usage:
  python3 scripts/ai_review.py              # Review staged changes
  python3 scripts/ai_review.py --skip-small  # Same as default (flag is not parsed): skip small docs/tests/config changes
  python3 scripts/ai_review.py --force       # Review even small changes
  python3 scripts/ai_review.py --backend claude  # Use claude CLI

Environment:
  AI_REVIEW_BACKEND  — Override default backend (codex|gemini|claude|ollama|none)

Exit codes:
  0 — Always (advisory only, never blocks)
"""

import os
import re
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

# Repository root: the parent of the directory that contains this script.
# Used as the working directory for the git and backend subprocess calls.
ROOT = Path(__file__).resolve().parents[1]

# Files to always skip in review
# (regular expressions, matched with re.search against each staged path)
SKIP_PATTERNS = [
    r"\.lock$",
    r"package-lock\.json$",
    r"pnpm-lock\.yaml$",
    r"\.min\.(js|css)$",
    r"\.map$",
    r"__pycache__",
    r"node_modules/",
    r"\.pyc$",
]

# Max diff size to send to the AI backend (chars)
# The diff is cut to this many characters before it is embedded in the prompt.
DIFF_LIMIT = 15000

# Timeout for AI backend calls (seconds)
# Passed as ``timeout=`` to subprocess.run for every backend invocation.
REVIEW_TIMEOUT = 90

# Instruction text placed ahead of the diff in the prompt sent to the backend.
# The pipe-separated line format and the NO_ISSUES_FOUND sentinel requested
# here are what parse_findings() and main() look for in the reply, so keep
# them in sync when editing this text.
REVIEW_PROMPT = """Review this git diff for common code issues.

Focus ONLY on these categories (in priority order):
1. SECURITY: SQL injection, XSS, path traversal, auth bypass, secrets exposure, unsafe deserialization
2. DATA_INTEGRITY: Missing validation, unsafe type coercion, null/undefined handling, race conditions
3. ERROR_HANDLING: Swallowed exceptions, missing error propagation, bare except/catch blocks
4. PERFORMANCE: N+1 queries, missing indexes, unbounded loops, memory leaks, blocking I/O in async
5. BEST_PRACTICES: Hardcoded values, missing cleanup, resource leaks, deprecated API usage

For each finding, output exactly this format (one per line):
SEVERITY|CATEGORY|FILE:LINE|DESCRIPTION

Where SEVERITY is: critical, high, medium, low
Where CATEGORY is: security, data_integrity, error_handling, performance, best_practices

If no issues found, output: NO_ISSUES_FOUND

Be concise. Only report real issues, not style preferences."""

# Supported backends and their command patterns
#
# "check" is the executable name looked up on PATH with shutil.which().
# The insertion order of this dict is the auto-detect priority used by
# detect_backend().
# NOTE(review): "needs_file" is not read anywhere in the visible code —
# confirm whether it is still needed.
BACKENDS = {
    "codex": {
        "check": "codex",
        "needs_file": True,
    },
    "gemini": {
        "check": "gemini",
        "needs_file": False,
    },
    "claude": {
        "check": "claude",
        "needs_file": False,
    },
    "ollama": {
        "check": "ollama",
        "needs_file": False,
    },
}


def get_staged_diff() -> str:
    """Return the unified diff (3 context lines) of the staged changes.

    The diff is produced with colour and external diff drivers disabled so
    that user-level git configuration cannot alter the text that is later
    parsed and sent to the backend. Output is decoded as UTF-8 with
    undecodable bytes replaced, so odd file contents cannot raise.

    Returns:
        The diff text, or an empty string when git could not be executed
        (the hook is advisory and must not fail in that case).
    """
    try:
        result = subprocess.run(
            ["git", "diff", "--cached", "--unified=3", "--no-color", "--no-ext-diff"],
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            cwd=ROOT,
        )
    except OSError as exc:
        # git missing or not executable: degrade to "nothing to review".
        print(f"AI REVIEW: Could not run git: {exc}", file=sys.stderr)
        return ""
    return result.stdout


def get_staged_files() -> list[str]:
    """Return the paths of all staged files.

    Uses ``-z`` so git emits NUL-terminated, unquoted paths; without it git
    C-quotes paths containing unusual or non-ASCII characters, which would
    make the names unusable for pattern matching.

    Returns:
        The staged paths (relative to the repository root), or an empty list
        when nothing is staged or git could not be executed.
    """
    try:
        result = subprocess.run(
            ["git", "diff", "--cached", "--name-only", "-z"],
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            cwd=ROOT,
        )
    except OSError as exc:
        # git missing or not executable: degrade to "nothing staged".
        print(f"AI REVIEW: Could not run git: {exc}", file=sys.stderr)
        return []
    return [name for name in result.stdout.split("\0") if name.strip()]


def should_skip(files: list[str], diff: str, force: bool = False) -> tuple[bool, str]:
    """Decide whether the AI review can be skipped for this change.

    Args:
        files: Staged file paths.
        diff: Unified diff of the staged changes.
        force: When true, never skip.

    Returns:
        ``(skip, reason)``. ``reason`` is a short human-readable explanation
        when ``skip`` is true and an empty string otherwise.
    """
    if force:
        return False, ""

    # Drop lock/generated files that are never worth reviewing.
    relevant = [
        path for path in files
        if not any(re.search(pattern, path) for pattern in SKIP_PATTERNS)
    ]
    if not relevant:
        return True, "Only generated/lock files changed"

    total_loc = _count_changed_lines(diff)
    if total_loc <= 30 and all(_is_low_risk_path(path) for path in relevant):
        return True, f"Small safe change ({total_loc} LOC, docs/tests/config only)"

    return False, ""


def _count_changed_lines(diff: str) -> int:
    """Count added plus removed lines in a unified diff.

    Lines between a ``diff ...`` line and the next ``@@`` hunk header are file
    headers (``index``, ``---``, ``+++``) and are ignored. Everything else that
    starts with ``+`` or ``-`` is counted, including changed lines whose own
    content begins with ``+`` or ``-``.
    """
    changed = 0
    in_file_header = False
    for line in diff.split("\n"):
        if line.startswith("diff "):
            in_file_header = True
        elif line.startswith("@@"):
            in_file_header = False
        elif not in_file_header and line[:1] in ("+", "-"):
            changed += 1
    return changed


def _is_low_risk_path(path: str) -> bool:
    """Return True when *path* is a docs, tests or config file.

    Matching is done on whole path components, the file suffix and the file's
    base name rather than on substrings, so e.g. ``src/latest/auth.py`` is not
    mistaken for something under a ``test/`` directory.
    """
    safe_dirs = {"docs", "tests", "test", "spec", "__tests__"}
    safe_suffixes = {".md", ".txt", ".json", ".yml", ".yaml", ".toml"}
    safe_names = {"CHANGELOG", "BACKLOG", "VERSION", "LICENSE", "README"}

    parsed = Path(path)
    if any(part in safe_dirs for part in parsed.parts[:-1]):
        return True
    if parsed.suffix in safe_suffixes:
        return True
    # "README.rst" -> "README", "LICENSE" -> "LICENSE"
    return parsed.name.split(".", 1)[0] in safe_names


def detect_backend(preferred: str | None = None) -> str | None:
    """Detect available AI backend. Returns backend name or None."""
    # Environment override
    env_backend = os.environ.get("AI_REVIEW_BACKEND", "").lower().strip()
    if env_backend == "none":
        return None
    if env_backend and env_backend in BACKENDS:
        if shutil.which(BACKENDS[env_backend]["check"]):
            return env_backend

    # Explicit preference
    if preferred and preferred in BACKENDS:
        if shutil.which(BACKENDS[preferred]["check"]):
            return preferred

    # Auto-detect in priority order
    for name, cfg in BACKENDS.items():
        if shutil.which(cfg["check"]):
            return name

    return None


def run_review(backend: str, diff: str) -> str | None:
    """Send the diff to the chosen backend and return its raw reply.

    The diff is cut to ``DIFF_LIMIT`` characters and appended to
    ``REVIEW_PROMPT`` inside a fenced ``diff`` block. Timeouts and OS-level
    failures of the backend process are reported on stderr and yield None,
    as does an unknown backend name.
    """
    clipped = diff[:DIFF_LIMIT]
    prompt = f"{REVIEW_PROMPT}\n\nHere is the diff:\n\n```diff\n{clipped}\n```"

    # Backend name -> function that runs it.
    runners = {
        "codex": _run_codex,
        "gemini": _run_gemini,
        "claude": _run_claude,
        "ollama": _run_ollama,
    }
    runner = runners.get(backend)
    if runner is None:
        return None

    try:
        return runner(prompt)
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as e:
        print(f"AI REVIEW: Backend '{backend}' failed: {e}", file=sys.stderr)
        return None


def _run_codex(prompt: str) -> str | None:
    """Run review via Codex CLI.

    Codex is asked to write its final message to a file (``-o``). The file
    lives in a private temporary directory that is removed afterwards, which
    avoids the name race of the deprecated ``tempfile.mktemp``.

    Returns:
        The contents of the output file, or None when the CLI is not
        installed or produced no file.

    Raises:
        subprocess.TimeoutExpired: If the CLI exceeds ``REVIEW_TIMEOUT``.
        OSError: If the CLI cannot be started.
    """
    codex_path = shutil.which("codex")
    if not codex_path:
        return None

    with tempfile.TemporaryDirectory(prefix="ai_review_") as tmp_dir:
        output_file = Path(tmp_dir) / "review.md"
        # stdout/stderr are captured only to keep the hook quiet; the review
        # itself is read from the output file, so no text decoding is needed.
        subprocess.run(
            [codex_path, "exec", "--full-auto", "-o", str(output_file), prompt],
            capture_output=True, timeout=REVIEW_TIMEOUT, cwd=ROOT,
        )
        if output_file.exists():
            # errors="replace": a stray invalid byte must not crash the hook.
            return output_file.read_text(encoding="utf-8", errors="replace")
    return None


def _run_gemini(prompt: str) -> str | None:
    """Run review via Gemini CLI (non-interactive, plan mode).

    Output is decoded as UTF-8 with undecodable bytes replaced, so the result
    does not depend on the locale and decoding can never raise.

    Returns:
        The CLI's stdout, or None when the CLI is not installed, exits
        non-zero, or prints nothing.

    Raises:
        subprocess.TimeoutExpired: If the CLI exceeds ``REVIEW_TIMEOUT``.
        OSError: If the CLI cannot be started.
    """
    gemini_path = shutil.which("gemini")
    if not gemini_path:
        return None

    result = subprocess.run(
        [gemini_path, "-p", prompt, "--approval-mode", "plan", "--output-format", "text"],
        capture_output=True, text=True, encoding="utf-8", errors="replace",
        timeout=REVIEW_TIMEOUT, cwd=ROOT,
    )
    if result.returncode == 0 and result.stdout.strip():
        return result.stdout
    return None


def _run_claude(prompt: str) -> str | None:
    """Run review via Claude CLI.

    Output is decoded as UTF-8 with undecodable bytes replaced, so the result
    does not depend on the locale and decoding can never raise.

    Returns:
        The CLI's stdout, or None when the CLI is not installed, exits
        non-zero, or prints nothing.

    Raises:
        subprocess.TimeoutExpired: If the CLI exceeds ``REVIEW_TIMEOUT``.
        OSError: If the CLI cannot be started.
    """
    claude_path = shutil.which("claude")
    if not claude_path:
        return None

    result = subprocess.run(
        [claude_path, "-p", prompt],
        capture_output=True, text=True, encoding="utf-8", errors="replace",
        timeout=REVIEW_TIMEOUT, cwd=ROOT,
    )
    if result.returncode == 0 and result.stdout.strip():
        return result.stdout
    return None


def _run_ollama(prompt: str, model: str = "codellama") -> str | None:
    """Run review via Ollama (local model).

    Output is decoded as UTF-8 with undecodable bytes replaced, so the result
    does not depend on the locale and decoding can never raise.

    Args:
        prompt: Full review prompt including the diff.
        model: Name of the Ollama model to run.

    Returns:
        The CLI's stdout, or None when the CLI is not installed, exits
        non-zero, or prints nothing.

    Raises:
        subprocess.TimeoutExpired: If the CLI exceeds ``REVIEW_TIMEOUT``.
        OSError: If the CLI cannot be started.
    """
    ollama_path = shutil.which("ollama")
    if not ollama_path:
        return None

    result = subprocess.run(
        [ollama_path, "run", model, prompt],
        capture_output=True, text=True, encoding="utf-8", errors="replace",
        timeout=REVIEW_TIMEOUT, cwd=ROOT,
    )
    if result.returncode == 0 and result.stdout.strip():
        return result.stdout
    return None


def parse_findings(review_output: str) -> list[dict]:
    """Parse structured findings from review output.

    A finding is a line of the form ``SEVERITY|CATEGORY|FILE:LINE|DESCRIPTION``
    whose severity is one of critical/high/medium/low (case-insensitive).
    Leading list markers (``-``, ``*``) and surrounding backticks are
    tolerated. Only the first three pipes act as separators, so a description
    that itself contains ``|`` is kept intact.

    Args:
        review_output: Raw text returned by the backend.

    Returns:
        One dict per finding with the keys ``severity``, ``category`` (both
        lower-cased), ``location`` and ``message``. Lines that do not match
        are ignored.
    """
    valid_severities = {"critical", "high", "medium", "low"}
    findings = []
    for raw_line in review_output.split("\n"):
        # Models often decorate lines as markdown bullets or inline code.
        line = raw_line.strip().strip("`").lstrip("-* ").strip()
        if "|" not in line:
            continue
        # maxsplit=3: everything after the third pipe belongs to the message.
        parts = [part.strip() for part in line.split("|", 3)]
        if len(parts) < 4:
            continue
        severity = parts[0].lower()
        if severity not in valid_severities:
            continue
        findings.append({
            "severity": severity,
            "category": parts[1].lower(),
            "location": parts[2],
            "message": parts[3],
        })
    return findings


def format_findings(findings: list[dict]) -> str:
    """Render findings as a plain-text report, most severe first.

    Findings keep their relative order within a severity level; entries whose
    severity is not critical/high/medium/low are left out. An empty input
    yields an empty string.
    """
    if not findings:
        return ""

    # Insertion order doubles as the severity ranking.
    icons = {"critical": "!!!", "high": "!!", "medium": "!", "low": "."}
    rank = {level: position for position, level in enumerate(icons)}

    # sorted() is stable, so same-severity findings stay in input order.
    ordered = sorted(
        (item for item in findings if item["severity"] in rank),
        key=lambda item: rank[item["severity"]],
    )

    report = ["AI CODE REVIEW FINDINGS:"]
    for item in ordered:
        level = item["severity"]
        report.append(
            f"  [{icons[level]}] {level.upper()} [{item['category']}] {item['location']}: {item['message']}"
        )
    return "\n".join(report)


def main() -> int:
    """Run the advisory AI review for the staged changes.

    Command-line flags (read from ``sys.argv``):
      --force           review even small docs/tests/config-only changes
      --backend NAME    prefer the given backend (``--backend=NAME`` also works)

    Returns:
        Always 0. The hook is advisory only, so even an unexpected error is
        reported on stderr instead of failing the commit.
    """
    try:
        _review_staged_changes(sys.argv[1:])
    except Exception as exc:
        # Top-level boundary: a crash here would block the commit, which this
        # hook promises never to do.
        print(f"AI REVIEW: Unexpected error, skipping review: {exc}", file=sys.stderr)
    return 0  # Advisory only — never blocks


def _parse_backend_arg(argv: list[str]) -> str | None:
    """Return the lower-cased value of the last ``--backend`` flag, if any."""
    backend = None
    for index, arg in enumerate(argv):
        if arg == "--backend" and index + 1 < len(argv):
            backend = argv[index + 1].lower()
        elif arg.startswith("--backend="):
            backend = arg.partition("=")[2].lower()
    return backend


def _review_staged_changes(argv: list[str]) -> None:
    """Review the staged diff and print any findings to stderr."""
    force = "--force" in argv
    backend_arg = _parse_backend_arg(argv)

    files = get_staged_files()
    if not files:
        return

    diff = get_staged_diff()
    if not diff:
        return

    skip, _reason = should_skip(files, diff, force)
    if skip:
        return

    backend = detect_backend(backend_arg)
    if not backend:
        print("AI REVIEW: No AI backend available (codex/gemini/claude/ollama), skipping.", file=sys.stderr)
        return

    review_output = run_review(backend, diff)
    if not review_output:
        print(f"AI REVIEW: Backend '{backend}' returned no output, skipping.", file=sys.stderr)
        return

    # Parse first: a reply that lists findings but also mentions the
    # NO_ISSUES_FOUND sentinel (e.g. by echoing the instructions) must not
    # have its findings suppressed.
    findings = parse_findings(review_output)
    if findings:
        print(format_findings(findings), file=sys.stderr)

        critical_count = sum(1 for f in findings if f["severity"] == "critical")
        high_count = sum(1 for f in findings if f["severity"] == "high")
        if critical_count or high_count:
            print(
                f"\nAI REVIEW ({backend}): {critical_count} critical, {high_count} high finding(s). "
                "Review before committing.",
                file=sys.stderr,
            )
        return

    if "NO_ISSUES_FOUND" in review_output:
        print(f"AI REVIEW ({backend}): No issues found.", file=sys.stderr)
        return

    # Backend returned text but no structured findings — print summary
    summary = review_output.strip()[:500]
    if summary and "no issues" not in summary.lower():
        print(f"AI REVIEW NOTES ({backend}):\n{summary}", file=sys.stderr)


# Script entry point: run the hook and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
