# Hugging Face Space: retrieval-augmented Q&A over an uploaded research-paper PDF.
import time

import gradio as gr
import xmltodict
from grobid_client.grobid_client import GrobidClient
from langchain import hub
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import GrobidParser
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

from utils import create_scheduler, save_feedback
class PaperQA:
    """Retrieval-augmented question answering over a single uploaded paper (PDF).

    A PDF is parsed with GROBID, split into chunks, embedded into an in-memory
    vector store, and wired into a two-node LangGraph pipeline
    (retrieve -> generate) that answers questions about the paper.
    """

    # Single prompt shown while no paper is loaded. Shared by answer_question
    # and slow_echo so the "no paper yet" check cannot drift out of sync with
    # the message actually returned (the original compared two different
    # strings, so slow_echo's fast path never fired).
    _NO_PAPER_MSG = "Please upload a PDF file first and wait for it to be loaded!"

    def __init__(self):
        # Compiled LangGraph pipeline; stays None until a paper is ingested.
        self.qa_graph = None
        # Path of the currently loaded PDF; used to invalidate the graph on upload.
        self.current_file = None
        # Feedback persistence: scheduler flushes Q/A pairs to feedback_file.
        self.scheduler, self.feedback_file = create_scheduler()

    class State(TypedDict):
        """State flowing between the retrieve and generate graph nodes."""

        question: str
        # Forward reference: Document is imported at module level and resolved
        # lazily via get_type_hints when the schema is inspected.
        context: List["Document"]
        answer: str

    def get_extra_docs(self, file_name):
        """Return the paper title extracted from *file_name* via a GROBID
        header-only pass.

        TODO: also extract the authors and abstract from the PDF header.
        """
        client = GrobidClient(config_path="./config.json")
        information = client.process_pdf(
            "processHeaderDocument",
            file_name,
            generateIDs=False,
            consolidate_header=False,
            consolidate_citations=False,
            include_raw_citations=False,
            include_raw_affiliations=False,
            tei_coordinates=False,
            segment_sentences=False,
        )
        # process_pdf returns a tuple whose third element is the TEI XML
        # payload — presumably (pdf_file, status, tei_xml); verify against the
        # grobid_client docs if this ever changes.
        dict_information = xmltodict.parse(information[2])
        return dict_information["tei"]["teiHeader"]["fileDesc"]["titleStmt"]["title"]

    def initiate_graph(self, file):
        """Parse *file* (a Gradio file object), index it, and compile the QA graph.

        Returns a human-readable status string shown in the upload status box.
        """
        if self.current_file != file.name:
            # New upload: drop the stale graph before rebuilding for this file.
            self.qa_graph = None
            self.current_file = file.name

        # Full-text extraction through a hosted GROBID service.
        loader = GenericLoader.from_filesystem(
            file.name,
            parser=GrobidParser(
                segment_sentences=False,
                grobid_server="https://jpangas-grobid-paper-extractor.hf.space/api/processFulltextDocument",
            ),
        )
        docs = loader.load()

        embeddings = OpenAIEmbeddings()
        vector_store = InMemoryVectorStore(embeddings)
        llm = ChatOpenAI(model="gpt-4o-mini")

        # Overlapping chunks keep sentence context intact across boundaries;
        # start indices let answers be traced back to positions in the paper.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        vector_store.add_documents(documents=all_splits)

        # Community RAG prompt template (question + retrieved context).
        prompt = hub.pull("rlm/rag-prompt")

        def retrieve(state: self.State):
            # Pull the chunks most similar to the question from the store.
            retrieved_docs = vector_store.similarity_search(state["question"])
            return {"context": retrieved_docs}

        def generate(state: self.State):
            docs_content = "\n\n".join(doc.page_content for doc in state["context"])
            messages = prompt.invoke(
                {"question": state["question"], "context": docs_content}
            )
            response = llm.invoke(messages)
            return {"answer": response.content}

        graph_builder = StateGraph(self.State).add_sequence([retrieve, generate])
        graph_builder.add_edge(START, "retrieve")
        self.qa_graph = graph_builder.compile()

        name = file.name.split("/")[-1]
        return f"The paper {name} has been loaded and is ready for questions!"

    def answer_question(self, question, history):
        """Answer *question* with the compiled graph, logging the Q/A pair.

        Returns the upload prompt when no paper has been loaded yet.
        """
        if self.qa_graph is None:
            return self._NO_PAPER_MSG
        response = self.qa_graph.invoke({"question": question})
        # Persist real Q/A pairs only; never log the "no paper" prompt.
        if response["answer"] != self._NO_PAPER_MSG:
            save_feedback(
                self.scheduler, self.feedback_file, question, response["answer"]
            )
        return response["answer"]

    def slow_echo(self, message, history):
        """Stream the answer character by character (typewriter effect)."""
        answer = self.answer_question(message, history)
        # Emit the "no paper" prompt in one piece instead of animating it.
        # (Bug fix: the original compared against a shorter string than
        # answer_question actually returns, so this guard never triggered.)
        if answer == self._NO_PAPER_MSG:
            yield answer
            return
        for end in range(1, len(answer) + 1):
            time.sleep(0.01)
            yield answer[:end]
def main():
    """Build and launch the Gradio UI: PDF upload, status box, and chat panel."""
    qa_app = PaperQA()
    with gr.Blocks() as demo:
        pdf_upload = gr.File(
            label="Upload a research paper as a PDF file and wait for it to be loaded",
            file_types=[".pdf"],
        )
        status_box = gr.Textbox(
            label="Status of Upload", value="No Paper Uploaded", interactive=False
        )
        # Chat panel streams answers via the app's typewriter generator.
        gr.ChatInterface(qa_app.slow_echo, type="messages")
        # Uploading a file (re)builds the QA graph and reports into the status box.
        pdf_upload.upload(fn=qa_app.initiate_graph, inputs=pdf_upload, outputs=status_box)
    demo.launch()


if __name__ == "__main__":
    main()