kcheng0816 committed on
Commit
ed05693
·
1 Parent(s): 8e6e753

update app.py to avoid repeated quiz questions

Browse files
Files changed (1) hide show
  1. app.py +95 -23
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import os
2
  import re
3
- import random
4
  import uuid
 
5
  from dotenv import load_dotenv
6
  import chainlit as cl
7
  from langchain.docstore.document import Document
@@ -21,18 +22,22 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, Tool
21
  from langchain_core.tools import tool
22
  from langchain_community.tools.tavily_search import TavilySearchResults
23
  from functools import partial
24
- from typing import Any, Callable, List, Optional, TypedDict, Union
25
  from langchain_core.messages import AnyMessage
26
  from langgraph.graph.message import add_messages
27
  from typing import TypedDict, Annotated
28
  from langgraph.prebuilt import ToolNode
29
  from langgraph.graph import StateGraph, END
30
- import json
 
 
 
31
 
32
  # Load API Keys
33
  load_dotenv()
34
  os.environ["LANGCHAIN_PROJECT"] = f"AIE5- Bible Study Tool - {uuid.uuid4().hex[0:8]}"
35
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
 
36
 
37
  path = "data/"
38
  book = "Genesis"
@@ -94,14 +99,23 @@ points = [
94
  client.upsert(collection_name=collection_name, points=points)
95
 
96
  # Cached embedder
97
- safe_namespace = "AIE5_BibleStudyTool"
98
- store = LocalFileStore("./cache/")
99
- cached_embedder = CacheBackedEmbeddings.from_bytes_store(
100
- huggingface_embeddings, store, namespace=safe_namespace, batch_size=32
101
- )
102
 
103
- # Retrieval functions (unchanged from original)
104
  def parse_verse_reference(ref: str):
 
 
 
 
 
 
 
 
 
105
  match = re.match(r"(\w+(?:\s\w+)?)\s(\d+):([\d,-]+)", ref)
106
  if not match:
107
  return None
@@ -117,6 +131,18 @@ def parse_verse_reference(ref: str):
117
  return book, chapter, verses
118
 
119
  def retrieve_verse_content(verse_range: str, client: QdrantClient):
 
 
 
 
 
 
 
 
 
 
 
 
120
  parsed = parse_verse_reference(verse_range)
121
  if not parsed:
122
  return "Invalid verse range format."
@@ -146,12 +172,29 @@ def retrieve_verse_content(verse_range: str, client: QdrantClient):
146
  return docs
147
 
148
  def retrieve_documents(question: str, collection_name: str, client: QdrantClient):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  reference_match = re.search(r"(\w+)\s?(\d+):\s?([\d,-]+)", question)
150
  if reference_match:
151
  verse_range = reference_match.group(1) + ' ' + reference_match.group(2) + ':' + reference_match.group(3)
152
  return retrieve_verse_content(verse_range, client)
153
  else:
154
- query_vector = cached_embedder.embed_query(question)
155
  search_result = client.query_points(
156
  collection_name=collection_name,
157
  query=query_vector,
@@ -168,7 +211,7 @@ def retrieve_documents(question: str, collection_name: str, client: QdrantClient
168
  ]
169
  return "No relevant documents found."
170
 
171
- # RAG setup (unchanged from original)
172
  RAG_PROMPT = """\
173
  You are a helpful assistant who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
174
 
@@ -180,10 +223,6 @@ You are a helpful assistant who answers questions based on provided context. You
180
  """
181
  rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
182
 
183
- from langchain_openai import ChatOpenAI
184
- from langchain.chat_models import init_chat_model
185
- from langchain_core.rate_limiters import InMemoryRateLimiter
186
-
187
  rate_limiter = InMemoryRateLimiter(
188
  requests_per_second=1,
189
  check_every_n_seconds=0.1,
@@ -191,6 +230,7 @@ rate_limiter = InMemoryRateLimiter(
191
  )
192
 
193
  chat_model = init_chat_model("gpt-4o-mini", rate_limiter=rate_limiter)
 
194
 
195
  def create_retriever_runnable(collection_name: str, client: QdrantClient) -> RunnableLambda:
196
  return RunnableLambda(lambda question: retrieve_documents(question, collection_name, client))
@@ -226,10 +266,21 @@ def _generate_quiz_question(verse_range: str, client: QdrantClient):
226
  docs = retrieve_verse_content(verse_range, client)
227
  if isinstance(docs, str):
228
  return {"error": docs}
 
 
 
 
 
 
 
 
 
 
229
  verse_content = "\n".join(
230
  f"{doc.metadata['book']} {doc.metadata['chapter']}:{doc.metadata['verse']} - {doc.page_content}"
231
- for doc in docs
232
  )
 
233
  quiz_prompt = ChatPromptTemplate.from_template(
234
  "Based on the following Bible verse(s), generate a multiple-choice quiz question with 4 options (A, B, C, D) "
235
  "and indicate the correct answer:\n\n"
@@ -243,7 +294,11 @@ def _generate_quiz_question(verse_range: str, client: QdrantClient):
243
  "Correct Answer: [Letter of correct answer]\n"
244
  "Explanation: [Brief explanation of why the answer is correct]\n"
245
  )
246
- response = (quiz_prompt | chat_model).invoke({"verse_content": verse_content})
 
 
 
 
247
  response_text = response.content.strip()
248
  lines = response_text.split("\n")
249
  question = ""
@@ -261,6 +316,7 @@ def _generate_quiz_question(verse_range: str, client: QdrantClient):
261
  correct_answer = line[len("Correct Answer:"):].strip()
262
  elif line.startswith("Explanation:"):
263
  explanation = line[len("Explanation:"):].strip()
 
264
  return {
265
  "quiz_question": question,
266
  "options": options,
@@ -273,16 +329,17 @@ def _generate_quiz_question(verse_range: str, client: QdrantClient):
273
  generate_quiz_question_tool = partial(_generate_quiz_question, client=client)
274
 
275
  @tool
276
- def generate_quiz_question(verse_range: str):
277
  """Generate a quiz question based on the content of the specified verse range."""
278
  quiz_data = generate_quiz_question_tool(verse_range)
279
  return json.dumps(quiz_data)
280
 
281
- tool_belt = [ai_rag_tool, tavily_tool, generate_quiz_question]
282
 
283
  # LLM for agent reasoning
284
  llm = init_chat_model("gpt-4o", temperature=0, rate_limiter=rate_limiter)
285
  llm_with_tools = llm.bind_tools(tool_belt)
 
286
 
287
  # Define the state
288
  class AgentState(TypedDict):
@@ -299,9 +356,9 @@ system_message = SystemMessage(content="""You are a Bible study assistant. You c
299
 
300
  - Use the 'ai_rag_tool' to answer questions about the Bible.
301
  - Use the 'tavily_tool' to search the internet for additional information.
302
- - Use the 'generate_quiz_question' tool when the user requests to start a quiz on a specific verse range, such as 'start quiz on Genesis 1:1-10'.
303
 
304
- When the user requests a quiz, extract the verse range from their message and pass it to the 'generate_quiz_question' tool.""")
305
 
306
 
307
  from typing import Optional
@@ -310,7 +367,22 @@ from langgraph.graph.message import AnyMessage, add_messages
310
  from typing import Annotated
311
 
312
 
 
313
  def call_mode(state):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  last_message = state["messages"][-1]
315
 
316
  if state.get("in_quiz", False):
@@ -344,7 +416,7 @@ def call_mode(state):
344
  if user_input == "yes":
345
  # Generate a new quiz question
346
  verse_range = state["verse_range"]
347
- quiz_data_str = generate_quiz_question(verse_range)
348
  quiz_data = json.loads(quiz_data_str)
349
  question = quiz_data["quiz_question"]
350
  options = "\n".join([f"{k}: {v}" for k, v in quiz_data["options"].items()])
@@ -407,7 +479,7 @@ def call_mode(state):
407
  prev_message = state["messages"][-2]
408
  if isinstance(prev_message, AIMessage) and prev_message.tool_calls:
409
  tool_call = prev_message.tool_calls[0]
410
- if tool_call["name"] == "generate_quiz_question":
411
  # Start the quiz
412
  quiz_data_str = last_message.content
413
  quiz_data = json.loads(quiz_data_str)
 
1
  import os
2
  import re
3
+ import json
4
  import uuid
5
+ import random
6
  from dotenv import load_dotenv
7
  import chainlit as cl
8
  from langchain.docstore.document import Document
 
22
  from langchain_core.tools import tool
23
  from langchain_community.tools.tavily_search import TavilySearchResults
24
  from functools import partial
25
+ from typing import Optional, TypedDict
26
  from langchain_core.messages import AnyMessage
27
  from langgraph.graph.message import add_messages
28
  from typing import TypedDict, Annotated
29
  from langgraph.prebuilt import ToolNode
30
  from langgraph.graph import StateGraph, END
31
+ from langchain.chat_models import init_chat_model
32
+ from langchain_core.rate_limiters import InMemoryRateLimiter
33
+ from langchain_core.globals import set_llm_cache
34
+ from langchain_core.caches import InMemoryCache
35
 
36
  # Load API Keys
37
  load_dotenv()
38
  os.environ["LANGCHAIN_PROJECT"] = f"AIE5- Bible Study Tool - {uuid.uuid4().hex[0:8]}"
39
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
40
+ print(os.environ["LANGCHAIN_PROJECT"])
41
 
42
  path = "data/"
43
  book = "Genesis"
 
99
  client.upsert(collection_name=collection_name, points=points)
100
 
101
  # Cached embedder
102
+ #safe_namespace = "AIE5_BibleStudyTool"
103
+ #store = LocalFileStore("./cache/")
104
+ #cached_embedder = CacheBackedEmbeddings.from_bytes_store(
105
+ # huggingface_embeddings, store, namespace=safe_namespace, batch_size=32
106
+ #)
107
 
108
+ # Retrieval functions
109
  def parse_verse_reference(ref: str):
110
+ """
111
+ Parse a verse reference string into book, chapter, and a list of verse numbers.
112
+
113
+ Args:
114
+ ref (str): The verse reference, e.g., "Genesis 1:1-10".
115
+
116
+ Returns:
117
+ tuple: (book, chapter, verses) where verses is a list of integers, or None if invalid.
118
+ """
119
  match = re.match(r"(\w+(?:\s\w+)?)\s(\d+):([\d,-]+)", ref)
120
  if not match:
121
  return None
 
131
  return book, chapter, verses
132
 
133
  def retrieve_verse_content(verse_range: str, client: QdrantClient):
134
+ """
135
+ Retrieve Bible verses from Qdrant based on the specified verse range.
136
+
137
+ Parameters:
138
+ - verse_range (str): The verse range in the format "Book Chapter:Verses", e.g., "Genesis 1:1-5".
139
+ - client (QdrantClient): The Qdrant client to query the database.
140
+
141
+ Returns:
142
+ - list[Document]: A list of Document objects containing the verse text and metadata.
143
+ - str: An error message if the verse range is invalid or no verses are found.
144
+ """
145
+ # Parse the verse range into book, chapter, and verses
146
  parsed = parse_verse_reference(verse_range)
147
  if not parsed:
148
  return "Invalid verse range format."
 
172
  return docs
173
 
174
  def retrieve_documents(question: str, collection_name: str, client: QdrantClient):
175
+ """
176
+ Retrieve documents from a Qdrant collection based on the input question.
177
+
178
+ This function first checks if the question contains a specific Bible verse reference
179
+ (e.g., "Genesis 1:1-5"). If a reference is found, it retrieves the exact verses using
180
+ `retrieve_verse_content`. If no reference is found, it performs a semantic search
181
+ using embeddings to find the most relevant documents.
182
+
183
+ Parameters:
184
+ - question (str): The input question or query string.
185
+ - collection_name (str): The name of the Qdrant collection to search in.
186
+ - client (QdrantClient): The Qdrant client object used to interact with the database.
187
+
188
+ Returns:
189
+ - list[Document]: A list of Document objects containing the relevant verse text and metadata.
190
+ - str: An error message if no relevant documents are found or if the verse reference is invalid.
191
+ """
192
  reference_match = re.search(r"(\w+)\s?(\d+):\s?([\d,-]+)", question)
193
  if reference_match:
194
  verse_range = reference_match.group(1) + ' ' + reference_match.group(2) + ':' + reference_match.group(3)
195
  return retrieve_verse_content(verse_range, client)
196
  else:
197
+ query_vector = huggingface_embeddings.embed_query(question)
198
  search_result = client.query_points(
199
  collection_name=collection_name,
200
  query=query_vector,
 
211
  ]
212
  return "No relevant documents found."
213
 
214
+ # RAG setup
215
  RAG_PROMPT = """\
216
  You are a helpful assistant who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
217
 
 
223
  """
224
  rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
225
 
 
 
 
 
226
  rate_limiter = InMemoryRateLimiter(
227
  requests_per_second=1,
228
  check_every_n_seconds=0.1,
 
230
  )
231
 
232
  chat_model = init_chat_model("gpt-4o-mini", rate_limiter=rate_limiter)
233
+ set_llm_cache(InMemoryCache())
234
 
235
def create_retriever_runnable(collection_name: str, client: QdrantClient) -> RunnableLambda:
    """Build a RunnableLambda that retrieves documents for a question.

    Binds the given Qdrant collection name and client so the returned
    runnable only needs the question string as input.
    """
    # partial pins collection_name/client up front; the runnable receives
    # just the question, exactly like the original lambda form.
    return RunnableLambda(
        partial(retrieve_documents, collection_name=collection_name, client=client)
    )
 
266
  docs = retrieve_verse_content(verse_range, client)
267
  if isinstance(docs, str):
268
  return {"error": docs}
269
+
270
+ # Randomly select a subset of verses if the range has more than 3 verses
271
+ num_verses = len(docs)
272
+ if num_verses > 3:
273
+ subset_size = random.randint(1, 3)
274
+ start_idx = random.randint(0, num_verses - subset_size)
275
+ selected_docs = docs[start_idx : start_idx + subset_size]
276
+ else:
277
+ selected_docs = docs
278
+
279
  verse_content = "\n".join(
280
  f"{doc.metadata['book']} {doc.metadata['chapter']}:{doc.metadata['verse']} - {doc.page_content}"
281
+ for doc in selected_docs
282
  )
283
+
284
  quiz_prompt = ChatPromptTemplate.from_template(
285
  "Based on the following Bible verse(s), generate a multiple-choice quiz question with 4 options (A, B, C, D) "
286
  "and indicate the correct answer:\n\n"
 
294
  "Correct Answer: [Letter of correct answer]\n"
295
  "Explanation: [Brief explanation of why the answer is correct]\n"
296
  )
297
+
298
+ # Use a higher temperature for more diverse question generation
299
+ chat_model_with_temp = chat_model.bind(temperature=0.8)
300
+ response = (quiz_prompt | chat_model_with_temp).invoke({"verse_content": verse_content})
301
+
302
  response_text = response.content.strip()
303
  lines = response_text.split("\n")
304
  question = ""
 
316
  correct_answer = line[len("Correct Answer:"):].strip()
317
  elif line.startswith("Explanation:"):
318
  explanation = line[len("Explanation:"):].strip()
319
+
320
  return {
321
  "quiz_question": question,
322
  "options": options,
 
329
  generate_quiz_question_tool = partial(_generate_quiz_question, client=client)
330
 
331
@tool
def quiz_question_generator(verse_range: str):
    """Generate a quiz question based on the content of the specified verse range."""
    # Delegate to the partial bound to the shared Qdrant client, then
    # serialize the quiz dict so the agent receives a JSON string payload.
    return json.dumps(generate_quiz_question_tool(verse_range))
336
 
337
+ tool_belt = [ai_rag_tool, tavily_tool, quiz_question_generator]
338
 
339
  # LLM for agent reasoning
340
  llm = init_chat_model("gpt-4o", temperature=0, rate_limiter=rate_limiter)
341
  llm_with_tools = llm.bind_tools(tool_belt)
342
+ set_llm_cache(InMemoryCache())
343
 
344
  # Define the state
345
  class AgentState(TypedDict):
 
356
 
357
  - Use the 'ai_rag_tool' to answer questions about the Bible.
358
  - Use the 'tavily_tool' to search the internet for additional information.
359
+ - Use the 'quiz_question_generator' tool when the user requests to start a quiz on a specific verse range, such as 'start quiz on Genesis 1:1-10'.
360
 
361
+ When the user requests a quiz, extract the verse range from their message and pass it to the 'quiz_question_generator' tool.""")
362
 
363
 
364
  from typing import Optional
 
367
  from typing import Annotated
368
 
369
 
370
+ #Agent function
371
  def call_mode(state):
372
+ """
373
+ Manage the conversation flow of the Bible Study Tool, focusing on quiz mode and regular interactions.
374
+
375
+ This function determines the next action in the conversation based on the user's input and the current state.
376
+ It handles quiz mode (processing answers, continuing or ending the quiz) and transitions to or from regular
377
+ question-answering mode. It also processes tool calls, such as starting a quiz, and delegates non-quiz queries
378
+ to a language model.
379
+
380
+ Parameters:
381
+ - state (dict): The current state of the conversation, containing messages, quiz status, and other data.
382
+
383
+ Returns:
384
+ - dict: An updated state dictionary with new messages and modified quiz-related fields as needed.
385
+ """
386
  last_message = state["messages"][-1]
387
 
388
  if state.get("in_quiz", False):
 
416
  if user_input == "yes":
417
  # Generate a new quiz question
418
  verse_range = state["verse_range"]
419
+ quiz_data_str = quiz_question_generator(verse_range)
420
  quiz_data = json.loads(quiz_data_str)
421
  question = quiz_data["quiz_question"]
422
  options = "\n".join([f"{k}: {v}" for k, v in quiz_data["options"].items()])
 
479
  prev_message = state["messages"][-2]
480
  if isinstance(prev_message, AIMessage) and prev_message.tool_calls:
481
  tool_call = prev_message.tool_calls[0]
482
+ if tool_call["name"] == "quiz_question_generator":
483
  # Start the quiz
484
  quiz_data_str = last_message.content
485
  quiz_data = json.loads(quiz_data_str)