from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI

from utils import _combine_documents, Retriever
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
from tools import Retrieve

# AGENT

class Agent:
    """Tool-calling agent: answers directly or emits tool calls for `Retrieve`."""

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        system_template: str = _AGENT_SYSTEM_TEMPLATE,
        temperature: float = 0.0,
    ) -> None:
        self.model_name = model_name
        self.system_template = system_template
        self.temperature = temperature

        self.runnable = self._create_runnable()
    
    def _create_runnable(self) -> Runnable:
        """Build the agent chain: prompt -> tool-calling model -> output parser."""
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                MessagesPlaceholder("chat_history", optional=True),
                ("human", "{query}"),
            ]
        )

        if "gpt" in self.model_name:
            model = ChatOpenAI(
                name="agent",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            ).bind_tools([Retrieve])

        elif "claude" in self.model_name:
            model = ChatAnthropic(
                name="agent",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            ).bind_tools([Retrieve])

        else:
            raise ValueError(f"Unsupported model name: {self.model_name!r}")

        agent_runnable = (
            prompt
            | model
            | ToolsAgentOutputParser()
        )

        return agent_runnable

# ANSWERER

class Answerer:
    """Retrieval-augmented answerer: fetches documents and answers the query from them."""

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        collection_index: int = 0,
        use_doctrines: bool = True,
        rewrite: bool = True,
        search_type: str = "similarity",
        similarity_threshold: float = 0.0,
        k: int = 15,
        temperature: float = 0.0,
        system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    ) -> None:
        self.model_name = model_name
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.rewrite = rewrite
        self.search_type = search_type
        self.similarity_threshold = similarity_threshold
        self.k = k
        self.temperature = temperature
        self.system_template = system_template

        self.runnable = self._create_runnable()
    
    def _create_runnable(self) -> Runnable:
        """Build the RAG chain: retrieve documents, stuff them into the prompt, answer."""
        vectorstore = Retriever(
            collection_index=self.collection_index,
            use_doctrines=self.use_doctrines,
            search_type=self.search_type,
            k=self.k,
            similarity_threshold=self.similarity_threshold,
        )

        # Add the retrieved documents to the chain input under the "docs" key.
        _retrieved_docs = RunnablePassthrough.assign(
            docs=itemgetter("query") | RunnableLambda(vectorstore._retrieve),
        )

        ANSWER_PROMPT = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                ("human", "{query}"),
            ]
        )

        if "gpt" in self.model_name:
            model = ChatOpenAI(
                name="answerer",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            )
        elif "claude" in self.model_name:
            model = ChatAnthropic(
                name="answerer",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            )
        else:
            raise ValueError(f"Unsupported model name: {self.model_name!r}")

        # Combine the retrieved documents into a single context string, answer the
        # query with the model, and pass the docs and original query through alongside.
        _answer = {
            "answer":
                RunnablePassthrough.assign(
                    context=lambda x: _combine_documents(x["docs"]),
                )
                | ANSWER_PROMPT
                | model,
            "docs": itemgetter("docs"),
            "standalone_question": itemgetter("query"),
        }

        chain = _retrieved_docs | _answer

        return chain
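

# Minimal usage sketch: assumes the local `utils`, `prompts`, and `tools` modules
# resolve and that API keys for the chosen provider are set in the environment.
# The query strings below are placeholders; the input keys match the prompts above
# ("query", plus an optional "chat_history" for the Agent).
if __name__ == "__main__":
    agent = Agent(model_name="gpt-4-turbo")
    # The agent either finishes directly or emits tool calls for `Retrieve`.
    agent_step = agent.runnable.invoke({"query": "What do the documents say about X?"})
    print(agent_step)

    answerer = Answerer(model_name="gpt-4-turbo", k=15)
    # The answerer chain returns a dict with "answer", "docs", and "standalone_question".
    result = answerer.runnable.invoke({"query": "What do the documents say about X?"})
    print(result["answer"].content)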