Spaces:
Sleeping
Sleeping
Update dspy_wrapper.py
Browse files- dspy_wrapper.py +19 -11
dspy_wrapper.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import dspy
|
2 |
from typing import List, Dict
|
3 |
import os
|
@@ -8,7 +9,6 @@ if not OPENAI_API_KEY:
|
|
8 |
raise RuntimeError("Missing OPENAI_API_KEY env var")
|
9 |
|
10 |
gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)
|
11 |
-
# using unimib credentials, switch to PeS if needed!
|
12 |
dspy.configure(lm=gpt_4o_mini)
|
13 |
|
14 |
|
@@ -19,7 +19,7 @@ class DSPyHybridRetriever(dspy.Module):
|
|
19 |
self.retriever = retriever
|
20 |
|
21 |
def forward(self, query: str, municipality: str = "", top_k: int = 5):
|
22 |
-
results = self.retriever.rerank(query, top_k=top_k, municipality_filter=municipality)
|
23 |
return {"retrieved_chunks": results}
|
24 |
|
25 |
class RetrieveChunks(dspy.Signature):
|
@@ -37,7 +37,7 @@ class RetrieveChunks(dspy.Signature):
|
|
37 |
|
38 |
class AnswerWithEvidence(dspy.Signature):
|
39 |
"""Answer the query using reasoning and retrieved chunks as context."""
|
40 |
-
query = dspy.InputField(desc="
|
41 |
retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")
|
42 |
answer = dspy.OutputField(desc="Final answer")
|
43 |
rationale = dspy.OutputField(desc="Chain-of-thought reasoning")
|
@@ -50,21 +50,29 @@ class RAGChain(dspy.Module):
|
|
50 |
self.retriever = retriever
|
51 |
self.answerer = answerer
|
52 |
|
53 |
-
def forward(self,
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
chunks = retrieved["retrieved_chunks"]
|
57 |
|
58 |
-
#
|
59 |
answer_result = self.answerer(
|
60 |
-
query=
|
61 |
retrieved_chunks=[c["chunk_text"] for c in chunks]
|
62 |
)
|
63 |
|
64 |
-
#
|
65 |
return {
|
66 |
-
"
|
67 |
-
"
|
|
|
|
|
|
|
68 |
"retrieved_chunks": chunks,
|
69 |
"chain_of_thought": answer_result.rationale,
|
70 |
"final_answer": answer_result.answer,
|
|
|
1 |
+
from query_preprocessing import preprocess_query
|
2 |
import dspy
|
3 |
from typing import List, Dict
|
4 |
import os
|
|
|
9 |
raise RuntimeError("Missing OPENAI_API_KEY env var")
|
10 |
|
11 |
gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)
|
|
|
12 |
dspy.configure(lm=gpt_4o_mini)
|
13 |
|
14 |
|
|
|
19 |
self.retriever = retriever
|
20 |
|
21 |
def forward(self, query: str, municipality: str = "", top_k: int = 5):
|
22 |
+
results = self.retriever.rerank(query, top_k=top_k, municipality_filter=municipality)
|
23 |
return {"retrieved_chunks": results}
|
24 |
|
25 |
class RetrieveChunks(dspy.Signature):
|
|
|
37 |
|
38 |
class AnswerWithEvidence(dspy.Signature):
|
39 |
"""Answer the query using reasoning and retrieved chunks as context."""
|
40 |
+
query = dspy.InputField(desc="Rewritten question")
|
41 |
retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")
|
42 |
answer = dspy.OutputField(desc="Final answer")
|
43 |
rationale = dspy.OutputField(desc="Chain-of-thought reasoning")
|
|
|
50 |
self.retriever = retriever
|
51 |
self.answerer = answerer
|
52 |
|
53 |
+
def forward(self, raw_query: str, municipality: str = "", feedback: str = ""):
|
54 |
+
pre = preprocess_query(raw_query, feedback)
|
55 |
+
rewritten = pre["rewritten_query"] or raw_query
|
56 |
+
extracted_muni = pre["municipality"] or ""
|
57 |
+
intent = pre["intent"]
|
58 |
+
muni = municipality if municipality.strip() else extracted_muni
|
59 |
+
|
60 |
+
retrieved = self.retriever(query=rewritten, municipality=muni)
|
61 |
chunks = retrieved["retrieved_chunks"]
|
62 |
|
63 |
+
# Answer + CoT using the rewritten query
|
64 |
answer_result = self.answerer(
|
65 |
+
query=rewritten,
|
66 |
retrieved_chunks=[c["chunk_text"] for c in chunks]
|
67 |
)
|
68 |
|
69 |
+
# Return everything for transparency & downstream use
|
70 |
return {
|
71 |
+
"original_query": raw_query,
|
72 |
+
"intent": intent,
|
73 |
+
"rewritten_query": rewritten,
|
74 |
+
"llm_municipality": extracted_muni,
|
75 |
+
"municipality": muni,
|
76 |
"retrieved_chunks": chunks,
|
77 |
"chain_of_thought": answer_result.rationale,
|
78 |
"final_answer": answer_result.answer,
|