# File: main.py
# (Modified to load embedding model at startup and await async pipeline run)
# --- Imports: stdlib / third-party / local ---
import asyncio
import os
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Union
from urllib.parse import unquote, urlparse

import httpx
import torch  # used to pick cuda/cpu for the embedding model
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from groq import AsyncGroq
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from pydantic import BaseModel, HttpUrl

# Load environment variables (e.g. GROQ_API_KEY) before any os.getenv() below.
load_dotenv()

# Import the Pipeline class (parsing, chunking, embedding, retrieval).
from pipeline import Pipeline
# --- FastAPI application object ---
# Route handlers further down the file are attached to this instance.
app = FastAPI(
    title="Llama-Index RAG with Groq",
    description=(
        "An API to process a PDF from a URL and answer a list of questions "
        "using a Llama-Index RAG pipeline."
    ),
)
# --- Pydantic Models for API Request and Response --- | |
class RunRequest(BaseModel):
    """Request body for the RAG run endpoint."""
    # URL of the PDF document to download and index (validated as an HTTP(S) URL).
    documents: HttpUrl
    # Questions to answer against the indexed document.
    questions: List[str]
class Answer(BaseModel):
    """A single question/answer pair produced by the pipeline."""
    question: str
    answer: str
class RunResponse(BaseModel):
    """Response body: all answers plus timing diagnostics."""
    answers: List[Answer]
    # Total wall-clock seconds for the whole request.
    processing_time: float
    # Per-step wall-clock seconds: download_pdf, pipeline_setup, retrieval, generation.
    step_timings: Dict[str, float]
# --- Global Configurations ---
# Groq API key; the literal "gsk_..." acts as a "not configured" sentinel.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "gsk_...")
GROQ_MODEL_NAME = "llama3-70b-8192"
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Global variable to hold the initialized embedding model
# (None until the startup hook populates it; checked per-request).
embed_model_instance: Union[HuggingFaceEmbedding, None] = None
if GROQ_API_KEY == "gsk_...":
    print("WARNING: GROQ_API_KEY is not set. Please set it in your environment or main.py.")
@app.on_event("startup")  # NOTE(review): this registration was missing — without it the model never loads
async def startup_event():
    """
    Load the embedding model once when the application starts.

    The (blocking) HuggingFace model construction runs in a worker thread so
    the event loop is not stalled; the result is stored in the module-level
    ``embed_model_instance`` and reused by every request.
    """
    global embed_model_instance
    print(f"Loading embedding model '{EMBEDDING_MODEL_NAME}' at startup...")
    # Prefer GPU when available; CPU works but embeds more slowly.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    # HuggingFaceEmbedding() loads weights synchronously -> off-load to a thread.
    embed_model_instance = await asyncio.to_thread(
        HuggingFaceEmbedding, model_name=EMBEDDING_MODEL_NAME, device=device
    )
    print("Embedding model loaded successfully.")
