import os
import json
import tempfile
import requests
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import List, Dict, Union, Any, Optional
from dotenv import load_dotenv
import asyncio
import httpx
import time
from urllib.parse import urlparse, unquote
import uuid
import re

# Import LangChain Document and text splitter
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from processing_utility import (
    extract_schema_from_file,
    process_document,
    download_and_parse_document_using_llama_index,
)

# Import the new classes and functions from rag_utils
from rag_utils import (
    process_markdown_with_recursive_chunking,
    generate_answer_with_groq,
    generate_hypothetical_document,
    HybridSearchManager,
    EmbeddingClient,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    TOP_K_CHUNKS,
    GROQ_MODEL_NAME,
)

load_dotenv()

# --- FastAPI App Initialization ---
app = FastAPI(
    title="HackRX RAG API",
    description="API for Retrieval-Augmented Generation from PDF documents.",
    version="1.0.0",
)
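
# HTTPBearer and Depends are imported above but never wired in. Below is a
# minimal sketch of the bearer-token check presumably intended, assuming the
# expected token lives in a hypothetical API_BEARER_TOKEN environment
# variable; attach it to a route via dependencies=[Depends(verify_token)].
bearer_scheme = HTTPBearer()

async def verify_token(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> None:
    expected_token = os.environ.get("API_BEARER_TOKEN")
    if expected_token and credentials.credentials != expected_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing bearer token.",
        )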

# --- Global instance for the HybridSearchManager ---
hybrid_search_manager: Optional[HybridSearchManager] = None

# Register the startup hook so the manager is actually created; without the
# decorator, the function below was defined but never called.
@app.on_event("startup")
async def startup_event():
    global hybrid_search_manager
    hybrid_search_manager = HybridSearchManager()
    # initialize_llama_extract_agent()
    print("Application startup complete. HybridSearchManager is ready.")

# --- Groq API Key Setup ---
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "NOT_FOUND")
if GROQ_API_KEY == "NOT_FOUND":
    print(
        "WARNING: GROQ_API_KEY is not set. Please set the GROQ_API_KEY "
        "environment variable before running in production."
    )

# --- Pydantic Models for Request and Response ---
class RunRequest(BaseModel):
    documents: str
    questions: List[str]

class Answer(BaseModel):
    answer: str

class RunResponse(BaseModel):
    answers: List[str]
    # step_timings: Dict[str, float]
    # hypothetical_documents: List[str]
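
# For illustration only, a request/response pair matching the models above
# (the URL and question text are placeholders, not real data):
#   POST body: {"documents": "https://example.com/policy.pdf",
#               "questions": ["What is the grace period for premium payment?"]}
#   Response:  {"answers": ["..."]}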

# Route path assumed from the API title; adjust if the deployed path differs.
@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """
    Runs the RAG pipeline for a given PDF document (converted to Markdown
    internally) and a list of questions.
    """
    pdf_url = request.documents
    questions = request.questions
    local_markdown_path = None
    step_timings = {}
    start_time_total = time.perf_counter()
    try:
        if hybrid_search_manager is None:
            raise HTTPException(
                status_code=500, detail="HybridSearchManager not initialized."
            )
        # Guard against an empty question list, which would otherwise cause a
        # division-by-zero in the per-query timing averages below.
        if not questions:
            raise HTTPException(status_code=400, detail="No questions provided.")

        # 1. Parsing: Download PDF and parse to Markdown
        start_time = time.perf_counter()
        markdown_content = await download_and_parse_document_using_llama_index(pdf_url)
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, encoding="utf-8", suffix=".md"
        ) as temp_md_file:
            temp_md_file.write(markdown_content)
            local_markdown_path = temp_md_file.name
        end_time = time.perf_counter()
        step_timings["parsing_to_markdown"] = end_time - start_time
        print(
            f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds."
        )

        # 2. Chunk Generation: Process Markdown into chunks
        start_time = time.perf_counter()
        processed_documents = process_markdown_with_recursive_chunking(
            local_markdown_path,
            CHUNK_SIZE,
            CHUNK_OVERLAP,
        )
        if not processed_documents:
            raise HTTPException(
                status_code=500, detail="Failed to process document into chunks."
            )
        end_time = time.perf_counter()
        step_timings["chunk_generation"] = end_time - start_time
        print(
            f"Chunk Generation took {step_timings['chunk_generation']:.2f} seconds."
        )

        # 3. Model Initialization and Embeddings Pre-computation
        start_time = time.perf_counter()
        await hybrid_search_manager.initialize_models(processed_documents)
        end_time = time.perf_counter()
        step_timings["model_initialization"] = end_time - start_time
        print(
            f"Model initialization took {step_timings['model_initialization']:.2f} seconds."
        )

        # --- NEW CONCURRENT WORKFLOW ---
        # 4. Concurrently generate all hypothetical documents
        start_time_hyde = time.perf_counter()
        hyde_tasks = [
            generate_hypothetical_document(q, GROQ_API_KEY) for q in questions
        ]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        end_time_hyde = time.perf_counter()
        step_timings["hyde_generation_total_time"] = end_time_hyde - start_time_hyde
        step_timings["hyde_generation_avg_time_per_query"] = (
            end_time_hyde - start_time_hyde
        ) / len(questions)

        # 5. Concurrently perform initial hybrid search to get candidates for ALL queries
        start_time_search = time.perf_counter()
        candidate_retrieval_tasks = [
            hybrid_search_manager.retrieve_candidates(q, hyde_doc)
            for q, hyde_doc in zip(questions, all_hyde_docs)
        ]
        all_candidates = await asyncio.gather(*candidate_retrieval_tasks)
        end_time_search = time.perf_counter()
        step_timings["candidate_retrieval_total_time"] = (
            end_time_search - start_time_search
        )

        # 6. Concurrently rerank the candidates for ALL queries
        start_time_rerank = time.perf_counter()
        rerank_tasks = [
            hybrid_search_manager.rerank_results(q, candidates, TOP_K_CHUNKS)
            for q, candidates in zip(questions, all_candidates)
        ]
        reranked_results_and_times = await asyncio.gather(*rerank_tasks)
        end_time_rerank = time.perf_counter()
        step_timings["reranking_total_time"] = end_time_rerank - start_time_rerank

        # Unpack the (results, elapsed_time) pairs returned by rerank_results;
        # the per-query times are kept for optional reporting.
        all_retrieved_results = [item[0] for item in reranked_results_and_times]
        all_rerank_times = [item[1] for item in reranked_results_and_times]
        step_timings["reranking_avg_time_per_query"] = (
            end_time_rerank - start_time_rerank
        ) / len(questions)

        # 7. Concurrently generate final answers
        start_time_generation = time.perf_counter()
        generation_tasks = []
        for question, retrieved_results in zip(questions, all_retrieved_results):
            if retrieved_results:
                generation_tasks.append(
                    generate_answer_with_groq(question, retrieved_results, GROQ_API_KEY)
                )
            else:
                # Use a pre-resolved Future so the fallback answer keeps its
                # position in the asyncio.gather results.
                no_info_future = asyncio.Future()
                no_info_future.set_result(
                    "No relevant information found in the document to answer this question."
                )
                generation_tasks.append(no_info_future)
        all_answer_texts = await asyncio.gather(*generation_tasks)
        end_time_generation = time.perf_counter()
        step_timings["generation_total_time"] = (
            end_time_generation - start_time_generation
        )
        step_timings["generation_avg_time_per_query"] = (
            end_time_generation - start_time_generation
        ) / len(questions)

        end_time_total = time.perf_counter()
        total_processing_time = end_time_total - start_time_total
        step_timings["total_processing_time"] = total_processing_time
        print("All questions processed.")

        all_answers = list(all_answer_texts)
        return RunResponse(
            answers=all_answers,
            # step_timings=step_timings,
            # hypothetical_documents=all_hyde_docs,
        )
    except HTTPException:
        raise
    except Exception as e:
        print(f"An unhandled error occurred: {e}")
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {e}"
        )
    finally:
        if local_markdown_path and os.path.exists(local_markdown_path):
            os.unlink(local_markdown_path)
            print(f"Cleaned up temporary markdown file: {local_markdown_path}")