"""Gradio app that summarizes a web page and then answers questions about it.

Flow (wired up in ``build_demo``): the user supplies an OpenAI API key and a
URL; the page is loaded and summarized with a map-reduce chain; a retrieval
chain over the page's chunks is prepared; finally the chat input is revealed.

State shared between the steps (``docs``, ``summary``, ``llm``, ``qa``) lives
in module-level globals that are created by ``load_page`` / ``prepare_chat``.
"""

import gradio as gr
from langchain.chains import (
    ConversationalRetrievalChain,
    LLMChain,
    MapReduceDocumentsChain,
    ReduceDocumentsChain,
    StuffDocumentsChain,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import WebBaseLoader


def _escape_braces(text):
    """Return *text* with ``{`` / ``}`` doubled so it is literal in a template."""
    return text.replace("{", "{{").replace("}", "}}")


def _build_summary_chain(llm):
    """Build the map-reduce chain used to summarize the split page documents."""
    # Map step: extract the main themes from each chunk.
    map_template = (
        "The following is a set of snippets from a web page: {docs} "
        "Based on this list of snippets, please identify the main themes "
        "Helpful Answer:"
    )
    map_prompt = PromptTemplate.from_template(map_template)
    map_chain = LLMChain(llm=llm, prompt=map_prompt)

    # Reduce step: merge the per-chunk themes into one summary.
    reduce_template = (
        "The following is set of summaries of a web page: {docs} "
        "Take these and distill it into a final, consolidated summary of the "
        "main themes. Helpful Answer:"
    )
    reduce_prompt = PromptTemplate.from_template(reduce_template)
    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)

    # Takes a list of documents, combines them into a single string, and
    # passes this to an LLMChain.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )

    # Combines and iteratively reduces the mapped documents.
    reduce_documents_chain = ReduceDocumentsChain(
        # This is the final chain that is called.
        combine_documents_chain=combine_documents_chain,
        # If documents exceed context for `StuffDocumentsChain`.
        collapse_documents_chain=combine_documents_chain,
        # The maximum number of tokens to group documents into.
        token_max=4000,
    )

    # Combining documents by mapping a chain over them, then combining results.
    return MapReduceDocumentsChain(
        # Map chain
        llm_chain=map_chain,
        # Reduce chain
        reduce_documents_chain=reduce_documents_chain,
        # The variable name in the llm_chain to put the documents in
        document_variable_name="docs",
        # Do not return the results of the map steps in the output
        return_intermediate_steps=False,
    )


def wait_for_summarization(url):
    """Return a fresh chat history holding only a "please wait" bot message.

    Args:
        url: The URL that is about to be summarized (shown in the message).

    Returns:
        A one-element history list of ``(user, bot)`` tuples.
    """
    return [(None, f"Please wait while I summarize the contents of {url}...")]


def load_page(url, api_key, history):
    """Load *url*, summarize it, and append the summary to the chat history.

    Sets the module globals ``docs`` (loaded documents), ``llm`` (chat model)
    and ``summary`` (summary text), which ``prepare_chat`` reads afterwards.

    Args:
        url: Address of the web page to load.
        api_key: OpenAI API key used for the chat model.
        history: Current chat history (list of ``(user, bot)`` tuples).

    Returns:
        A new history list with the summary appended as a bot message.
    """
    global docs, summary, llm

    docs = WebBaseLoader(url).load()
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo-1106", temperature=0, openai_api_key=api_key
    )

    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=0
    )
    split_docs = text_splitter.split_documents(docs)

    summary = _build_summary_chain(llm).run(split_docs)
    return history + [(None, summary)]


def prepare_chat(api_key, history):
    """Index the loaded page and build the conversational retrieval chain.

    Reads the globals ``docs``, ``summary`` and ``llm`` set by ``load_page``
    and stores the resulting chain in the global ``qa``.

    Args:
        api_key: OpenAI API key used for the embeddings.
        history: Current chat history (list of ``(user, bot)`` tuples).

    Returns:
        A new history list with a "ready" bot message appended.
    """
    global qa

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=128)
    documents = text_splitter.split_documents(docs)

    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    # NOTE(review): no collection name is passed, so the default collection is
    # used; confirm whether repeated loads in one process can mix pages.
    vectorstore = Chroma.from_documents(documents, embeddings)
    retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": 6}
    )

    # The summary is model output and may contain "{" or "}". It is spliced
    # into a format-style template, so braces must be escaped or they would be
    # parsed as template placeholders and break prompt formatting.
    qa_prompt_template = (
        "As an AI assistant you help in answering questions about the "
        "contents of a web page. \n"
        "The summary of the current web page is this: "
        + _escape_braces(summary)
        + " Also, consider this additional context that may be relevant for "
        "the user's question: {context} "
        "Please answer following question: {question}"
    )
    qa_prompt = PromptTemplate(
        template=qa_prompt_template, input_variables=["context", "question"]
    )

    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True, output_key="answer"
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
    )
    return history + [(None, "You can now ask me specific questions about the page.")]


def chatbot_function(message, history):
    """Answer *message* with the global ``qa`` chain and extend the history.

    Args:
        message: Text the user entered.
        history: Current chat history (list of ``(user, bot)`` tuples).

    Returns:
        A ``(textbox_value, history)`` pair: an empty string to clear the
        input box, and the history with the new exchange appended. Empty or
        whitespace-only messages leave the history unchanged.
    """
    # Nothing to ask: skip the retrieval/LLM call instead of sending a blank
    # question and adding an empty turn to the chat.
    if not message or not message.strip():
        return "", history
    return "", history + [(message, qa.run(message))]


def build_demo():
    """Build and return the Gradio Blocks UI with all event handlers wired."""
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        # Configuration row: API key, URL and the "Load" button.
        with gr.Row() as config_row:
            with gr.Column():
                api_key_box = gr.Textbox(
                    show_label=False,
                    placeholder="OpenAI API Key",
                    container=False,
                    autofocus=True,
                    # Mask the secret instead of showing it in clear text.
                    type="password",
                )
                url_box = gr.Textbox(
                    show_label=False,
                    placeholder="URL",
                    container=False,
                )
                load_btn = gr.Button(value="Load", variant="primary")

        # Chat row: hidden until "Load" is clicked.
        with gr.Row(visible=False) as chat_row:
            with gr.Column():
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="chatbot",
                        label="Web Chat",
                        height=550,
                    )
                # Input row: hidden until the retrieval chain is ready.
                with gr.Row(visible=False) as inputs_row:
                    with gr.Column(scale=8):
                        text_box = gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press ENTER",
                            autofocus=True,
                            container=False,
                        )
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(
                            value="Send",
                            variant="primary",
                        )

        # Load pipeline: hide config -> show chat -> "please wait" message ->
        # summarize page -> build QA chain -> reveal the input row.
        load_btn.click(
            lambda: gr.update(visible=False),
            outputs=[config_row],
        ).then(
            lambda: gr.update(visible=True),
            outputs=[chat_row],
        ).then(
            wait_for_summarization,
            inputs=[url_box],
            outputs=[chatbot],
        ).then(
            load_page,
            inputs=[url_box, api_key_box, chatbot],
            outputs=[chatbot],
        ).then(
            prepare_chat,
            inputs=[api_key_box, chatbot],
            outputs=[chatbot],
        ).then(
            lambda: gr.update(visible=True),
            outputs=[inputs_row],
        )

        # Pressing ENTER and clicking "Send" both submit the question.
        text_box.submit(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
        submit_btn.click(
            chatbot_function,
            [text_box, chatbot],
            [text_box, chatbot],
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()