"""Custom retrieval chains: multi-query expansion and citation-tagged retrieval."""
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from langchain.chains import ConversationalRetrievalChain
from langchain.chains.conversational_retrieval.base import _get_chat_history
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseMessage, BaseRetriever, Document

from chains.azure_openai import CustomAzureOpenAI
from config import (
    DEPLOYMENT_ID,
    OPENAI_API_BASE,
    OPENAI_API_KEY,
    OPENAI_API_TYPE,
    OPENAI_API_VERSION,
)
from prompts.custom_chain import HUMAN_PROMPT_TEMPLATE, SYSTEM_PROMPT_TEMPLATE
class MultiQueriesChain(LLMChain):
    """LLM chain that expands one user question into several search queries.

    Configured with the Azure OpenAI deployment from ``config`` and the
    system/human prompt pair from ``prompts.custom_chain``.
    """

    # temperature=0.0 so query expansion is deterministic across calls.
    llm = CustomAzureOpenAI(
        deployment_name=DEPLOYMENT_ID,
        openai_api_type=OPENAI_API_TYPE,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        openai_api_key=OPENAI_API_KEY,
        temperature=0.0,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
            HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE),
        ]
    )
# Module-level singleton used for multi-query expansion of user questions.
llm_chain = MultiQueriesChain()
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain that tags each retrieved document with a
    1-based citation marker (``[1]``, ``[2]``, ...) for downstream citation
    rendering."""

    # Index to connect to.
    retriever: BaseRetriever
    # Hard cap on total tokens across retrieved docs; None disables trimming.
    max_tokens_limit: Optional[int] = None

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
    ) -> List[Document]:
        """Retrieve documents for *question* and prefix each with a citation index.

        Strips stray U+FFFD replacement characters from the content edges,
        normalizes the ``source`` metadata to a plain string, then trims the
        list to fit under ``max_tokens_limit``.
        """
        docs = self.retriever.get_relevant_documents(question)
        for idx, doc in enumerate(docs, start=1):
            # Drop mojibake replacement chars, then prepend the citation
            # index the answer prompt refers back to.
            content = doc.page_content.strip("\ufffd")
            doc.page_content = f"[{idx}] {content}"
            # Coerce the source to a plain string for citation output.
            doc.metadata["source"] = f"{doc.metadata['source']}"
        return self._reduce_tokens_below_limit(docs)

    # NOTE(review): a commented-out multi-query variant of _get_docs (expanding
    # the question via llm_chain and de-duplicating results) was removed as
    # dead code; recover it from version control if it is to be revived.