from fastapi import FastAPI, Depends, HTTPException, UploadFile, File
import pandas as pd
import lancedb
from functools import cached_property, lru_cache
from pydantic import Field, BaseModel
from typing import Optional, Dict, List, Annotated, Any
from fastapi import APIRouter
import uuid
import io
from io import BytesIO
import csv
import sqlite3

# LlamaIndex imports
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.core.schema import TextNode
from llama_index.core import StorageContext, load_index_from_storage
import json
import os
import shutil

# Router for this module: every route is served under the /rag prefix
# and grouped under the "rag" tag.
router = APIRouter(prefix="/rag", tags=["rag"])

# Global LlamaIndex setting: embed text with the FastEmbed BGE-small model.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")

# Database connection dependency
@lru_cache()
def get_db_connection(db_path: str = "./lancedb/dev"):
    """Return a LanceDB connection for ``db_path``.

    The result is memoised by ``lru_cache``, so repeated calls with the same
    path return the same connection object.
    """
    connection = lancedb.connect(db_path)
    return connection

def get_db():
    """Open a new SQLite connection to the table-metadata database.

    The database lives at ``./data/tablesv2.db``. The parent directory is
    created on demand because ``sqlite3.connect`` creates the database file
    but not missing directories.

    Returns:
        A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects, so
        columns can be read by name. The caller is responsible for closing it.
    """
    db_file = './data/tablesv2.db'
    os.makedirs(os.path.dirname(db_file), exist_ok=True)
    conn = sqlite3.connect(db_file)
    conn.row_factory = sqlite3.Row
    return conn

