from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import List
import os
import uuid
from dotenv import load_dotenv

# Import our local utilities instead of aimakerspace
from text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from openai_utils import SystemRolePrompt, UserRolePrompt, ChatOpenAI
from vector_store import VectorDatabase

load_dotenv()

app = FastAPI()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

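# NOTE: a wildcard origin list is convenient for a demo; restrict
# allow_origins to known frontends before deploying to production.
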
# Prompt templates
system_template = """\
Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""

system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}
"""

user_role_prompt = UserRolePrompt(user_prompt_template)

# Initialize components
text_splitter = CharacterTextSplitter()
chat_openai = ChatOpenAI()

# Store vector databases for each session
vector_dbs = {}
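# NOTE: sessions live in process memory only, so they are lost on restart
# and are not shared across multiple workers or replicas.
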
class QueryRequest(BaseModel):
    session_id: str
    query: str
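# Example request body (hypothetical values):
#   {"session_id": "123e4567-e89b-12d3-a456-426614174000",
#    "query": "What is this document about?"}
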
# NOTE: the route paths on the decorators below are assumptions;
# adjust them to match whatever the frontend calls.
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    # Save the upload to a temp file next to the app
    file_path = f"temp_{file.filename}"
    try:
        with open(file_path, "wb") as f:
            content = await file.read()
            f.write(content)

        # Process file: choose a loader based on the file extension
        loader = PDFLoader(file_path) if file.filename.lower().endswith('.pdf') else TextFileLoader(file_path)
        documents = loader.load_documents()
        texts = text_splitter.split_texts(documents)

        # Create vector database
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)

        # Generate session ID and store vector_db
        session_id = str(uuid.uuid4())
        vector_dbs[session_id] = vector_db

        return {"session_id": session_id, "message": "File processed successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Cleanup: remove the temp file whether or not processing succeeded
        if os.path.exists(file_path):
            os.remove(file_path)

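# Example usage (hypothetical path and port):
#   curl -F "file=@mydoc.pdf" http://localhost:7860/upload
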
@app.post("/query")
async def query(request: QueryRequest):
    try:
        vector_db = vector_dbs.get(request.session_id)
        if not vector_db:
            raise HTTPException(status_code=404, detail="Session not found")

        # Retrieve context; search_by_text returns (text, score) pairs,
        # so keep only the text when assembling the prompt context
        context_list = await vector_db.search_by_text(request.query, k=4)
        context_prompt = "\n".join([str(context[0]) for context in context_list])

        # Generate prompts
        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(
            question=request.query,
            context=context_prompt
        )

        # Get response
        response = await chat_openai.acomplete(
            [formatted_system_prompt, formatted_user_prompt]
        )

        return {
            "answer": str(response),
            "context": [str(context[0]) for context in context_list]
        }
    except HTTPException:
        # Re-raise as-is so the 404 above is not swallowed into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

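# Example usage (hypothetical values):
#   curl -X POST http://localhost:7860/query \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "<id from /upload>", "query": "Summarize the document"}'
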
# Optional: Cleanup endpoint
@app.delete("/session/{session_id}")
async def cleanup_session(session_id: str):
    if session_id in vector_dbs:
        del vector_dbs[session_id]
        return {"message": "Session cleaned up successfully"}
    raise HTTPException(status_code=404, detail="Session not found")

# Serve static files from static directory
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")
async def read_root():
    return FileResponse('static/index.html')

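# Optional entry point for running locally (assumes uvicorn is installed;
# port 7860 follows the Hugging Face Spaces convention, adjust as needed)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)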