import chainlit as cl
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.vectorstores import FAISS
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
from langchain.document_loaders import WikipediaLoader, CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import ChatPromptTemplate
from langchain.agents import Tool
from langchain.agents import ZeroShotAgent, AgentExecutor
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
@cl.author_rename
def rename(orig_author: str):
    """Translate an author name through a fixed lookup table.

    Names that have no entry in the table are returned unchanged.
    """
    display_names = {"RetrievalQA": "Consulting The Barbenheimer"}
    if orig_author in display_names:
        return display_names[orig_author]
    return orig_author
def _make_text_splitter(separators):
    """Return a recursive character splitter using the shared chunk settings."""
    return RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=512,
        length_function=len,
        is_separator_regex=False,
        separators=separators,
    )


def _load_and_chunk_movie_docs(wikipedia_query, csv_path, wikipedia_splitter, csv_splitter):
    """Load one movie's Wikipedia article and review CSV, then chunk both.

    Args:
        wikipedia_query: Query string passed to ``WikipediaLoader``.
        csv_path: Path of the review CSV; its ``Review`` column is used as
            the ``source_column``.
        wikipedia_splitter: Splitter applied to the Wikipedia documents.
        csv_splitter: Splitter applied to the CSV documents.

    Returns:
        A ``(wikipedia_chunks, csv_chunks)`` tuple.
    """
    wikipedia_docs = WikipediaLoader(
        query=wikipedia_query,
        load_max_docs=1,
        doc_content_chars_max=10000000,
    ).load()
    csv_docs = CSVLoader(file_path=csv_path, source_column="Review").load()
    return (
        wikipedia_splitter.transform_documents(wikipedia_docs),
        csv_splitter.transform_documents(csv_docs),
    )


def _build_wikipedia_ensemble_retriever(wikipedia_chunks, embedder):
    """Combine a BM25 retriever and a FAISS retriever over the same chunks.

    Each underlying retriever is configured with ``k = 1``; the ensemble
    weights are 0.25 (BM25) and 0.75 (FAISS).
    """
    bm25_retriever = BM25Retriever.from_documents(wikipedia_chunks)
    bm25_retriever.k = 1

    faiss_retriever = FAISS.from_documents(wikipedia_chunks, embedder).as_retriever(
        search_kwargs={"k": 1}
    )

    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.25, 0.75],  # should sum to 1
    )


