"""InQuest: a Streamlit app for question answering over uploaded PDF documents."""
import itertools
import os
import random
from typing import Any, List, Tuple

import streamlit as st
from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
    AgentTokenBufferMemory,
)
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import QAGenerationChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
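
# Streamlit page setup and the assistant's opening message.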
st.set_page_config(page_title="InQuest", page_icon="📄")

starter_message = "Ask me anything about the Doc!"


@st.cache_resource
def create_prompt(openai_api_key: str) -> Tuple[ChatPromptTemplate, ChatOpenAI]:
    """Create the agent prompt and the streaming chat model."""
    llm = ChatOpenAI(
        temperature=0,
        model_name="gpt-3.5-turbo",
        streaming=True,
        openai_api_key=openai_api_key,
    )

    message = SystemMessage(
        content=(
            "You are a helpful chatbot tasked with answering questions about "
            "context given through uploaded documents. Unless explicitly "
            "stated otherwise, assume questions are about that context, and "
            "resolve any ambiguity in its favor."
        )
    )

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=message,
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )

    return prompt, llm


@st.cache_data
def save_file_locally(file: Any) -> str:
    """Save an uploaded file to the local temp directory."""
    os.makedirs("tempdir", exist_ok=True)
    doc_path = os.path.join("tempdir", file.name)
    with open(doc_path, "wb") as f:
        f.write(file.getbuffer())

    return doc_path


@st.cache_data
def load_docs(files: List[Any], url: bool = False) -> str:
    """Load the uploaded PDF files and concatenate their page contents."""
    documents = []
    if not url:
        st.info("`Reading doc ...`")
        for file in files:
            doc_path = save_file_locally(file)
            pages = PyPDFLoader(doc_path)
            documents.extend(pages.load())

    return ",".join([doc.page_content for doc in documents])


@st.cache_data
def gen_embeddings() -> HuggingFaceEmbeddings:
    """Load the HuggingFace embedding model, caching it locally."""
    embeddings = HuggingFaceEmbeddings(cache_folder="hf_model")
    return embeddings


@st.cache_resource
def process_corpus(corpus: str, chunk_size: int = 1000, overlap: int = 50) -> List:
    """Split the corpus, index it in FAISS, and wrap it in a retriever tool."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=overlap
    )
    texts = text_splitter.split_text(corpus)

    num_chunks = len(texts)
    st.write(f"Number of text chunks: {num_chunks}")

    embeddings = gen_embeddings()

    # Index the chunks and retrieve the top 4 matches per query.
    retriever = FAISS.from_texts(texts, embeddings).as_retriever(
        search_kwargs={"k": 4}
    )

    tool = create_retriever_tool(
        retriever,
        "search_docs",
        "Searches and returns documents relevant to the user input question, "
        "using the provided context as a source.",
    )

    return [tool]


@st.cache_resource
def generate_agent_executor(text: str) -> AgentExecutor:
    """Build the conversational agent over the document retrieval tools."""
    tools = process_corpus(text)

    # `llm` and `prompt` are created in the sidebar once a valid API key is set.
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)

    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        return_intermediate_steps=True,
    )
    return agent_executor


@st.cache_data
def generate_eval(raw_text: str, N: int, chunk: int) -> List:
    """Generate N sample question-answer pairs from random chunks of the text."""
    update = st.empty()
    ques_update = st.empty()
    update.info("`Generating sample questions ...`")
    n = len(raw_text)
    # Sample N random sub-sequences of `chunk` characters each.
    starting_indices = [random.randint(0, n - chunk) for _ in range(N)]
    sub_sequences = [raw_text[i : i + chunk] for i in starting_indices]
    chain = QAGenerationChain.from_llm(llm)
    eval_set = []
    for i, b in enumerate(sub_sequences):
        try:
            qa = chain.run(b)
            eval_set.append(qa)
            ques_update.info(f"Creating Question: {i + 1}")
        except ValueError:
            st.warning(f"Error in generating Question: {i + 1}...", icon="⚠️")
            continue

    eval_set_full = list(itertools.chain.from_iterable(eval_set))

    update.empty()
    ques_update.empty()

    return eval_set_full


# Not cached: this function only renders widgets, so it must run on every rerun.
def gen_side_bar_qa(text: str) -> None:
    """Render auto-generated sample question-answer cards in the sidebar."""
    if text:
        if "eval_set" not in st.session_state:
            num_eval_questions = 5
            st.session_state.eval_set = generate_eval(text, num_eval_questions, 3000)

        for i, qa_pair in enumerate(st.session_state.eval_set):
            st.sidebar.markdown(
                f"""
                <div class="css-card">
                <span class="card-tag">Question {i + 1}</span>
                <p style="font-size: 12px;">{qa_pair['question']}</p>
                <p style="font-size: 12px;">{qa_pair['answer']}</p>
                </div>
                """,
                unsafe_allow_html=True,
            )
        st.write("Ready to answer your questions.")
st.markdown(
    """
    <style>
    #MainMenu {
        visibility: hidden;
    }
    footer {
        visibility: hidden;
    }
    .css-card {
        border-radius: 0px;
        padding: 30px 10px 10px 10px;
        background-color: black;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 10px;
        font-family: "IBM Plex Sans", sans-serif;
    }
    .card-tag {
        border-radius: 0px;
        padding: 1px 5px 1px 5px;
        margin-bottom: 10px;
        position: absolute;
        left: 0px;
        top: 0px;
        font-size: 0.6rem;
        font-family: "IBM Plex Sans", sans-serif;
        color: white;
        background-color: green;
    }
    .css-zt5igj {
        left: 0;
    }
    span.css-10trblm {
        margin-left: 0;
    }
    div.css-1kyxreq {
        margin-top: -40px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

st.write(
    """
    <div style="display: flex; align-items: center; margin-left: 0;">
    <h1 style="display: inline-block;">InQuest</h1>
    <sup style="margin-left:5px;font-size:small; color: green;">beta</sup>
    </div>
    """,
    unsafe_allow_html=True,
)
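
# Sidebar: collect the OpenAI API key and set up the model, prompt, and memory.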
with st.sidebar:
    openai_api_key = st.text_input(
        "OpenAI API Key", key="api_key_openai", type="password"
    )
    if openai_api_key and openai_api_key.startswith("sk-"):
        prompt, llm = create_prompt(openai_api_key)
        memory = AgentTokenBufferMemory(llm=llm)
    else:
        "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
        st.info("Please add a valid OpenAI API key in the sidebar.")

# Everything below needs a valid key, so halt the script until one is provided.
if not openai_api_key or not openai_api_key.startswith("sk-"):
    st.stop()
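
# Document upload and processing.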
splitter_type = "RecursiveCharacterTextSplitter"

uploaded_files = st.file_uploader(
    "Upload a PDF Document", type=["pdf"], accept_multiple_files=True
)

if uploaded_files:
    # Reset the cached sample questions when a new set of files is uploaded.
    if (
        "last_uploaded_files" not in st.session_state
        or st.session_state.last_uploaded_files != uploaded_files
    ):
        st.session_state.last_uploaded_files = uploaded_files
        if "eval_set" in st.session_state:
            del st.session_state["eval_set"]

    raw_pdf_text = load_docs(uploaded_files)
    st.success("Documents uploaded and processed.")
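
# Stop until a document is uploaded, then build (and cache) the retrieval agent.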
if not uploaded_files:
    st.stop()

agent_executor = generate_agent_executor(raw_pdf_text)

if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
    st.session_state["messages"] = [AIMessage(content=starter_message)]

for msg in st.session_state.messages:
    if isinstance(msg, AIMessage):
        st.chat_message("assistant").write(msg.content)
    elif isinstance(msg, HumanMessage):
        st.chat_message("user").write(msg.content)
    memory.chat_memory.add_message(msg)
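
# Chat loop: echo the user's question and stream the agent's answer.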
if user_question := st.chat_input(placeholder=starter_message):
    st.chat_message("user").write(user_question)

    with st.chat_message("assistant"):
        # Stream the agent's intermediate steps into the chat container.
        st_callback = StreamlitCallbackHandler(
            st.container(),
            expand_new_thoughts=True,
            collapse_completed_thoughts=True,
        )

        response = agent_executor(
            {"input": user_question, "history": st.session_state.messages},
            callbacks=[st_callback],
            include_run_info=True,
        )
        st.session_state.messages.append(AIMessage(content=response["output"]))

        st.write(response["output"])

        # Persist the exchange and mirror the token-buffer memory into session state.
        memory.save_context({"input": user_question}, response)

        st.session_state["messages"] = memory.buffer

        # Run info is available because include_run_info=True above.
        run_id = response["__run"].run_id

    col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])
|