Spaces:
Sleeping
Sleeping
import chainlit as cl | |
from langchain.retrievers import BM25Retriever, EnsembleRetriever | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.embeddings import CacheBackedEmbeddings | |
from langchain.storage import LocalFileStore | |
from langchain.agents.agent_toolkits import create_retriever_tool | |
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent | |
from langchain.document_loaders import WikipediaLoader, CSVLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.agents import Tool | |
from langchain.agents import ZeroShotAgent, AgentExecutor | |
from langchain.chat_models import ChatOpenAI | |
from langchain import LLMChain | |
import os | |
# The OpenAI API key must be supplied through the environment (for example a
# Hugging Face Spaces secret), never committed to source control. A key that was
# previously hard-coded here must be treated as leaked and revoked.
if not os.environ.get("OPENAI_API_KEY"):
    import warnings

    warnings.warn(
        "OPENAI_API_KEY is not set; OpenAI calls will fail.",
        RuntimeWarning,
        stacklevel=1,
    )
@cl.author_rename
def rename(orig_author: str) -> str:
    """Map an internal chain/agent author name to the name shown in the UI.

    Registered with Chainlit via ``@cl.author_rename``; without the decorator
    Chainlit never calls this function.

    Args:
        orig_author: The author name Chainlit would otherwise display.

    Returns:
        The display name for ``orig_author``, or ``orig_author`` unchanged when
        no mapping exists.
    """
    # NOTE(review): this app runs an AgentExecutor rather than a RetrievalQA
    # chain, so this key may never match -- confirm which author names appear.
    rename_dict = {"RetrievalQA": "Consulting The Barbenheimer"}
    return rename_dict.get(orig_author, orig_author)
def _make_text_splitters():
    """Build the splitters used to chunk Wikipedia pages and CSV review rows.

    Returns:
        A ``(wikipedia_splitter, csv_splitter)`` tuple. Both target chunks of
        1024 characters with a 512-character overlap.
    """
    wikipedia_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=512,
        length_function=len,
        is_separator_regex=False,
        # Prefer splitting at "==" section headings, then newlines, then spaces.
        separators=["\n==", "\n", " "],
    )
    csv_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=512,
        length_function=len,
        is_separator_regex=False,
        # Prefer splitting at newlines, then spaces.
        separators=["\n", " "],
    )
    return wikipedia_splitter, csv_splitter


def _make_cached_embedder():
    """Return OpenAI embeddings backed by an on-disk cache under ``./.cache/``.

    Cache entries are namespaced by the embedding model name.
    """
    store = LocalFileStore("./.cache/")
    core_embeddings_model = OpenAIEmbeddings()
    return CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model,
        store,
        namespace=core_embeddings_model.model,
    )


def _load_chunks(wikipedia_query, csv_path, wikipedia_splitter, csv_splitter):
    """Load one Wikipedia article and one reviews CSV, then chunk both.

    Args:
        wikipedia_query: Search query passed to ``WikipediaLoader``.
        csv_path: Path of the reviews CSV file.
        wikipedia_splitter: Splitter applied to the Wikipedia documents.
        csv_splitter: Splitter applied to the CSV documents.

    Returns:
        ``(wikipedia_chunks, csv_chunks)`` lists of documents.
    """
    wikipedia_docs = WikipediaLoader(
        query=wikipedia_query,
        load_max_docs=1,
        doc_content_chars_max=10000000,  # large cap so the page is not truncated
    ).load()
    # source_column="Review": the row's "Review" value becomes metadata["source"].
    csv_docs = CSVLoader(file_path=csv_path, source_column="Review").load()
    return (
        wikipedia_splitter.transform_documents(wikipedia_docs),
        csv_splitter.transform_documents(csv_docs),
    )


def _wikipedia_ensemble_retriever(chunks, embedder):
    """Combine keyword (BM25) and vector (FAISS) search over the same chunks.

    Each retriever returns its single best chunk; the ensemble weights BM25 at
    0.25 and FAISS at 0.75 (weights should sum to 1).
    """
    bm25_retriever = BM25Retriever.from_documents(chunks)
    bm25_retriever.k = 1
    faiss_retriever = FAISS.from_documents(chunks, embedder).as_retriever(
        search_kwargs={"k": 1}
    )
    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.25, 0.75],
    )


def _format_docs(docs):
    """Join retrieved documents' text so the prompt sees content, not reprs."""
    return "\n\n".join(doc.page_content for doc in docs)


def _build_barbie_agent(llm, wikipedia_splitter, csv_splitter, embedder):
    """Build the conversational retrieval agent that answers Barbie questions.

    The agent has two tools: an ensemble search over the "Barbie (film)"
    Wikipedia article and a FAISS search over ./barbie_data/barbie.csv.
    """
    wikipedia_chunks, csv_chunks = _load_chunks(
        "Barbie (film)",
        "./barbie_data/barbie.csv",
        wikipedia_splitter,
        csv_splitter,
    )
    csv_retriever = FAISS.from_documents(csv_chunks, embedder).as_retriever()
    wikipedia_tool = create_retriever_tool(
        retriever=_wikipedia_ensemble_retriever(wikipedia_chunks, embedder),
        name="Search_Wikipedia",
        description="Useful for when you need to answer questions about plot, cast, production, release, music, marketing, reception, themes and analysis of the Barbie movie.",
    )
    reviews_tool = create_retriever_tool(
        retriever=csv_retriever,
        name="Search_Reviews",
        description="Useful for when you need to answer questions about public reviews of the Barbie movie.",
    )
    return create_conversational_retrieval_agent(
        llm=llm, tools=[wikipedia_tool, reviews_tool], verbose=True
    )


