# ChadGPT / main.py
# Commit a19a241 ("Upload 6 files", verified) by PercivalFletcher; 5.31 kB.
# file: main.py
import time
import os
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Any
from dotenv import load_dotenv
# Import functions and classes from the new modular files
from document_processor import ingest_and_parse_document
from chunking import process_and_chunk
from embedding import EmbeddingClient
from retrieval import Retriever, generate_hypothetical_document
from generation import generate_answer
# Load variables from a local .env file (if present) into the process
# environment. This must run before the os.environ lookup and the client
# constructors below so they can see those values.
load_dotenv()
# --- FastAPI App Initialization ---
app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation from documents.",
    version="2.0.0",
)
# --- Global Clients and API Keys ---
# May be None when GROQ_API_KEY is unset; it is passed unchecked to the
# HyDE and answer-generation helpers used by the /hackrx/run endpoint.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
# Module-level singletons, created once at import and reused by every request
# handled by this process.
embedding_client = EmbeddingClient()
# NOTE(review): this single Retriever is re-indexed with a new document on every
# /hackrx/run call, so concurrent requests share its index state; confirm the
# endpoint serializes access to it.
retriever = Retriever(embedding_client=embedding_client)
# --- Pydantic Models ---
class RunRequest(BaseModel):
    """Request body for ``POST /hackrx/run``: one document plus the questions to ask about it."""

    # URL of the document to ingest; pydantic validates it as an HTTP(S) URL.
    document_url: HttpUrl
    # Questions to answer; the response carries one answer per question, in order.
    questions: List[str]
class RunResponse(BaseModel):
    """Response body for ``POST /hackrx/run``."""

    # Generated answers, aligned index-for-index with ``RunRequest.questions``.
    answers: List[str]
class TestRequest(BaseModel):
    """Request body shared by the ``/test/parse`` and ``/test/chunk`` diagnostic endpoints."""

    # URL of the document to parse (and, for /test/chunk, to chunk).
    document_url: HttpUrl
# --- Endpoints ---
# --- Test endpoint: document parsing only ---
@app.post("/test/parse", response_model=Dict[str, Any], tags=["Testing"])
async def test_parsing_endpoint(request: TestRequest):
    """
    Tests the document ingestion and parsing phase.
    Returns the full markdown content and the time taken.

    Response keys:
        parsing_time_seconds: wall-clock seconds spent in ``ingest_and_parse_document``.
        character_count: length of the parsed markdown.
        content: the parsed markdown itself.

    Raises:
        HTTPException: re-raised unchanged if the parser raises one; any other
            exception is reported as a 500 carrying the error text in ``detail``.
    """
    print("--- Running Parsing Test ---")
    start_time = time.perf_counter()
    try:
        markdown_content = await ingest_and_parse_document(request.document_url)
        duration = time.perf_counter() - start_time
        print(f"--- Parsing took {duration:.2f} seconds ---")
        return {
            "parsing_time_seconds": duration,
            "character_count": len(markdown_content),
            "content": markdown_content,
        }
    except HTTPException:
        # Already an HTTP error with a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500, detail=f"An error occurred during parsing: {str(e)}"
        ) from e
# Guards the shared module-level ``retriever``. Every /hackrx/run call re-indexes
# it with that request's document, and the handler awaits between indexing and
# retrieval. Without a lock, a concurrent request could replace the index
# underneath an in-flight one. The lock is created lazily so it is constructed
# inside the server's running event loop.
_retriever_lock = None  # asyncio.Lock once first requested


def _get_retriever_lock() -> asyncio.Lock:
    """Return the process-wide lock that serializes index+retrieve on ``retriever``."""
    global _retriever_lock
    # No await between the check and the assignment, so this is atomic on the event loop.
    if _retriever_lock is None:
        _retriever_lock = asyncio.Lock()
    return _retriever_lock


@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """
    Runs the full RAG pipeline for a given document URL and a list of questions.

    Stages: parse the document, chunk it, generate one hypothetical (HyDE)
    document per question, index the chunks and retrieve per question, then
    generate the answers. Per-question work within each stage runs concurrently.

    Returns:
        RunResponse whose ``answers`` are aligned index-for-index with
        ``request.questions``.

    Raises:
        HTTPException: 400 if the document yields no chunks (or any other
            HTTPException raised by a stage, unchanged); 500 for any other error.
    """
    try:
        # --- STAGE 1 & 2: DOCUMENT INGESTION AND CHUNKING ---
        print("--- Kicking off RAG Pipeline ---")
        markdown_content = await ingest_and_parse_document(request.document_url)
        documents = process_and_chunk(markdown_content)
        if not documents:
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # Step A: Concurrently generate hypothetical documents for all questions.
        # These do not use the index, so they run before taking the retriever lock.
        all_hyde_docs = await asyncio.gather(
            *(generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions)
        )

        # --- STAGE 3: INDEXING (Embedding + BM25) + Step B: RETRIEVAL ---
        # Indexing and retrieval must see the same document, so both happen
        # under the lock that protects the shared retriever.
        async with _get_retriever_lock():
            retriever.index(documents)
            all_retrieved_chunks = await asyncio.gather(
                *(
                    retriever.retrieve(q, hyde_doc)
                    for q, hyde_doc in zip(request.questions, all_hyde_docs)
                )
            )

        # Step C: Concurrently generate final answers for all questions.
        final_answers = await asyncio.gather(
            *(
                generate_answer(q, chunks, GROQ_API_KEY)
                for q, chunks in zip(request.questions, all_retrieved_chunks)
            )
        )
        print("--- RAG Pipeline Completed Successfully ---")
        return RunResponse(answers=final_answers)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) keep their status code
        # instead of being re-wrapped as a 500.
        raise
    except Exception as e:
        print(f"An unhandled error occurred in the pipeline: {e}")
        # Re-raising as a 500 error for the client, chained to keep the original traceback
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        ) from e
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
async def test_chunking_endpoint(request: TestRequest):
    """
    Tests both the parsing and chunking phases together.
    Returns the final list of chunks and the total time taken.

    Response keys:
        total_time_seconds: wall-clock seconds for parsing plus chunking
            (measured before chunks are converted for the response).
        chunk_count: number of chunks produced.
        chunks: one ``{"page_content": ..., "metadata": ...}`` dict per chunk,
            in the order ``process_and_chunk`` returned them.

    Raises:
        HTTPException: re-raised unchanged if a stage raises one; any other
            exception is reported as a 500 carrying the error text in ``detail``.
    """
    print("--- Running Parsing and Chunking Test ---")
    start_time = time.perf_counter()
    try:
        # Step 1: Parse the document
        markdown_content = await ingest_and_parse_document(request.document_url)
        # Step 2: Chunk the parsed content
        documents = process_and_chunk(markdown_content)
        duration = time.perf_counter() - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")
        # Convert chunk objects to plain dicts for the JSON response
        # (assumes each chunk's metadata is itself JSON-serializable; TODO confirm).
        chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in documents
        ]
        return {
            "total_time_seconds": duration,
            "chunk_count": len(chunk_results),
            "chunks": chunk_results,
        }
    except HTTPException:
        # Already an HTTP error with a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500, detail=f"An error occurred during chunking: {str(e)}"
        ) from e