@cl.on_chat_start
async def init():
    """Build the Barbie and Oppenheimer retrieval pipelines and the top-level agent.

    Sends a "Building Index..." status message, builds the retrievers and
    tools for both movies, wraps them in a ``ZeroShotAgent`` executor,
    stores that executor in the user session under the key ``"chain"``, and
    finally updates the status message to "Agent ready!".
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

    # Wikipedia text: split on section headings first, then newlines, then spaces.
    wikipedia_text_splitter = _make_text_splitter(["\n==", "\n", " "])
    # CSV rows: split on newlines first, then spaces.
    csv_text_splitter = _make_text_splitter(["\n", " "])

    # Embeddings are cached on disk, namespaced by the embedding model name.
    store = LocalFileStore("./.cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model,
        store,
        namespace=core_embeddings_model.model,
    )

    # ---- Barbie retrieval system (Wikipedia + review CSV) ----
    chunked_barbie_wikipedia_docs, chunked_barbie_csv_docs = _load_and_chunk_movie_docs(
        "Barbie (film)",
        "./barbie_data/barbie.csv",
        wikipedia_text_splitter,
        csv_text_splitter,
    )
    barbie_csv_faiss_retriever = FAISS.from_documents(
        chunked_barbie_csv_docs, embedder
    ).as_retriever()
    barbie_ensemble_retriever = _build_wikipedia_ensemble_retriever(
        chunked_barbie_wikipedia_docs, embedder
    )

    barbie_wikipedia_retrieval_tool = create_retriever_tool(
        retriever=barbie_ensemble_retriever,
        name='Search_Wikipedia',
        description='Useful for when you need to answer questions about plot, cast, production, release, music, marketing, reception, themes and analysis of the Barbie movie.'
    )
    barbie_csv_retrieval_tool = create_retriever_tool(
        retriever=barbie_csv_faiss_retriever,
        name='Search_Reviews',
        description='Useful for when you need to answer questions about public reviews of the Barbie movie.'
    )
    barbie_retriever_tools = [barbie_wikipedia_retrieval_tool, barbie_csv_retrieval_tool]

    barbie_retriever_agent_executor = create_conversational_retrieval_agent(
        llm=llm, tools=barbie_retriever_tools, verbose=True
    )

    # ---- Oppenheimer retrieval system (Wikipedia + review CSV) ----
    # The query names the film article explicitly, matching the Barbie query;
    # a bare "Oppenheimer" query may resolve to a different article.
    chunked_opp_wikipedia_docs, chunked_opp_csv_docs = _load_and_chunk_movie_docs(
        "Oppenheimer (film)",
        "./oppenheimer_data/oppenheimer.csv",
        wikipedia_text_splitter,
        csv_text_splitter,
    )
    opp_csv_faiss_retriever = FAISS.from_documents(
        chunked_opp_csv_docs, embedder
    ).as_retriever()
    opp_ensemble_retriever = _build_wikipedia_ensemble_retriever(
        chunked_opp_wikipedia_docs, embedder
    )

    system_message = (
        "Use the information from the below two sources to answer any questions.\n"
        "Source 1: public user reviews about the Oppenheimer movie\n"
        "{source1}\n"
        "Source 2: the wikipedia page for the Oppenheimer movie including the plot summary, cast, and production information\n"
        "{source2}\n"
    )
    oppenheimer_prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("human", "{question}")]
    )

    # The retrievers must line up with the prompt's labels: source1 is the
    # review CSV, source2 is the Wikipedia ensemble.
    oppenheimer_multisource_chain = {
        "source1": (lambda x: x["question"]) | opp_csv_faiss_retriever,
        "source2": (lambda x: x["question"]) | opp_ensemble_retriever,
        "question": lambda x: x["question"],
    } | oppenheimer_prompt | llm

    # ---- Top-level agent ----
    # Tool functions return plain answer text so the agent's observation is
    # the answer itself rather than a stringified dict / message object.
    def query_barbie(question):
        return barbie_retriever_agent_executor({"input": question})["output"]

    def query_oppenheimer(question):
        return oppenheimer_multisource_chain.invoke({"question": question}).content

    tools = [
        Tool(
            name="BarbieInfo",
            func=query_barbie,
            description='Useful when you need to answer questions about the Barbie movie'
        ),
        Tool(
            name="OppenheimerInfo",
            func=query_oppenheimer,
            description='Useful when you need to answer questions about the Oppenheimer movie'
        ),
    ]

    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    # No stray quote after "Begin!" -- it would otherwise be sent to the model.
    suffix = "Begin!\nQuestion: {input}\n{agent_scratchpad}"
    agent_prompt = ZeroShotAgent.create_prompt(
        tools=tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=['input', 'agent_scratchpad']
    )

    llm_chain = LLMChain(llm=llm, prompt=agent_prompt, verbose=True)

    barbenheimer_agent = ZeroShotAgent(
        llm_chain=llm_chain,
        tools=tools,
        verbose=True,
    )
    barbenheimer_agent_chain = AgentExecutor.from_agent_and_tools(
        agent=barbenheimer_agent,
        tools=tools,
        verbose=True,
    )

    cl.user_session.set("chain", barbenheimer_agent_chain)

    # Edit the status message already sent instead of sending it a second time.
    msg.content = "Agent ready!"
    await msg.update()
@cl.on_message
async def main(message):
    """Run the session's agent chain on an incoming message and send the answer.

    The chain is read from the user session under the key ``"chain"``. If no
    chain is stored there, a short notice is sent instead of raising. The
    chain result is expected to expose the answer under ``"output"``; no
    source elements are attached to the reply.
    """
    chain = cl.user_session.get("chain")
    if chain is None:
        await cl.Message(
            content="The agent is not ready yet. Please try again in a moment."
        ).send()
        return

    cb = cl.LangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): flag is set by hand before the run; presumably this makes
    # the handler treat the answer as already reached -- confirm against the
    # Chainlit version in use.
    cb.answer_reached = True

    # The chain is synchronous; run it off the event loop so the async handler
    # does not block while the agent works.
    res = await cl.make_async(chain)(message, callbacks=[cb])

    answer = res["output"]
    source_elements = []

    await cl.Message(content=answer, elements=source_elements).send()