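# app.py: Gradio demo that loads a web page, summarizes it with a LangChain
# map-reduce chain, then answers follow-up questions about the page via
# retrieval-augmented QA over a Chroma vector store.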
import gradio as gr
from langchain.chains import (
    ConversationalRetrievalChain,
    LLMChain,
    MapReduceDocumentsChain,
    ReduceDocumentsChain,
    StuffDocumentsChain,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import WebBaseLoader
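

# Module-level state shared across the Gradio callbacks below (the original
# relies on `global` statements; declaring the names here makes that explicit).
docs = None     # loaded page documents
summary = None  # map-reduce summary of the page
llm = None      # chat model
qa = None       # conversational retrieval chain


# Return an interim chatbot message so the user sees immediate feedback
# while the (slow) summarization step runs in the next event in the chain.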
def wait_for_summarization(url):
    return [(None, f"Please wait while I summarize the contents of {url}...")]
def load_page(url, api_key, history):
    global docs, summary, llm
    loader = WebBaseLoader(url)
    docs = loader.load()
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo-1106", temperature=0, openai_api_key=api_key
    )
    # Map step: identify the themes of each chunk of the page individually.
    map_template = """The following is a set of snippets from a web page:
{docs}
Based on this list of snippets, please identify the main themes.
Helpful Answer:"""
    map_prompt = PromptTemplate.from_template(map_template)
    map_chain = LLMChain(llm=llm, prompt=map_prompt)
    # Reduce step: merge the per-chunk summaries into one.
    reduce_template = """The following is a set of summaries of a web page:
{docs}
Take these and distill them into a final, consolidated summary of the main themes.
Helpful Answer:"""
    reduce_prompt = PromptTemplate.from_template(reduce_template)
    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
    # Takes a list of documents, combines them into a single string,
    # and passes this to an LLMChain.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    # Combines and iteratively reduces the mapped documents.
    reduce_documents_chain = ReduceDocumentsChain(
        # This is the final chain that is called.
        combine_documents_chain=combine_documents_chain,
        # Used if the documents exceed the context of `StuffDocumentsChain`.
        collapse_documents_chain=combine_documents_chain,
        # The maximum number of tokens to group documents into.
        token_max=4000,
    )
    # Combine documents by mapping a chain over them, then combining the results.
    map_reduce_chain = MapReduceDocumentsChain(
        # Map chain
        llm_chain=map_chain,
        # Reduce chain
        reduce_documents_chain=reduce_documents_chain,
        # The variable name in the llm_chain to put the documents in
        document_variable_name="docs",
        # Do not return the results of the map steps in the output.
        return_intermediate_steps=False,
    )
    # Split by token count so each chunk fits comfortably in the model context.
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0
    )
    split_docs = text_splitter.split_documents(docs)
    summary = map_reduce_chain.run(split_docs)
    return history + [(None, summary)]
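

# Build the retrieval-augmented QA chain: embed smaller, overlapping chunks
# of the page into a Chroma vector store and wire up a
# ConversationalRetrievalChain that injects the page summary plus the top-k
# retrieved chunks into every prompt.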
def prepare_chat(api_key, history):
    global docs, summary, llm, qa
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=128)
    documents = text_splitter.split_documents(docs)
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    vectorstore = Chroma.from_documents(documents, embeddings)
    retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": 6}
    )
    # Bake the page summary into the prompt; retrieved chunks fill {context}.
    qa_prompt_template = (
        """As an AI assistant, you help answer questions about the contents of a web page.
The summary of the current web page is this:
"""
        + summary
        + """
Also, consider this additional context that may be relevant for the user's question:
{context}
Please answer the following question: {question}"""
    )
    qa_prompt = PromptTemplate(
        template=qa_prompt_template, input_variables=["context", "question"]
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
    return history + [(None, "You can now ask me specific questions about the page.")]
def chatbot_function(message, history):
    global qa
    return "", history + [(message, qa.run(message))]
def build_demo():
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        with gr.Row() as config_row:
            with gr.Column():
                api_key_box = gr.Textbox(
                    show_label=False,
                    placeholder="OpenAI API Key",
                    container=False,
                    autofocus=True,
                )
                url_box = gr.Textbox(
                    show_label=False,
                    placeholder="URL",
                    container=False,
                )
                load_btn = gr.Button(value="Load", variant="primary")
        with gr.Row(visible=False) as chat_row:
            with gr.Column():
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        label="Web Chat",
                        height=550,
                    )
                with gr.Row(visible=False) as inputs_row:
                    with gr.Column(scale=8):
                        text_box = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press ENTER",
                            autofocus=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(
                            value="Send",
                            variant="primary",
                        )

        # Clicking "Load" hides the config row, reveals the chat, shows an
        # interim message, summarizes the page, builds the QA chain, and
        # finally reveals the input row.
        load_btn.click(
            lambda: gr.update(visible=False),
            outputs=[config_row],
        ).then(
            lambda: gr.update(visible=True),
            outputs=[chat_row],
        ).then(
            wait_for_summarization,
            inputs=[url_box],
            outputs=[chatbot],
        ).then(
            load_page,
            inputs=[url_box, api_key_box, chatbot],
            outputs=[chatbot],
        ).then(
            prepare_chat,
            inputs=[api_key_box, chatbot],
            outputs=[chatbot],
        ).then(
            lambda: gr.update(visible=True),
            outputs=[inputs_row],
        )

        # Both pressing ENTER and clicking "Send" submit the message.
        text_box.submit(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
        submit_btn.click(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
    return demo
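

# Entry point: build the UI and start a local Gradio server
# (http://127.0.0.1:7860 by default).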
if __name__ == "__main__":
    demo = build_demo()
    demo.launch()