# src/memory.py
import json
import logging
import sqlite3
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Configure root logging (INFO and above, timestamped) when this module is imported.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class MemoryManager:
    """SQLite-backed store for semantic memories and recent user interactions.

    Semantic memories live in ``semantic_memory`` (concept, description, tags as a
    JSON list, importance, last_accessed as an ISO-8601 string). Retrieval ranks them
    by averaging TF-IDF cosine similarity to the query, the stored importance and a
    recency bonus. User interactions live in ``user_interactions`` and are pruned by
    age via ``cleanup_expired_interactions``.

    Every write commits immediately. Call ``close()`` (or use the instance as a
    context manager) to release the database connection.
    """

    def __init__(self, db_path: str):
        """Open (or create) the SQLite database at *db_path* and ensure the schema exists."""
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.create_tables()
        # English stop words are dropped when vectorizing queries and memories.
        self.vectorizer = TfidfVectorizer(stop_words='english')
        logging.info("MemoryManager initialized and tables created.")

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self.conn.close()

    def __enter__(self) -> "MemoryManager":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    @staticmethod
    def _like_prefix(prefix: str) -> str:
        """Return a LIKE pattern (used with a backslash ESCAPE clause) matching values starting with *prefix*.

        The escape character, ``%`` and ``_`` are escaped so they match literally
        instead of acting as wildcards.
        """
        escaped = prefix.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return escaped + "%"

    def create_tables(self):
        """Create the tables and indexes if missing, migrating older semantic_memory schemas."""
        # Create tables if they don't exist
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS semantic_memory
                           (id INTEGER PRIMARY KEY, concept TEXT, description TEXT, last_accessed DATETIME, tags TEXT, importance REAL DEFAULT 0.5)''')

        # Databases created before tags/importance existed lack these columns; add them in place.
        self.cursor.execute("PRAGMA table_info(semantic_memory)")
        columns = [column[1] for column in self.cursor.fetchall()]
        if 'tags' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN tags TEXT")
        if 'importance' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN importance REAL DEFAULT 0.5")

        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_concept ON semantic_memory (concept)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_last_accessed ON semantic_memory (last_accessed)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_tags ON semantic_memory (tags)''')

        # Create table for user interactions
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS user_interactions
                           (user_id TEXT, query TEXT, response TEXT, timestamp DATETIME)''')

        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_user_interactions_timestamp ON user_interactions (timestamp)''')
        self.conn.commit()
        logging.info("Tables and indexes created successfully.")

    def add_semantic_memory(self, concept: str, description: str, tags: Optional[List[str]] = None,
                            importance: float = 0.5):
        """Insert a semantic memory stamped with the current local time.

        Args:
            concept: Short name of the memory (matched by prefix in lookups).
            description: Free-text body of the memory.
            tags: Optional tags, stored as a JSON array (``[]`` when omitted).
            importance: Weight averaged into the retrieval score; 0.5 matches the
                column default.
        """
        tags_str = json.dumps(tags if tags is not None else [])
        self.cursor.execute(
            "INSERT INTO semantic_memory (concept, description, last_accessed, tags, importance) "
            "VALUES (?, ?, ?, ?, ?)",
            (concept, description, datetime.now().isoformat(), tags_str, importance))
        self.conn.commit()
        logging.info("Semantic memory added.")

    def retrieve_relevant_memories(self, query: str, limit: int = 30) -> List[Dict[str, Any]]:
        """Return up to *limit* memories ranked by relevance to *query*.

        Returns an empty list when *query* has no usable terms: empty, whitespace,
        or only stop words / tokens shorter than two characters. Each result is a
        dict with ``concept``, ``description`` and ``importance`` keys.
        """
        # Run the vectorizer's own analyzer (lowercase, tokenize, drop stop words).
        # If it yields no tokens, fit_transform() would raise "empty vocabulary".
        # The previous check tested words against the *string* 'english' (a substring test).
        if not self.vectorizer.build_analyzer()(query):
            return []

        scored_memories = self._score_memories(query, self._get_all_memories())
        # sorted() is stable, so ties keep the SQL order (importance, then last access).
        ranked = sorted(scored_memories, key=lambda item: item[1], reverse=True)
        return [memory for memory, _score in ranked[:limit]]

    def _get_all_memories(self) -> List[Tuple[Dict[str, Any], Optional[datetime], Optional[List[str]]]]:
        """Load every semantic memory as ``(memory_dict, last_accessed, tags)``.

        Rows are ordered by importance, then last access, both descending.
        ``last_accessed`` is ``None`` when the column is NULL; ``tags`` is ``None``
        when the stored value is NULL or empty.
        """
        self.cursor.execute("SELECT concept, description, importance, last_accessed, tags FROM semantic_memory ORDER BY importance DESC, last_accessed DESC")
        return [
            (
                {"concept": concept, "description": description, "importance": importance},
                datetime.fromisoformat(last_accessed) if last_accessed else None,
                json.loads(tags) if tags else None,
            )
            for concept, description, importance, last_accessed, tags in self.cursor.fetchall()
        ]

    def _score_memories(self, query: str,
                        memories: List[Tuple[Dict[str, Any], Optional[datetime], Optional[List[str]]]]
                        ) -> List[Tuple[Dict[str, Any], float]]:
        """Score each memory as the mean of query similarity, importance and recency.

        The vectorizer is re-fitted on the query alone, so similarity only counts
        terms that appear in the query. Recency is ``1 / (1 + age_in_minutes)``
        (0.5 when the timestamp is missing). A NULL importance counts as 0.5.
        The caller must ensure the query has at least one token.
        """
        if not memories:
            return []

        query_vector = self.vectorizer.fit_transform([query])
        texts = [f"{memory['concept']} {memory['description']}" for memory, _ts, _tags in memories]
        # One batched transform/similarity call gives the same values as per-memory calls.
        similarities = cosine_similarity(query_vector, self.vectorizer.transform(texts))[0]

        now = datetime.now()
        scored_memories = []
        for (memory, timestamp, _tags), similarity in zip(memories, similarities):
            importance = memory.get('importance')
            if importance is None:
                importance = 0.5

            if timestamp:
                # Clamp at zero so future timestamps (clock skew) cannot divide by zero or exceed 1.
                age_minutes = max(0.0, (now - timestamp).total_seconds() / 60)
                recency = 1 / (1 + age_minutes)  # Favor recent memories
            else:
                recency = 0.5  # Neutral recency when no timestamp is stored

            score = (similarity + importance + recency) / 3
            scored_memories.append((memory, score))

        return scored_memories

    def section_exists(self, concept: str) -> bool:
        """Return True if any memory's concept starts with *concept*.

        SQLite LIKE is case-insensitive only for ASCII letters. ``%`` and ``_``
        in *concept* are matched literally.
        """
        self.cursor.execute(
            "SELECT EXISTS(SELECT 1 FROM semantic_memory WHERE concept LIKE ? ESCAPE '\\')",
            (self._like_prefix(concept.lower()),))
        return bool(self.cursor.fetchone()[0])

    def add_user_interaction(self, user_id: str, query: str, response: str):
        """Record one query/response exchange for *user_id*, stamped with the current local time."""
        self.cursor.execute("INSERT INTO user_interactions (user_id, query, response, timestamp) VALUES (?, ?, ?, ?)",
                            (user_id, query, response, datetime.now().isoformat()))
        self.conn.commit()
        logging.info("User interaction added: User ID: %s, Query: %s, Response: %s", user_id, query, response)

    def get_user_interactions(self, user_id: str) -> List[Dict[str, Any]]:
        """Return *user_id*'s interactions, oldest first.

        Each item is a dict with ``query``, ``response`` and ``timestamp`` (ISO string) keys.
        """
        # Without ORDER BY, SQLite does not guarantee any row order; rowid breaks timestamp ties.
        self.cursor.execute(
            "SELECT query, response, timestamp FROM user_interactions WHERE user_id = ? ORDER BY timestamp, rowid",
            (user_id,))
        return [{"query": query, "response": response, "timestamp": timestamp}
                for query, response, timestamp in self.cursor.fetchall()]

    def cleanup_expired_interactions(self, max_age_minutes: float = 5):
        """Delete interactions older than *max_age_minutes* (default 5).

        ISO-8601 strings produced by ``datetime.isoformat`` compare chronologically
        as text, so a string comparison is enough.
        """
        cutoff_time = datetime.now() - timedelta(minutes=max_age_minutes)
        self.cursor.execute("DELETE FROM user_interactions WHERE timestamp < ?", (cutoff_time.isoformat(),))
        self.conn.commit()
        logging.info("Expired user interactions cleaned up. Cutoff time: %s", cutoff_time)

    def get_section_description(self, section_name: str) -> str:
        """Return the description of the first memory whose concept starts with *section_name*.

        Returns "" when nothing matches. ``%`` and ``_`` in *section_name* are
        matched literally; the earliest-inserted match wins.
        """
        section_name = section_name.lower()

        self.cursor.execute(
            "SELECT description FROM semantic_memory WHERE concept LIKE ? ESCAPE '\\' ORDER BY id LIMIT 1",
            (self._like_prefix(section_name),))
        result = self.cursor.fetchone()
        if result:
            logging.info("Found section: %s", section_name)
            return result[0]
        logging.warning("Section not found: %s", section_name)
        return ""

    def _count_concepts_with_prefix(self, prefix: str) -> int:
        """Return how many memories have a concept starting with *prefix* (ASCII case-insensitive)."""
        self.cursor.execute(
            "SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE ? ESCAPE '\\'",
            (self._like_prefix(prefix),))
        return self.cursor.fetchone()[0]

    def count_chroniques(self) -> int:
        """Count memories whose concept starts with "chronique #"."""
        count = self._count_concepts_with_prefix("chronique #")
        logging.info("Number of chroniques: %s", count)
        return count

    def count_flash_infos(self) -> int:
        """Count memories whose concept starts with "flash info fl-"."""
        count = self._count_concepts_with_prefix("flash info fl-")
        logging.info("Number of flash infos: %s", count)
        return count

    def count_chronique_faqs(self) -> int:
        """Count memories whose concept starts with "chronique-faq #"."""
        count = self._count_concepts_with_prefix("chronique-faq #")
        logging.info("Number of chronique-faqs: %s", count)
        return count

if __name__ == "__main__":
    # One-shot maintenance entry point: prune expired interactions in the local database.
    db_path = "agent.db"
    memory_manager = MemoryManager(db_path)
    try:
        memory_manager.cleanup_expired_interactions()
    finally:
        # Release the SQLite connection even if cleanup raises.
        memory_manager.conn.close()