saashley committed on
Commit
0152504
·
verified ·
1 Parent(s): e07150a

Update dspy_wrapper.py

Browse files
Files changed (1) hide show
  1. dspy_wrapper.py +19 -11
dspy_wrapper.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import dspy
2
  from typing import List, Dict
3
  import os
@@ -8,7 +9,6 @@ if not OPENAI_API_KEY:
8
  raise RuntimeError("Missing OPENAI_API_KEY env var")
9
 
10
 gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)  # build the OpenAI gpt-4o-mini client; key validated above
11
- # using unimib credentials, switch to PeS if needed!
12
 dspy.configure(lm=gpt_4o_mini)  # register as DSPy's process-wide default LM
13
 
14
 
@@ -19,7 +19,7 @@ class DSPyHybridRetriever(dspy.Module):
19
  self.retriever = retriever
20
 
21
 def forward(self, query: str, municipality: str = "", top_k: int = 5):
     """Retrieve the top_k chunks for `query` via the wrapped retriever's rerank(); empty municipality means no filter — TODO confirm filter semantics in retriever."""
22
- results = self.retriever.rerank(query, top_k=top_k, municipality_filter=municipality) # remember to change to rerank
23
 return {"retrieved_chunks": results}
24
 
25
  class RetrieveChunks(dspy.Signature):
@@ -37,7 +37,7 @@ class RetrieveChunks(dspy.Signature):
37
 
38
 class AnswerWithEvidence(dspy.Signature):
39
 """Answer the query using reasoning and retrieved chunks as context."""
40
- query = dspy.InputField(desc="User's question")  # input: the user's question text
41
 retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")  # input: evidence passed to the LM
42
 answer = dspy.OutputField(desc="Final answer")  # output: final answer text
43
 rationale = dspy.OutputField(desc="Chain-of-thought reasoning")  # output: reasoning trace produced alongside the answer
@@ -50,21 +50,29 @@ class RAGChain(dspy.Module):
50
  self.retriever = retriever
51
  self.answerer = answerer
52
 
53
- def forward(self, query: str, municipality: str = ""):
54
- # retrieve full dicts
55
- retrieved = self.retriever(query=query, municipality=municipality)
 
 
 
 
 
56
  chunks = retrieved["retrieved_chunks"]
57
 
58
- # feed only the raw text into the CoT module
59
  answer_result = self.answerer(
60
- query=query,
61
  retrieved_chunks=[c["chunk_text"] for c in chunks]
62
  )
63
 
64
- # return both the metadata and the LLM answer
65
  return {
66
- "query": query,
67
- "municipality": municipality,
 
 
 
68
  "retrieved_chunks": chunks,
69
  "chain_of_thought": answer_result.rationale,
70
  "final_answer": answer_result.answer,
 
1
+ from query_preprocessing import preprocess_query
2
  import dspy
3
  from typing import List, Dict
4
  import os
 
9
  raise RuntimeError("Missing OPENAI_API_KEY env var")
10
 
11
 gpt_4o_mini = dspy.LM('openai/gpt-4o-mini', api_key=OPENAI_API_KEY)  # build the OpenAI gpt-4o-mini client; key validated above

12
 dspy.configure(lm=gpt_4o_mini)  # register as DSPy's process-wide default LM
13
 
14
 
 
19
  self.retriever = retriever
20
 
21
 def forward(self, query: str, municipality: str = "", top_k: int = 5):
     """Retrieve the top_k chunks for `query` via the wrapped retriever's rerank(); empty municipality means no filter — TODO confirm filter semantics in retriever."""
22
+ results = self.retriever.rerank(query, top_k=top_k, municipality_filter=municipality)
23
 return {"retrieved_chunks": results}
24
 
25
  class RetrieveChunks(dspy.Signature):
 
37
 
38
 class AnswerWithEvidence(dspy.Signature):
39
 """Answer the query using reasoning and retrieved chunks as context."""
40
+ query = dspy.InputField(desc="Rewritten question")  # input: query AFTER preprocess_query rewriting (see RAGChain.forward)
41
 retrieved_chunks = dspy.InputField(desc="Retrieved text chunks (List[dict])")  # input: evidence passed to the LM
42
 answer = dspy.OutputField(desc="Final answer")  # output: final answer text
43
 rationale = dspy.OutputField(desc="Chain-of-thought reasoning")  # output: reasoning trace produced alongside the answer
 
50
  self.retriever = retriever
51
  self.answerer = answerer
52
 
53
+ def forward(self, raw_query: str, municipality: str = "", feedback: str = ""):
54
+ pre = preprocess_query(raw_query, feedback)
55
+ rewritten = pre["rewritten_query"] or raw_query
56
+ extracted_muni = pre["municipality"] or ""
57
+ intent = pre["intent"]
58
+ muni = municipality if municipality.strip() else extracted_muni
59
+
60
+ retrieved = self.retriever(query=rewritten, municipality=muni)
61
  chunks = retrieved["retrieved_chunks"]
62
 
63
+ # Answer + CoT using the rewritten query
64
  answer_result = self.answerer(
65
+ query=rewritten,
66
  retrieved_chunks=[c["chunk_text"] for c in chunks]
67
  )
68
 
69
+ # Return everything for transparency & downstream use
70
  return {
71
+ "original_query": raw_query,
72
+ "intent": intent,
73
+ "rewritten_query": rewritten,
74
+ "llm_municipality": extracted_muni,
75
+ "municipality": muni,
76
  "retrieved_chunks": chunks,
77
  "chain_of_thought": answer_result.rationale,
78
  "final_answer": answer_result.answer,