import os

from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.embeddings import CohereEmbeddings, HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import (CharacterTextSplitter, NLTKTextSplitter,
                                     RecursiveCharacterTextSplitter, SpacyTextSplitter)
from langchain.vectorstores import Chroma, ElasticVectorSearch, FAISS
from pypdf import PdfReader

from schema import BotType, EmbeddingTypes, IndexerType, TransformType


class QnASystem:
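    """End-to-end PDF question answering built on LangChain: load a PDF,
    split it into chunks, embed and index the chunks, and answer questions
    through a retrieval chain (one-shot or conversational)."""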

    def read_and_load_pdf(self, f_data):
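        """Read a PDF file object and wrap each page in a Document,
        recording the page number and source filename as metadata."""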
        pdf_data = PdfReader(f_data)
        documents = []
        for idx, page in enumerate(pdf_data.pages):
            documents.append(Document(page_content=page.extract_text(),
                                      metadata={"page_no": idx, "source": f_data.name}))

        self.documents = documents

    def document_transformer(self, transform_type: TransformType):
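        """Split the loaded pages into chunks using the selected text splitter."""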
        match transform_type:
            case TransformType.CharacterTransform:
                t_type = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
            case TransformType.RecursiveTransform:
                t_type = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
            case TransformType.NLTKTransform:
                t_type = NLTKTextSplitter()
            case TransformType.SpacyTransform:
                t_type = SpacyTextSplitter()

            case _:
                raise ValueError("Invalid Transformer Type")

        self.transformed_documents = t_type.split_documents(documents=self.documents)

    def generate_embeddings(self, embedding_type: EmbeddingTypes = EmbeddingTypes.OPENAI,
                            indexer_type: IndexerType = IndexerType.FAISS, **kwargs):
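        """Build the embedding model and companion LLM for the chosen provider,
        then embed the transformed documents into the chosen vector store."""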
        temperature = kwargs.get("temperature", 0)
        max_tokens = kwargs.get("max_tokens", 512)
        match embedding_type:
            case EmbeddingTypes.OPENAI:
                api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
                if not api_key:
                    raise ValueError("No OpenAI API key provided")
                os.environ["OPENAI_API_KEY"] = api_key
                embeddings = OpenAIEmbeddings()
                llm = OpenAI(temperature=temperature, max_tokens=max_tokens)
            case EmbeddingTypes.HUGGING_FACE:
                embeddings = HuggingFaceEmbeddings(model_name=kwargs.get("model_name"))
                # HuggingFaceHub expects "max_length" rather than "max_tokens"
                llm = HuggingFaceHub(repo_id=kwargs.get("model_name"),
                                     model_kwargs={"temperature": temperature, "max_length": max_tokens})
            case EmbeddingTypes.COHERE:
                embeddings = CohereEmbeddings(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"))
                # Cohere takes temperature and max_tokens as direct fields, not via model_kwargs
                llm = Cohere(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"),
                             temperature=temperature, max_tokens=max_tokens)
            case _:
                raise ValueError("Invalid Embedding Type")

        match indexer_type:
            case IndexerType.FAISS:
                indexer = FAISS
                indexer_kwargs = {}
            case IndexerType.CHROMA:
                indexer = Chroma
                indexer_kwargs = {}
            case IndexerType.ELASTICSEARCH:
                # ElasticVectorSearch needs the cluster URL at index-build time,
                # so forward it to from_documents below
                indexer = ElasticVectorSearch
                indexer_kwargs = {"elasticsearch_url": kwargs.get("elasticsearch_url")}
            case _:
                raise ValueError("Invalid Indexer Type")

        self.llm = llm
        self.indexer = indexer
        self.vector_store = indexer.from_documents(documents=self.transformed_documents,
                                                   embedding=embeddings, **indexer_kwargs)

    def get_retriever(self, search_type="similarity", top_k=5, **kwargs):
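        """Expose the vector store as a retriever returning the top_k most similar chunks."""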
        retriever = self.vector_store.as_retriever(search_type=search_type, search_kwargs={"k": top_k})
        self.retriever = retriever

    def get_prompt(self, bot_type: BotType, **kwargs):
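        """Return the prompt template matching the bot type."""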
        match bot_type:
            case BotType.qna:
                prompt = """
                You are a smart and helpful AI assistant who answers the question given the context
                {context}
                Question: {question}
                """
            case BotType.conversational:
                prompt = """
                Given the following conversation and a follow up question,
                rephrase the follow up question to be a standalone question, in its original language.
                \nChat History:\n{chat_history}\nFollow Up Input: {question}\nStandalone question:
                """
            case _:
                raise ValueError("Invalid Bot Type")
        # from_template infers input variables, so each bot type only declares
        # the placeholders its template actually uses
        return PromptTemplate.from_template(prompt)

    def build_qa(self, qa_type: BotType, chain_type="stuff",
                 return_documents: bool = True, **kwargs):
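        """Construct the retrieval chain: a one-shot RetrievalQA chain for
        BotType.qna, or a memory-backed ConversationalRetrievalChain for
        BotType.conversational."""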
        match qa_type:
            case BotType.qna:
                self.chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever, chain_type=chain_type,
                                                         return_source_documents=return_documents, verbose=True)

            case BotType.conversational:
                self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,
                                                       output_key="answer")
                self.chain = ConversationalRetrievalChain.from_llm(llm=self.llm, retriever=self.retriever,
                                                                   chain_type=chain_type,
                                                                   return_source_documents=return_documents,
                                                                   memory=self.memory, verbose=True)

            case _:
                raise ValueError("Invalid QA Type")

        return self.chain

    def ask_question(self, query):
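        """Run a query through the chain, using the input key each chain type
        expects ("query" for RetrievalQA, "question" for ConversationalRetrievalChain)."""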
        if isinstance(self.chain, RetrievalQA):
            data = {"query": query}
        else:
            data = {"question": query}
        return self.chain(data)

    def build_chain(self, transform_type, embedding_type, indexer_type, **kwargs):
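        """Run the full pipeline (split, embed, index, retrieve, chain) and
        return the ready-to-use chain."""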
        # Reuse the chain if the pipeline has already been built
        if hasattr(self, "chain"):
            return self.chain
        self.document_transformer(transform_type)
        self.generate_embeddings(embedding_type=embedding_type,
                                 indexer_type=indexer_type, **kwargs)
        self.get_retriever(**kwargs)
        return self.build_qa(qa_type=kwargs.get("bot_type"), **kwargs)


if __name__ == "__main__":
    qna = QnASystem()
    with open("../docs/Doc A.pdf", "rb") as f:
        qna.read_and_load_pdf(f)
        chain = qna.build_chain(
            transform_type=TransformType.RecursiveTransform,
            embedding_type=EmbeddingTypes.OPENAI, indexer_type=IndexerType.FAISS,
            chain_type="map_reduce", bot_type=BotType.conversational, return_documents=True
        )
        question = qna.ask_question(query="Hi! Summarize the document.")
        question = qna.ask_question(query="What happened from June 1984 to September 1996")
        print(question)