import os
import json
import tempfile
import requests
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import List, Dict, Union, Any, Optional
from dotenv import load_dotenv
import asyncio
import httpx
import time
from urllib.parse import urlparse, unquote
import uuid
import re

# Import LangChain Document and text splitter
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

from processing_utility import (
    extract_schema_from_file,
    process_document,
    download_and_parse_document_using_llama_index,
)

# Import the new classes and functions from rag_utils
from rag_utils import (
    process_markdown_with_recursive_chunking,
    generate_answer_with_groq,
    generate_hypothetical_document,
    HybridSearchManager,
    EmbeddingClient,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    TOP_K_CHUNKS,
    GROQ_MODEL_NAME,
)

load_dotenv()

# --- FastAPI App Initialization ---
app = FastAPI(
    title="HackRX RAG API",
    description="API for Retrieval-Augmented Generation from PDF documents.",
    version="1.0.0",
)

# --- Global instance for the HybridSearchManager ---
hybrid_search_manager: Optional[HybridSearchManager] = None


@app.on_event("startup")
async def startup_event():
    global hybrid_search_manager
    hybrid_search_manager = HybridSearchManager()
    # initialize_llama_extract_agent()
    print("Application startup complete. HybridSearchManager is ready.")


# --- Groq API Key Setup ---
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "NOT_FOUND")
if GROQ_API_KEY == "NOT_FOUND":
    print(
        "WARNING: GROQ_API_KEY is not set. Please set the GROQ_API_KEY "
        "environment variable before running in production."
    )


# --- Pydantic Models for Request and Response ---
class RunRequest(BaseModel):
    documents: str
    questions: List[str]


class Answer(BaseModel):
    answer: str


class RunResponse(BaseModel):
    answers: List[str]
    # step_timings: Dict[str, float]
    # hypothetical_documents: List[str]


@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    """
    Runs the RAG pipeline for a given PDF document (converted to Markdown
    internally) and a list of questions.
    """
    pdf_url = request.documents
    questions = request.questions
    local_markdown_path = None
    step_timings = {}
    start_time_total = time.perf_counter()

    try:
        if hybrid_search_manager is None:
            raise HTTPException(
                status_code=500, detail="HybridSearchManager not initialized."
            )
        # Guard against an empty question list (avoids division by zero in the
        # per-query timing averages below).
        if not questions:
            raise HTTPException(status_code=400, detail="No questions provided.")

        # 1. Parsing: Download PDF and parse to Markdown
        start_time = time.perf_counter()
        markdown_content = await download_and_parse_document_using_llama_index(pdf_url)
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, encoding="utf-8", suffix=".md"
        ) as temp_md_file:
            temp_md_file.write(markdown_content)
            local_markdown_path = temp_md_file.name
        end_time = time.perf_counter()
        step_timings["parsing_to_markdown"] = end_time - start_time
        print(
            f"Parsing to Markdown took {step_timings['parsing_to_markdown']:.2f} seconds."
        )

        # 2. Chunk Generation: Process Markdown into chunks
        start_time = time.perf_counter()
        processed_documents = process_markdown_with_recursive_chunking(
            local_markdown_path,
            CHUNK_SIZE,
            CHUNK_OVERLAP,
        )
        if not processed_documents:
            raise HTTPException(
                status_code=500, detail="Failed to process document into chunks."
            )
        end_time = time.perf_counter()
        step_timings["chunk_generation"] = end_time - start_time
        print(
            f"Chunk Generation took {step_timings['chunk_generation']:.2f} seconds."
        )
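        # NOTE (assumption): process_markdown_with_recursive_chunking is expected to
        # return a list of LangChain Document chunks (page_content + metadata), which
        # initialize_models below indexes for hybrid retrieval.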
        # 3. Model Initialization and Embeddings Pre-computation
        start_time = time.perf_counter()
        await hybrid_search_manager.initialize_models(processed_documents)
        end_time = time.perf_counter()
        step_timings["model_initialization"] = end_time - start_time
        print(
            f"Model initialization took {step_timings['model_initialization']:.2f} seconds."
        )

        # --- NEW CONCURRENT WORKFLOW ---
        # 4. Concurrently generate all hypothetical documents
        start_time_hyde = time.perf_counter()
        hyde_tasks = [
            generate_hypothetical_document(q, GROQ_API_KEY) for q in questions
        ]
        all_hyde_docs = await asyncio.gather(*hyde_tasks)
        end_time_hyde = time.perf_counter()
        step_timings["hyde_generation_total_time"] = end_time_hyde - start_time_hyde
        step_timings["hyde_generation_avg_time_per_query"] = (
            end_time_hyde - start_time_hyde
        ) / len(questions)

        # 5. Concurrently perform initial hybrid search to get candidates for ALL queries
        start_time_search = time.perf_counter()
        candidate_retrieval_tasks = [
            hybrid_search_manager.retrieve_candidates(q, hyde_doc)
            for q, hyde_doc in zip(questions, all_hyde_docs)
        ]
        all_candidates = await asyncio.gather(*candidate_retrieval_tasks)
        end_time_search = time.perf_counter()
        step_timings["candidate_retrieval_total_time"] = (
            end_time_search - start_time_search
        )

        # 6. Concurrently rerank the candidates for ALL queries
        start_time_rerank = time.perf_counter()
        rerank_tasks = [
            hybrid_search_manager.rerank_results(q, candidates, TOP_K_CHUNKS)
            for q, candidates in zip(questions, all_candidates)
        ]
        reranked_results_and_times = await asyncio.gather(*rerank_tasks)
        end_time_rerank = time.perf_counter()
        step_timings["reranking_total_time"] = end_time_rerank - start_time_rerank

        # Unpack reranked results and timings
        all_retrieved_results = [item[0] for item in reranked_results_and_times]
        all_rerank_times = [item[1] for item in reranked_results_and_times]
        step_timings["reranking_avg_time_per_query"] = (
            end_time_rerank - start_time_rerank
        ) / len(questions)

        # 7. Concurrently generate final answers
        start_time_generation = time.perf_counter()
        generation_tasks = []
        for question, retrieved_results in zip(questions, all_retrieved_results):
            if retrieved_results:
                generation_tasks.append(
                    generate_answer_with_groq(
                        question, retrieved_results, GROQ_API_KEY
                    )
                )
            else:
                no_info_future = asyncio.Future()
                no_info_future.set_result(
                    "No relevant information found in the document to answer this question."
                )
                generation_tasks.append(no_info_future)
        all_answer_texts = await asyncio.gather(*generation_tasks)
        end_time_generation = time.perf_counter()
        step_timings["generation_total_time"] = (
            end_time_generation - start_time_generation
        )
        step_timings["generation_avg_time_per_query"] = (
            end_time_generation - start_time_generation
        ) / len(questions)

        end_time_total = time.perf_counter()
        total_processing_time = end_time_total - start_time_total
        step_timings["total_processing_time"] = total_processing_time
        print("All questions processed.")

        all_answers = list(all_answer_texts)

        return RunResponse(
            answers=all_answers,
            # step_timings=step_timings,
            # hypothetical_documents=all_hyde_docs,
        )

    except HTTPException as e:
        raise e
    except Exception as e:
        print(f"An unhandled error occurred: {e}")
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {e}"
        )
    finally:
        if local_markdown_path and os.path.exists(local_markdown_path):
            os.unlink(local_markdown_path)
            print(f"Cleaned up temporary markdown file: {local_markdown_path}")
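

# --- Example client call (illustrative sketch) ---
# A minimal usage sketch, assuming the app is already being served locally on
# port 8000 (e.g. `uvicorn main:app --port 8000`; the module name "main" is an
# assumption). The sample document URL and question below are placeholders for
# illustration, not part of the application itself.
if __name__ == "__main__":
    example_payload = {
        "documents": "https://example.com/sample-policy.pdf",  # placeholder PDF URL
        "questions": ["What is the grace period for premium payment?"],  # placeholder question
    }
    # The full pipeline (parsing, chunking, retrieval, generation) can take a
    # while, so allow a generous timeout on the request.
    response = httpx.post(
        "http://localhost:8000/hackrx/run", json=example_payload, timeout=300.0
    )
    response.raise_for_status()
    print(json.dumps(response.json(), indent=2))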