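"""Gradio app for chatting with a research paper: upload a PDF, have GROBID
parse it, index the chunks with OpenAI embeddings, and answer questions
through a LangGraph retrieve-and-generate pipeline."""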
import time
import gradio as gr
import xmltodict
from grobid_client.grobid_client import GrobidClient
from langchain import hub
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import GrobidParser
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from utils import create_scheduler, save_feedback
# Sentinel message returned until a paper has been uploaded and indexed;
# answer_question and slow_echo both compare against it.
UPLOAD_PROMPT = "Please upload a PDF file first and wait for it to be loaded!"


class PaperQA:
def __init__(self):
self.qa_graph = None
self.current_file = None
self.scheduler, self.feedback_file = create_scheduler()
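    # State shared between the retrieve and generate nodes of the graph.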
class State(TypedDict):
question: str
context: List[Document]
answer: str
def get_extra_docs(self, file_name):
# TODO: Add the code to extract the title, authors, and abstract from the PDF file
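        # Requires a config.json pointing at a reachable GROBID server.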
client = GrobidClient(config_path="./config.json")
information = client.process_pdf(
"processHeaderDocument",
file_name,
generateIDs=False,
consolidate_header=False,
consolidate_citations=False,
include_raw_citations=False,
include_raw_affiliations=False,
tei_coordinates=False,
segment_sentences=False,
)
        # process_pdf returns (pdf_file, status, text); index 2 is the TEI XML,
        # whose root element is uppercase <TEI>.
        dict_information = xmltodict.parse(information[2])
        title = dict_information["TEI"]["teiHeader"]["fileDesc"]["titleStmt"]["title"]
        # An attributed <title> parses to a dict; its text lives under "#text".
        return title["#text"] if isinstance(title, dict) else title
def initiate_graph(self, file):
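        # Rebuild the RAG pipeline only when a different file is uploaded.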
if self.current_file != file.name:
self.qa_graph = None
self.current_file = file.name
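            # Parse the PDF into section-level documents via the hosted GROBID service.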
loader = GenericLoader.from_filesystem(
file.name,
parser=GrobidParser(
segment_sentences=False,
grobid_server="https://jpangas-grobid-paper-extractor.hf.space/api/processFulltextDocument",
),
)
docs = loader.load()
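            # Embed the chunks and index them in an in-memory vector store
            # (both OpenAI clients read OPENAI_API_KEY from the environment).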
embeddings = OpenAIEmbeddings()
vector_store = InMemoryVectorStore(embeddings)
llm = ChatOpenAI(model="gpt-4o-mini")
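            # Split documents into overlapping chunks so answers keep local context.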
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, chunk_overlap=200, add_start_index=True
)
all_splits = text_splitter.split_documents(docs)
vector_store.add_documents(documents=all_splits)
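            # Pull the community RAG prompt (rlm/rag-prompt) from the LangChain hub.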
prompt = hub.pull("rlm/rag-prompt")
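            # Retrieval step: fetch the stored chunks most similar to the question.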
def retrieve(state: self.State):
retrieved_docs = vector_store.similarity_search(state["question"])
return {"context": retrieved_docs}
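            # Generation step: stuff the retrieved chunks into the prompt and ask the LLM.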
def generate(state: self.State):
docs_content = "\n\n".join(doc.page_content for doc in state["context"])
messages = prompt.invoke(
{"question": state["question"], "context": docs_content}
)
response = llm.invoke(messages)
return {"answer": response.content}
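            # Wire the two steps into a linear retrieve -> generate graph.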
graph_builder = StateGraph(self.State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
self.qa_graph = graph_builder.compile()
name = file.name.split("/")[-1]
return f"The paper {name} has been loaded and is ready for questions!"
    def answer_question(self, question, history):
        if self.qa_graph is None:
            return UPLOAD_PROMPT
        response = self.qa_graph.invoke({"question": question})
        # Log only genuine question/answer pairs, never the upload prompt.
        if response["answer"] != UPLOAD_PROMPT:
            save_feedback(
                self.scheduler, self.feedback_file, question, response["answer"]
            )
        return response["answer"]
    def slow_echo(self, message, history):
        answer = self.answer_question(message, history)
        # Emit the upload prompt at once rather than typing it out.
        if answer == UPLOAD_PROMPT:
            yield answer
            return
        # Stream the answer one character at a time for a typing effect.
        for i in range(len(answer)):
            time.sleep(0.01)
            yield answer[: i + 1]
def main():
qa_app = PaperQA()
with gr.Blocks() as demo:
file_input = gr.File(
label="Upload a research paper as a PDF file and wait for it to be loaded",
file_types=[".pdf"],
)
textbox = gr.Textbox(
label="Status of Upload", value="No Paper Uploaded", interactive=False
)
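        # The chat interface streams answers through slow_echo.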
gr.ChatInterface(qa_app.slow_echo, type="messages")
file_input.upload(fn=qa_app.initiate_graph, inputs=file_input, outputs=textbox)
demo.launch()
if __name__ == "__main__":
main()