Spaces: Sleeping
Update dspy_wrapper.py
Browse files
- dspy_wrapper.py (+19 -11)
dspy_wrapper.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import dspy
|
| 2 |
from typing import List, Dict
|
| 3 |
import os
|
|
@@ -8,7 +9,6 @@ if not OPENAI_API_KEY:
|
|
| 8 |
raise RuntimeError("Missing OPENAI_API_KEY env var")
|
| 9 |
|
| 10 |
gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)
|
| 11 |
-
# using unimib credentials, switch to PeS if needed!
|
| 12 |
dspy.configure(lm=gpt_4o_mini)
|
| 13 |
|
| 14 |
|
|
@@ -19,7 +19,7 @@ class DSPyHybridRetriever(dspy.Module):
|
|
| 19 |
self.retriever = retriever
|
| 20 |
|
| 21 |
def forward(self, query: str, municipality: str = "", top_k: int = 5):
|
| 22 |
-
results = self.retriever.rerank(query, top_k=top_k, municipality_filter=municipality)
|
| 23 |
return {"retrieved_chunks": results}
|
| 24 |
|
| 25 |
class RetrieveChunks(dspy.Signature):
|
|
@@ -37,7 +37,7 @@ class RetrieveChunks(dspy.Signature):
|
|
| 37 |
|
| 38 |
class AnswerWithEvidence(dspy.Signature):
|
| 39 |
"""Answer the query using reasoning and retrieved chunks as context."""
|
| 40 |
-
query = dspy.InputField(desc="
|
| 41 |
retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")
|
| 42 |
answer = dspy.OutputField(desc="Final answer")
|
| 43 |
rationale = dspy.OutputField(desc="Chain-of-thought reasoning")
|
|
@@ -50,21 +50,29 @@ class RAGChain(dspy.Module):
|
|
| 50 |
self.retriever = retriever
|
| 51 |
self.answerer = answerer
|
| 52 |
|
| 53 |
-
def forward(self,
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
chunks = retrieved["retrieved_chunks"]
|
| 57 |
|
| 58 |
-
#
|
| 59 |
answer_result = self.answerer(
|
| 60 |
-
query=
|
| 61 |
retrieved_chunks=[c["chunk_text"] for c in chunks]
|
| 62 |
)
|
| 63 |
|
| 64 |
-
#
|
| 65 |
return {
|
| 66 |
-
"
|
| 67 |
-
"
|
|
|
|
|
|
|
|
|
|
| 68 |
"retrieved_chunks": chunks,
|
| 69 |
"chain_of_thought": answer_result.rationale,
|
| 70 |
"final_answer": answer_result.answer,
|
|
|
|
| 1 |
+
from query_preprocessing import preprocess_query
|
| 2 |
import dspy
|
| 3 |
from typing import List, Dict
|
| 4 |
import os
|
|
|
|
| 9 |
raise RuntimeError("Missing OPENAI_API_KEY env var")
|
| 10 |
|
| 11 |
gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)
|
|
|
|
| 12 |
dspy.configure(lm=gpt_4o_mini)
|
| 13 |
|
| 14 |
|
|
|
|
| 19 |
self.retriever = retriever
|
| 20 |
|
| 21 |
def forward(self, query: str, municipality: str = "", top_k: int = 5):
|
| 22 |
+
results = self.retriever.rerank(query, top_k=top_k, municipality_filter=municipality)
|
| 23 |
return {"retrieved_chunks": results}
|
| 24 |
|
| 25 |
class RetrieveChunks(dspy.Signature):
|
|
|
|
| 37 |
|
| 38 |
class AnswerWithEvidence(dspy.Signature):
|
| 39 |
"""Answer the query using reasoning and retrieved chunks as context."""
|
| 40 |
+
query = dspy.InputField(desc="Rewritten question")
|
| 41 |
retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")
|
| 42 |
answer = dspy.OutputField(desc="Final answer")
|
| 43 |
rationale = dspy.OutputField(desc="Chain-of-thought reasoning")
|
|
|
|
| 50 |
self.retriever = retriever
|
| 51 |
self.answerer = answerer
|
| 52 |
|
| 53 |
+
def forward(self, raw_query: str, municipality: str = "", feedback: str = ""):
|
| 54 |
+
pre = preprocess_query(raw_query, feedback)
|
| 55 |
+
rewritten = pre["rewritten_query"] or raw_query
|
| 56 |
+
extracted_muni = pre["municipality"] or ""
|
| 57 |
+
intent = pre["intent"]
|
| 58 |
+
muni = municipality if municipality.strip() else extracted_muni
|
| 59 |
+
|
| 60 |
+
retrieved = self.retriever(query=rewritten, municipality=muni)
|
| 61 |
chunks = retrieved["retrieved_chunks"]
|
| 62 |
|
| 63 |
+
# Answer + CoT using the rewritten query
|
| 64 |
answer_result = self.answerer(
|
| 65 |
+
query=rewritten,
|
| 66 |
retrieved_chunks=[c["chunk_text"] for c in chunks]
|
| 67 |
)
|
| 68 |
|
| 69 |
+
# Return everything for transparency & downstream use
|
| 70 |
return {
|
| 71 |
+
"original_query": raw_query,
|
| 72 |
+
"intent": intent,
|
| 73 |
+
"rewritten_query": rewritten,
|
| 74 |
+
"llm_municipality": extracted_muni,
|
| 75 |
+
"municipality": muni,
|
| 76 |
"retrieved_chunks": chunks,
|
| 77 |
"chain_of_thought": answer_result.rationale,
|
| 78 |
"final_answer": answer_result.answer,
|