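"""Custom retrieval chains for the Azure OpenAI deployment: MultiQueriesChain
expands a question into several search queries, and
CustomConversationalRetrievalChain rewrites retrieved documents as numbered,
source-tagged snippets before answering."""
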
import os
from typing import Any, Dict, List, Optional

from langchain.chains import ConversationalRetrievalChain
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import BaseRetriever, Document

from chains.azure_openai import CustomAzureOpenAI
from config import (
    DEPLOYMENT_ID,
    OPENAI_API_BASE,
    OPENAI_API_KEY,
    OPENAI_API_TYPE,
    OPENAI_API_VERSION,
)
from prompts.custom_chain import SYSTEM_PROMPT_TEMPLATE, HUMAN_PROMPT_TEMPLATE

class MultiQueriesChain(LLMChain):
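    """Expands a user question into several comma-separated search queries
    via the Azure OpenAI deployment; the multi-query `_get_docs` variant
    below consumes its output."""
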
    llm = CustomAzureOpenAI(
        deployment_name=DEPLOYMENT_ID,
        openai_api_type=OPENAI_API_TYPE,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        openai_api_key=OPENAI_API_KEY,
        temperature=0.0,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
            HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE),
        ]
    )

# Shared instance used by the multi-query _get_docs variant below.
llm_chain = MultiQueriesChain()

class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    retriever: BaseRetriever
    """Retriever used to fetch relevant documents."""
    max_tokens_limit: Optional[int] = None

    def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
        """Retrieve relevant documents and rewrite each one as a numbered,
        source-tagged snippet so the answer can cite its sources."""
        docs = self.retriever.get_relevant_documents(question)
        for idx, d in enumerate(docs):
            # Drop stray Unicode replacement characters left by decoding errors.
            content = d.page_content.strip("\ufffd")
            # Keep full URLs for web sources; use just the file name otherwise.
            if "https:" in d.metadata["source"]:
                source = d.metadata["source"]
            else:
                source = os.path.basename(d.metadata["source"])
            d.page_content = f'[{idx + 1}]\t "{content}"\nSource: {source}'
        return self._reduce_tokens_below_limit(docs)
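    # Illustrative rewritten page_content for a local-file hit (text made up):
    #   [2]   "Employees may request remote work for up to two days a week."
    #   Source: handbook.pdf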
    # Alternative multi-query retrieval: expand the question into several
    # queries with MultiQueriesChain, retrieve for each, then de-duplicate.
    # def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    #     results = llm_chain.predict(question=question) + "\n"
    #     queries = [q.strip() for q in results.split(", ")]
    #     docs = []
    #     for query in queries[:3]:
    #         self.retriever.search_kwargs = {"k": 3}
    #         docs.extend(self.retriever.get_relevant_documents(query))
    #     # De-duplicate by (content, metadata) while preserving insertion order.
    #     unique_documents_dict = {
    #         (doc.page_content, tuple(sorted(doc.metadata.items()))): doc
    #         for doc in docs
    #     }
    #     unique_documents = list(unique_documents_dict.values())
    #     return self._reduce_tokens_below_limit(unique_documents)
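

# Minimal usage sketch, assuming an existing vector store (`vectorstore` is
# hypothetical) and LangChain's stock question-condensing prompt; adapt the
# names to this project's actual wiring.
#
#     from langchain.chains.question_answering import load_qa_chain
#     from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
#
#     llm = CustomAzureOpenAI(
#         deployment_name=DEPLOYMENT_ID,
#         openai_api_type=OPENAI_API_TYPE,
#         openai_api_base=OPENAI_API_BASE,
#         openai_api_version=OPENAI_API_VERSION,
#         openai_api_key=OPENAI_API_KEY,
#         temperature=0.0,
#     )
#     chain = CustomConversationalRetrievalChain(
#         retriever=vectorstore.as_retriever(),
#         combine_docs_chain=load_qa_chain(llm, chain_type="stuff"),
#         question_generator=LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT),
#         max_tokens_limit=3000,
#     )
#     result = chain({"question": "What does the handbook say?", "chat_history": []})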