import os
from typing import Any, Dict, List, Optional

from langchain.chains import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import BaseRetriever, Document

from chains.azure_openai import CustomAzureOpenAI
from config import (
    DEPLOYMENT_ID,
    OPENAI_API_BASE,
    OPENAI_API_KEY,
    OPENAI_API_TYPE,
    OPENAI_API_VERSION,
)
from prompts.custom_chain import HUMAN_PROMPT_TEMPLATE, SYSTEM_PROMPT_TEMPLATE
class MultiQueriesChain(LLMChain):
    """LLMChain that rewrites a user question into several search queries."""

    llm = CustomAzureOpenAI(
        deployment_name=DEPLOYMENT_ID,
        openai_api_type=OPENAI_API_TYPE,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        openai_api_key=OPENAI_API_KEY,
        temperature=0.0,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
            HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE),
        ]
    )


# Shared instance, used by the multi-query variant of _get_docs below.
llm_chain = MultiQueriesChain()
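# A minimal usage sketch (the question string is an illustrative assumption;
# the prompt templates are expected to make the model return a single
# comma-separated line of reformulated queries):
#
#     raw = llm_chain.predict(question="How do I reset my password?")
#     queries = [q.strip() for q in raw.split(", ")]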
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    retriever: BaseRetriever
    """Index to connect to."""
    max_tokens_limit: Optional[int] = None

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
    ) -> List[Document]:
        """Retrieve docs and prefix each with a citation number and its source."""
        docs = self.retriever.get_relevant_documents(question)
        for idx, d in enumerate(docs):
            # Strip Unicode replacement characters left by lossy decoding;
            # keep full URLs as sources, but shorten local paths to basenames.
            content = d.page_content.strip("\ufffd")
            if "https:" in d.metadata["source"]:
                source = d.metadata["source"]
            else:
                source = os.path.basename(d.metadata["source"])
            d.page_content = f'[{idx + 1}]\t "{content}"\nSource: {source}'
        return self._reduce_tokens_below_limit(docs)
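    # _reduce_tokens_below_limit is inherited from ConversationalRetrievalChain:
    # when max_tokens_limit is set (and the combine-docs chain is a "stuff"
    # chain), it drops documents from the end of the list until the remaining
    # page contents fit within the token budget.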
    # Alternative _get_docs: expand the question into multiple queries with
    # llm_chain, retrieve for each query, then deduplicate the results.
    # def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    #     results = llm_chain.predict(question=question) + "\n"
    #     print(results)
    #     queries = list(map(lambda x: x.strip(), results.split(', ')))
    #     docs = []
    #     print(queries)
    #     for query in queries[:3]:
    #         self.retriever.search_kwargs = {"k": 3}
    #         doc = self.retriever.get_relevant_documents(query)
    #         docs.extend(doc)
    #     unique_documents_dict = {
    #         (doc.page_content, tuple(sorted(doc.metadata.items()))): doc
    #         for doc in docs
    #     }
    #     unique_documents = list(unique_documents_dict.values())
    #     return self._reduce_tokens_below_limit(unique_documents)
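# A hedged construction sketch for the custom chain (`vectorstore` and the
# parameter values are assumptions, not defined in this module):
#
#     from langchain.chains.question_answering import load_qa_chain
#
#     qa = CustomConversationalRetrievalChain(
#         retriever=vectorstore.as_retriever(),
#         combine_docs_chain=load_qa_chain(llm_chain.llm, chain_type="stuff"),
#         question_generator=llm_chain,
#         max_tokens_limit=3000,
#     )
#     result = qa({"question": "What is X?", "chat_history": []})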