File size: 4,686 Bytes
9807ccc
46fbfbe
9807ccc
46fbfbe
 
 
9807ccc
 
 
46fbfbe
 
9807ccc
 
 
46fbfbe
 
9807ccc
15f76b2
 
 
 
 
46fbfbe
15f76b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9807ccc
 
15f76b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9807ccc
15f76b2
 
 
9807ccc
15f76b2
 
9807ccc
15f76b2
 
 
9807ccc
15f76b2
46fbfbe
 
 
 
15f76b2
9807ccc
15f76b2
 
 
 
 
9807ccc
15f76b2
 
 
9807ccc
d789b69
688c931
15f76b2
 
688c931
 
15f76b2
688c931
 
17c8aa2
688c931
 
 
d789b69
46fbfbe
15f76b2
 
9807ccc
15f76b2
9807ccc
17c8aa2
688c931
46fbfbe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import time

import gradio as gr
import xmltodict
from grobid_client.grobid_client import GrobidClient
from langchain import hub
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import GrobidParser
from langchain_core.documents import Document
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

from utils import create_scheduler, save_feedback


class PaperQA:
    """Question-answering session over a single uploaded research paper.

    ``initiate_graph`` parses the uploaded PDF through a GROBID server,
    indexes the resulting chunks in an in-memory vector store and compiles a
    two-step (retrieve -> generate) LangGraph pipeline.  ``answer_question``
    runs that pipeline, and ``slow_echo`` streams the answer to the chat UI.
    """

    # Reply given while no paper has been indexed yet.  Both
    # ``answer_question`` and ``slow_echo`` refer to this single constant so
    # the sentinel comparison cannot drift away from the returned text.
    NOT_LOADED_MESSAGE = (
        "Please upload a PDF file first and wait for it to be loaded!"
    )

    def __init__(self):
        """Start with no paper loaded and set up feedback persistence."""
        # Compiled LangGraph pipeline; ``None`` until a paper has been loaded.
        self.qa_graph = None
        # Path of the file the current pipeline was (or is being) built from.
        self.current_file = None
        # ``create_scheduler`` comes from the project's ``utils`` module; its
        # two return values are handed back to ``save_feedback`` later on.
        self.scheduler, self.feedback_file = create_scheduler()

    class State(TypedDict):
        """State dictionary passed between the nodes of the QA graph."""

        question: str
        context: List[Document]
        answer: str

    def get_extra_docs(self, file_name):
        """Return the paper title from GROBID's header extraction.

        Args:
            file_name: Path of the PDF file to send to GROBID.

        Returns:
            The value found at ``tei/teiHeader/fileDesc/titleStmt/title`` in
            the parsed response.
        """
        # TODO: Add the code to extract the title, authors, and abstract from the PDF file
        client = GrobidClient(config_path="./config.json")
        information = client.process_pdf(
            "processHeaderDocument",
            file_name,
            generateIDs=False,
            consolidate_header=False,
            consolidate_citations=False,
            include_raw_citations=False,
            include_raw_affiliations=False,
            tei_coordinates=False,
            segment_sentences=False,
        )
        # The third element of the result is parsed as the XML payload.
        # NOTE(review): presumably the result is (file, status, text) and the
        # status should be checked before parsing -- confirm against the
        # grobid_client version in use.
        dict_information = xmltodict.parse(information[2])
        # NOTE(review): GROBID documents usually have an upper-case ``TEI``
        # root element, and ``title`` may parse to a dict when the element
        # carries attributes -- verify both against a real response.
        title = dict_information["tei"]["teiHeader"]["fileDesc"]["titleStmt"]["title"]
        return title

    def initiate_graph(self, file):
        """Index the uploaded paper and compile the retrieve/generate graph.

        Args:
            file: Uploaded file object; its ``name`` attribute is used as the
                path of the PDF on disk.

        Returns:
            A status string announcing that the paper is ready for questions.
        """
        if self.current_file != file.name:
            # A different paper is being loaded: drop the old pipeline so no
            # question is answered from the previous document meanwhile.
            self.qa_graph = None
            self.current_file = file.name

        loader = GenericLoader.from_filesystem(
            file.name,
            parser=GrobidParser(
                segment_sentences=False,
                grobid_server="https://jpangas-grobid-paper-extractor.hf.space/api/processFulltextDocument",
            ),
        )

        docs = loader.load()

        embeddings = OpenAIEmbeddings()
        vector_store = InMemoryVectorStore(embeddings)

        llm = ChatOpenAI(model="gpt-4o-mini")
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200, add_start_index=True
        )
        all_splits = text_splitter.split_documents(docs)
        vector_store.add_documents(documents=all_splits)
        prompt = hub.pull("rlm/rag-prompt")

        def retrieve(state: self.State):
            """Graph node: look up chunks similar to the question."""
            retrieved_docs = vector_store.similarity_search(state["question"])
            return {"context": retrieved_docs}

        def generate(state: self.State):
            """Graph node: answer the question from the retrieved chunks."""
            docs_content = "\n\n".join(doc.page_content for doc in state["context"])
            messages = prompt.invoke(
                {"question": state["question"], "context": docs_content}
            )
            response = llm.invoke(messages)
            return {"answer": response.content}

        graph_builder = StateGraph(self.State).add_sequence([retrieve, generate])
        graph_builder.add_edge(START, "retrieve")
        self.qa_graph = graph_builder.compile()

        name = file.name.split("/")[-1]
        return f"The paper {name} has been loaded and is ready for questions!"

    def answer_question(self, question, history):
        """Answer ``question`` from the loaded paper and record the exchange.

        Args:
            question: The user's question.
            history: Chat history supplied by the UI; not used here.

        Returns:
            The generated answer, or ``NOT_LOADED_MESSAGE`` when no paper has
            been loaded yet (nothing is recorded in that case).
        """
        if self.qa_graph is None:
            return self.NOT_LOADED_MESSAGE

        response = self.qa_graph.invoke({"question": question})
        answer = response["answer"]
        # The not-loaded case returned above, so every answer reaching this
        # point is a real model answer and is saved.
        save_feedback(self.scheduler, self.feedback_file, question, answer)
        return answer

    def slow_echo(self, message, history):
        """Stream the answer to ``message`` as progressively longer prefixes.

        Args:
            message: The user's chat message.
            history: Chat history supplied by the UI; forwarded unchanged.

        Yields:
            ``NOT_LOADED_MESSAGE`` once when no paper is loaded; otherwise
            each prefix of the answer, one character longer every 10 ms.
        """
        answer = self.answer_question(message, history)
        if answer == self.NOT_LOADED_MESSAGE:
            # Show the notice at once instead of typing it out.
            yield answer
            return

        for end in range(1, len(answer) + 1):
            time.sleep(0.01)
            yield answer[:end]


def main():
    """Build the Gradio interface around one ``PaperQA`` session and launch it."""
    paper_qa = PaperQA()

    with gr.Blocks() as app:
        # Uploader first, then the read-only status line, then the chat.
        uploader = gr.File(
            file_types=[".pdf"],
            label="Upload a research paper as a PDF file and wait for it to be loaded",
        )
        status_box = gr.Textbox(
            value="No Paper Uploaded",
            label="Status of Upload",
            interactive=False,
        )
        gr.ChatInterface(paper_qa.slow_echo, type="messages")

        # Each upload is passed to initiate_graph; the string it returns is
        # written into the status box.
        uploader.upload(
            fn=paper_qa.initiate_graph,
            inputs=uploader,
            outputs=status_box,
        )

    app.launch()


# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()