# (Hosting-page header residue — "Spaces: Running" — kept as a comment so the file stays valid Python.)
| from fastapi import FastAPI, Depends, HTTPException, UploadFile, File | |
| import pandas as pd | |
| import lancedb | |
| from functools import cached_property, lru_cache | |
| from pydantic import Field, BaseModel | |
| from typing import Optional, Dict, List, Annotated, Any | |
| from fastapi import APIRouter | |
| import uuid | |
| import io | |
| from io import BytesIO | |
| import csv | |
| # LlamaIndex imports | |
| from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex | |
| from llama_index.vector_stores.lancedb import LanceDBVectorStore | |
| from llama_index.embeddings.fastembed import FastEmbedEmbedding | |
| from llama_index.core.schema import TextNode | |
| from llama_index.core import StorageContext, load_index_from_storage | |
| import json | |
| import os | |
| import shutil | |
# Router for all retrieval-augmented-generation (RAG) endpoints, mounted under /rag.
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)
# Configure global LlamaIndex settings
# Every index built in this module embeds text with FastEmbed's BGE-small model.
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
# JSON registry on disk mapping user_id -> list of {"table_id", "table_name"} records.
tables_file_path = './data/tables.json'
| # Database connection dependency | |
def get_db_connection(db_path: str = "./lancedb/dev"):
    """FastAPI dependency: open and return a LanceDB connection at *db_path*."""
    connection = lancedb.connect(db_path)
    return connection
| # Pydantic models | |
class CreateTableResponse(BaseModel):
    """Response payload returned after a table is created and indexed."""
    # Identifier of the LanceDB table (also used as the persisted index directory name).
    table_id: str
    # Human-readable status message.
    message: str
    # "success" on the happy path.
    status: str
    # Display name shown to the user.
    table_name: str
class QueryTableResponse(BaseModel):
    """Response payload for a retrieval query."""
    # Shaped as {'data': [{'text': ..., 'score': ...}, ...]}.
    results: Dict[str, Any]
    # Number of retrieved chunks in results['data'].
    total_results: int
async def create_embedding_table(
    user_id: str,
    files: List[UploadFile] = File(...),
    table_id: Optional[str] = None,
    table_name: Optional[str] = None
) -> CreateTableResponse:
    """Create a LanceDB-backed vector table and index the uploaded files.

    Args:
        user_id: Owner of the table; used as the key in the tables registry.
        files: Uploaded documents (.pdf, .docx, .csv, .txt, .md only).
        table_id: Optional existing table id; a UUID is generated when absent.
        table_name: Optional display name; a default is generated when absent.

    Returns:
        CreateTableResponse with the table id/name and a status message.

    Raises:
        HTTPException: 400 for unnamed files or unsupported extensions,
            500 when indexing or registry persistence fails.
    """
    allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
    for file in files:
        if file.filename is None:
            raise HTTPException(status_code=400, detail="File must have a valid name.")
        file_extension = os.path.splitext(file.filename)[1].lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
            )

    if table_id is None:
        table_id = str(uuid.uuid4())
    table_name = f"knowledge-base-{str(uuid.uuid4())[:4]}" if not table_name else table_name

    # Stage the uploads on disk so SimpleDirectoryReader can pick them up.
    directory_path = f"./data/{table_id}"
    os.makedirs(directory_path, exist_ok=True)
    for file in files:
        # basename() drops any client-supplied directory components,
        # preventing path traversal via a crafted filename like "../../x.txt".
        safe_name = os.path.basename(file.filename)
        file_path = os.path.join(directory_path, safe_name)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

    try:
        # "overwrite" recreates the LanceDB table when the id already exists.
        vector_store = LanceDBVectorStore(
            uri="./lancedb/dev",
            table_name=table_id,
            mode="overwrite",
            query_type="hybrid"
        )
        documents = SimpleDirectoryReader(directory_path).load_data()
        # Attach the vector store through a StorageContext: that is the
        # documented way to make from_documents() write embeddings into
        # LanceDB (a bare vector_store= kwarg is not part of the API).
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context
        )
        index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")

        # Record the table in the per-user JSON registry.
        try:
            registry_path = './data/tables.json'
            os.makedirs(os.path.dirname(registry_path), exist_ok=True)
            # Load existing registry, or start fresh if missing/corrupt.
            try:
                with open(registry_path, 'r') as f:
                    tables = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                tables = {}
            if user_id not in tables:
                tables[user_id] = []
            if table_id not in [table['table_id'] for table in tables[user_id]]:
                tables[user_id].append({"table_id": table_id, "table_name": table_name})
            with open(registry_path, 'w') as f:
                json.dump(tables, f)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")

        return CreateTableResponse(
            table_id=table_id,
            message="Table created and documents indexed successfully",
            status="success",
            table_name=table_name
        )
    except HTTPException:
        # Propagate our own error responses unchanged instead of re-wrapping
        # them in the generic "Table creation failed" message below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")
