import os
import re
import json
import uuid
import random
from functools import partial
from typing import Annotated, Optional, TypedDict

from dotenv import load_dotenv
import chainlit as cl
from bs4 import BeautifulSoup
from langchain.docstore.document import Document
from langchain.storage import LocalFileStore
from langchain.embeddings import CacheBackedEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.chat_models import init_chat_model
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.globals import set_llm_cache
from langchain_core.caches import InMemoryCache
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from qdrant_client import QdrantClient
from qdrant_client.http.models import (
    VectorParams,
    Distance,
    PointStruct,
    Filter,
    FieldCondition,
    MatchValue,
    MatchAny,
)
# Load API keys
load_dotenv()
os.environ["LANGCHAIN_PROJECT"] = f"AIE5- Bible Study Tool - {uuid.uuid4().hex[0:8]}"
os.environ["LANGCHAIN_TRACING_V2"] = "true"
print(os.environ["LANGCHAIN_PROJECT"])

path = "data/"
book = "Genesis"
collection_name = "genesis_study"
# Load Genesis documents (unchanged from original)
def load_genesis_documents(path, book_name):
    documents = []
    for file in os.listdir(path):
        if file.endswith(".html"):
            file_path = os.path.join(path, file)
            with open(file_path, "r", encoding="utf-8") as f:
                soup = BeautifulSoup(f, "html.parser")
                p_tags = soup.find_all("p", align="left")
                for p_tag in p_tags:
                    verse_texts = [content.strip() for content in p_tag.contents
                                   if isinstance(content, str) and content.strip()]
                    for verse in verse_texts:
                        match = re.match(r"\[(\d+):(\d+)\]\s*(.*)", verse)
                        if match:
                            chapter = int(match.group(1))
                            verse_num = int(match.group(2))
                            text = match.group(3)
                            doc = Document(
                                page_content=text,
                                metadata={"book": book_name, "chapter": chapter, "verse": verse_num}
                            )
                            documents.append(doc)
    return documents
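# Illustrative input/output: a <p align="left"> text node such as
# "[1:1] In the beginning God created the heaven and the earth." becomes
# Document(page_content="In the beginning God created the heaven and the earth.",
#          metadata={"book": "Genesis", "chapter": 1, "verse": 1}).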
documents = load_genesis_documents(path, book)

# Initialize embeddings
huggingface_embeddings = HuggingFaceEmbeddings(model_name="kcheng0816/finetuned_arctic_genesis")
dimension = len(huggingface_embeddings.embed_query("test"))

# Set up Qdrant client and collection
client = QdrantClient(":memory:")
client.create_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=dimension, distance=Distance.COSINE)
)

# Generate and upload embeddings
embeddings = huggingface_embeddings.embed_documents([doc.page_content for doc in documents])
points = [
    PointStruct(
        id=str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{doc.metadata['chapter']}_{doc.metadata['verse']}")),
        vector=embedding,
        payload={
            "text": doc.page_content,
            "book": doc.metadata["book"],
            "chapter": doc.metadata["chapter"],
            "verse": doc.metadata["verse"]
        }
    )
    for embedding, doc in zip(embeddings, documents)
]
client.upsert(collection_name=collection_name, points=points)
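# Optional sanity check: with unique (chapter, verse) pairs, the point count
# should match the number of parsed documents.
# assert client.count(collection_name=collection_name).count == len(documents)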
# Cached embedder
# safe_namespace = "AIE5_BibleStudyTool"
# store = LocalFileStore("./cache/")
# cached_embedder = CacheBackedEmbeddings.from_bytes_store(
#     huggingface_embeddings, store, namespace=safe_namespace, batch_size=32
# )
# Retrieval functions
def parse_verse_reference(ref: str):
    """
    Parse a verse reference string into book, chapter, and a list of verse numbers.

    Args:
        ref (str): The verse reference, e.g., "Genesis 1:1-10".

    Returns:
        tuple: (book, chapter, verses) where verses is a list of integers, or None if invalid.
    """
    match = re.match(r"(\w+(?:\s\w+)?)\s(\d+):([\d,-]+)", ref)
    if not match:
        return None
    book, chapter, verse_part = match.groups()
    chapter = int(chapter)
    verses = []
    for part in verse_part.split(','):
        if '-' in part:
            start, end = map(int, part.split('-'))
            verses.extend(range(start, end + 1))
        else:
            verses.append(int(part))
    return book, chapter, verses
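# Illustrative examples:
#   parse_verse_reference("Genesis 1:1-3,5") -> ("Genesis", 1, [1, 2, 3, 5])
#   parse_verse_reference("not a reference") -> None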
def retrieve_verse_content(verse_range: str, client: QdrantClient):
    """
    Retrieve Bible verses from Qdrant based on the specified verse range.

    Parameters:
    - verse_range (str): The verse range in the format "Book Chapter:Verses", e.g., "Genesis 1:1-5".
    - client (QdrantClient): The Qdrant client to query the database.

    Returns:
    - list[Document]: A list of Document objects containing the verse text and metadata.
    - str: An error message if the verse range is invalid or no verses are found.
    """
    # Parse the verse range into book, chapter, and verses
    parsed = parse_verse_reference(verse_range)
    if not parsed:
        return "Invalid verse range format."
    book, chapter, verses = parsed
    verse_filter = Filter(
        must=[
            FieldCondition(key="book", match=MatchValue(value=book)),
            FieldCondition(key="chapter", match=MatchValue(value=chapter)),
            FieldCondition(key="verse", match=MatchAny(any=verses))
        ]
    )
    search_result = client.scroll(
        collection_name=collection_name,
        scroll_filter=verse_filter,
        limit=len(verses)
    )
    if not search_result[0]:
        return "No verses found for the specified range."
    sorted_points = sorted(search_result[0], key=lambda p: p.payload["verse"])
    docs = [
        Document(
            page_content=p.payload["text"],
            metadata=p.payload
        )
        for p in sorted_points
    ]
    return docs
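# Illustrative call: retrieve_verse_content("Genesis 1:1-3", client) returns the
# three matching Documents sorted by verse, or an error string if the reference
# is malformed or no points match.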
def retrieve_documents(question: str, collection_name: str, client: QdrantClient):
    """
    Retrieve documents from a Qdrant collection based on the input question.

    This function first checks if the question contains a specific Bible verse reference
    (e.g., "Genesis 1:1-5"). If a reference is found, it retrieves the exact verses using
    `retrieve_verse_content`. If no reference is found, it performs a semantic search
    using embeddings to find the most relevant documents.

    Parameters:
    - question (str): The input question or query string.
    - collection_name (str): The name of the Qdrant collection to search in.
    - client (QdrantClient): The Qdrant client object used to interact with the database.

    Returns:
    - list[Document]: A list of Document objects containing the relevant verse text and metadata.
    - str: An error message if no relevant documents are found or if the verse reference is invalid.
    """
    reference_match = re.search(r"(\w+)\s?(\d+):\s?([\d,-]+)", question)
    if reference_match:
        verse_range = reference_match.group(1) + ' ' + reference_match.group(2) + ':' + reference_match.group(3)
        return retrieve_verse_content(verse_range, client)
    else:
        query_vector = huggingface_embeddings.embed_query(question)
        search_result = client.query_points(
            collection_name=collection_name,
            query=query_vector,
            limit=5,
            with_payload=True
        ).points
        if search_result:
            return [
                Document(
                    page_content=point.payload["text"],
                    metadata=point.payload
                )
                for point in search_result
            ]
        return "No relevant documents found."
# RAG setup
RAG_PROMPT = """\
You are a helpful assistant who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.
### Question
{question}
### Context
{context}
"""
rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

rate_limiter = InMemoryRateLimiter(
    requests_per_second=1,
    check_every_n_seconds=0.1,
    max_bucket_size=10,
)
chat_model = init_chat_model("gpt-4o-mini", rate_limiter=rate_limiter)
set_llm_cache(InMemoryCache())
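# Note: with the in-memory LLM cache set, repeated identical prompts in this
# process are served from the cache rather than making a second API call;
# nothing persists across restarts.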
def create_retriever_runnable(collection_name: str, client: QdrantClient) -> RunnableLambda:
    return RunnableLambda(lambda question: retrieve_documents(question, collection_name, client))

retrieval_runnable = create_retriever_runnable(collection_name, client)

def format_docs(docs):
    if isinstance(docs, str):
        return docs
    return "\n\n".join(f"Genesis {doc.metadata['chapter']}:{doc.metadata['verse']} - {doc.page_content}" for doc in docs)

rag_chain = (
    {"context": retrieval_runnable | RunnableLambda(format_docs), "question": RunnablePassthrough()}
    | RunnablePassthrough.assign(response=rag_prompt | chat_model | StrOutputParser())
)
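# Illustrative output shape: rag_chain.invoke("Who built the ark?") returns a dict
# with "context" (formatted verses or an error string), "question", and "response".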
# Tools
def format_contexts(docs):
    return "\n\n".join(docs) if isinstance(docs, list) else docs

@tool
def ai_rag_tool(question: str):
    """Useful for when you need to answer questions about the Bible."""
    response = rag_chain.invoke(question)
    return {
        "message": [HumanMessage(content=response["response"])],
        "context": format_contexts(response["context"])
    }

tavily_tool = TavilySearchResults(max_results=5)
def _generate_quiz_question(verse_range: str, client: QdrantClient):
    docs = retrieve_verse_content(verse_range, client)
    if isinstance(docs, str):
        return {"error": docs}
    # Randomly select a subset of verses if the range has more than 3 verses
    num_verses = len(docs)
    if num_verses > 3:
        subset_size = random.randint(1, 3)
        start_idx = random.randint(0, num_verses - subset_size)
        selected_docs = docs[start_idx : start_idx + subset_size]
    else:
        selected_docs = docs
    verse_content = "\n".join(
        f"{doc.metadata['book']} {doc.metadata['chapter']}:{doc.metadata['verse']} - {doc.page_content}"
        for doc in selected_docs
    )
    quiz_prompt = ChatPromptTemplate.from_template(
        "Based on the following Bible verse(s), generate a multiple-choice quiz question with 4 options (A, B, C, D) "
        "and indicate the correct answer:\n\n"
        "{verse_content}\n\n"
        "Format your response as follows:\n"
        "Question: [Your question here]\n"
        "A: [Option A]\n"
        "B: [Option B]\n"
        "C: [Option C]\n"
        "D: [Option D]\n"
        "Correct Answer: [Letter of correct answer]\n"
        "Explanation: [Brief explanation of why the answer is correct]\n"
    )
    # Use a higher temperature for more diverse question generation
    chat_model_with_temp = chat_model.bind(temperature=0.8)
    response = (quiz_prompt | chat_model_with_temp).invoke({"verse_content": verse_content})
    response_text = response.content.strip()
    lines = response_text.split("\n")
    question = ""
    options = {}
    correct_answer = ""
    explanation = ""
    for line in lines:
        line = line.strip()
        if line.startswith("Question:"):
            question = line[len("Question:"):].strip()
        elif line.startswith(("A:", "B:", "C:", "D:")):
            key, value = line.split(":", 1)
            options[key.strip()] = value.strip()
        elif line.startswith("Correct Answer:"):
            correct_answer = line[len("Correct Answer:"):].strip()
        elif line.startswith("Explanation:"):
            explanation = line[len("Explanation:"):].strip()
    return {
        "quiz_question": question,
        "options": options,
        "correct_answer": correct_answer,
        "explanation": explanation,
        "verse_range": verse_range,
        "verse_content": verse_content
    }
generate_quiz_question_tool = partial(_generate_quiz_question, client=client)

@tool
def quiz_question_generator(verse_range: str):
    """Generate a quiz question based on the content of the specified verse range."""
    quiz_data = generate_quiz_question_tool(verse_range)
    return json.dumps(quiz_data)
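# Illustrative return value: quiz_question_generator.invoke("Genesis 1:1-5") yields
# a JSON string with keys "quiz_question", "options", "correct_answer",
# "explanation", "verse_range", and "verse_content" (a single "error" key
# instead if the range cannot be resolved).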
tool_belt = [ai_rag_tool, tavily_tool, quiz_question_generator]

# LLM for agent reasoning
llm = init_chat_model("gpt-4o", temperature=0, rate_limiter=rate_limiter)
llm_with_tools = llm.bind_tools(tool_belt)
# Define the state
class AgentState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    in_quiz: bool
    quiz_question: Optional[dict]
    verse_range: Optional[str]
    quiz_score: int
    quiz_total: int
    waiting_for_answer: bool
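# Illustrative quiz-state lifecycle (see call_mode below):
#   in_quiz=False -> 'start quiz on Genesis 1:1-5' triggers the quiz tool ->
#   in_quiz=True, waiting_for_answer=True -> user answers -> feedback,
#   waiting_for_answer=False -> 'yes' asks a new question, 'no' resets the quiz fields.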
# System message
system_message = SystemMessage(content="""You are a Bible study assistant. You can answer questions about the Bible, search the internet for related information, or generate quiz questions based on specific verse ranges.
- Use the 'ai_rag_tool' to answer questions about the Bible.
- Use the 'tavily_tool' to search the internet for additional information.
- Use the 'quiz_question_generator' tool when the user requests to start a quiz on a specific verse range, such as 'start quiz on Genesis 1:1-10'.
When the user requests a quiz, extract the verse range from their message and pass it to the 'quiz_question_generator' tool.""")
# Agent function
def call_mode(state):
    """
    Manage the conversation flow of the Bible Study Tool, focusing on quiz mode and regular interactions.

    This function determines the next action in the conversation based on the user's input and the current state.
    It handles quiz mode (processing answers, continuing or ending the quiz) and transitions to or from regular
    question-answering mode. It also processes tool calls, such as starting a quiz, and delegates non-quiz queries
    to a language model.

    Parameters:
    - state (dict): The current state of the conversation, containing messages, quiz status, and other data.

    Returns:
    - dict: An updated state dictionary with new messages and modified quiz-related fields as needed.
    """
    last_message = state["messages"][-1]
    if state.get("in_quiz", False):
        if state.get("waiting_for_answer", False):
            # Process the user's answer
            quiz_data = state["quiz_question"]
            user_answer = last_message.content.strip().upper()
            correct_answer = quiz_data["correct_answer"]
            new_quiz_total = state["quiz_total"] + 1
            if user_answer == correct_answer:
                new_quiz_score = state["quiz_score"] + 1
                feedback = f"Correct! {quiz_data['explanation']}"
            else:
                new_quiz_score = state["quiz_score"]
                feedback = f"Incorrect. The correct answer is {correct_answer}. {quiz_data['explanation']}"
            return {
                "messages": [
                    AIMessage(content=feedback),
                    AIMessage(content="Would you like another question? Type 'Yes' to continue or 'No' to end the quiz.")
                ],
                "quiz_total": new_quiz_total,
                "quiz_score": new_quiz_score,
                "waiting_for_answer": False,
                "quiz_question": state["quiz_question"],
                "in_quiz": True,
                "verse_range": state["verse_range"]
            }
        else:
            # Handle the user's decision to continue or stop the quiz
            user_input = last_message.content.strip().lower()
            if user_input == "yes":
                # Generate a new quiz question
                verse_range = state["verse_range"]
                quiz_data_str = quiz_question_generator.invoke(verse_range)
                quiz_data = json.loads(quiz_data_str)
                question = quiz_data["quiz_question"]
                options = "\n".join([f"{k}: {v}" for k, v in quiz_data["options"].items()])
                verse_content = quiz_data["verse_content"]
                message_to_user = (
                    f"Based on the following verse(s):\n\n{verse_content}\n\n"
                    f"Here's your quiz question:\n\n{question}\n\n{options}\n\n"
                    "Please select your answer (A, B, C, or D)."
                )
                return {
                    "messages": [AIMessage(content=message_to_user)],
                    "quiz_question": quiz_data,
                    "waiting_for_answer": True,
                    "quiz_total": state["quiz_total"],
                    "quiz_score": state["quiz_score"],
                    "in_quiz": True,
                    "verse_range": state["verse_range"]
                }
            elif user_input == "no":
                # End the quiz and provide a summary
                score = state["quiz_score"]
                total = state["quiz_total"]
                continue_message = "Ask me anything about Genesis or type 'start quiz on <verse range>' (e.g., 'start quiz on Genesis 1:1-5') for a trivia challenge."
                if total > 0:
                    percentage = (score / total) * 100
                    if percentage == 100:
                        feedback = "Excellent! You got all questions correct. Please continue your Bible study!"
                    elif percentage >= 80:
                        feedback = "Great job! You have a strong understanding. Please continue your Bible study!"
                    elif percentage >= 50:
                        feedback = "Good effort! Keep practicing to improve. Please continue your Bible study!"
                    else:
feedback = "Don’t worry, keep your Bible studying and you’ll get better!" | |
summary = f"You got {score} out of {total} questions correct. {feedback} \n\n {continue_message}" | |
else: | |
summary = "No questions were attempted." | |
return { | |
"messages": [AIMessage(content=summary)], | |
"in_quiz": False, | |
"quiz_question": None, | |
"verse_range": None, | |
"quiz_score": 0, | |
"quiz_total": 0, | |
"waiting_for_answer": False | |
} | |
else: | |
# Handle invalid input | |
return { | |
"messages": [AIMessage(content="Please type 'Yes' to continue or 'No' to end the quiz.")], | |
"quiz_total": state["quiz_total"], | |
"quiz_score": state["quiz_score"], | |
"waiting_for_answer": False, | |
"quiz_question": state["quiz_question"], | |
"in_quiz": True, | |
"verse_range": state["verse_range"] | |
} | |
# Handle starting the quiz or other tool calls | |
if len(state["messages"]) >= 2 and isinstance(last_message, ToolMessage): | |
prev_message = state["messages"][-2] | |
if isinstance(prev_message, AIMessage) and prev_message.tool_calls: | |
tool_call = prev_message.tool_calls[0] | |
if tool_call["name"] == "quiz_question_generator": | |
# Start the quiz | |
quiz_data_str = last_message.content | |
quiz_data = json.loads(quiz_data_str) | |
verse_range = quiz_data["verse_range"] | |
question = quiz_data["quiz_question"] | |
options = "\n".join([f"{k}: {v}" for k, v in quiz_data["options"].items()]) | |
verse_content = quiz_data["verse_content"] | |
message_to_user = ( | |
f"Based on the following verse(s):\n\n{verse_content}\n\n" | |
f"Here's your quiz question:\n\n{question}\n\n{options}\n\n" | |
"Please select your answer (A, B, C, or D)." | |
) | |
return { | |
"messages": [AIMessage(content=message_to_user)], | |
"in_quiz": True, | |
"verse_range": verse_range, | |
"quiz_score": 0, | |
"quiz_total": 0, | |
"quiz_question": quiz_data, | |
"waiting_for_answer": True | |
} | |
# Process regular questions or commands | |
messages = [system_message] + state["messages"] | |
response = llm_with_tools.invoke(messages) | |
return {"messages": [response]} | |
tool_node = ToolNode(tool_belt)

def should_continue(state):
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "action"
    return END
# Build the graph
uncompiled_graph = StateGraph(AgentState)
uncompiled_graph.add_node("agent", call_mode)
uncompiled_graph.add_node("action", tool_node)
uncompiled_graph.set_entry_point("agent")
uncompiled_graph.add_conditional_edges("agent", should_continue)
uncompiled_graph.add_edge("action", "agent")
compiled_graph = uncompiled_graph.compile()
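# Standalone smoke test (illustrative; uncomment to run outside Chainlit):
# out = compiled_graph.invoke({
#     "messages": [HumanMessage(content="Who is Abraham?")],
#     "in_quiz": False, "quiz_question": None, "verse_range": None,
#     "quiz_score": 0, "quiz_total": 0, "waiting_for_answer": False,
# })
# print(out["messages"][-1].content)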
# Chainlit integration
@cl.on_chat_start
async def start():
    system_message = SystemMessage(content="Welcome to the Bible Study Tool!")
    initial_state = {
        "messages": [system_message],
        "in_quiz": False,
        "quiz_question": None,
        "verse_range": None,
        "quiz_score": 0,
        "quiz_total": 0,
        "waiting_for_answer": False
    }
    cl.user_session.set("state", initial_state)
    await cl.Message(content="Welcome to the Bible Study Tool! Ask me anything about Genesis or type 'start quiz on <verse range>' (e.g., 'start quiz on Genesis 1:1-5') for a trivia challenge.").send()
@cl.on_message
async def main(message: cl.Message):
    state = cl.user_session.get("state")
    current_messages = len(state["messages"])
    state["messages"].append(HumanMessage(content=message.content))
    result = await compiled_graph.ainvoke(state)
    cl.user_session.set("state", result)
    new_messages = result["messages"][current_messages + 1:]
    for msg in new_messages:
        if isinstance(msg, AIMessage):
            await cl.Message(content=msg.content).send()