# File: llm_observability.py
import functools
import json
import logging
import sqlite3
from contextlib import closing
from datetime import datetime
from typing import Dict, Any, List, Optional, Callable

# Configure the root logger when this module is imported: INFO level and a
# "timestamp - level - message" line format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module, used by the code below.
logger = logging.getLogger(__name__)

def log_execution(func: Callable) -> Callable:
    """Decorate *func* so that each call is logged.

    An INFO record is written before the call and another after it returns.
    If the call raises an ``Exception``, an ERROR record is written and the
    same exception is re-raised to the caller.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that carries *func*'s metadata (via ``functools.wraps``) and
        returns whatever *func* returns.
    """
    @functools.wraps(func)
    def logged_call(*call_args: Any, **call_kwargs: Any) -> Any:
        # Read the name at call time, exactly as the log messages need it.
        name = func.__name__
        logger.info(f"Executing {name}")
        try:
            outcome = func(*call_args, **call_kwargs)
            logger.info(f"{name} completed successfully")
            return outcome
        except Exception as exc:
            logger.error(f"Error in {name}: {exc}")
            # Logging is observational only; the caller still sees the error.
            raise

    return logged_call


class LLMObservabilityManager:
    """Store LLM call observations in a SQLite file and report on them.

    Each observation is one row of the ``llm_observations`` table. Every
    method opens its own short-lived connection to ``db_path`` and closes it
    before returning.
    """

    # strftime() patterns used to bucket rows for the dashboard time series.
    _TIME_BUCKET_FORMATS = {
        'hour': "%Y-%m-%d %H:00:00",
        'day': "%Y-%m-%d",
        'week': "%Y-%W",
        'month': "%Y-%m",
    }
    # Pattern used when an unknown interval name is requested.
    _DEFAULT_BUCKET_FORMAT = "%Y-%m-%d"

    def __init__(self, db_path: str = "/data/llm_observability_v2.db"):
        """Remember the database path and make sure the table exists.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = db_path
        self.create_table()

    def _connect(self):
        """Return a context manager yielding a connection that is closed on exit.

        ``with sqlite3.connect(...)`` alone only commits or rolls back; it does
        not close the connection, so the connection is wrapped in ``closing``.
        """
        return closing(sqlite3.connect(self.db_path))

    @staticmethod
    def _rows_as_dicts(cursor: sqlite3.Cursor) -> List[Dict[str, Any]]:
        """Fetch all remaining rows of *cursor* as ``{column_name: value}`` dicts."""
        column_names = [description[0] for description in cursor.description]
        return [dict(zip(column_names, row)) for row in cursor.fetchall()]

    @staticmethod
    def _safe_round(value: Any, decimals: int = 2) -> float:
        """Round *value* to *decimals* places; ``None`` or invalid input gives 0.0."""
        if value is None:
            return 0.0
        try:
            return round(float(value), decimals)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _safe_divide(numerator: Any, denominator: Any, decimals: int = 2) -> float:
        """Divide and round; a missing/zero denominator or invalid input gives 0.0."""
        if not denominator:
            return 0.0
        try:
            return round(float(numerator or 0) / float(denominator), decimals)
        except (TypeError, ValueError, ZeroDivisionError):
            # ZeroDivisionError covers truthy zeros such as the string "0".
            return 0.0

    def create_table(self) -> None:
        """Create the ``llm_observations`` table if it does not exist yet."""
        # The inner ``conn`` context commits; the outer one closes the connection.
        with self._connect() as conn, conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS llm_observations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    conversation_id TEXT,
                    created_at DATETIME,
                    status TEXT,
                    request TEXT,
                    response TEXT,
                    model TEXT,
                    prompt_tokens INTEGER,
                    completion_tokens INTEGER,
                    total_tokens INTEGER,
                    cost FLOAT,
                    latency FLOAT,
                    user TEXT
                )
            ''')

    def insert_observation(self, response: str, conversation_id: str, status: str, request: str, model: str, prompt_tokens: int,completion_tokens: int, total_tokens: int, cost: float, latency: float, user: str):
        """Insert one observation row, timestamped with the current local time.

        Note that ``response`` is the first positional parameter, ahead of
        ``conversation_id``.

        Args:
            response: Value stored in the ``response`` column.
            conversation_id: Value stored in the ``conversation_id`` column.
            status: Value stored in the ``status`` column; the dashboard
                counts rows whose status is ``'error'``.
            request: Value stored in the ``request`` column.
            model: Value stored in the ``model`` column.
            prompt_tokens: Value stored in the ``prompt_tokens`` column.
            completion_tokens: Value stored in the ``completion_tokens`` column.
            total_tokens: Value stored in the ``total_tokens`` column.
            cost: Value stored in the ``cost`` column.
            latency: Value stored in the ``latency`` column.
            user: Value stored in the ``user`` column.
        """
        # Store the timestamp as text in the same "YYYY-MM-DD HH:MM:SS[.ffffff]"
        # form that sqlite3's default datetime adapter produced, without relying
        # on that adapter (deprecated since Python 3.12).
        created_at = datetime.now().isoformat(sep=" ")

        with self._connect() as conn, conn:
            conn.execute('''
                INSERT INTO llm_observations 
                (conversation_id, created_at, status, request, response, model, prompt_tokens, completion_tokens,total_tokens, cost, latency, user)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                conversation_id,
                created_at,
                status,
                request,
                response,
                model,
                prompt_tokens,
                completion_tokens,
                total_tokens,
                cost,
                latency,
                user,
            ))

    def get_observations(self, conversation_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return observations ordered by ``created_at`` (oldest first).

        Args:
            conversation_id: When truthy, only rows of that conversation are
                returned. ``None`` or an empty string returns every row.

        Returns:
            One dict per row, keyed by column name.
        """
        with self._connect() as conn:
            cursor = conn.cursor()
            if conversation_id:
                cursor.execute('SELECT * FROM llm_observations WHERE conversation_id = ? ORDER BY created_at', (conversation_id,))
            else:
                cursor.execute('SELECT * FROM llm_observations ORDER BY created_at')
            return self._rows_as_dicts(cursor)

    def get_all_observations(self) -> List[Dict[str, Any]]:
        """Return every observation, ordered by ``created_at``."""
        return self.get_observations()

    def get_all_unique_conversation_observations(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return the most recent observation of each conversation, newest first.

        Args:
            limit: Maximum number of rows to return; ``None`` means no limit.

        Returns:
            One dict per row, keyed by column name.

        Raises:
            ValueError: If ``limit`` cannot be converted to an integer.
            TypeError: If ``limit`` is of a type ``int()`` does not accept.
        """
        # Keep only rows whose timestamp equals the newest timestamp of their
        # own conversation.
        query = '''
            SELECT * FROM llm_observations o1
            WHERE created_at = (
                SELECT MAX(created_at) 
                FROM llm_observations o2 
                WHERE o2.conversation_id = o1.conversation_id
            )
            ORDER BY created_at DESC
        '''
        params: tuple = ()
        if limit is not None:
            # Bind the limit instead of formatting it into the SQL text.
            query += ' LIMIT ?'
            params = (int(limit),)

        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(query, params)
            return self._rows_as_dicts(cursor)

    def get_dashboard_statistics(self, days: Optional[int] = None, time_series_interval: str = 'day') -> Dict[str, Any]:
        """
        Get statistical metrics for LLM usage dashboard with time series data.

        Any exception raised while computing the statistics is logged and an
        all-zero structure (see ``_get_empty_statistics``) is returned instead.

        Args:
            days (int, optional): Number of days to look back. If None, returns all-time statistics
            time_series_interval (str): Interval for time series data ('hour', 'day', 'week', 'month');
                any other value falls back to daily buckets

        Returns:
            Dict containing dashboard statistics and time series data, with the keys
            ``general_stats``, ``model_distribution``, ``token_metrics``,
            ``top_users``, ``time_series`` and ``trends``
        """
        try:
            # The WHERE clause is a fixed literal; the caller-supplied value is
            # only ever passed as a bound parameter.
            where_clause = ""
            where_params: tuple = ()
            if days is not None:
                # Rows are stamped with local datetime.now(), so the cutoff must
                # be computed in local time too ('now' alone is UTC).
                where_clause = "WHERE created_at >= datetime('now', 'localtime', ?)"
                where_params = (f"-{days} days",)

            bucket_format = self._TIME_BUCKET_FORMATS.get(
                time_series_interval, self._DEFAULT_BUCKET_FORMAT)

            with self._connect() as conn:
                cursor = conn.cursor()

                # Get general statistics
                cursor.execute(f"""
                    SELECT 
                        COUNT(*) as total_requests,
                        COUNT(DISTINCT conversation_id) as unique_conversations,
                        COUNT(DISTINCT user) as unique_users,
                        SUM(total_tokens) as total_tokens,
                        SUM(cost) as total_cost,
                        AVG(latency) as avg_latency,
                        SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as error_count
                    FROM llm_observations
                    {where_clause}
                """, where_params)
                row = cursor.fetchone()
                if not row:
                    return self._get_empty_statistics()
                general_stats = dict(zip([col[0] for col in cursor.description], row))

                # Get model distribution (fetch the rows exactly once)
                cursor.execute(f"""
                    SELECT model, COUNT(*) as count
                    FROM llm_observations
                    {where_clause}
                    GROUP BY model
                    ORDER BY count DESC
                """, where_params)
                model_distribution = {model: count for model, count in cursor.fetchall()}

                # Get average tokens per request
                cursor.execute(f"""
                    SELECT 
                        AVG(prompt_tokens) as avg_prompt_tokens,
                        AVG(completion_tokens) as avg_completion_tokens
                    FROM llm_observations
                    {where_clause}
                """, where_params)
                token_averages = dict(zip([col[0] for col in cursor.description], cursor.fetchone()))

                # Get top users by request count
                cursor.execute(f"""
                    SELECT user, COUNT(*) as request_count, 
                           SUM(total_tokens) as total_tokens,
                           SUM(cost) as total_cost
                    FROM llm_observations
                    {where_clause}
                    GROUP BY user
                    ORDER BY request_count DESC
                    LIMIT 5
                """, where_params)
                top_users = [
                    {
                        "user": user_row[0],
                        "request_count": user_row[1],
                        "total_tokens": user_row[2] or 0,
                        "total_cost": self._safe_round(user_row[3]),
                    }
                    for user_row in cursor.fetchall()
                ]

                # Get time series data; the strftime pattern is the first "?".
                cursor.execute(f"""
                    SELECT 
                        strftime(?, created_at) as time_bucket,
                        COUNT(*) as request_count,
                        SUM(total_tokens) as total_tokens,
                        SUM(cost) as total_cost,
                        AVG(latency) as avg_latency,
                        COUNT(DISTINCT user) as unique_users,
                        SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as error_count
                    FROM llm_observations
                    {where_clause}
                    GROUP BY time_bucket
                    ORDER BY time_bucket
                """, (bucket_format, *where_params))
                time_series = [
                    {
                        "timestamp": bucket[0],
                        "request_count": bucket[1] or 0,
                        "total_tokens": bucket[2] or 0,
                        "total_cost": self._safe_round(bucket[3]),
                        "avg_latency": self._safe_round(bucket[4]),
                        "unique_users": bucket[5] or 0,
                        "error_count": bucket[6] or 0,
                    }
                    for bucket in cursor.fetchall()
                ]

            # Calculate trends safely
            trends = self._calculate_trends(time_series)

            # Scale to percent before dividing so the single rounding step
            # keeps two decimals of the percentage (e.g. 33.33, not 33.0).
            error_rate = self._safe_divide(
                (general_stats["error_count"] or 0) * 100,
                general_stats["total_requests"],
            )

            return {
                "general_stats": {
                    "total_requests": general_stats["total_requests"] or 0,
                    "unique_conversations": general_stats["unique_conversations"] or 0,
                    "unique_users": general_stats["unique_users"] or 0,
                    "total_tokens": general_stats["total_tokens"] or 0,
                    "total_cost": self._safe_round(general_stats["total_cost"]),
                    "avg_latency": self._safe_round(general_stats["avg_latency"]),
                    "error_rate": error_rate,
                },
                "model_distribution": model_distribution,
                "token_metrics": {
                    "avg_prompt_tokens": self._safe_round(token_averages["avg_prompt_tokens"]),
                    "avg_completion_tokens": self._safe_round(token_averages["avg_completion_tokens"]),
                },
                "top_users": top_users,
                "time_series": time_series,
                "trends": trends,
            }
        except sqlite3.Error as e:
            logger.exception(f"Database error in get_dashboard_statistics: {e}")
            return self._get_empty_statistics()
        except Exception as e:
            # Deliberate best-effort: the dashboard gets zeros rather than a crash.
            logger.exception(f"Error in get_dashboard_statistics: {e}")
            return self._get_empty_statistics()

    def _get_empty_statistics(self) -> Dict[str, Any]:
        """Return an empty statistics structure when no data is available."""
        return {
            "general_stats": {
                "total_requests": 0,
                "unique_conversations": 0,
                "unique_users": 0,
                "total_tokens": 0,
                "total_cost": 0.0,
                "avg_latency": 0.0,
                "error_rate": 0.0
            },
            "model_distribution": {},
            "token_metrics": {
                "avg_prompt_tokens": 0.0,
                "avg_completion_tokens": 0.0
            },
            "top_users": [],
            "time_series": [],
            "trends": {
                "request_trend": 0.0,
                "cost_trend": 0.0,
                "token_trend": 0.0
            }
        }

    def _calculate_trends(self, time_series: List[Dict[str, Any]]) -> Dict[str, float]:
        """Compare the last two time buckets and return percentage changes.

        With fewer than two buckets every trend is 0.0.
        """
        if len(time_series) < 2:
            return {
                "request_trend": 0.0,
                "cost_trend": 0.0,
                "token_trend": 0.0
            }

        previous, current = time_series[-2], time_series[-1]
        return {
            "request_trend": self._calculate_percentage_change(
                previous["request_count"], current["request_count"]),
            "cost_trend": self._calculate_percentage_change(
                previous["total_cost"], current["total_cost"]),
            "token_trend": self._calculate_percentage_change(
                previous["total_tokens"], current["total_tokens"])
        }

    def _calculate_percentage_change(self, old_value: Any, new_value: Any) -> float:
        """Calculate percentage change between two values safely.

        ``None`` counts as 0. When the old value is 0 the result is 100.0 if
        the new value is positive and 0.0 otherwise. Values that cannot be
        converted to ``float`` yield 0.0.
        """
        try:
            old_value = float(old_value or 0)
            new_value = float(new_value or 0)
        except (TypeError, ValueError):
            return 0.0

        if old_value == 0:
            return 100.0 if new_value > 0 else 0.0
        return round(((new_value - old_value) / old_value) * 100, 2)