async def query_table(
    table_id: str,
    query: str,
    user_id: str,
    limit: Optional[int] = 10
) -> QueryTableResponse:
    """Retrieve the top *limit* matching chunks for *query* from a persisted index."""
    try:
        # The table id doubles as the on-disk index directory name.
        index_dir = f"./lancedb/index/{table_id}"
        ctx = StorageContext.from_defaults(persist_dir=index_dir)
        retriever = load_index_from_storage(ctx).as_retriever(similarity_top_k=limit)
        hits = retriever.retrieve(query)
        matches = [{'text': hit.text, 'score': hit.score} for hit in hits]
        return QueryTableResponse(results={'data': matches}, total_results=len(matches))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
async def get_tables(user_id: str):
    """Return the registered tables for *user_id*, or [] when none are recorded."""
    registry_path = './data/tables.json'
    try:
        with open(registry_path, 'r') as fh:
            registry = json.load(fh)
        return registry.get(user_id, [])
    except (FileNotFoundError, json.JSONDecodeError):
        # A missing or corrupt registry simply means "no tables yet".
        return []
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
async def health_check():
    """Liveness probe; always reports the service as healthy."""
    return {"status": "healthy"}
async def startup():
    """Seed the demo 'digiyatra' index from a bundled CSV at application startup.

    Best-effort: if the seed CSV is absent the seeding step is skipped so the
    application can still boot instead of crashing on FileNotFoundError.
    """
    print("RAG Router started")
    table_name = "digiyatra"
    source_csv = 'combined_digi_yatra.csv'
    # Guard: skip seeding gracefully when the demo data file is not shipped.
    if not os.path.exists(source_csv):
        print(f"Seed file {source_csv} not found; skipping digiyatra index build")
        return
    vector_store = LanceDBVectorStore(
        uri="./lancedb/dev",
        table_name=table_name,
        mode="overwrite",
        query_type="hybrid"
    )
    # One TextNode per CSV data row (header row skipped).
    # NOTE: TextNode is already imported at module level; the old local
    # re-import was redundant and has been removed.
    with open(source_csv, newline='') as f:
        rows = list(csv.reader(f))
    nodes = [TextNode(text=str(row), id_=str(uuid.uuid4())) for row in rows[1:]]
    # Attach the vector store via a StorageContext so the index actually
    # writes into LanceDB (a bare vector_store= kwarg is not part of the
    # VectorStoreIndex API).
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(nodes, storage_context=storage_context)
    index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
    # Register the seeded table for the 'digiyatra' demo user; ensure the
    # ./data directory exists (startup may run before any upload created it).
    os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
    user_id = "digiyatra"
    tables = {user_id: [{"table_id": table_name, "table_name": table_name}]}
    with open(tables_file_path, 'w') as f:
        json.dump(tables, f)
async def shutdown():
    """Log application shutdown for the RAG router."""
    print("RAG Router shutdown")