import os

from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, CohereEmbeddings
from langchain.llms import OpenAI, HuggingFaceHub, Cohere
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, NLTKTextSplitter, \
    SpacyTextSplitter
from langchain.vectorstores import FAISS, Chroma, ElasticVectorSearch
from pypdf import PdfReader

from schema import EmbeddingTypes, IndexerType, TransformType, BotType
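
# The local `schema` module is not part of this listing. A plausible minimal
# definition, inferred from how the enums are used below (an assumption, not
# the original file):
#
#     from enum import Enum
#
#     class EmbeddingTypes(Enum):
#         OPENAI = "openai"
#         HUGGING_FACE = "hugging_face"
#         COHERE = "cohere"
#
#     class IndexerType(Enum):
#         FAISS = "faiss"
#         CHROMA = "chroma"
#         ELASTICSEARCH = "elasticsearch"
#
#     class TransformType(Enum):
#         CharacterTransform = "character"
#         RecursiveTransform = "recursive"
#         NLTKTransform = "nltk"
#         SpacyTransform = "spacy"
#
#     class BotType(Enum):
#         qna = "qna"
#         conversational = "conversational"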

class QnASystem:
    """Builds a document QnA / conversational retrieval pipeline with LangChain."""

    def read_and_load_pdf(self, f_data):
        # Wrap each PDF page in a Document, keeping the page number and
        # source file name as metadata so answers can cite their origin.
        pdf_data = PdfReader(f_data)
        documents = []
        for idx, page in enumerate(pdf_data.pages):
            documents.append(Document(page_content=page.extract_text(),
                                      metadata={"page_no": idx, "source": f_data.name}))
        self.documents = documents

    def document_transformer(self, transform_type: TransformType):
        # Split the raw pages into retrieval-sized chunks.
        match transform_type:
            case TransformType.CharacterTransform:
                t_type = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
            case TransformType.RecursiveTransform:
                t_type = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
            case TransformType.NLTKTransform:
                t_type = NLTKTextSplitter()
            case TransformType.SpacyTransform:
                t_type = SpacyTextSplitter()
            case _:
                raise ValueError("Invalid Transformer Type")
        self.transformed_documents = t_type.split_documents(documents=self.documents)

    def generate_embeddings(self, embedding_type: EmbeddingTypes = EmbeddingTypes.OPENAI,
                            indexer_type: IndexerType = IndexerType.FAISS, **kwargs):
        temperature = kwargs.get("temperature", 0)
        max_tokens = kwargs.get("max_tokens", 512)
        match embedding_type:
            case EmbeddingTypes.OPENAI:
                api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
                if not api_key:
                    raise ValueError("Missing OpenAI API key")
                os.environ["OPENAI_API_KEY"] = api_key
                embeddings = OpenAIEmbeddings()
                llm = OpenAI(temperature=temperature, max_tokens=max_tokens)
            case EmbeddingTypes.HUGGING_FACE:
                embeddings = HuggingFaceEmbeddings(model_name=kwargs.get("model_name"))
                # Hub text-generation models take the output budget as
                # `max_new_tokens` rather than `max_tokens`.
                llm = HuggingFaceHub(repo_id=kwargs.get("model_name"),
                                     model_kwargs={"temperature": temperature,
                                                   "max_new_tokens": max_tokens})
            case EmbeddingTypes.COHERE:
                embeddings = CohereEmbeddings(model=kwargs.get("model_name"),
                                              cohere_api_key=kwargs.get("api_key"))
                # LangChain's Cohere wrapper exposes these as top-level fields,
                # not via model_kwargs.
                llm = Cohere(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"),
                             temperature=temperature, max_tokens=max_tokens)
            case _:
                raise ValueError("Invalid Embedding Type")
        # from_documents is a classmethod, so keep the class itself (not an
        # instance) and pass store-specific arguments alongside the documents.
        indexer_kwargs = {}
        match indexer_type:
            case IndexerType.FAISS:
                indexer = FAISS
            case IndexerType.CHROMA:
                indexer = Chroma
            case IndexerType.ELASTICSEARCH:
                indexer = ElasticVectorSearch
                indexer_kwargs["elasticsearch_url"] = kwargs.get("elasticsearch_url")
            case _:
                raise ValueError("Invalid Indexer Type")
        self.llm = llm
        self.indexer = indexer
        self.vector_store = indexer.from_documents(documents=self.transformed_documents,
                                                   embedding=embeddings, **indexer_kwargs)

    def get_retriever(self, search_type="similarity", top_k=5, **kwargs):
        # Expose the vector store as a retriever returning the top_k chunks.
        self.retriever = self.vector_store.as_retriever(search_type=search_type,
                                                        search_kwargs={"k": top_k})
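
    # `as_retriever` also accepts search_type="mmr" (maximal marginal
    # relevance), which trades a little raw similarity for more diverse
    # chunks, e.g.:
    #     qna.get_retriever(search_type="mmr", top_k=5)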

    def get_prompt(self, bot_type: BotType, **kwargs):
        # Each template only declares the variables it actually uses;
        # PromptTemplate validates that the two match.
        match bot_type:
            case BotType.qna:
                prompt = """
                You are a smart and helpful AI assistant. Answer the question using the given context.
                {context}
                Question: {question}
                """
                input_variables = ["context", "question"]
            case BotType.conversational:
                prompt = """
                Given the following conversation and a follow up question,
                rephrase the follow up question to be a standalone question, in its original language.
                \nChat History:\n{chat_history}\nFollow Up Input: {question}\nStandalone question:
                """
                input_variables = ["chat_history", "question"]
            case _:
                raise ValueError("Invalid Bot Type")
        return PromptTemplate(input_variables=input_variables, template=prompt)
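
    # These templates are not wired into build_qa below, which falls back on
    # the chains' default prompts. One way to apply the QnA template (valid
    # for chain_type="stuff"; other chain types take different prompt kwargs):
    #
    #     RetrievalQA.from_chain_type(
    #         llm=self.llm, retriever=self.retriever, chain_type="stuff",
    #         chain_type_kwargs={"prompt": self.get_prompt(BotType.qna)},
    #     )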

    def build_qa(self, qa_type: BotType, chain_type="stuff",
                 return_documents: bool = True, **kwargs):
        match qa_type:
            case BotType.qna:
                self.chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever,
                                                         chain_type=chain_type,
                                                         return_source_documents=return_documents,
                                                         verbose=True)
            case BotType.conversational:
                # output_key="answer" tells the memory which field to store
                # when the chain also returns source documents.
                self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,
                                                       output_key="answer")
                self.chain = ConversationalRetrievalChain.from_llm(llm=self.llm, retriever=self.retriever,
                                                                   chain_type=chain_type,
                                                                   return_source_documents=return_documents,
                                                                   memory=self.memory, verbose=True)
            case _:
                raise ValueError("Invalid QA Type")
        return self.chain

    def ask_question(self, query):
        # RetrievalQA expects its input under "query";
        # ConversationalRetrievalChain expects it under "question".
        if isinstance(self.chain, RetrievalQA):
            data = {"query": query}
        else:
            data = {"question": query}
        return self.chain(data)
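
    # The chain returns a dict: RetrievalQA puts the answer under "result",
    # ConversationalRetrievalChain under "answer"; with
    # return_source_documents=True both also include "source_documents".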

    def build_chain(self, transform_type, embedding_type, indexer_type, **kwargs):
        # Reuse an already-built chain instead of re-embedding the documents.
        if hasattr(self, "chain"):
            return self.chain
        self.document_transformer(transform_type)
        self.generate_embeddings(embedding_type=embedding_type,
                                 indexer_type=indexer_type, **kwargs)
        self.get_retriever(**kwargs)
        return self.build_qa(qa_type=kwargs.get("bot_type"), **kwargs)

if __name__ == "__main__":
    qna = QnASystem()
    with open("../docs/Doc A.pdf", "rb") as f:
        qna.read_and_load_pdf(f)
    chain = qna.build_chain(
        transform_type=TransformType.RecursiveTransform,
        embedding_type=EmbeddingTypes.OPENAI, indexer_type=IndexerType.FAISS,
        chain_type="map_reduce", bot_type=BotType.conversational, return_documents=True
    )
    # The first call seeds the conversation memory; only the follow-up
    # answer is printed.
    qna.ask_question(query="Hi! Summarize the document.")
    answer = qna.ask_question(query="What happened from June 1984 to September 1996?")
    print(answer)
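
# With return_documents=True the response also carries the retrieved chunks,
# which can be printed alongside the answer, e.g.:
#     for doc in answer["source_documents"]:
#         print(doc.metadata["page_no"], doc.page_content[:80])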