# file: main.py
import time
import os
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Any
from dotenv import load_dotenv

# Import functions and classes from the new modular files
from document_processor import ingest_and_parse_document
from chunking import process_and_chunk
from embedding import EmbeddingClient
from retrieval import Retriever, generate_hypothetical_document
from generation import generate_answer

load_dotenv()

# --- FastAPI App Initialization ---
app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation from documents.",
    version="2.0.0",
)

# --- Global Clients and API Keys ---
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)

# --- Pydantic Models ---
class RunRequest(BaseModel):
    document_url: HttpUrl
    questions: List[str]


class RunResponse(BaseModel):
    answers: List[str]


class TestRequest(BaseModel):
    document_url: HttpUrl

# --- Endpoints ---

# --- NEW: Test Endpoint for Parsing ---
@app.post("/test/parse")  # route path is an assumption; adjust to match your deployment
async def test_parsing_endpoint(request: TestRequest):
    """
    Tests the document ingestion and parsing phase.
    Returns the full markdown content and the time taken.
    """
    print("--- Running Parsing Test ---")
    start_time = time.perf_counter()
    try:
        markdown_content = await ingest_and_parse_document(request.document_url)
        end_time = time.perf_counter()
        duration = end_time - start_time
        print(f"--- Parsing took {duration:.2f} seconds ---")
        return {
            "parsing_time_seconds": duration,
            "character_count": len(markdown_content),
            "content": markdown_content,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred during parsing: {str(e)}")


@app.post("/run", response_model=RunResponse)  # route path is an assumption; adjust as needed
async def run_rag_pipeline(request: RunRequest):
    """
    Runs the full RAG pipeline for a given document URL and a list of questions.
    """
    try:
        # --- STAGE 1 & 2: DOCUMENT INGESTION AND CHUNKING ---
        print("--- Kicking off RAG Pipeline ---")
        markdown_content = await ingest_and_parse_document(request.document_url)
        documents = process_and_chunk(markdown_content)
        if not documents:
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING (Embedding + BM25) ---
        # This step builds the search index for the current document.
        retriever.index(documents)

        # --- CONCURRENT WORKFLOW FOR ALL QUESTIONS ---
        # Step A: Concurrently generate hypothetical documents for all questions
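        # HyDE (Hypothetical Document Embeddings): expanding each question into a
        # draft answer passage tends to improve semantic retrieval; how the draft is
        # combined with the raw question is handled inside retrieval.py.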
        hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)

        # Step B: Concurrently retrieve relevant chunks for all questions
        retrieval_tasks = [
            retriever.retrieve(q, hyde_doc)
            for q, hyde_doc in zip(request.questions, all_hyde_docs)
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)

        # Step C: Concurrently generate final answers for all questions
        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)

        print("--- RAG Pipeline Completed Successfully ---")
        return RunResponse(answers=final_answers)
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) pass through unchanged
        raise
    except Exception as e:
        print(f"An unhandled error occurred in the pipeline: {e}")
        # Re-raising as a 500 error for the client
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        )
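
# Example request body for run_rag_pipeline (illustrative values only):
# {
#     "document_url": "https://example.com/sample.pdf",
#     "questions": ["What does the document say about X?", "Who is the author?"]
# }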


@app.post("/test/chunk")  # route path is an assumption; adjust as needed
async def test_chunking_endpoint(request: TestRequest):
    """
    Tests both the parsing and chunking phases together.
    Returns the final list of chunks and the total time taken.
    """
    print("--- Running Parsing and Chunking Test ---")
    start_time = time.perf_counter()
    try:
        # Step 1: Parse the document
        markdown_content = await ingest_and_parse_document(request.document_url)
        # Step 2: Chunk the parsed content
        documents = process_and_chunk(markdown_content)
        end_time = time.perf_counter()
        duration = end_time - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")

        # Convert Document objects to a JSON-serializable list
        chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in documents
        ]
        return {
            "total_time_seconds": duration,
            "chunk_count": len(chunk_results),
            "chunks": chunk_results,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred during chunking: {str(e)}")