# --- Async Groq Generation Function ---
async def generate_answer_with_groq(query: str, retrieved_results: List[dict], groq_api_key: str) -> str:
    """
    Generate an answer for ``query`` with the Groq chat API, grounded only in
    the retrieved chunks' content.

    Never raises: returns an explanatory error string when the key is missing
    or the API call fails.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set. Cannot generate answer."

    client = AsyncGroq(api_key=groq_api_key)

    # Assemble the grounding context: one labelled section per retrieved chunk.
    sections = []
    for idx, chunk in enumerate(retrieved_results, start=1):
        meta = chunk.get("document_metadata", {})
        heading = meta.get("section_heading", meta.get("file_name", "N/A"))
        sections.append(
            f"--- Context Chunk {idx} ---\n"
            f"Document Part: {heading}\n"
            f"Content: {chunk.get('content', '')}\n"
            f"-------------------------"
        )
    context = "\n\n".join(sections)

    prompt = (
        "You are a specialized document analyzer assistant. Your task is to answer the user's question "
        "solely based on the provided context. If the answer cannot be found in the provided context, "
        "clearly state that you do not have enough information.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n\n"
        "Answer:"
    )

    try:
        completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GROQ_MODEL_NAME,
            temperature=0.7,
            max_tokens=500,
        )
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."
    return completion.choices[0].message.content
# --- FastAPI Endpoints ---
@app.get("/health")  # NOTE(review): route registration was missing; path assumed — confirm against deployment
async def health_check():
    """Liveness probe: always returns {"status": "ok"}."""
    return {"status": "ok"}
@app.post("/run", response_model=RunResponse)  # NOTE(review): registration was missing; confirm the intended route path
async def run_rag_pipeline(request: RunRequest):
    """
    Run the full RAG pipeline for a PDF URL and a list of questions.

    Steps (each timed into ``step_timings``):
      1. download the PDF to a temporary file,
      2. build the Pipeline (parsing, node creation, embeddings),
      3. retrieve relevant nodes for all questions concurrently,
      4. generate answers via Groq concurrently.

    Raises HTTPException on download failures, missing configuration, or any
    unexpected internal error. The temporary file is always cleaned up.
    """
    pdf_url = request.documents
    questions = request.questions
    local_pdf_path = None
    step_timings: Dict[str, float] = {}
    start_time_total = time.perf_counter()

    # Fail fast if the startup hook did not populate the shared embedding model.
    if not embed_model_instance:
        raise HTTPException(
            status_code=500,
            detail="Embedding model not loaded. Application startup failed."
        )
    # "gsk_..." is the unset-key sentinel from the module configuration.
    if not GROQ_API_KEY or GROQ_API_KEY == "gsk_...":
        raise HTTPException(
            status_code=500,
            detail="Groq API key is not configured. Please set the GROQ_API_KEY environment variable."
        )
    try:
        # 1. Download PDF
        start_time = time.perf_counter()
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(str(pdf_url), timeout=30.0, follow_redirects=True)
                response.raise_for_status()  # 4xx/5xx -> httpx.HTTPStatusError
                doc_bytes = response.content
                print("Download successful.")
            except httpx.HTTPStatusError as e:
                raise HTTPException(status_code=e.response.status_code, detail=f"HTTP error downloading PDF: {e.response.status_code} - {e.response.text}")
            except httpx.RequestError as e:
                raise HTTPException(status_code=400, detail=f"Network error downloading PDF: {e}")
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"An unexpected error occurred during download: {e}")

        # Persist to a secure temp file. delete=False so the path survives the
        # context manager and can be handed to the Pipeline; removed in finally.
        # (The original URL-derived filename was computed but never used — dropped.)
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_pdf_file:
            temp_pdf_file.write(doc_bytes)
            local_pdf_path = temp_pdf_file.name
        end_time = time.perf_counter()
        step_timings["download_pdf"] = end_time - start_time
        print(f"PDF download took {step_timings['download_pdf']:.2f} seconds.")

        # 2. Initialize and run the pipeline (parsing, node creation, embeddings).
        start_time = time.perf_counter()
        pipeline = Pipeline(groq_api_key=GROQ_API_KEY, pdf_path=local_pdf_path, embed_model=embed_model_instance)
        await pipeline.run()  # Pipeline.run() is async; await it directly.
        end_time = time.perf_counter()
        step_timings["pipeline_setup"] = end_time - start_time
        print(f"Pipeline setup took {step_timings['pipeline_setup']:.2f} seconds.")

        # 3. Concurrent retrieval: retrieve_nodes is blocking -> run in threads.
        start_time_retrieval = time.perf_counter()
        print(f"\nStarting concurrent retrieval for {len(questions)} questions...")
        retrieval_tasks = [asyncio.to_thread(pipeline.retrieve_nodes, q) for q in questions]
        all_retrieved_results = await asyncio.gather(*retrieval_tasks)
        end_time_retrieval = time.perf_counter()
        step_timings["retrieval"] = end_time_retrieval - start_time_retrieval
        print(f"Retrieval phase completed in {step_timings['retrieval']:.2f} seconds.")

        # 4. Concurrent generation: one Groq call per question.
        start_time_generation = time.perf_counter()
        print(f"\nStarting concurrent answer generation for {len(questions)} questions...")
        generation_tasks = [
            generate_answer_with_groq(q, retrieved_results, GROQ_API_KEY)
            for q, retrieved_results in zip(questions, all_retrieved_results)
        ]
        all_answer_texts = await asyncio.gather(*generation_tasks)
        end_time_generation = time.perf_counter()
        step_timings["generation"] = end_time_generation - start_time_generation
        print(f"Generation phase completed in {step_timings['generation']:.2f} seconds.")

        end_time_total = time.perf_counter()
        total_processing_time = end_time_total - start_time_total
        answers = [Answer(question=q, answer=a) for q, a in zip(questions, all_answer_texts)]
        return RunResponse(
            answers=answers,
            processing_time=total_processing_time,
            step_timings=step_timings,
        )
    except HTTPException:
        raise  # pass FastAPI errors through unchanged (bare raise keeps traceback)
    except Exception as e:
        print(f"An unhandled error occurred: {e}")
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {e}"
        )
    finally:
        # Always remove the downloaded temp file, even on failure paths.
        if local_pdf_path and os.path.exists(local_pdf_path):
            os.unlink(local_pdf_path)
            print(f"Cleaned up temporary PDF file: {local_pdf_path}")