def init_db():
    """Create the ``tables`` and ``table_files`` tables if they do not exist.

    ``tables`` holds one row per knowledge base (owner, id, display name,
    creation time); ``table_files`` holds one row per file of a knowledge
    base, unique per ``(table_id, filename)``.

    The connection opened here is always closed, even if a statement fails.
    """
    db = get_db()
    try:
        db.execute('''
            CREATE TABLE IF NOT EXISTS tables (
                id INTEGER PRIMARY KEY,
                user_id TEXT NOT NULL,
                table_id TEXT NOT NULL,
                table_name TEXT NOT NULL,
                created_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        db.execute('''
            CREATE TABLE IF NOT EXISTS table_files (
                id INTEGER PRIMARY KEY,
                table_id TEXT NOT NULL,
                filename TEXT NOT NULL,
                file_path TEXT NOT NULL,
                FOREIGN KEY (table_id) REFERENCES tables (table_id),
                UNIQUE(table_id, filename)
            )
        ''')
        db.commit()
    finally:
        db.close()

# Pydantic models
class CreateTableResponse(BaseModel):
    # Response body of the /create_table endpoint.
    # table_id: identifier of the created table (caller-supplied, otherwise a
    #   generated UUID string).
    table_id: str
    # message: human-readable outcome text; the endpoint sends "Success".
    message: str
    # status: machine-readable outcome; the endpoint sends "success".
    status: str
    # table_name: display name of the table (caller-supplied or generated).
    table_name: str

class QueryTableResponse(BaseModel):
    # Response body of the /query_table endpoint.
    # results: the endpoint fills this with {'data': [...]} where each entry
    #   has the keys 'text' and 'score' of one retrieved node.
    results: Dict[str, Any]
    # total_results: number of entries in results['data'].
    total_results: int


@router.post("/create_table", response_model=CreateTableResponse)
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Store uploaded files, index them and register the table for a user.

    Args:
        user_id: Owner recorded for the table.
        files: Uploaded documents; only .pdf, .docx, .csv, .txt and .md
            files are accepted.
        table_id: Identifier of the table. A UUID is generated when omitted.
            It is used as a directory name, so path separators and the
            names "." / ".." are rejected.
        table_name: Display name. A "knowledge-base-xxxx" name is generated
            when omitted.

    Returns:
        CreateTableResponse with the table id and name.

    Raises:
        HTTPException: 400 for an invalid table id, a missing filename or an
            unsupported file type; 403 when the table id is already
            registered to a different user; 500 for any other failure.
    """
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
    db = None
    try:
        db = get_db()
        table_id = table_id or str(uuid.uuid4())
        table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"

        # table_id becomes a directory under ./data and ./lancedb/index, so it
        # must not be able to escape those directories.
        if table_id in {".", ".."} or "/" in table_id or "\\" in table_id:
            raise HTTPException(status_code=400, detail="Invalid table_id")

        # Never overwrite files or index data registered to another user.
        owned_by_other = db.execute(
            'SELECT 1 FROM tables WHERE table_id = ? AND user_id != ?',
            (table_id, user_id)
        ).fetchone()
        if owned_by_other:
            raise HTTPException(status_code=403, detail="Table belongs to another user")

        # Check if table exists
        existing = db.execute(
            'SELECT id FROM tables WHERE user_id = ? AND table_id = ?',
            (user_id, table_id)
        ).fetchone()

        # Validate every upload before anything is written to disk, so a bad
        # file in the batch does not leave a half-populated directory behind.
        uploads = []
        for file in files:
            if not file.filename:
                raise HTTPException(status_code=400, detail="Invalid filename")
            # Keep only the final path component of the client-supplied name
            # (both separator styles) to prevent writes outside the table dir.
            safe_name = os.path.basename(file.filename.replace("\\", "/"))
            if not safe_name:
                raise HTTPException(status_code=400, detail="Invalid filename")
            if os.path.splitext(safe_name)[1].lower() not in allowed_extensions:
                raise HTTPException(status_code=400, detail="Unsupported file type")
            uploads.append((safe_name, file))

        directory_path = f"./data/{table_id}"
        os.makedirs(directory_path, exist_ok=True)

        for safe_name, file in uploads:
            file_path = os.path.join(directory_path, safe_name)
            with open(file_path, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)

        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_id,
            mode="overwrite",
            query_type="hybrid"
        )

        # The whole directory is re-read, so files from earlier uploads to the
        # same table are indexed again together with the new ones.
        documents = SimpleDirectoryReader(directory_path).load_data()
        # NOTE(review): vector_store is passed as a plain keyword rather than
        # through a StorageContext - confirm the index really writes to LanceDB.
        index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

        if not existing:
            db.execute(
                'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
                (user_id, table_id, table_name)
            )

        for safe_name, _ in uploads:
            db.execute(
                'INSERT OR REPLACE INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
                (table_id, safe_name, f"./data/{table_id}/{safe_name}")
            )
        db.commit()

        return CreateTableResponse(
            table_id=table_id,
            message="Success",
            status="success",
            table_name=table_name
        )

    except HTTPException:
        # Deliberate 4xx responses must not be rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if db is not None:
            db.close()


@router.post("/query_table/{table_id}", response_model=QueryTableResponse)
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Query the database table using LlamaIndex.

    Loads the index persisted under ``./lancedb/index/{table_id}`` and returns
    the text and score of the top matching nodes.

    Args:
        table_id: Identifier of the table to search.
        query: Search text handed to the retriever.
        user_id: Accepted for API compatibility; not used by this endpoint.
        limit: Maximum number of nodes to return; ``None`` falls back to 10.

    Raises:
        HTTPException: 400 for an invalid table id or a non-positive limit,
            404 when no persisted index exists for the table, 500 for any
            other failure.
    """
    try:
        # table_id is used as a directory name; keep it inside the index dir.
        if table_id in {"", ".", ".."} or "/" in table_id or "\\" in table_id:
            raise HTTPException(status_code=400, detail="Invalid table_id")

        top_k = 10 if limit is None else limit
        if top_k < 1:
            raise HTTPException(status_code=400, detail="limit must be a positive integer")

        persist_dir = f"./lancedb/index/{table_id}"
        if not os.path.isdir(persist_dir):
            raise HTTPException(status_code=404, detail="Table not found")

        # load index and retriever
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=top_k)

        # Get response
        response = retriever.retrieve(query)

        # Format results
        results = [{
            'text': node.text,
            'score': node.score
        } for node in response]

        return QueryTableResponse(
            results={'data': results},
            total_results=len(results)
        )

    except HTTPException:
        # Keep the intended 4xx status instead of rewrapping it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}") from e

@router.get("/get_tables/{user_id}")
async def get_tables(user_id: str):
    """List the tables registered for a user together with their files.

    Args:
        user_id: Owner whose tables are listed.

    Returns:
        A list of dicts with the keys ``table_id``, ``table_name``,
        ``created_at`` and ``documents`` (the filenames of the table; empty
        when the table has no files).
    """
    db = get_db()
    try:
        # One row per (table, file). Filenames are grouped in Python rather
        # than with GROUP_CONCAT + split(','), which would break any filename
        # that itself contains a comma.
        rows = db.execute('''
            SELECT
                t.table_id,
                t.table_name,
                t.created_time as created_at,
                tf.filename
            FROM tables t
            LEFT JOIN table_files tf ON t.table_id = tf.table_id
            WHERE t.user_id = ?
            ORDER BY t.table_id, tf.id
        ''', (user_id,)).fetchall()
    finally:
        db.close()

    tables = {}
    for row in rows:
        entry = tables.setdefault(row['table_id'], {
            'table_id': row['table_id'],
            'table_name': row['table_name'],
            'created_at': row['created_at'],
            'documents': []
        })
        # The LEFT JOIN yields a NULL filename for tables without files.
        if row['filename']:
            entry['documents'].append(row['filename'])

    return list(tables.values())


@router.delete("/delete_table/{table_id}")
async def delete_table(table_id: str, user_id: str):
    """Delete a table's files, persisted index and database rows.

    Args:
        table_id: Identifier of the table to delete.
        user_id: Must match the owner recorded for the table.

    Returns:
        ``{"message": "Table deleted successfully"}`` on success.

    Raises:
        HTTPException: 400 for an invalid table id, 404 when the table does
            not exist for this user, 500 for any other failure.
    """
    db = None
    try:
        # table_id is interpolated into paths handed to shutil.rmtree, so it
        # must never point outside the data / index directories.
        if table_id in {"", ".", ".."} or "/" in table_id or "\\" in table_id:
            raise HTTPException(status_code=400, detail="Invalid table_id")

        db = get_db()

        # Verify user owns the table
        table = db.execute(
            'SELECT * FROM tables WHERE table_id = ? AND user_id = ?',
            (table_id, user_id)
        ).fetchone()

        if not table:
            raise HTTPException(status_code=404, detail="Table not found or unauthorized")

        # Delete files from filesystem
        for path in (f"./data/{table_id}", f"./lancedb/index/{table_id}"):
            if os.path.exists(path):
                shutil.rmtree(path)

        # Delete from database
        db.execute('DELETE FROM table_files WHERE table_id = ?', (table_id,))
        db.execute('DELETE FROM tables WHERE table_id = ?', (table_id,))
        db.commit()

        return {"message": "Table deleted successfully"}

    except HTTPException:
        # Let the 400/404 above reach the client instead of becoming a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if db is not None:
            db.close()


@router.get("/health")
async def health_check():
    # Liveness probe: always answers with a fixed "healthy" payload.
    payload = {"status": "healthy"}
    return payload

@router.on_event("startup")
async def startup():
    """Initialise the metadata database and seed the "digiyatra" table.

    The seed step runs only when no row with table_id "digiyatra" exists. It
    turns every data row of ``combined_digi_yatra.csv`` into a text node,
    builds and persists an index, and registers the table and its source
    file in SQLite. The connection is closed on every path.
    """
    init_db()
    print("RAG Router started")

    table_name = "digiyatra"
    user_id = "digiyatra"

    db = get_db()
    try:
        # Check if table already exists
        existing = db.execute('SELECT id FROM tables WHERE table_id = ?', (table_name,)).fetchone()
        if existing:
            return

        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_name,
            mode="overwrite",
            query_type="hybrid"
        )

        with open('combined_digi_yatra.csv', newline='') as f:
            reader = csv.reader(f)
            # Skip the header row without loading the whole file into a list.
            next(reader, None)
            nodes = [TextNode(text=str(row), id_=str(uuid.uuid4()))
                     for row in reader]

        # NOTE(review): vector_store is passed as a plain keyword rather than
        # through a StorageContext - confirm the index really writes to LanceDB.
        index = VectorStoreIndex(nodes, vector_store=vector_store)
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")

        db.execute(
            'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
            (user_id, table_name, table_name)
        )
        db.execute(
            'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
            (table_name, 'combined_digi_yatra.csv', 'combined_digi_yatra.csv')
        )
        db.commit()
    finally:
        db.close()

@router.on_event("shutdown")
async def shutdown():
    """Announce on stdout that the RAG router is shutting down."""
    message = "RAG Router shutdown"
    print(message)