"""Custom retrieval chain that tags retrieved documents with citation markers.

Also defines ``MultiQueriesChain``, an LLM chain that rewrites a user question
into several search queries, exposed as the module-level ``llm_chain`` singleton.
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from langchain.schema import BaseMessage, BaseRetriever, Document
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

from chains.azure_openai import CustomAzureOpenAI
# Single import from config (the original had two separate ones).
from config import (
    DEPLOYMENT_ID,
    OPENAI_API_BASE,
    OPENAI_API_KEY,
    OPENAI_API_TYPE,
    OPENAI_API_VERSION,
)
from prompts.custom_chain import HUMAN_PROMPT_TEMPLATE, SYSTEM_PROMPT_TEMPLATE


class MultiQueriesChain(LLMChain):
    """LLM chain that expands a single question into multiple search queries.

    Backed by an Azure OpenAI deployment with ``temperature=0.0`` so the
    generated queries are deterministic.
    """

    llm = CustomAzureOpenAI(
        deployment_name=DEPLOYMENT_ID,
        openai_api_type=OPENAI_API_TYPE,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        openai_api_key=OPENAI_API_KEY,
        temperature=0.0,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
            HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE),
        ]
    )


# Module-level singleton; kept for external callers and for a future
# multi-query retrieval path (a disabled prototype previously lived here).
llm_chain = MultiQueriesChain()


class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """ConversationalRetrievalChain that numbers retrieved docs as citations."""

    # Index to connect to.
    retriever: BaseRetriever
    # When set, retrieved docs are trimmed by _reduce_tokens_below_limit so
    # their combined token count stays under this bound.
    max_tokens_limit: Optional[int] = None

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        """Retrieve documents for ``question`` and tag each with a citation index.

        Each document's ``page_content`` is prefixed with ``[n]`` (1-based
        retrieval rank) so downstream prompts can cite sources, and its
        ``source`` metadata is coerced to a plain string.

        Args:
            question: The (possibly rephrased) user question to search for.
            inputs: Full chain inputs; unused here but required by the base
                class signature.

        Returns:
            The retrieved documents, citation-tagged and token-limited.
        """
        docs = self.retriever.get_relevant_documents(question)
        for rank, doc in enumerate(docs, start=1):
            # Strip U+FFFD replacement characters left over from bad decoding,
            # then prepend the citation marker used by the answer prompt.
            cleaned = doc.page_content.strip("\ufffd")
            doc.page_content = f"[{rank}] {cleaned}"
            # Coerce the source to str (the original used an f-string no-op).
            doc.metadata["source"] = str(doc.metadata["source"])
        return self._reduce_tokens_below_limit(docs)