from __future__ import annotations

import hashlib
import json
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any

class CacheManager:
    def __init__(self, cache_dir: str | Path = "cache"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        
        # Create SQLite database for structured results
        self.db_path = self.cache_dir / "extraction_cache.db"
        self._init_db()
    
    def _init_db(self):
        """Initialize the SQLite database with necessary tables."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS extractions (
                    input_hash TEXT,
                    form_type TEXT,
                    result TEXT,
                    model_name TEXT,
                    timestamp DATETIME,
                    PRIMARY KEY (input_hash, form_type)
                )
            """)
            
            conn.execute("""
                CREATE TABLE IF NOT EXISTS transcripts (
                    video_id TEXT PRIMARY KEY,
                    transcript TEXT,
                    timestamp DATETIME
                )
            """)
    
    def _hash_content(self, content: str) -> str:
        """Generate a stable hash for input content."""
        return hashlib.sha256(content.encode('utf-8')).hexdigest()
    
    def get_transcript(self, video_id: str) -> str | None:
        """Retrieve a cached transcript if it exists."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT transcript FROM transcripts WHERE video_id = ?",
                (video_id,)
            )
            result = cursor.fetchone()
            return result[0] if result else None
    
    def store_transcript(self, video_id: str, transcript: str):
        """Store a transcript in the cache."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO transcripts (video_id, transcript, timestamp)
                VALUES (?, ?, ?)
                """,
                (video_id, transcript, datetime.now())
            )
    
    def get_extraction(
        self,
        input_content: str,
        form_type: str,
        model_name: str
    ) -> dict | None:
        """Retrieve cached extraction results if they exist."""
        input_hash = self._hash_content(input_content)
        
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                """
                SELECT result FROM extractions 
                WHERE input_hash = ? AND form_type = ? AND model_name = ?
                """,
                (input_hash, form_type, model_name)
            )
            result = cursor.fetchone()
            
            if result:
                return json.loads(result[0])
        return None
    
    def store_extraction(
        self,
        input_content: str,
        form_type: str,
        result: dict,
        model_name: str
    ):
        """Store extraction results in the cache."""
        input_hash = self._hash_content(input_content)
        
        with sqlite3.connect(self.db_path) as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO extractions 
                (input_hash, form_type, result, model_name, timestamp)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    input_hash,
                    form_type,
                    json.dumps(result),
                    model_name,
                    datetime.now()
                )
            )
    
    def clear_cache(self, older_than_days: int | None = None):
        """Clear the cache, optionally only entries older than specified days."""
        with sqlite3.connect(self.db_path) as conn:
            if older_than_days is not None:
                conn.execute(
                    """
                    DELETE FROM extractions 
                    WHERE timestamp < datetime('now', ?)
                    """,
                    (f'-{older_than_days} days',)
                )
                conn.execute(
                    """
                    DELETE FROM transcripts 
                    WHERE timestamp < datetime('now', ?)
                    """,
                    (f'-{older_than_days} days',)
                )
            else:
                conn.execute("DELETE FROM extractions")
                conn.execute("DELETE FROM transcripts") 
    
    def cleanup_gradio_cache(self):
        """Clean up Gradio's example cache directory."""
        gradio_cache = Path(".gradio")
        if gradio_cache.exists():
            import shutil
            shutil.rmtree(gradio_cache)
            print("Cleaned up Gradio cache")