Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- chunking.py +68 -0
- document_processor.py +88 -0
- embedding.py +40 -0
- generation.py +57 -0
- main.py +149 -0
- retrieval.py +139 -0
chunking.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# file: chunking.py
|
2 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
3 |
+
from langchain_core.documents import Document
|
4 |
+
from typing import List
|
5 |
+
from unstructured.partition.md import partition_md
|
6 |
+
from unstructured.documents.elements import Header, Footer, PageBreak, Table, NarrativeText
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
# --- Configuration ---
|
11 |
+
CHUNK_SIZE = 1000
|
12 |
+
CHUNK_OVERLAP = 200
|
13 |
+
|
14 |
+
def process_and_chunk(raw_text: str) -> List[Document]:
|
15 |
+
"""
|
16 |
+
Partitions raw text from a document processor using 'unstructured',
|
17 |
+
correctly interpreting it as markdown to preserve table structures,
|
18 |
+
and then chunks the remaining text content.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
raw_text: The raw string content of the document (expected to be markdown).
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
A list of Document objects, including structured tables and chunked text.
|
25 |
+
"""
|
26 |
+
if not raw_text:
|
27 |
+
print("Warning: Input text for chunking is empty.")
|
28 |
+
return []
|
29 |
+
|
30 |
+
print(f"Processing raw text of length {len(raw_text)} with 'unstructured' markdown parser.")
|
31 |
+
|
32 |
+
# --- FIX: Change content_type to "text/markdown" ---
|
33 |
+
# This tells unstructured to use its specialized markdown parser, which
|
34 |
+
# correctly handles tables and other structures from your PyMuPDF output.
|
35 |
+
elements = partition_md(text=raw_text)
|
36 |
+
|
37 |
+
documents = []
|
38 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
39 |
+
chunk_size=CHUNK_SIZE,
|
40 |
+
chunk_overlap=CHUNK_OVERLAP,
|
41 |
+
length_function=len,
|
42 |
+
is_separator_regex=False,
|
43 |
+
)
|
44 |
+
|
45 |
+
for element in elements:
|
46 |
+
if isinstance(element, (Header, Footer, PageBreak)):
|
47 |
+
continue
|
48 |
+
# Process tables
|
49 |
+
if "unstructured.documents.elements.Table" in str(type(element)):
|
50 |
+
table_html = element.metadata.text_as_html
|
51 |
+
table_metadata = element.metadata.to_dict()
|
52 |
+
table_metadata['content_type'] = 'table'
|
53 |
+
documents.append(Document(page_content=table_html, metadata=table_metadata))
|
54 |
+
# Process and chunk narrative text
|
55 |
+
elif "unstructured.documents.elements.NarrativeText" in str(type(element)):
|
56 |
+
chunks = text_splitter.split_text(element.text)
|
57 |
+
for chunk in chunks:
|
58 |
+
chunk_metadata = element.metadata.to_dict()
|
59 |
+
chunk_metadata['content_type'] = 'text'
|
60 |
+
documents.append(Document(page_content=chunk, metadata=chunk_metadata))
|
61 |
+
# Handle other elements directly
|
62 |
+
else:
|
63 |
+
general_metadata = element.metadata.to_dict()
|
64 |
+
general_metadata['content_type'] = 'other'
|
65 |
+
documents.append(Document(page_content=element.text, metadata=general_metadata))
|
66 |
+
|
67 |
+
print(f"Created {len(documents)} documents (chunks and tables).")
|
68 |
+
return documents
|
document_processor.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# file: document_processing.py
|
2 |
+
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import httpx
|
6 |
+
from pathlib import Path
|
7 |
+
from urllib.parse import urlparse, unquote
|
8 |
+
from llama_index.readers.file import PyMuPDFReader
|
9 |
+
from llama_index.core import Document as LlamaDocument
|
10 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
11 |
+
from pydantic import HttpUrl
|
12 |
+
from typing import List
|
13 |
+
|
14 |
+
# Define the batch size for parallel processing
|
15 |
+
BATCH_SIZE = 25
|
16 |
+
|
17 |
+
def _process_page_batch(documents_batch: List[LlamaDocument]) -> str:
|
18 |
+
"""
|
19 |
+
Helper function to extract content from a batch of LlamaIndex Document objects
|
20 |
+
and join them into a single string.
|
21 |
+
"""
|
22 |
+
return "\n\n".join([d.get_content() for d in documents_batch])
|
23 |
+
|
24 |
+
async def ingest_and_parse_document(doc_url: HttpUrl) -> str:
|
25 |
+
"""
|
26 |
+
Asynchronously downloads a document, saves it locally, and parses it to
|
27 |
+
Markdown text using PyMuPDFReader with parallel processing.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
doc_url: The Pydantic-validated URL of the document to process.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
A single string containing the document's extracted text.
|
34 |
+
"""
|
35 |
+
print(f"Initiating download from: {doc_url}")
|
36 |
+
LOCAL_STORAGE_DIR = "data/"
|
37 |
+
os.makedirs(LOCAL_STORAGE_DIR, exist_ok=True)
|
38 |
+
|
39 |
+
try:
|
40 |
+
# Asynchronously download the document
|
41 |
+
async with httpx.AsyncClient() as client:
|
42 |
+
response = await client.get(str(doc_url), timeout=30.0, follow_redirects=True)
|
43 |
+
response.raise_for_status()
|
44 |
+
doc_bytes = response.content
|
45 |
+
print("Download successful.")
|
46 |
+
|
47 |
+
# Determine a valid local filename
|
48 |
+
parsed_path = urlparse(str(doc_url)).path
|
49 |
+
filename = unquote(os.path.basename(parsed_path)) or "downloaded_document.pdf"
|
50 |
+
local_file_path = Path(os.path.join(LOCAL_STORAGE_DIR, filename))
|
51 |
+
|
52 |
+
# Save the document locally
|
53 |
+
with open(local_file_path, "wb") as f:
|
54 |
+
f.write(doc_bytes)
|
55 |
+
print(f"Document saved locally at: {local_file_path}")
|
56 |
+
|
57 |
+
# Parse the document using LlamaIndex's PyMuPDFReader
|
58 |
+
print("Parsing document with PyMuPDFReader...")
|
59 |
+
loader = PyMuPDFReader()
|
60 |
+
docs_from_loader = loader.load_data(file_path=local_file_path)
|
61 |
+
|
62 |
+
# Parallelize the extraction of text from loaded pages
|
63 |
+
start_time = time.perf_counter()
|
64 |
+
all_extracted_texts = []
|
65 |
+
with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
|
66 |
+
futures = [
|
67 |
+
executor.submit(_process_page_batch, docs_from_loader[i:i + BATCH_SIZE])
|
68 |
+
for i in range(0, len(docs_from_loader), BATCH_SIZE)
|
69 |
+
]
|
70 |
+
for future in as_completed(futures):
|
71 |
+
all_extracted_texts.append(future.result())
|
72 |
+
|
73 |
+
doc_text = "\n\n".join(all_extracted_texts)
|
74 |
+
elapsed_time = time.perf_counter() - start_time
|
75 |
+
print(f"Time taken for parallel text extraction: {elapsed_time:.4f} seconds.")
|
76 |
+
|
77 |
+
if not doc_text:
|
78 |
+
raise ValueError("Document parsing yielded no content.")
|
79 |
+
|
80 |
+
print(f"Parsing complete. Extracted {len(doc_text)} characters.")
|
81 |
+
return doc_text
|
82 |
+
|
83 |
+
except httpx.HTTPStatusError as e:
|
84 |
+
print(f"Error downloading document: {e}")
|
85 |
+
raise
|
86 |
+
except Exception as e:
|
87 |
+
print(f"An unexpected error occurred during document processing: {e}")
|
88 |
+
raise
|
embedding.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# file: embedding.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
# --- Configuration ---
|
8 |
+
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
|
9 |
+
|
10 |
+
class EmbeddingClient:
|
11 |
+
"""A client for generating text embeddings using a local sentence transformer model."""
|
12 |
+
|
13 |
+
def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
|
14 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
+
self.model = SentenceTransformer(model_name, device=self.device)
|
16 |
+
print(f"EmbeddingClient initialized with model '{model_name}' on device '{self.device}'.")
|
17 |
+
|
18 |
+
def create_embeddings(self, texts: List[str]) -> torch.Tensor:
|
19 |
+
"""
|
20 |
+
Generates embeddings for a list of text chunks.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
texts: A list of strings to be embedded.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
A torch.Tensor containing the generated embeddings.
|
27 |
+
"""
|
28 |
+
if not texts:
|
29 |
+
return torch.tensor([])
|
30 |
+
|
31 |
+
print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
|
32 |
+
try:
|
33 |
+
embeddings = self.model.encode(
|
34 |
+
texts, convert_to_tensor=True, show_progress_bar=False
|
35 |
+
)
|
36 |
+
print("Embeddings generated successfully.")
|
37 |
+
return embeddings
|
38 |
+
except Exception as e:
|
39 |
+
print(f"An error occurred during embedding generation: {e}")
|
40 |
+
raise
|
generation.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# file: generation.py
|
2 |
+
from groq import AsyncGroq
|
3 |
+
from typing import List, Dict
|
4 |
+
|
5 |
+
# --- Configuration ---
|
6 |
+
GROQ_MODEL_NAME = "llama3-8b-8192"
|
7 |
+
|
8 |
+
async def generate_answer(query: str, context_chunks: List[Dict], groq_api_key: str) -> str:
|
9 |
+
"""
|
10 |
+
Generates a final answer using the Groq API based on the query and retrieved context.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
query: The user's original question.
|
14 |
+
context_chunks: A list of the most relevant, reranked document chunks.
|
15 |
+
groq_api_key: The API key for the Groq service.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
A string containing the generated answer.
|
19 |
+
"""
|
20 |
+
if not groq_api_key:
|
21 |
+
return "Error: Groq API key is not set."
|
22 |
+
if not context_chunks:
|
23 |
+
return "I do not have enough information to answer this question based on the provided document."
|
24 |
+
|
25 |
+
print("Generating final answer with Groq...")
|
26 |
+
client = AsyncGroq(api_key=groq_api_key)
|
27 |
+
|
28 |
+
# Format the context for the prompt
|
29 |
+
context_str = "\n\n---\n\n".join(
|
30 |
+
[f"Context Chunk:\n{chunk['content']}" for chunk in context_chunks]
|
31 |
+
)
|
32 |
+
|
33 |
+
prompt = (
|
34 |
+
"You are an expert Q&A system. Your task is to extract information with 100% accuracy from the provided text. Provide a brief and direct answer."
|
35 |
+
"Do not mention the context in your response. Answer *only* using the information from the provided document."
|
36 |
+
"Do not infer, add, or assume any information that is not explicitly written in the source text. If the answer is not in the document, state that the information is not available."
|
37 |
+
"When the question involves numbers, percentages, or monetary values, extract the exact figures from the text."
|
38 |
+
"Double-check that the value corresponds to the correct plan or condition mentioned in the question."
|
39 |
+
"\n\n"
|
40 |
+
f"CONTEXT:\n{context_str}\n\n"
|
41 |
+
f"QUESTION:\n{query}\n\n"
|
42 |
+
"ANSWER:"
|
43 |
+
)
|
44 |
+
|
45 |
+
try:
|
46 |
+
chat_completion = await client.chat.completions.create(
|
47 |
+
messages=[{"role": "user", "content": prompt}],
|
48 |
+
model=GROQ_MODEL_NAME,
|
49 |
+
temperature=0.2, # Lower temperature for more factual answers
|
50 |
+
max_tokens=500,
|
51 |
+
)
|
52 |
+
answer = chat_completion.choices[0].message.content
|
53 |
+
print("Answer generated successfully.")
|
54 |
+
return answer
|
55 |
+
except Exception as e:
|
56 |
+
print(f"An error occurred during Groq API call: {e}")
|
57 |
+
return "Could not generate an answer due to an API error."
|
main.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# file: main.py
|
2 |
+
import time
|
3 |
+
import os
|
4 |
+
import asyncio
|
5 |
+
from fastapi import FastAPI, HTTPException
|
6 |
+
from pydantic import BaseModel, HttpUrl
|
7 |
+
from typing import List, Dict, Any
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
|
10 |
+
# Import functions and classes from the new modular files
|
11 |
+
from document_processor import ingest_and_parse_document
|
12 |
+
from chunking import process_and_chunk
|
13 |
+
from embedding import EmbeddingClient
|
14 |
+
from retrieval import Retriever, generate_hypothetical_document
|
15 |
+
from generation import generate_answer
|
16 |
+
|
17 |
+
load_dotenv()
|
18 |
+
|
19 |
+
# --- FastAPI App Initialization ---
|
20 |
+
app = FastAPI(
|
21 |
+
title="Modular RAG API",
|
22 |
+
description="A modular API for Retrieval-Augmented Generation from documents.",
|
23 |
+
version="2.0.0",
|
24 |
+
)
|
25 |
+
|
26 |
+
# --- Global Clients and API Keys ---
|
27 |
+
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
|
28 |
+
embedding_client = EmbeddingClient()
|
29 |
+
retriever = Retriever(embedding_client=embedding_client)
|
30 |
+
|
31 |
+
|
32 |
+
# --- Pydantic Models ---
|
33 |
+
class RunRequest(BaseModel):
|
34 |
+
document_url: HttpUrl
|
35 |
+
questions: List[str]
|
36 |
+
|
37 |
+
class RunResponse(BaseModel):
|
38 |
+
answers: List[str]
|
39 |
+
|
40 |
+
class TestRequest(BaseModel):
|
41 |
+
document_url: HttpUrl
|
42 |
+
#Endpoints
|
43 |
+
|
44 |
+
# --- NEW: Test Endpoint for Parsing ---
|
45 |
+
@app.post("/test/parse", response_model=Dict[str, Any], tags=["Testing"])
|
46 |
+
async def test_parsing_endpoint(request: TestRequest):
|
47 |
+
"""
|
48 |
+
Tests the document ingestion and parsing phase.
|
49 |
+
Returns the full markdown content and the time taken.
|
50 |
+
"""
|
51 |
+
print("--- Running Parsing Test ---")
|
52 |
+
start_time = time.perf_counter()
|
53 |
+
|
54 |
+
try:
|
55 |
+
markdown_content = await ingest_and_parse_document(request.document_url)
|
56 |
+
|
57 |
+
end_time = time.perf_counter()
|
58 |
+
duration = end_time - start_time
|
59 |
+
print(f"--- Parsing took {duration:.2f} seconds ---")
|
60 |
+
|
61 |
+
return {
|
62 |
+
"parsing_time_seconds": duration,
|
63 |
+
"character_count": len(markdown_content),
|
64 |
+
"content": markdown_content
|
65 |
+
}
|
66 |
+
except Exception as e:
|
67 |
+
raise HTTPException(status_code=500, detail=f"An error occurred during parsing: {str(e)}")
|
68 |
+
|
69 |
+
@app.post("/hackrx/run", response_model=RunResponse)
|
70 |
+
async def run_rag_pipeline(request: RunRequest):
|
71 |
+
"""
|
72 |
+
Runs the full RAG pipeline for a given document URL and a list of questions.
|
73 |
+
"""
|
74 |
+
try:
|
75 |
+
# --- STAGE 1 & 2: DOCUMENT INGESTION AND CHUNKING ---
|
76 |
+
print("--- Kicking off RAG Pipeline ---")
|
77 |
+
markdown_content = await ingest_and_parse_document(request.document_url)
|
78 |
+
documents = process_and_chunk(markdown_content)
|
79 |
+
|
80 |
+
if not documents:
|
81 |
+
raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")
|
82 |
+
|
83 |
+
# --- STAGE 3: INDEXING (Embedding + BM25) ---
|
84 |
+
# This step builds the search index for the current document.
|
85 |
+
retriever.index(documents)
|
86 |
+
|
87 |
+
# --- CONCURRENT WORKFLOW FOR ALL QUESTIONS ---
|
88 |
+
|
89 |
+
# Step A: Concurrently generate hypothetical documents for all questions
|
90 |
+
hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
|
91 |
+
all_hyde_docs = await asyncio.gather(*hyde_tasks)
|
92 |
+
|
93 |
+
# Step B: Concurrently retrieve relevant chunks for all questions
|
94 |
+
retrieval_tasks = [
|
95 |
+
retriever.retrieve(q, hyde_doc)
|
96 |
+
for q, hyde_doc in zip(request.questions, all_hyde_docs)
|
97 |
+
]
|
98 |
+
all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
|
99 |
+
|
100 |
+
# Step C: Concurrently generate final answers for all questions
|
101 |
+
answer_tasks = [
|
102 |
+
generate_answer(q, chunks, GROQ_API_KEY)
|
103 |
+
for q, chunks in zip(request.questions, all_retrieved_chunks)
|
104 |
+
]
|
105 |
+
final_answers = await asyncio.gather(*answer_tasks)
|
106 |
+
|
107 |
+
print("--- RAG Pipeline Completed Successfully ---")
|
108 |
+
return RunResponse(answers=final_answers)
|
109 |
+
|
110 |
+
except Exception as e:
|
111 |
+
print(f"An unhandled error occurred in the pipeline: {e}")
|
112 |
+
# Re-raising as a 500 error for the client
|
113 |
+
raise HTTPException(
|
114 |
+
status_code=500, detail=f"An internal server error occurred: {str(e)}"
|
115 |
+
)
|
116 |
+
|
117 |
+
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
|
118 |
+
async def test_chunking_endpoint(request: TestRequest):
|
119 |
+
"""
|
120 |
+
Tests both the parsing and chunking phases together.
|
121 |
+
Returns the final list of chunks and the total time taken.
|
122 |
+
"""
|
123 |
+
print("--- Running Parsing and Chunking Test ---")
|
124 |
+
start_time = time.perf_counter()
|
125 |
+
|
126 |
+
try:
|
127 |
+
# Step 1: Parse the document
|
128 |
+
markdown_content = await ingest_and_parse_document(request.document_url)
|
129 |
+
|
130 |
+
# Step 2: Chunk the parsed content
|
131 |
+
documents = process_and_chunk(markdown_content)
|
132 |
+
|
133 |
+
end_time = time.perf_counter()
|
134 |
+
duration = end_time - start_time
|
135 |
+
print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")
|
136 |
+
|
137 |
+
# Convert Document objects to a JSON-serializable list
|
138 |
+
chunk_results = [
|
139 |
+
{"page_content": doc.page_content, "metadata": doc.metadata}
|
140 |
+
for doc in documents
|
141 |
+
]
|
142 |
+
|
143 |
+
return {
|
144 |
+
"total_time_seconds": duration,
|
145 |
+
"chunk_count": len(chunk_results),
|
146 |
+
"chunks": chunk_results
|
147 |
+
}
|
148 |
+
except Exception as e:
|
149 |
+
raise HTTPException(status_code=500, detail=f"An error occurred during chunking: {str(e)}")
|
retrieval.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# file: retrieval.py
|
2 |
+
|
3 |
+
import time
|
4 |
+
import asyncio
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from groq import AsyncGroq
|
8 |
+
from rank_bm25 import BM25Okapi
|
9 |
+
from sentence_transformers import CrossEncoder
|
10 |
+
from sklearn.preprocessing import MinMaxScaler
|
11 |
+
from torch.nn.functional import cosine_similarity
|
12 |
+
from typing import List, Dict, Tuple
|
13 |
+
|
14 |
+
from embedding import EmbeddingClient
|
15 |
+
from langchain_core.documents import Document
|
16 |
+
|
17 |
+
# --- Configuration ---
|
18 |
+
HYDE_MODEL = "llama3-8b-8192"
|
19 |
+
RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'
|
20 |
+
INITIAL_K_CANDIDATES = 20
|
21 |
+
TOP_K_CHUNKS = 10
|
22 |
+
|
23 |
+
async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
|
24 |
+
"""Generates a hypothetical document (HyDE) to enhance search."""
|
25 |
+
if not groq_api_key:
|
26 |
+
print("Groq API key not set. Skipping HyDE generation.")
|
27 |
+
return ""
|
28 |
+
|
29 |
+
print(f"Starting HyDE generation for query: '{query}'...")
|
30 |
+
client = AsyncGroq(api_key=groq_api_key)
|
31 |
+
prompt = (
|
32 |
+
f"Write a brief, formal passage that answers the following question. "
|
33 |
+
f"Use specific terminology as if it were from a larger document. "
|
34 |
+
f"Do not include the question or conversational text.\n\n"
|
35 |
+
f"Question: {query}\n\n"
|
36 |
+
f"Hypothetical Passage:"
|
37 |
+
)
|
38 |
+
|
39 |
+
try:
|
40 |
+
chat_completion = await client.chat.completions.create(
|
41 |
+
messages=[{"role": "user", "content": prompt}],
|
42 |
+
model=HYDE_MODEL,
|
43 |
+
temperature=0.7,
|
44 |
+
max_tokens=500,
|
45 |
+
)
|
46 |
+
return chat_completion.choices[0].message.content
|
47 |
+
except Exception as e:
|
48 |
+
print(f"An error occurred during HyDE generation: {e}")
|
49 |
+
return ""
|
50 |
+
|
51 |
+
class Retriever:
|
52 |
+
"""Manages hybrid search, combining BM25, dense search, and a reranker."""
|
53 |
+
|
54 |
+
def __init__(self, embedding_client: EmbeddingClient):
|
55 |
+
self.embedding_client = embedding_client
|
56 |
+
self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
|
57 |
+
self.bm25 = None
|
58 |
+
self.document_chunks = []
|
59 |
+
self.chunk_embeddings = None
|
60 |
+
print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")
|
61 |
+
|
62 |
+
def index(self, documents: List[Document]):
|
63 |
+
"""Builds the search index from document chunks."""
|
64 |
+
self.document_chunks = documents
|
65 |
+
corpus = [doc.page_content for doc in documents]
|
66 |
+
if not corpus:
|
67 |
+
print("No documents to index.")
|
68 |
+
return
|
69 |
+
|
70 |
+
print("Indexing documents for retrieval...")
|
71 |
+
# 1. Initialize BM25 model
|
72 |
+
tokenized_corpus = [doc.split(" ") for doc in corpus]
|
73 |
+
self.bm25 = BM25Okapi(tokenized_corpus)
|
74 |
+
# 2. Compute and store dense embeddings
|
75 |
+
self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
|
76 |
+
print("Indexing complete.")
|
77 |
+
|
78 |
+
def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
|
79 |
+
"""Performs the initial hybrid search to get candidate chunks."""
|
80 |
+
if self.bm25 is None or self.chunk_embeddings is None:
|
81 |
+
raise ValueError("Retriever has not been indexed. Call index() first.")
|
82 |
+
|
83 |
+
# Enhance query with hypothetical document
|
84 |
+
enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
|
85 |
+
|
86 |
+
# BM25 (keyword) search
|
87 |
+
tokenized_query = query.split(" ")
|
88 |
+
bm25_scores = self.bm25.get_scores(tokenized_query)
|
89 |
+
|
90 |
+
# Dense (semantic) search
|
91 |
+
query_embedding = self.embedding_client.create_embeddings([enhanced_query])
|
92 |
+
dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()
|
93 |
+
|
94 |
+
# Normalize and combine scores
|
95 |
+
scaler = MinMaxScaler()
|
96 |
+
norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
|
97 |
+
norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
|
98 |
+
combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
|
99 |
+
|
100 |
+
# Get top initial candidates
|
101 |
+
top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
|
102 |
+
return [(idx, combined_scores[idx]) for idx in top_indices]
|
103 |
+
|
104 |
+
async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
|
105 |
+
"""Reranks the candidate chunks using a CrossEncoder model."""
|
106 |
+
if not candidates:
|
107 |
+
return []
|
108 |
+
|
109 |
+
print(f"Reranking {len(candidates)} candidates...")
|
110 |
+
rerank_input = [[query, chunk["content"]] for chunk in candidates]
|
111 |
+
|
112 |
+
# Run synchronous prediction in a separate thread
|
113 |
+
rerank_scores = await asyncio.to_thread(
|
114 |
+
self.reranker.predict, rerank_input, show_progress_bar=False
|
115 |
+
)
|
116 |
+
|
117 |
+
# Combine candidates with their new scores and sort
|
118 |
+
for candidate, score in zip(candidates, rerank_scores):
|
119 |
+
candidate['rerank_score'] = score
|
120 |
+
|
121 |
+
candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
|
122 |
+
return candidates[:TOP_K_CHUNKS]
|
123 |
+
|
124 |
+
async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
|
125 |
+
"""Executes the full retrieval pipeline: hybrid search followed by reranking."""
|
126 |
+
print(f"Retrieving documents for query: '{query}'")
|
127 |
+
# 1. Get initial candidates from hybrid search
|
128 |
+
initial_candidates_info = self._hybrid_search(query, hyde_doc)
|
129 |
+
|
130 |
+
retrieved_candidates = [{
|
131 |
+
"content": self.document_chunks[idx].page_content,
|
132 |
+
"metadata": self.document_chunks[idx].metadata,
|
133 |
+
"initial_score": score
|
134 |
+
} for idx, score in initial_candidates_info]
|
135 |
+
|
136 |
+
# 2. Rerank the candidates to get the final list
|
137 |
+
final_chunks = await self._rerank(query, retrieved_candidates)
|
138 |
+
print(f"Retrieved and reranked {len(final_chunks)} final chunks.")
|
139 |
+
return final_chunks
|