def _build_oppenheimer_chain(llm, wikipedia_splitter, csv_splitter, embedder):
    """Build an LCEL chain answering Oppenheimer questions from two sources.

    The chain is invoked with ``{"question": str}`` and ends in ``llm``, so it
    returns the chat model's message.
    """
    wikipedia_chunks, csv_chunks = _load_chunks(
        "Oppenheimer",
        "./oppenheimer_data/oppenheimer.csv",
        wikipedia_splitter,
        csv_splitter,
    )
    csv_retriever = FAISS.from_documents(csv_chunks, embedder).as_retriever()
    wikipedia_retriever = _wikipedia_ensemble_retriever(wikipedia_chunks, embedder)

    system_message = (
        "Use the information from the below two sources to answer any questions.\n"
        "\n"
        "Source 1: public user reviews about the Oppenheimer movie\n"
        "<source1>\n"
        "{source1}\n"
        "</source1>\n"
        "\n"
        "Source 2: the wikipedia page for the Oppenheimer movie including the plot summary, cast, and production information\n"
        "<source2>\n"
        "{source2}\n"
        "</source2>\n"
    )
    oppenheimer_prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("human", "{question}")]
    )
    # Keep in sync with the labels above: source1 = reviews (CSV),
    # source2 = Wikipedia article.
    return {
        "source1": (lambda x: x["question"]) | csv_retriever | _format_docs,
        "source2": (lambda x: x["question"]) | wikipedia_retriever | _format_docs,
        "question": lambda x: x["question"],
    } | oppenheimer_prompt | llm


def _build_tools(barbie_agent_executor, oppenheimer_chain):
    """Wrap the two movie pipelines as tools for the top-level ZeroShotAgent."""

    def query_barbie(question):
        # The agent executor returns a dict of outputs; only the answer text
        # (under "output") is useful as the tool's observation.
        return barbie_agent_executor({"input": question})["output"]

    def query_oppenheimer(question):
        # The chain ends in a chat model, which returns a message object.
        return oppenheimer_chain.invoke({"question": question}).content

    return [
        Tool(
            name="BarbieInfo",
            func=query_barbie,
            description="Useful when you need to answer questions about the Barbie movie",
        ),
        Tool(
            name="OppenheimerInfo",
            func=query_oppenheimer,
            description="Useful when you need to answer questions about the Oppenheimer movie",
        ),
    ]


def _build_barbenheimer_agent(llm, tools):
    """Create the ZeroShotAgent that routes questions to the movie tools."""
    prefix = "Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"
    suffix = "Begin!\n\nQuestion: {input}\n{agent_scratchpad}"
    agent_prompt = ZeroShotAgent.create_prompt(
        tools=tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=["input", "agent_scratchpad"],
    )
    llm_chain = LLMChain(llm=llm, prompt=agent_prompt, verbose=True)
    return ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)


@cl.on_chat_start
async def init():
    """Build the Barbenheimer agent when a Chainlit chat session starts.

    Posts a "Building Index..." status message, builds the Barbie agent and the
    Oppenheimer chain, wraps both as tools for a ZeroShotAgent, stores the
    resulting AgentExecutor in the user session under "chain", then edits the
    status message to "Agent ready!".
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    wikipedia_splitter, csv_splitter = _make_text_splitters()
    embedder = _make_cached_embedder()

    # NOTE(review): loading and indexing below are synchronous calls made from
    # this coroutine; consider cl.make_async if startup blocks other sessions.
    barbie_agent_executor = _build_barbie_agent(
        llm, wikipedia_splitter, csv_splitter, embedder
    )
    oppenheimer_chain = _build_oppenheimer_chain(
        llm, wikipedia_splitter, csv_splitter, embedder
    )

    tools = _build_tools(barbie_agent_executor, oppenheimer_chain)
    barbenheimer_agent = _build_barbenheimer_agent(llm, tools)
    barbenheimer_agent_chain = await cl.make_async(AgentExecutor.from_agent_and_tools)(
        agent=barbenheimer_agent,
        tools=tools,
        verbose=True,
    )
    cl.user_session.set("chain", barbenheimer_agent_chain)

    # Edit the status message in place instead of sending it a second time.
    msg.content = "Agent ready!"
    await msg.update()
@cl.on_message
async def main(message):
    """Answer one user message with the session's agent and post the reply.

    Args:
        message: The incoming user message, passed unchanged to the chain that
            ``init`` stored in the user session under "chain".
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    # NOTE(review): with stream_final_answer=False this flag likely has no
    # effect -- confirm before relying on it.
    cb.answer_reached = True
    res = await chain.acall(message, callbacks=[cb])

    # An AgentExecutor returns its answer under "output"; "result" is the key a
    # RetrievalQA chain uses, so accept either.
    answer = res["output"] if "output" in res else res["result"]

    # Only retrieval chains return source documents; the agent does not.
    docs = res.get("source_documents") or []
    source_urls = []
    seen_sources = set()
    for doc in docs:
        source = doc.metadata["source"]
        if source in seen_sources:
            continue
        seen_sources.add(source)
        # Assumes metadata["source"] is an IMDb path -- confirm against the CSVs.
        source_urls.append("https://www.imdb.com" + source)
    source_elements = [cl.Text(content=url, name="Review URL") for url in source_urls]

    if source_urls:
        # Join the URL strings directly; element content is str, not bytes.
        answer += f"\nSources: {', '.join(source_urls)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()