Anuttama Chakraborty committed on
Commit
4d8af9a
·
1 Parent(s): d006092
Files changed (1) hide show
  1. RagWithConfidenceScore.py +15 -34
RagWithConfidenceScore.py CHANGED
@@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor
16
  class RagWithScore:
17
  def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2",
18
  cross_encoder_name="cross-encoder/ms-marco-TinyBERT-L-2-v2",
19
- llm_name="gpt2",
20
  documents_dir="financial_docs"):
21
  """
22
  Initialize the Financial RAG system
@@ -39,12 +39,12 @@ class RagWithScore:
39
  "text-generation",
40
  model=llm_name,
41
  tokenizer=self.tokenizer,
42
- torch_dtype=torch.bfloat16, # Use float16 if bfloat16 is not supported
43
  device_map="auto",
44
- max_new_tokens=512, # Adjust based on your needs
45
- do_sample=False, # Set to False for deterministic outputs
46
- temperature=0.2, # Reduce randomness
47
- top_p=1.0 # No nucleus sampling
48
  )
49
 
50
  # Store paths
@@ -71,6 +71,7 @@ class RagWithScore:
71
 
72
  import os
73
 
 
74
  def load_and_process_documents(self):
75
  """Load, split and process financial documents"""
76
 
@@ -91,19 +92,8 @@ class RagWithScore:
91
 
92
  return self.vector_store
93
 
94
- def load_or_create_vector_store(self):
95
- try:
96
- print("Loading existing FAISS index...")
97
- self.vector_store = FAISS.load_local("faiss_index", self.embedding_model)
98
- print("FAISS index loaded successfully")
99
- except Exception as e:
100
- print(f"Error loading FAISS index: {e}")
101
- print("Creating new FAISS index...")
102
- # Code to create a new vector store
103
- documents = self.load_and_process_documents() # Make sure this method exists
104
- print("New FAISS index created and saved")
105
-
106
 
 
107
  def generate_answer(self, query, context):
108
  """Generate answer and calculate confidence score concurrently."""
109
  # Format context into a single string
@@ -142,22 +132,8 @@ class RagWithScore:
142
  return answer
143
 
144
 
145
- # def calculate_confidence_score(self, query, retrieved_docs, answer):
146
- # """A simpler confidence score calculation focused on consistency and LLM confidence"""
147
-
148
- # # Get LLM confidence
149
- # llm_confidence = self._get_llm_confidence(query, retrieved_docs, answer)
150
-
151
- # # Get consistency score
152
- # consistency_score = self._measure_answer_consistency(query, retrieved_docs, answer)
153
-
154
- # # Simple weighted average
155
- # confidence_score = (0.6 * consistency_score) + (0.4 * llm_confidence)
156
-
157
- # print(f"confidence score : {confidence_score}")
158
-
159
- # return confidence_score
160
 
 
161
  def calculate_confidence_score(self, query, retrieved_docs, answer):
162
  """
163
  Calculate confidence score using embedding similarity (parallelized).
@@ -175,6 +151,7 @@ class RagWithScore:
175
  return similarity
176
 
177
 
 
178
  def get_confidence_level(self, confidence_score):
179
  """
180
  Convert numerical confidence score to a level (high, medium, low)
@@ -194,6 +171,7 @@ class RagWithScore:
194
  else:
195
  return "very low"
196
 
 
197
  def apply_input_guardrail(self, query):
198
  """Check if query violates input guardrails"""
199
  query_lower = query.lower()
@@ -204,6 +182,7 @@ class RagWithScore:
204
 
205
  return False, ""
206
 
 
207
  def retrieve_with_reranking(self, query, top_k=5, rerank_top_k=3):
208
 
209
  print("retrieve_with_reranking start")
@@ -241,6 +220,7 @@ class RagWithScore:
241
 
242
  return [doc for (doc, _), _ in reranked_results[:rerank_top_k]]
243
 
 
244
  def is_financial_question(self,query):
245
  financial_keywords = [
246
  "finance", "financial", "revenue", "profit", "loss", "ebitda", "cash flow",
@@ -251,7 +231,8 @@ class RagWithScore:
251
  ]
252
  query_lower = query.lower()
253
  return any(keyword in query_lower for keyword in financial_keywords)
254
-
 
255
  def answer_question(self, query):
256
  """End-to-end pipeline to answer a question with confidence score"""
257
 
 
16
  class RagWithScore:
17
  def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2",
18
  cross_encoder_name="cross-encoder/ms-marco-TinyBERT-L-2-v2",
19
+ llm_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
20
  documents_dir="financial_docs"):
21
  """
22
  Initialize the Financial RAG system
 
39
  "text-generation",
40
  model=llm_name,
41
  tokenizer=self.tokenizer,
42
+ torch_dtype=torch.bfloat16,
43
  device_map="auto",
44
+ max_new_tokens=512,
45
+ do_sample=False, # Set to False for deterministic outputs
46
+ temperature=0.2, # Reduce randomness
47
+ top_p=1.0 # No nucleus sampling
48
  )
49
 
50
  # Store paths
 
71
 
72
  import os
73
 
74
+ ## Loading documents and creating the vector index at the start of the application
75
  def load_and_process_documents(self):
76
  """Load, split and process financial documents"""
77
 
 
92
 
93
  return self.vector_store
94
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ ## Generates a response by building a prompt from the query and context and calling the SLM with that prompt
97
  def generate_answer(self, query, context):
98
  """Generate answer and calculate confidence score concurrently."""
99
  # Format context into a single string
 
132
  return answer
133
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ ## For the confidence score, cosine similarity is computed between the query embedding and the answer embedding
137
  def calculate_confidence_score(self, query, retrieved_docs, answer):
138
  """
139
  Calculate confidence score using embedding similarity (parallelized).
 
151
  return similarity
152
 
153
 
154
+ ## The confidence level is derived from the numerical confidence score
155
  def get_confidence_level(self, confidence_score):
156
  """
157
  Convert numerical confidence score to a level (high, medium, low)
 
171
  else:
172
  return "very low"
173
 
174
+ ## An input guardrail is applied to filter harmful user queries
175
  def apply_input_guardrail(self, query):
176
  """Check if query violates input guardrails"""
177
  query_lower = query.lower()
 
182
 
183
  return False, ""
184
 
185
+ ## First the top 5 chunks are retrieved; then, after reranking with the cross-encoder, the top rerank_top_k chunks are returned
186
  def retrieve_with_reranking(self, query, top_k=5, rerank_top_k=3):
187
 
188
  print("retrieve_with_reranking start")
 
220
 
221
  return [doc for (doc, _), _ in reranked_results[:rerank_top_k]]
222
 
223
+ ## To handle irrelevant questions, a rule-based classifier is used to classify the questions
224
  def is_financial_question(self,query):
225
  financial_keywords = [
226
  "finance", "financial", "revenue", "profit", "loss", "ebitda", "cash flow",
 
231
  ]
232
  query_lower = query.lower()
233
  return any(keyword in query_lower for keyword in financial_keywords)
234
+
235
+ ## The end-to-end pipeline that generates the answer and confidence score from the query
236
  def answer_question(self, query):
237
  """End-to-end pipeline to answer a question with confidence score"""
238