"""This is a public module. It should have a docstring.""" import itertools import os import random from typing import Any, List, Tuple import streamlit as st from langchain.agents import AgentExecutor, OpenAIFunctionsAgent from langchain.agents.agent_toolkits import create_retriever_tool from langchain.agents.openai_functions_agent.agent_token_buffer_memory import ( AgentTokenBufferMemory, ) from langchain.callbacks import StreamlitCallbackHandler from langchain.chains import QAGenerationChain from langchain.chat_models import ChatOpenAI from langchain.document_loaders import PyPDFLoader from langchain.embeddings import HuggingFaceEmbeddings from langchain.prompts import MessagesPlaceholder from langchain.schema import AIMessage, HumanMessage, SystemMessage from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.vectorstores import FAISS st.set_page_config(page_title="InQuest", page_icon="📚") starter_message = "Ask me anything about the Doc!" @st.cache_resource def create_prompt(openai_api_key: str) -> Tuple[SystemMessage, ChatOpenAI]: """Create prompt.""" # Make your OpenAI API request here llm = ChatOpenAI( temperature=0, model_name="gpt-3.5-turbo", streaming=True, openai_api_key=openai_api_key, ) message = SystemMessage( content=( "You are a helpful chatbot who is tasked with answering questions about context given through uploaded documents." # noqa: E501 comment "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the context given." # noqa: E501 comment "If there is any ambiguity, you probably assume they are about that." # noqa: E501 comment ) ) prompt = OpenAIFunctionsAgent.create_prompt( system_message=message, extra_prompt_messages=[MessagesPlaceholder(variable_name="history")], ) return prompt, llm @st.cache_data def save_file_locally(file: Any) -> str: """Save uploaded files locally.""" doc_path = os.path.join("tempdir", file.name) with open(doc_path, "wb") as f: f.write(file.getbuffer()) return doc_path @st.cache_data def load_docs(files: List[Any], url: bool = False) -> str: """Load and process the uploaded PDF files.""" if not url: st.info("`Reading doc ...`") documents = [] for file in files: doc_path = save_file_locally(file) pages = PyPDFLoader(doc_path) documents.extend(pages.load()) return ",".join([doc.page_content for doc in documents]) @st.cache_data def gen_embeddings() -> HuggingFaceEmbeddings: """Generate embeddings for given model.""" embeddings = HuggingFaceEmbeddings( cache_folder="hf_model" ) # https://github.com/UKPLab/sentence-transformers/issues/1828 return embeddings @st.cache_resource def process_corpus(corpus: str, chunk_size: int = 1000, overlap: int = 50) -> List: """Process text for Semantic Search.""" text_splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=overlap ) texts = text_splitter.split_text(corpus) # Display the number of text chunks num_chunks = len(texts) st.write(f"Number of text chunks: {num_chunks}") # select embedding model embeddings = gen_embeddings() # create vectorstore vectorstore = FAISS.from_texts(texts, embeddings).as_retriever( search_kwargs={"k": 4} ) # create retriever tool tool = create_retriever_tool( vectorstore, "search_docs", "Searches and returns documents using the context provided as a source, relevant to the user input question.", # noqa: E501 comment ) tools = [tool] return tools @st.cache_data def generate_agent_executer(text: str) -> List[AgentExecutor]: """Generate the memory functionality.""" tools = 
    # `llm` and `prompt` are module-level globals produced by create_prompt()
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        return_intermediate_steps=True,
    )
    return agent_executor


@st.cache_data
def generate_eval(raw_text: str, N: int, chunk: int) -> List:
    """Generate N question-answer pairs from random chunks of the document."""
    # IN: text, number of questions N, chunk size to draw each question from
    # OUT: eval set as a list of {"question": ..., "answer": ...} dicts
    update = st.empty()
    ques_update = st.empty()
    update.info("`Generating sample questions ...`")
    n = len(raw_text)
    starting_indices = [random.randint(0, n - chunk) for _ in range(N)]
    sub_sequences = [raw_text[i : i + chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(llm)
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            ques_update.info(f"Creating Question: {i + 1}")
        except ValueError:
            st.warning(f"Error in generating Question: {i + 1}...", icon="⚠️")
            continue
    eval_set_full = list(itertools.chain.from_iterable(eval_set))
    update.empty()
    ques_update.empty()
    return eval_set_full


@st.cache_resource()
def gen_side_bar_qa(text: str) -> None:
    """Display generated question-answer pairs in the sidebar."""
    if text:
        # Generate the question-answer pairs once and keep them in session state
        if "eval_set" not in st.session_state:
            num_eval_questions = 5  # Number of question-answer pairs to generate
            st.session_state.eval_set = generate_eval(text, num_eval_questions, 3000)

        # Display the question-answer pairs in the sidebar
        for i, qa_pair in enumerate(st.session_state.eval_set):
            st.sidebar.markdown(
                f"""
                {qa_pair['question']}

                {qa_pair['answer']}
                """
            )
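

# ---------------------------------------------------------------------------
# Hypothetical wiring (a minimal sketch, not part of the original module):
# one way the helpers above could be combined in the app body. Widget labels
# and variable names such as `uploaded_files` and `user_question` are
# illustrative assumptions only.
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
uploaded_files = st.file_uploader("Upload PDFs", type="pdf", accept_multiple_files=True)

if openai_api_key and uploaded_files:
    # `prompt` and `llm` become the module-level globals used by the helpers above
    prompt, llm = create_prompt(openai_api_key)
    raw_text = load_docs(uploaded_files)
    gen_side_bar_qa(raw_text)  # show sample Q&A pairs in the sidebar
    agent_executor = generate_agent_executer(raw_text)

    user_question = st.chat_input(starter_message)
    if user_question:
        st.chat_message("user").write(user_question)
        st_callback = StreamlitCallbackHandler(st.container())
        response = agent_executor(
            {"input": user_question, "history": []},
            callbacks=[st_callback],
        )
        st.chat_message("assistant").write(response["output"])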