mtyrrell committed · Commit c997974 · 1 Parent(s): a1e5650

feedback functionality completed (using some test elements)
.gitignore CHANGED
@@ -1,2 +1,8 @@
 .DS_store
- /testing/
+ .env
+ /testing/
+ /logs/
+ logging_config.py
+ /data/
+ app_interactions.jsonl
+ auditqa/__pycache__/
app.py CHANGED
@@ -14,25 +14,57 @@ from auditqa.retriever import get_context
 from auditqa.reader import nvidia_client, dedicated_endpoint
 from auditqa.utils import make_html_source, parse_output_llm_with_sources, save_logs, get_message_template
 from dotenv import load_dotenv
+ from threading import Lock
+ import json
+ from functools import partial
+
+ # TESTING DEBUG LOG
+ from auditqa.logging_config import setup_logging
+ setup_logging()
+ import logging
+ logger = logging.getLogger(__name__)
+
 load_dotenv()

- # fetch tokens and model config params
- SPACES_LOG = os.environ["SPACES_LOG"]
+ # # fetch tokens and model config params
+ # SPACES_LOG = os.environ["SPACES_LOG"]
+ # model_config = getconfig("model_params.cfg")
+
+ # # create the local logs repo
+ # JSON_DATASET_DIR = Path("json_dataset")
+ # JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
+ # JSON_DATASET_PATH = JSON_DATASET_DIR / f"logs-{uuid4()}.json"
+
+ # # the logs are written to dataset repo periodically from local logs
+ # # https://huggingface.co/spaces/Wauplin/space_to_dataset_saver
+ # scheduler = CommitScheduler(
+ #     repo_id="GIZ/spaces_logs",
+ #     repo_type="dataset",
+ #     folder_path=JSON_DATASET_DIR,
+ #     path_in_repo="audit_chatbot",
+ #     token=SPACES_LOG )
+
+ # Make logging optional
+ SPACES_LOG = os.getenv("SPACES_LOG", "app_interactions.jsonl") # TESTING: local logging setup
 model_config = getconfig("model_params.cfg")

- # create the local logs repo
- JSON_DATASET_DIR = Path("json_dataset")
- JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
- JSON_DATASET_PATH = JSON_DATASET_DIR / f"logs-{uuid4()}.json"
+ # TESTING: local logging setup
+ class LocalScheduler:
+     def __init__(self, filepath):
+         self.filepath = Path(filepath)
+         self.lock = Lock()
+
+         # Create the file if it doesn't exist
+         if not self.filepath.exists():
+             with self.filepath.open('w') as f:
+                 f.write('')
+
+ # Instead of HuggingFace CommitScheduler, use local scheduler
+ scheduler = LocalScheduler(SPACES_LOG) # TESTING: local logging setup
+ JSON_DATASET_PATH = Path(SPACES_LOG) # TESTING: local logging setup

 # the logs are written to dataset repo periodically from local logs
 # https://huggingface.co/spaces/Wauplin/space_to_dataset_saver
- scheduler = CommitScheduler(
-     repo_id="GIZ/spaces_logs",
-     repo_type="dataset",
-     folder_path=JSON_DATASET_DIR,
-     path_in_repo="audit_chatbot",
-     token=SPACES_LOG )

 #####--------------- VECTOR STORE -------------------------------------------------
 # reports contain the already created chunks from Markdown version of pdf reports
@@ -40,7 +72,7 @@ scheduler = CommitScheduler(
 # We need to create the local vectorstore collection once using load_chunks
 # vectorestore colection are stored on persistent storage so this needs to be run only once
 # hence, comment out line below when creating for first time
- #vectorstores = load_new_chunks()
+ # vectorstores = load_new_chunks()
 # once the vectore embeddings are created we will use qdrant client to access these
 vectorstores = get_local_qdrant()

@@ -53,6 +85,20 @@ def start_chat(query,history):
 def finish_chat():
     return (gr.update(interactive = True,value = ""))

+ def submit_feedback(feedback, logs_data):
+     """Handle feedback submission"""
+     try:
+         if logs_data is None:
+             logger.error("No logs data available for feedback")
+             return gr.update(visible=False), gr.update(visible=True)
+
+         save_logs(scheduler, JSON_DATASET_PATH, logs_data, feedback)
+         return gr.update(visible=False), gr.update(visible=True)
+     except Exception as e:
+         logger.error(f"Error saving feedback: {e}")
+         # Still need to return the expected outputs even on error
+         return gr.update(visible=False), gr.update(visible=True)
+
 async def chat(query,history,sources,reports,subtype,year):
     """taking a query and a message history, use a pipeline (reformulation, retriever, answering)
     to yield a tuple of:(messages in gradio format/messages in langchain format, source documents)
@@ -71,6 +117,7 @@ async def chat(query,history,sources,reports,subtype,year):
     vectorstore = vectorstores["allreports"]

     ##------------------------------get context----------------------------------------------
+
     context_retrieved = get_context(vectorstore=vectorstore,query=query,reports=reports,
                         sources=sources,subtype=subtype,year=year)
     context_retrieved_formatted = "||".join(doc.page_content for doc in context_retrieved)
@@ -111,6 +158,23 @@ async def chat(query,history,sources,reports,subtype,year):

     ##-----------------------get answer from endpoints------------------------------
     answer_yet = ""
+     # Create logs data structure at the beginning (so that feedback can be saved after streaming)
+     timestamp = str(datetime.now().timestamp())
+     logs_data = {
+         "system_prompt": SYSTEM_PROMPT,
+         "sources": sources,
+         "reports": reports,
+         "subtype": subtype,
+         "year": year,
+         "question": query,
+         "retriever": model_config.get('retriever','MODEL'),
+         "endpoint_type": model_config.get('reader','TYPE'),
+         "reader": model_config.get('reader','NVIDIA_MODEL'),
+         "docs": [doc.page_content for doc in context_retrieved],
+         "answer": "",  # Updated after streaming
+         "time": timestamp,
+     }
+
     if model_config.get('reader','TYPE') == 'NVIDIA':
         chat_model = nvidia_client()
         async def process_stream():
@@ -130,49 +194,53 @@ async def chat(query,history,sources,reports,subtype,year):
                 answer_yet += token
                 parsed_answer = parse_output_llm_with_sources(answer_yet)
                 history[-1] = (query, parsed_answer)
-                 yield [tuple(x) for x in history], docs_html
+                 # Update logs_data with current answer
+                 logs_data["answer"] = parsed_answer
+                 yield [tuple(x) for x in history], docs_html, logs_data

         # Stream the response updates
         async for update in process_stream():
             yield update

     else:
-         chat_model = dedicated_endpoint()
+         chat_model = dedicated_endpoint() # TESTING: ADAPTED FOR HF INFERENCE API
         async def process_stream():
-             # Without nonlocal, Python would create a new local variable answer_yet inside process_stream(),
-             # instead of modifying the one from the outer scope.
-             nonlocal answer_yet # Use the outer scope's answer_yet variable
-             # Iterate over the streaming response chunks
-             async for chunk in chat_model.astream(messages):
-                 token = chunk.content
-                 answer_yet += token
-                 parsed_answer = parse_output_llm_with_sources(answer_yet)
-                 history[-1] = (query, parsed_answer)
-                 yield [tuple(x) for x in history], docs_html
+             nonlocal answer_yet
+             try:
+                 formatted_messages = [
+                     {
+                         "role": msg.type if hasattr(msg, 'type') else msg.role,
+                         "content": msg.content
+                     }
+                     for msg in messages
+                 ]
+
+                 response = chat_model.chat_completion(
+                     messages=formatted_messages,
+                     max_tokens=int(model_config.get('reader', 'MAX_TOKENS'))
+                 )
+
+                 response_text = response.choices[0].message.content
+                 words = response_text.split()
+                 for word in words:
+                     answer_yet += word + " "
+                     parsed_answer = parse_output_llm_with_sources(answer_yet)
+                     history[-1] = (query, parsed_answer)
+                     # Update logs_data with current answer
+                     logs_data["answer"] = parsed_answer
+                     yield [tuple(x) for x in history], docs_html, logs_data
+                     await asyncio.sleep(0.05)
+
+             except Exception as e:
+                 logger.error(f"Error in process_stream: {str(e)}")
+                 raise

-         # Stream the response updates
         async for update in process_stream():
             yield update

-     # logging the event
     try:
-         timestamp = str(datetime.now().timestamp())
-         logs = {
-             "system_prompt": SYSTEM_PROMPT,
-             "sources":sources,
-             "reports":reports,
-             "subtype":subtype,
-             "year":year,
-             "question":query,
-             "sources":sources,
-             "retriever":model_config.get('retriever','MODEL'),
-             "endpoint_type":model_config.get('reader','TYPE'),
-             "raeder":model_config.get('reader','NVIDIA_MODEL'),
-             "docs":[doc.page_content for doc in context_retrieved],
-             "answer": history[-1][1],
-             "time": timestamp,
-         }
-         save_logs(scheduler,JSON_DATASET_PATH,logs)
+         # Save log after streaming is complete
+         save_logs(scheduler, JSON_DATASET_PATH, logs_data)
     except Exception as e:
         logging.error(e)

@@ -378,21 +446,61 @@ with gr.Blocks(title="Audit Q&A", css= "style.css", theme=theme,elem_id = "main-component") as demo:



-     # using event listeners for 1. query box 2. click on example question
+     #-------------------- Feedback UI elements + state management -------------------------
+     with gr.Row(visible=False) as feedback_row:
+         gr.Markdown("Was this response helpful?")
+         with gr.Row():
+             okay_btn = gr.Button("👍 Okay", elem_classes="feedback-button")
+             not_okay_btn = gr.Button("👎 Not to expectations", elem_classes="feedback-button")
+
+     feedback_thanks = gr.Markdown("Thanks for the feedback!", visible=False)
+     feedback_state = gr.State()  # Add state to store logs data
+
+     def show_feedback(logs):
+         """Show feedback buttons and store logs in state"""
+         return gr.update(visible=True), gr.update(visible=False), logs
+
+     def submit_feedback_okay(logs_data):
+         """Handle 'okay' feedback submission"""
+         return submit_feedback("okay", logs_data)
+
+     def submit_feedback_not_okay(logs_data):
+         """Handle 'not okay' feedback submission"""
+         return submit_feedback("not_okay", logs_data)
+
+     okay_btn.click(
+         submit_feedback_okay,
+         [feedback_state],
+         [feedback_row, feedback_thanks]
+     )
+
+     not_okay_btn.click(
+         submit_feedback_not_okay,
+         [feedback_state],
+         [feedback_row, feedback_thanks]
+     )
+
+     #-------------------- Gradio voodoo continued -------------------------
+
+     # Using event listeners for 1. query box 2. click on example question
     # https://www.gradio.app/docs/gradio/textbox#event-listeners-arguments
     (textbox
-         .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
-         # queue must be set as False (default) so the process is not waiting for another to be finished
-         .then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], queue=True, concurrency_limit=8, api_name="chat_textbox")
-         .then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))
+         .submit(start_chat, [textbox, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
+         # queue must be set as False (default) so the process is not waiting for another to be finished
+         .then(chat, [textbox, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox, feedback_state], queue=True, concurrency_limit=8, api_name="chat_textbox")
+         .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_textbox")
+         .then(finish_chat, None, [textbox], api_name="finish_chat_textbox"))

     (examples_hidden
         .change(start_chat, [examples_hidden, chatbot], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
         # queue must be set as False (default) so the process is not waiting for another to be finished
-         .then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox], concurrency_limit=8, api_name="chat_examples")
-         .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
-         )
-
+         .then(chat, [examples_hidden, chatbot, dropdown_sources, dropdown_reports, dropdown_category, dropdown_year], [chatbot, sources_textbox, feedback_state], queue=True, concurrency_limit=8, api_name="chat_examples")
+         .then(show_feedback, [feedback_state], [feedback_row, feedback_thanks, feedback_state], api_name="show_feedback_examples")
+         .then(finish_chat, None, [textbox], api_name="finish_chat_examples"))
+
+
 demo.queue()

- demo.launch()
+ demo.launch()
+ logger.info("App launched")
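
The local logging path added in app.py reduces to a lock-guarded append of one JSON record per interaction. A minimal standalone sketch of that pattern, simplified from the LocalScheduler and save_logs shown in this diff (the demo filename below is hypothetical, not the app's):

    # Simplified sketch of the lock-guarded JSONL logging introduced above (not the app code verbatim)
    import json
    from pathlib import Path
    from threading import Lock

    class LocalScheduler:
        """Stand-in for huggingface_hub.CommitScheduler: only the .lock attribute is used by save_logs."""
        def __init__(self, filepath):
            self.filepath = Path(filepath)
            self.lock = Lock()
            self.filepath.touch(exist_ok=True)   # create the file if it doesn't exist

    def save_logs(scheduler, path, logs, feedback=None):
        if feedback:
            logs["feedback"] = feedback          # optional user rating, added on button click
        with scheduler.lock:                     # serialize concurrent writers
            with path.open("a") as f:
                json.dump(logs, f)
                f.write("\n")                    # one record per line assumed for a .jsonl file

    scheduler = LocalScheduler("demo_interactions.jsonl")   # hypothetical filename
    save_logs(scheduler, scheduler.filepath, {"question": "q", "answer": "a"}, feedback="okay")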
auditqa/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/auditqa/__pycache__/__init__.cpython-310.pyc and b/auditqa/__pycache__/__init__.cpython-310.pyc differ
 
auditqa/__pycache__/process_chunks.cpython-310.pyc CHANGED
Binary files a/auditqa/__pycache__/process_chunks.cpython-310.pyc and b/auditqa/__pycache__/process_chunks.cpython-310.pyc differ
 
auditqa/__pycache__/reader.cpython-310.pyc CHANGED
Binary files a/auditqa/__pycache__/reader.cpython-310.pyc and b/auditqa/__pycache__/reader.cpython-310.pyc differ
 
auditqa/__pycache__/reports.cpython-310.pyc CHANGED
Binary files a/auditqa/__pycache__/reports.cpython-310.pyc and b/auditqa/__pycache__/reports.cpython-310.pyc differ
 
auditqa/__pycache__/retriever.cpython-310.pyc CHANGED
Binary files a/auditqa/__pycache__/retriever.cpython-310.pyc and b/auditqa/__pycache__/retriever.cpython-310.pyc differ
 
auditqa/__pycache__/sample_questions.cpython-310.pyc CHANGED
Binary files a/auditqa/__pycache__/sample_questions.cpython-310.pyc and b/auditqa/__pycache__/sample_questions.cpython-310.pyc differ
 
auditqa/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/auditqa/__pycache__/utils.cpython-310.pyc and b/auditqa/__pycache__/utils.cpython-310.pyc differ
 
auditqa/process_chunks.py CHANGED
@@ -11,10 +11,17 @@ from qdrant_client import QdrantClient
 from auditqa.reports import files, report_list
 from langchain.docstore.document import Document
 import configparser
+ from pathlib import Path

 # read all the necessary variables
 device = 'cuda' if cuda.is_available() else 'cpu'
- path_to_data = "./reports/"
+ path_to_data = "./reports/"
+
+ # TESTING DEBUG LOG
+ from auditqa.logging_config import setup_logging
+ setup_logging()
+ import logging
+ logger = logging.getLogger(__name__)


 ##---------------------functions -------------------------------------------##
@@ -118,7 +125,7 @@ def load_new_chunks():
     """
     this method reads through the files and report_list to create the vector database
     """
-
+     logger.info("Loading new chunks")
     # we iterate through the files which contain information about its
     # 'source'=='category', 'subtype', these are used in UI for document selection
     # which will be used later for filtering database
@@ -161,7 +168,7 @@ def load_new_chunks():
     qdrant_collections['allreports'] = Qdrant.from_documents(
         all_documents,
         embeddings,
-         path="/data/local_qdrant",
+         path="./data/local_qdrant",
         collection_name='allreports',
     )
     print(qdrant_collections)
@@ -169,14 +176,16 @@ def load_new_chunks():
     return qdrant_collections

 def get_local_qdrant():
-     """once the local qdrant server is created this is used to make the connection to exisitng server"""
+     """once the local qdrant server is created this is used to make the connection to existing server"""
     config = getconfig("./model_params.cfg")
     qdrant_collections = {}
     embeddings = HuggingFaceEmbeddings(
         model_kwargs = {'device': device},
         encode_kwargs = {'normalize_embeddings': True},
         model_name=config.get('retriever','MODEL'))
-     client = QdrantClient(path="/data/local_qdrant")
-     print("Collections in local Qdrant:",client.get_collections())
+
+     # Change the path to a local directory
+     client = QdrantClient(path="./data/local_qdrant")
+     print("Collections in local Qdrant:", client.get_collections())
     qdrant_collections['allreports'] = Qdrant(client=client, collection_name='allreports', embeddings=embeddings, )
     return qdrant_collections
auditqa/reader.py CHANGED
@@ -7,38 +7,49 @@ import os
 from dotenv import load_dotenv
 load_dotenv()

+ # TESTING DEBUG LOG
+ from auditqa.logging_config import setup_logging
+ setup_logging()
+ import logging
+ logger = logging.getLogger(__name__)

- model_config = getconfig("model_params.cfg")
- NVIDIA_SERVER = os.environ["NVIDIA_SERVERLESS"]
- HF_token = os.environ["LLAMA_3_1"]
+ model_config = getconfig("model_params.cfg")
+ # NVIDIA_SERVER = os.environ["NVIDIA_SERVERLESS"]
+ # HF_token = os.environ["LLAMA_3_1"]
+ HF_token = os.getenv('LLAMA_3_1') # TESTING

 def nvidia_client():
+     logger.info("NVIDIA client activated")
     """ returns the nvidia server client """
-     client = InferenceClient(
-         base_url=model_config.get('reader','NVIDIA_ENDPOINT'),
-         api_key=NVIDIA_SERVER)
-     print("getting nvidia client")
-
-     return client
+     try:
+         NVIDIA_SERVER = os.environ["NVIDIA_SERVERLESS"]
+         client = InferenceClient(
+             base_url=model_config.get('reader','NVIDIA_ENDPOINT'),
+             api_key=NVIDIA_SERVER)
+         print("getting nvidia client")
+         return client
+     except KeyError:
+         raise KeyError("NVIDIA_SERVERLESS environment variable not set. Required for NVIDIA endpoint.")

+ # TESTING VERSION
 def dedicated_endpoint():
-     """ returns the dedicated server endpoint"""
-
-     # Set up the streaming callback handler
-     callback = StreamingStdOutCallbackHandler()
-
-     # Initialize the HuggingFaceEndpoint with streaming enabled
-     llm_qa = HuggingFaceEndpoint(
-         endpoint_url=model_config.get('reader', 'DEDICATED_ENDPOINT'),
-         max_new_tokens=int(model_config.get('reader','MAX_TOKENS')),
-         repetition_penalty=1.03,
-         timeout=70,
-         huggingfacehub_api_token=HF_token,
-         streaming=True, # Enable streaming for real-time token generation
-         callbacks=[callback] # Add the streaming callback handler
-     )
-
-     # Create a ChatHuggingFace instance with the streaming-enabled endpoint
-     chat_model = ChatHuggingFace(llm=llm_qa)
-     print("getting dedicated endpoint wrapped in ChathuggingFace ")
-     return chat_model
+     logger.info("Serverless endpoint activated")
+     try:
+         hf_api_key = os.environ["LLAMA_3_1"]
+         if not hf_api_key:
+             raise ValueError("LLAMA_3_1 environment variable is empty")
+
+         model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+         logger.info(f"Initializing InferenceClient with model: {model_id}")
+
+         client = InferenceClient(
+             model=model_id,
+             api_key=hf_api_key,
+         )
+
+         logger.info("Serverless InferenceClient initialization successful")
+         return client
+
+     except Exception as e:
+         logger.error(f"Error initializing dedicated endpoint: {str(e)}")
+         raise
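
The client returned by dedicated_endpoint() is a plain huggingface_hub InferenceClient, which app.py now drives through chat_completion rather than a LangChain wrapper. A minimal sketch of that call pattern, assuming huggingface_hub is installed and the LLAMA_3_1 token is set in the environment:

    # Sketch only: mirrors how app.py consumes the serverless client from this commit
    import os
    from huggingface_hub import InferenceClient

    client = InferenceClient(
        model="meta-llama/Meta-Llama-3-8B-Instruct",   # model id hard-coded in dedicated_endpoint()
        api_key=os.getenv("LLAMA_3_1"),
    )
    response = client.chat_completion(
        messages=[{"role": "user", "content": "Summarize the audit findings on payroll."}],  # example query
        max_tokens=256,
    )
    print(response.choices[0].message.content)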
 
 
 
 
 
 
 
 
 
 
 
 
auditqa/retriever.py CHANGED
@@ -4,6 +4,12 @@ from langchain.retrievers import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import CrossEncoderReranker
 from langchain_community.cross_encoders import HuggingFaceCrossEncoder

+ # TESTING DEBUG LOG
+ from auditqa.logging_config import setup_logging
+ setup_logging()
+ import logging
+ logger = logging.getLogger(__name__)
+
 model_config = getconfig("model_params.cfg")

 def create_filter(reports:list = [],sources:str =None,
@@ -35,9 +41,12 @@ def create_filter(reports:list = [],sources:str =None,
     return filter


+
 def get_context(vectorstore,query,reports,sources,subtype,year):
+     logger.info("Retriever activated")
     # create metadata filter
-     filter = create_filter(reports=reports,sources=sources,subtype=subtype,year=year)
+     # filter = create_filter(reports=reports,sources=sources,subtype=subtype,year=year)
+     filter = None

     # getting context
     retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
@@ -50,7 +59,9 @@ def get_context(vectorstore,query,reports,sources,subtype,year):
     compression_retriever = ContextualCompressionRetriever(
         base_compressor=compressor, base_retriever=retriever
     )
+
     context_retrieved = compression_retriever.invoke(query)
+     logger.info(f"retrieved paragraphs:{len(context_retrieved)}")
     print(f"retrieved paragraphs:{len(context_retrieved)}")

     return context_retrieved
auditqa/utils.py CHANGED
@@ -6,10 +6,14 @@ from langchain.schema import (
     SystemMessage,
 )

- def save_logs(scheduler, JSON_DATASET_PATH, logs) -> None:
+ def save_logs(scheduler, JSON_DATASET_PATH, logs, feedback=None) -> None:
     """ Every interaction with app saves the log of question and answer,
-     this is to get the usage statistics of app and evaluate model performances
+     this is to get the usage statistics of app and evaluate model performances.
+     Also saves user feedback (when provided).
     """
+     if feedback:
+         logs["feedback"] = feedback  # optional
+
     with scheduler.lock:
         with JSON_DATASET_PATH.open("a") as f:
             json.dump(logs, f)
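
Because each call appends one JSON object to the interaction log, the file can be post-processed with a few lines. A sketch for tallying feedback, assuming each record ends up on its own line of app_interactions.jsonl (the default name set in app.py):

    # Sketch: count feedback values in the JSONL interaction log
    import json
    from collections import Counter
    from pathlib import Path

    counts = Counter()
    for line in Path("app_interactions.jsonl").read_text().splitlines():
        if line.strip():
            counts[json.loads(line).get("feedback", "none")] += 1
    print(counts)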
model_params.cfg CHANGED
@@ -6,9 +6,9 @@ TOP_K = 20
 MODEL = BAAI/bge-reranker-base
 TOP_K = 3
 [reader]
- TYPE = NVIDIA
+ TYPE = DEDICATED
 DEDICATED_MODEL = meta-llama/Llama-3.1-8B-Instruct
 DEDICATED_ENDPOINT = https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud
 NVIDIA_MODEL = meta-llama/Llama-3.1-8B-Instruct
 NVIDIA_ENDPOINT = https://huggingface.co/api/integrations/dgx/v1
- MAX_TOKENS = 512
+ MAX_TOKENS = 256
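
The TYPE switch is what app.py and reader.py key off via model_config.get('reader','TYPE'); with TYPE = DEDICATED the serverless path is used and the NVIDIA branch is skipped. A minimal sketch of reading this file, assuming getconfig is a thin wrapper around configparser:

    # Sketch of how model_params.cfg drives endpoint selection (getconfig assumed to wrap configparser)
    import configparser

    config = configparser.ConfigParser()
    config.read("model_params.cfg")

    if config.get("reader", "TYPE") == "NVIDIA":
        print("would use nvidia_client()")
    else:
        print("would use dedicated_endpoint(), max_tokens =", config.get("reader", "MAX_TOKENS"))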
style.css CHANGED
@@ -1,4 +1,3 @@
-
 /* :root {
     --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
 } */
@@ -360,3 +359,15 @@ span.chatbot > p > img{
 .a-doc-ref{
     text-decoration: none !important;
 }
+
+ .feedback-button {
+     border: none;
+     padding: 8px 16px;
+     border-radius: 4px;
+     cursor: pointer;
+     transition: background-color 0.3s;
+ }
+
+ .feedback-button:hover {
+     opacity: 0.8;
+ }