# capitolati-rag / dspy_wrapper.py
# Last commit: "Update dspy_wrapper.py" by saashley (0152504, verified)
from query_preprocessing import preprocess_query
import dspy
from typing import List, Dict
import os
# Fail fast at import time: the LM below cannot be built without a key.
# An unset variable and an empty string are both rejected.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("Missing OPENAI_API_KEY env var")

# Build the gpt-4o-mini client and register it as DSPy's default LM via
# dspy.configure, so modules in this file need no explicit `lm=`.
gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)
dspy.configure(lm=gpt_4o_mini)
# == Building Blocks ==
class DSPyHybridRetriever(dspy.Module):
    """DSPy module adapter around an external retriever's ``rerank`` method.

    The wrapped results are returned under the ``"retrieved_chunks"`` key so
    that callers can index the output like a mapping.
    """

    def __init__(self, retriever):
        super().__init__()
        # Expected to expose rerank(query, top_k=..., municipality_filter=...).
        self.retriever = retriever

    def forward(self, query: str, municipality: str = "", top_k: int = 5):
        """Return ``{"retrieved_chunks": <rerank results>}`` for *query*.

        Args:
            query: Text passed positionally to ``retriever.rerank``.
            municipality: Forwarded unchanged as ``municipality_filter``.
                How an empty string is treated is up to the wrapped retriever.
            top_k: Forwarded unchanged as ``top_k``.
        """
        reranked = self.retriever.rerank(
            query,
            top_k=top_k,
            municipality_filter=municipality,
        )
        return {"retrieved_chunks": reranked}
class RetrieveChunks(dspy.Signature):
    """Given a user query and optional municipality, retrieve relevant text chunks."""

    # NOTE: DSPy uses the docstring above as this signature's instruction
    # text, so it is part of the prompt; edit it deliberately.
    # Field types are taken from annotations. DSPy/pydantic do not recognize
    # a ``type=`` keyword on InputField/OutputField, so ``type=List[Dict]``
    # never typed the output (unannotated fields default to ``str``).
    # Every field is annotated so that DSPy keeps the declaration order.
    query: str = dspy.InputField(desc="User's question")
    municipality: str = dspy.InputField(desc="Optional municipality filter")
    # NOTE(review): this description names a ``chunk`` key, while RAGChain
    # reads ``chunk_text`` from retriever results. Confirm which key the
    # retriever actually produces.
    retrieved_chunks: List[Dict] = dspy.OutputField(
        desc=(
            "List of retrieved chunks, each as a dict with keys: "
            "`chunk`, `document_id`, `section`, `level`, `page`, "
            "`dense_score`, `sparse_score`, `graph_score`, `final_score`"
        ),
    )
class AnswerWithEvidence(dspy.Signature):
    """Answer the query using reasoning and retrieved chunks as context."""

    # NOTE: DSPy uses the docstring above as this signature's instruction
    # text (prompt), so it is intentionally left as-is.
    query = dspy.InputField(desc="Rewritten question")
    # RAGChain passes a list of chunk *text strings* here (it extracts the
    # text of each retrieved chunk dict). The description says so, so that
    # the LM is not told to expect dicts.
    retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[str])")
    answer = dspy.OutputField(desc="Final answer")
    # NOTE(review): if this signature is wrapped in dspy.ChainOfThought, DSPy
    # adds its own reasoning output and this field may duplicate it. Check
    # how the answerer is constructed.
    rationale = dspy.OutputField(desc="Chain-of-thought reasoning")
# == RAG Pipeline ==
class RAGChain(dspy.Module):
    """Preprocess a query, retrieve supporting chunks, and answer from them.

    ``retriever`` is called as ``retriever(query=..., municipality=...)`` and
    must return a mapping with a ``"retrieved_chunks"`` list of chunk dicts.
    ``answerer`` is called with ``query`` and ``retrieved_chunks`` (a list of
    chunk texts) and must return an object exposing ``answer`` and
    ``rationale`` attributes.
    """

    def __init__(self, retriever, answerer):
        super().__init__()
        self.retriever = retriever
        self.answerer = answerer

    @staticmethod
    def _chunk_text(chunk):
        """Return the text of one retrieved chunk dict.

        Prefers ``chunk_text`` (the key this pipeline has always read) and
        falls back to ``chunk`` (the key named in the ``RetrieveChunks``
        description). Raises ``KeyError`` if neither key is present.
        """
        if "chunk_text" in chunk:
            return chunk["chunk_text"]
        return chunk["chunk"]

    def forward(self, raw_query: str, municipality: str = "", feedback: str = ""):
        """Run preprocessing, retrieval and answering for one user query.

        Args:
            raw_query: The user's question as typed. It is also the fallback
                retrieval query when preprocessing yields no rewrite.
            municipality: Optional explicit municipality filter. It is used
                (stripped) when non-blank. Otherwise, including when ``None``
                is passed, the municipality extracted by ``preprocess_query``
                is used.
            feedback: Passed through to ``preprocess_query``.

        Returns:
            A dict holding the original/rewritten query, the detected
            intent, the extracted and effective municipality, the raw
            retrieved chunks, the answerer's rationale and final answer.
        """
        pre = preprocess_query(raw_query, feedback)
        rewritten = pre["rewritten_query"] or raw_query
        extracted_muni = pre["municipality"] or ""
        intent = pre["intent"]

        # An explicit, non-blank municipality overrides the extracted one.
        # Strip it so surrounding whitespace does not reach the retriever's
        # filter, and tolerate None from callers.
        user_muni = (municipality or "").strip()
        muni = user_muni or extracted_muni

        retrieved = self.retriever(query=rewritten, municipality=muni)
        chunks = retrieved["retrieved_chunks"]

        # Answer + CoT using the rewritten query. The answerer receives
        # plain chunk texts, not the full chunk dicts.
        answer_result = self.answerer(
            query=rewritten,
            retrieved_chunks=[self._chunk_text(c) for c in chunks],
        )

        # Return everything for transparency & downstream use
        return {
            "original_query": raw_query,
            "intent": intent,
            "rewritten_query": rewritten,
            "llm_municipality": extracted_muni,
            "municipality": muni,
            "retrieved_chunks": chunks,
            "chain_of_thought": answer_result.rationale,
            "final_answer": answer_result.answer,
        }