File size: 2,969 Bytes
0152504
2fc692a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0152504
2fc692a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0152504
2fc692a
 
 
 
 
 
 
 
 
 
 
 
0152504
 
 
 
 
 
 
 
2fc692a
 
0152504
2fc692a
0152504
2fc692a
 
 
0152504
2fc692a
0152504
 
 
 
 
2fc692a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from query_preprocessing import preprocess_query
import dspy 
from typing import List, Dict
import os


# Pull the OpenAI credential from the environment and fail at import time
# when it is absent or empty, so a misconfigured deployment is caught early.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY is None or OPENAI_API_KEY == "":
    raise RuntimeError("Missing OPENAI_API_KEY env var")

# Build the gpt-4o-mini client and install it as dspy's configured LM.
gpt_4o_mini = dspy.LM(
    "openai/gpt-4o-mini",
    api_key=OPENAI_API_KEY,
)
dspy.configure(lm=gpt_4o_mini)


# == Building Blocks ==
class DSPyHybridRetriever(dspy.Module):
    """Expose a reranking retriever as a DSPy module.

    The wrapped ``retriever`` must provide
    ``rerank(query, top_k=..., municipality_filter=...)``.
    """

    def __init__(self, retriever):
        super().__init__()
        self.retriever = retriever

    def forward(self, query: str, municipality: str = "", top_k: int = 5):
        """Rerank for ``query`` and return the hits under ``"retrieved_chunks"``.

        ``municipality`` is forwarded unchanged as ``municipality_filter`` and
        ``top_k`` as ``top_k``.
        """
        return {
            "retrieved_chunks": self.retriever.rerank(
                query,
                top_k=top_k,
                municipality_filter=municipality,
            ),
        }
    
class RetrieveChunks(dspy.Signature):
    """Given a user query and optional municipality, retrieve relevant text chunks."""
    # Declares the retrieval contract: `query` and `municipality` in,
    # `retrieved_chunks` out. No other definition visible in this file
    # references this signature.
    # NOTE(review): DSPy presumably uses the docstring above as the task
    # instructions for the LM, so rewording it would change prompts — confirm
    # before editing.
    query = dspy.InputField(desc="User's question")
    municipality = dspy.InputField(desc="Optional municipality filter")
    # NOTE(review): the description names the text key `chunk`; confirm this
    # matches the key the retriever really emits and the key consumers read.
    retrieved_chunks = dspy.OutputField(
        desc=(
            "List of retrieved chunks, each as a dict with keys: "
            "`chunk`, `document_id`, `section`, `level`, `page`, "
            "`dense_score`, `sparse_score`, `graph_score`, `final_score`"
        ),
        # NOTE(review): the type is passed as a `type=` keyword rather than as
        # a class-level annotation; verify DSPy honors it — TODO confirm.
        type=List[Dict]  # each item is a dict carrying all those fields
    )

class AnswerWithEvidence(dspy.Signature):
    """Answer the query using reasoning and retrieved chunks as context."""
    # Declares the answering contract: `query` and `retrieved_chunks` in,
    # `answer` and `rationale` out.
    # NOTE(review): DSPy presumably uses the docstring above as the task
    # instructions for the LM — confirm before rewording it.
    query = dspy.InputField(desc="Rewritten question")
    # NOTE(review): described as List[dict], but RAGChain.forward in this file
    # passes a list of the chunks' text values rather than the chunk dicts.
    retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")
    # NOTE(review): `answer` is declared before `rationale`; if DSPy generates
    # output fields in declaration order, the reasoning is produced after the
    # answer — confirm that this ordering is intended.
    answer = dspy.OutputField(desc="Final answer")
    rationale = dspy.OutputField(desc="Chain-of-thought reasoning")


# == RAG Pipeline ==
class RAGChain(dspy.Module):
    """RAG pipeline: preprocess the query, retrieve chunks, then answer.

    Args:
        retriever: Callable invoked as ``retriever(query=..., municipality=...)``
            that returns a mapping with a ``"retrieved_chunks"`` list of dicts.
        answerer: Callable invoked as
            ``answerer(query=..., retrieved_chunks=...)`` that returns an
            object exposing ``answer`` and ``rationale`` attributes.
    """

    # Keys under which a chunk's text may be stored, in lookup order.
    # `chunk_text` is the key this pipeline has always read; `chunk` is the
    # key named in the RetrieveChunks signature description in this file.
    _CHUNK_TEXT_KEYS = ("chunk_text", "chunk")

    def __init__(self, retriever, answerer):
        super().__init__()
        self.retriever = retriever
        self.answerer = answerer

    @staticmethod
    def _chunk_text(chunk):
        """Return the text of one retrieved chunk dict.

        Raises:
            KeyError: If the chunk carries none of the known text keys.
        """
        for key in RAGChain._CHUNK_TEXT_KEYS:
            if key in chunk:
                return chunk[key]
        raise KeyError(
            f"retrieved chunk has none of the text keys "
            f"{RAGChain._CHUNK_TEXT_KEYS}; available keys: {sorted(chunk)}"
        )

    def forward(self, raw_query: str, municipality: str = "", feedback: str = ""):
        """Run the full pipeline for one user query.

        Args:
            raw_query: The question as typed by the user.
            municipality: Explicit municipality filter. When blank (or
                ``None``), the municipality extracted by preprocessing is
                used instead.
            feedback: Passed through to ``preprocess_query``.

        Returns:
            A dict with the keys ``original_query``, ``intent``,
            ``rewritten_query``, ``llm_municipality``, ``municipality``,
            ``retrieved_chunks``, ``chain_of_thought`` and ``final_answer``.

        Raises:
            KeyError: If a retrieved chunk has no recognizable text key.
        """
        pre = preprocess_query(raw_query, feedback)
        # Fall back to the raw query when preprocessing yields no rewrite.
        rewritten = pre["rewritten_query"] or raw_query
        extracted_muni = pre["municipality"] or ""
        intent = pre["intent"]

        # A caller-supplied municipality wins over the extracted one; the
        # `or ""` guard keeps a None argument from crashing on .strip().
        has_explicit_muni = bool((municipality or "").strip())
        muni = municipality if has_explicit_muni else extracted_muni

        retrieved = self.retriever(query=rewritten, municipality=muni)
        chunks = retrieved["retrieved_chunks"]

        # Answer + CoT using the rewritten query; only the chunk text is
        # handed to the answerer, not the full chunk dicts.
        answer_result = self.answerer(
            query=rewritten,
            retrieved_chunks=[self._chunk_text(c) for c in chunks],
        )

        # Return everything for transparency & downstream use
        return {
            "original_query": raw_query,
            "intent": intent,
            "rewritten_query": rewritten,
            "llm_municipality": extracted_muni,
            "municipality": muni,
            "retrieved_chunks": chunks,
            "chain_of_thought": answer_result.rationale,
            "final_answer": answer_result.answer,
        }