# Chat_QnA_v2/chains/custom_chain.py
# Provenance: Hugging Face Space "Chat_QnA_v2" by binh99,
# commit a4b89be ("update cosmos db"), file size 2.77 kB.
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from langchain.schema import BaseMessage, BaseRetriever, Document
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate)
from config import DEPLOYMENT_ID
from prompts.custom_chain import SYSTEM_PROMPT_TEMPLATE, HUMAN_PROMPT_TEMPLATE
from config import OPENAI_API_TYPE, OPENAI_API_VERSION, OPENAI_API_KEY, OPENAI_API_BASE
from chains.azure_openai import CustomAzureOpenAI
class MultiQueriesChain(LLMChain):
    """LLMChain that turns a user question into multiple search queries.

    The model and prompt are fixed at class level: an Azure OpenAI chat
    deployment (all connection settings taken from ``config``) combined with
    the system/human templates from ``prompts.custom_chain``.
    """

    # Deterministic generation (temperature 0.0) against the configured
    # Azure OpenAI deployment.
    llm = CustomAzureOpenAI(
        deployment_name=DEPLOYMENT_ID,
        openai_api_type=OPENAI_API_TYPE,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        openai_api_key=OPENAI_API_KEY,
        temperature=0.0,
    )
    # Chat prompt assembled from the imported system + human templates.
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
            HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE),
        ]
    )


# Module-level singleton chain instance, built once at import time.
llm_chain = MultiQueriesChain()
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain that tags retrieved docs with citation markers.

    Overrides ``_get_docs`` so each retrieved document's ``page_content`` is
    prefixed with a 1-based ``[n]`` index, letting the answer prompt cite its
    sources by number.
    """

    retriever: BaseRetriever
    """Index to connect to."""
    max_tokens_limit: Optional[int] = None  # when set, docs are trimmed to fit

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any]
    ) -> List[Document]:
        """Retrieve documents for *question* and add a ``[n]`` citation prefix.

        Args:
            question: The (possibly condensed) user question to search with.
            inputs: Full chain inputs; unused here but part of the base signature.

        Returns:
            The retrieved documents, citation-prefixed and reduced to fit
            ``max_tokens_limit``.
        """
        docs = self.retriever.get_relevant_documents(question)
        for idx, doc in enumerate(docs, start=1):
            # Strip U+FFFD replacement characters left by lossy decoding,
            # then prepend the citation index.
            # (The original also reassigned doc.metadata["source"] to itself —
            # a no-op, removed.)
            content = doc.page_content.strip("\ufffd")
            doc.page_content = f'[{idx}] {content}'
        return self._reduce_tokens_below_limit(docs)