import os, sys
from pathlib import Path
import aiosqlite
import asyncio
from typing import Optional, Tuple, Dict
from contextlib import asynccontextmanager
import logging
import json  # Added for serialization/deserialization
from .utils import ensure_content_dirs, generate_content_hash
from .models import CrawlResult, MarkdownGenerationResult
import xxhash
import aiofiles
from .config import NEED_MIGRATION
from .version_manager import VersionManager
from .async_logger import AsyncLogger
from .utils import get_error_context, create_box_message
# Configure logging for this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The data directory is "<root>/.crawl4ai", where <root> is taken from the
# CRAWL4_AI_BASE_DIRECTORY environment variable and falls back to the user's
# home directory. It is created before the database path inside it is derived.
base_directory = os.path.join(
    os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home()), ".crawl4ai"
)
os.makedirs(base_directory, exist_ok=True)
DB_PATH = os.path.join(base_directory, "crawl4ai.db")

class AsyncDatabaseManager:
    def __init__(self, pool_size: int = 10, max_retries: int = 3):
        """Prepare paths, concurrency primitives and helper objects.

        The database schema itself is set up later, on the first call to
        ``get_connection``.

        Args:
            pool_size: Capacity of ``connection_semaphore``, i.e. how many
                connections may be held at the same time.
            max_retries: Number of attempts made by ``execute_with_retry``.
        """
        # Tunables
        self.pool_size = pool_size
        self.max_retries = max_retries

        # Storage locations: the SQLite file and the content directories that
        # ensure_content_dirs returns for the folder holding that file.
        self.db_path = DB_PATH
        storage_root = os.path.dirname(DB_PATH)
        self.content_paths = ensure_content_dirs(storage_root)

        # Connection bookkeeping: open connections keyed by asyncio task id,
        # a semaphore limiting how many exist, and locks guarding the pool
        # dict and the one-time initialization.
        self.connection_pool: Dict[int, aiosqlite.Connection] = {}
        self.connection_semaphore = asyncio.Semaphore(pool_size)
        self.pool_lock = asyncio.Lock()
        self.init_lock = asyncio.Lock()
        self._initialized = False

        self.version_manager = VersionManager()

        # NOTE(review): base_directory already ends in ".crawl4ai", so this
        # resolves to "<base>/.crawl4ai/.crawl4ai/crawler_db.log" - confirm
        # whether the repeated segment is intended.
        log_path = os.path.join(base_directory, ".crawl4ai", "crawler_db.log")
        self.logger = AsyncLogger(log_file=log_path, verbose=False, tag_width=10)
        
        
    async def initialize(self):
        """Create the base schema, verify it, and run version updates if required.

        Raises:
            Exception: Any failure is logged and re-raised; this includes the
                case where ``crawled_data`` is still missing after creation.
        """
        try:
            self.logger.info("Initializing database", tag="INIT")
            # Make sure the folder that holds the database file exists
            os.makedirs(os.path.dirname(self.db_path), exist_ok=True)

            # Ask the version manager up front whether an update is pending
            update_pending = self.version_manager.needs_update()

            # The base table is created unconditionally
            await self.ainit_db()

            # Confirm the table is really there before going any further
            table_query = "SELECT name FROM sqlite_master WHERE type='table' AND name='crawled_data'"
            async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
                async with db.execute(table_query) as cursor:
                    if not await cursor.fetchone():
                        raise Exception("crawled_data table was not created")

            if not update_pending:
                self.logger.success("Database initialization completed successfully", tag="COMPLETE")
                return

            # Version changed or fresh install: bring schema and data up to date
            self.logger.info("New version detected, running updates", tag="INIT")
            await self.update_db_schema()
            from .migrations import run_migration  # Imported here to avoid circular imports
            await run_migration()
            # Record the new version only after the migration succeeded
            self.version_manager.update_version()
            self.logger.success("Version update completed successfully", tag="COMPLETE")

        except Exception as e:
            self.logger.error(
                message="Database initialization error: {error}",
                tag="ERROR",
                params={"error": str(e)},
            )
            self.logger.info(
                message="Database will be initialized on first use",
                tag="INIT",
            )
            raise

            
    async def cleanup(self):
        """Cleanup connections when shutting down"""
        async with self.pool_lock:
            for conn in self.connection_pool.values():
                await conn.close()
            self.connection_pool.clear()

    @asynccontextmanager
    async def get_connection(self):
        """Yield a verified SQLite connection dedicated to the current asyncio task.

        On first use the database is initialized; ``init_lock`` makes sure
        concurrent callers trigger that only once. ``connection_semaphore``
        limits how many connections are held at a time. The connection is
        registered in ``connection_pool`` under the id of the current task and
        is closed and unregistered again when the context exits.

        Yields:
            aiosqlite.Connection: Connection with WAL journaling and a 5000 ms
            busy timeout whose ``crawled_data`` table has all expected columns.

        Raises:
            ValueError: If ``crawled_data`` lacks an expected column.
            Exception: Initialization or connection errors, and any exception
                raised inside the ``async with`` body, are logged and re-raised.
        """
        if not self._initialized:
            async with self.init_lock:
                # Re-check under the lock: another task may have completed the
                # initialization while this one was waiting.
                if not self._initialized:
                    try:
                        await self.initialize()
                        self._initialized = True
                    except Exception as e:
                        error_context = get_error_context(sys.exc_info())
                        self.logger.error(
                            message="Database initialization failed:\n{error}\n\nContext:\n{context}\n\nTraceback:\n{traceback}",
                            tag="ERROR",
                            force_verbose=True,
                            params={
                                "error": str(e),
                                "context": error_context["code_context"],
                                "traceback": error_context["full_traceback"]
                            }
                        )
                        raise

        await self.connection_semaphore.acquire()
        task_id = id(asyncio.current_task())

        try:
            async with self.pool_lock:
                if task_id not in self.connection_pool:
                    self.connection_pool[task_id] = await self._open_verified_connection()

            yield self.connection_pool[task_id]

        except Exception as e:
            # Single logging point for both connection-setup failures and
            # errors raised by the caller's ``async with`` body, so each
            # failure is reported once.
            error_context = get_error_context(sys.exc_info())
            error_message = (
                f"Unexpected error in db get_connection at line {error_context['line_no']} "
                f"in {error_context['function']} ({error_context['filename']}):\n"
                f"Error: {str(e)}\n\n"
                f"Code context:\n{error_context['code_context']}"
            )
            self.logger.error(
                message=create_box_message(error_message, type= "error"),
            )
            raise
        finally:
            try:
                async with self.pool_lock:
                    # Unregister before closing so a failing close() cannot
                    # leave a dead connection behind in the pool.
                    conn = self.connection_pool.pop(task_id, None)
                    if conn is not None:
                        await conn.close()
            finally:
                # Always hand the slot back, even if closing failed
                self.connection_semaphore.release()

    async def _open_verified_connection(self):
        """Open a connection, apply the pragmas and check the ``crawled_data`` columns.

        The connection is closed again if any of these steps fails, so a
        failed setup does not leak it.

        Returns:
            aiosqlite.Connection: The ready-to-use connection.

        Raises:
            ValueError: If ``crawled_data`` lacks an expected column.
        """
        conn = await aiosqlite.connect(
            self.db_path,
            timeout=30.0
        )
        try:
            await conn.execute('PRAGMA journal_mode = WAL')
            await conn.execute('PRAGMA busy_timeout = 5000')

            # Verify database structure
            async with conn.execute("PRAGMA table_info(crawled_data)") as cursor:
                columns = await cursor.fetchall()
            column_names = {col[1] for col in columns}
            expected_columns = {
                'url', 'html', 'cleaned_html', 'markdown', 'extracted_content',
                'success', 'media', 'links', 'metadata', 'screenshot',
                'response_headers', 'downloaded_files'
            }
            missing_columns = expected_columns - column_names
            if missing_columns:
                raise ValueError(f"Database missing columns: {missing_columns}")
        except BaseException:
            # Includes cancellation: never leave the half-configured
            # connection open.
            await conn.close()
            raise
        return conn


    async def execute_with_retry(self, operation, *args):
        """Execute database operations with retry logic"""
        for attempt in range(self.max_retries):
            try:
                async with self.get_connection() as db:
                    result = await operation(db, *args)
                    await db.commit()
                    return result
            except Exception as e:
                if attempt == self.max_retries - 1:
                    self.logger.error(
                        message="Operation failed after {retries} attempts: {error}",
                        tag="ERROR",
                        force_verbose=True,
                        params={
                            "retries": self.max_retries,
                            "error": str(e)
                        }
                    )                    
                    raise
                await asyncio.sleep(1 * (attempt + 1))  # Exponential backoff

    async def ainit_db(self):
        """Create the ``crawled_data`` table if it does not exist yet.

        String defaults use single-quoted SQL literals: double-quoted strings
        are only accepted by SQLite as a legacy fallback that can be disabled
        at build time.
        """
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            await db.execute("""
                CREATE TABLE IF NOT EXISTS crawled_data (
                    url TEXT PRIMARY KEY,
                    html TEXT,
                    cleaned_html TEXT,
                    markdown TEXT,
                    extracted_content TEXT,
                    success BOOLEAN,
                    media TEXT DEFAULT '{}',
                    links TEXT DEFAULT '{}',
                    metadata TEXT DEFAULT '{}',
                    screenshot TEXT DEFAULT '',
                    response_headers TEXT DEFAULT '{}',
                    downloaded_files TEXT DEFAULT '{}'  -- New column added
                )
            """)
            await db.commit()

        

    async def update_db_schema(self):
        """Add whichever later-introduced columns ``crawled_data`` is still missing."""
        async with aiosqlite.connect(self.db_path, timeout=30.0) as db:
            # The context manager closes the cursor once the rows are fetched
            async with db.execute("PRAGMA table_info(crawled_data)") as cursor:
                columns = await cursor.fetchall()
            # Index 1 of each table_info row is the column name
            column_names = {column[1] for column in columns}

            # List of new columns to add
            new_columns = ['media', 'links', 'metadata', 'screenshot', 'response_headers', 'downloaded_files']

            for column in new_columns:
                if column not in column_names:
                    await self.aalter_db_add_column(column, db)
            await db.commit()

    async def aalter_db_add_column(self, new_column: str, db):
        """Add new column to the database"""
        if new_column == 'response_headers':
            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT "{{}}"')
        else:
            await db.execute(f'ALTER TABLE crawled_data ADD COLUMN {new_column} TEXT DEFAULT ""')
        self.logger.info(
            message="Added column '{column}' to the database",
            tag="INIT",
            params={"column": new_column}
        )        


    async def aget_cached_url(self, url: str) -> Optional[CrawlResult]:
        """Rebuild the cached ``CrawlResult`` for ``url``.

        The row stores content hashes; the actual content is read back from
        the content files, JSON columns are decoded, and only keys that are
        annotated on ``CrawlResult`` are passed to its constructor.

        Args:
            url: Primary key of the cached row.

        Returns:
            The cached result, or ``None`` when the URL is not cached or the
            lookup failed (failures are logged, not raised).
        """
        # Content type (key of ``self.content_paths``) under which
        # ``acache_url`` stores each hashed column.
        content_types = {
            'html': 'html',
            'cleaned_html': 'cleaned',
            'markdown': 'markdown',
            'extracted_content': 'extracted',
            'screenshot': 'screenshots',
        }

        async def _get(db):
            async with db.execute(
                'SELECT * FROM crawled_data WHERE url = ?', (url,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return None

                # Get column names
                columns = [description[0] for description in cursor.description]

            # Create dict from row data
            row_dict = dict(zip(columns, row))

            # Replace each stored hash with the content it points to
            for field, content_type in content_types.items():
                hash_value = row_dict[field]
                content = await self._load_content(hash_value, content_type) if hash_value else None
                row_dict[field] = content or ""

            # Parse JSON fields; empty or malformed values become {}
            json_fields = ['media', 'links', 'metadata', 'response_headers', 'markdown']
            for field in json_fields:
                try:
                    row_dict[field] = json.loads(row_dict[field]) if row_dict[field] else {}
                except json.JSONDecodeError:
                    row_dict[field] = {}

            # The decoded markdown object is exposed as markdown_v2, while
            # 'markdown' itself becomes the raw markdown text when present.
            if isinstance(row_dict['markdown'], dict):
                row_dict['markdown_v2'] = row_dict['markdown']
                if row_dict['markdown'].get('raw_markdown'):
                    row_dict['markdown'] = row_dict['markdown']['raw_markdown']

            # Parse downloaded_files
            try:
                downloaded = json.loads(row_dict['downloaded_files']) if row_dict['downloaded_files'] else []
            except json.JSONDecodeError:
                downloaded = []
            # The column default is "{}", which decodes to a dict; anything
            # that is not a list is treated as "no files".
            row_dict['downloaded_files'] = downloaded if isinstance(downloaded, list) else []

            # Remove any fields not in CrawlResult model
            valid_fields = CrawlResult.__annotations__.keys()
            filtered_dict = {k: v for k, v in row_dict.items() if k in valid_fields}

            return CrawlResult(**filtered_dict)

        try:
            return await self.execute_with_retry(_get)
        except Exception as e:
            self.logger.error(
                message="Error retrieving cached URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            return None

    async def acache_url(self, result: CrawlResult):
        """Cache a crawl result, inserting or updating the row for ``result.url``.

        Bulky content (html, cleaned html, markdown, extracted content,
        screenshot) is written to content files and only the returned hashes
        are stored in the row; the remaining fields are stored as JSON text.

        Args:
            result: The crawl result to cache.

        Errors raised while writing the row are logged and swallowed; errors
        raised while storing the content files propagate to the caller.
        """
        try:
            # The attribute may exist but be None, so test the value rather
            # than mere presence before serializing it.
            markdown_v2 = getattr(result, 'markdown_v2', None)
            if isinstance(result.markdown, MarkdownGenerationResult):
                markdown_json = result.markdown.model_dump_json()
            elif markdown_v2 is not None:
                markdown_json = markdown_v2.model_dump_json()
            elif isinstance(result.markdown, str):
                markdown_result = MarkdownGenerationResult(raw_markdown=result.markdown)
                markdown_json = markdown_result.model_dump_json()
            else:
                markdown_json = MarkdownGenerationResult().model_dump_json()
        except Exception as e:
            self.logger.warning(
                message=f"Error processing markdown content: {str(e)}",
                tag="WARNING"
            )
            # Fallback to empty markdown result
            markdown_json = MarkdownGenerationResult().model_dump_json()

        # column name -> (content, content type used to pick the directory)
        content_map = {
            'html': (result.html, 'html'),
            'cleaned_html': (result.cleaned_html or "", 'cleaned'),
            'markdown': (markdown_json, 'markdown'),
            'extracted_content': (result.extracted_content or "", 'extracted'),
            'screenshot': (result.screenshot or "", 'screenshots')
        }

        # Store content files and get hashes
        content_hashes = {}
        for field, (content, content_type) in content_map.items():
            content_hashes[field] = await self._store_content(content, content_type)

        async def _cache(db):
            await db.execute('''
                INSERT INTO crawled_data (
                    url, html, cleaned_html, markdown,
                    extracted_content, success, media, links, metadata,
                    screenshot, response_headers, downloaded_files
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(url) DO UPDATE SET
                    html = excluded.html,
                    cleaned_html = excluded.cleaned_html,
                    markdown = excluded.markdown,
                    extracted_content = excluded.extracted_content,
                    success = excluded.success,
                    media = excluded.media,
                    links = excluded.links,
                    metadata = excluded.metadata,
                    screenshot = excluded.screenshot,
                    response_headers = excluded.response_headers,
                    downloaded_files = excluded.downloaded_files
            ''', (
                result.url,
                content_hashes['html'],
                content_hashes['cleaned_html'],
                content_hashes['markdown'],
                content_hashes['extracted_content'],
                result.success,
                json.dumps(result.media),
                json.dumps(result.links),
                json.dumps(result.metadata or {}),
                content_hashes['screenshot'],
                json.dumps(result.response_headers or {}),
                json.dumps(result.downloaded_files or [])
            ))

        try:
            await self.execute_with_retry(_cache)
        except Exception as e:
            self.logger.error(
                message="Error caching URL: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            

    async def aget_total_count(self) -> int:
        """Get total number of cached URLs"""
        async def _count(db):
            async with db.execute('SELECT COUNT(*) FROM crawled_data') as cursor:
                result = await cursor.fetchone()
                return result[0] if result else 0

        try:
            return await self.execute_with_retry(_count)
        except Exception as e:
            self.logger.error(
                message="Error getting total count: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            return 0

    async def aclear_db(self):
        """Clear all data from the database"""
        async def _clear(db):
            await db.execute('DELETE FROM crawled_data')

        try:
            await self.execute_with_retry(_clear)
        except Exception as e:
            self.logger.error(
                message="Error clearing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )

    async def aflush_db(self):
        """Drop the entire table"""
        async def _flush(db):
            await db.execute('DROP TABLE IF EXISTS crawled_data')

        try:
            await self.execute_with_retry(_flush)
        except Exception as e:
            self.logger.error(
                message="Error flushing database: {error}",
                tag="ERROR",
                force_verbose=True,
                params={"error": str(e)}
            )
            
                
    async def _store_content(self, content: str, content_type: str) -> str:
        """Store content in filesystem and return hash"""
        if not content:
            return ""
            
        content_hash = generate_content_hash(content)
        file_path = os.path.join(self.content_paths[content_type], content_hash)
        
        # Only write if file doesn't exist
        if not os.path.exists(file_path):
            async with aiofiles.open(file_path, 'w', encoding='utf-8') as f:
                await f.write(content)
                
        return content_hash

    async def _load_content(self, content_hash: str, content_type: str) -> Optional[str]:
        """Load content from filesystem by hash"""
        if not content_hash:
            return None
            
        file_path = os.path.join(self.content_paths[content_type], content_hash)
        try:
            async with aiofiles.open(file_path, 'r', encoding='utf-8') as f:
                return await f.read()
        except:
            self.logger.error(
                message="Failed to load content: {file_path}",
                tag="ERROR",
                force_verbose=True,
                params={"file_path": file_path}
            )
            return None

# Module-level singleton: one shared manager, built with the default
# pool_size and max_retries as soon as this module is imported.
async_db_manager = AsyncDatabaseManager()