Spaces:
Running
Running
File size: 4,341 Bytes
d46cc41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from utils import _combine_documents, Retriever
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
from tools import Retrieve
# AGENT
class Agent():
    """Tool-calling agent pipeline: prompt -> chat model bound to ``Retrieve`` -> output parser.

    The composed pipeline is built once at construction time and exposed as
    ``self.runnable``.

    Args:
        model_name: Chat model identifier. Must contain ``"gpt"`` (served via
            ``ChatOpenAI``) or ``"claude"`` (served via ``ChatAnthropic``).
        system_template: Text used for the system message of the prompt.
        temperature: Sampling temperature forwarded to the chat model.

    Raises:
        ValueError: If ``model_name`` contains neither ``"gpt"`` nor ``"claude"``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        system_template: str = _AGENT_SYSTEM_TEMPLATE,
        temperature: float = 0.0,
    ) -> None:
        self.model_name = model_name
        self.system_template = system_template
        self.temperature = temperature
        # Built last: _create_runnable reads the attributes assigned above.
        self.runnable = self._create_runnable()

    def _create_runnable(self) -> RunnableParallel:
        """Assemble the prompt | model | parser pipeline.

        The prompt holds the system template, an optional ``chat_history``
        placeholder and the human ``{query}``. The model has the ``Retrieve``
        tool bound, and its output goes through ``ToolsAgentOutputParser``.

        NOTE(review): the return annotation says ``RunnableParallel`` but a
        ``|``-composed pipeline looks like a sequence runnable; the annotation
        is kept as-is to avoid changing the file's imports. Confirm and adjust.

        Raises:
            ValueError: If ``self.model_name`` matches no supported provider.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                MessagesPlaceholder("chat_history", optional=True),
                ("human", "{query}"),
            ]
        )

        # Pick the provider class from the model name. Fail loudly on anything
        # else instead of hitting an unbound local further down.
        if "gpt" in self.model_name:
            chat_model_cls = ChatOpenAI
        elif "claude" in self.model_name:
            chat_model_cls = ChatAnthropic
        else:
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: "
                "expected a name containing 'gpt' or 'claude'."
            )

        model = chat_model_cls(
            name="agent",
            streaming=True,
            model=self.model_name,
            temperature=self.temperature,
        ).bind_tools([Retrieve])

        return prompt | model | ToolsAgentOutputParser()
# ANSWERER
class Answerer():
    """Retrieval-augmented answering pipeline: retrieve docs, then prompt a chat model.

    The composed pipeline is built once at construction time and exposed as
    ``self.runnable``. Its input is a mapping with a ``"query"`` key; its
    output is a mapping with the keys ``"answer"``, ``"docs"`` and
    ``"standalone_question"``.

    Args:
        model_name: Chat model identifier. Must contain ``"gpt"`` (served via
            ``ChatOpenAI``) or ``"claude"`` (served via ``ChatAnthropic``).
        collection_index: Forwarded to ``Retriever``.
        use_doctrines: Forwarded to ``Retriever``.
        rewrite: Stored on the instance; not read anywhere in this class.
        search_type: Forwarded to ``Retriever``.
        similarity_threshold: Forwarded to ``Retriever``.
        k: Forwarded to ``Retriever``.
        temperature: Sampling temperature forwarded to the chat model.
        system_template: Text used for the system message of the answer prompt.

    Raises:
        ValueError: If ``model_name`` contains neither ``"gpt"`` nor ``"claude"``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        collection_index: int = 0,
        use_doctrines: bool = True,
        rewrite: bool = True,
        search_type: str = "similarity",
        similarity_threshold: float = 0.0,
        k: int = 15,
        temperature: float = 0.0,
        system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    ) -> None:
        self.model_name = model_name
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.rewrite = rewrite
        self.search_type = search_type
        self.similarity_threshold = similarity_threshold
        self.k = k
        self.temperature = temperature
        self.system_template = system_template
        # Built last: _create_runnable reads the attributes assigned above.
        self.runnable = self._create_runnable()

    def _create_runnable(self) -> RunnableParallel:
        """Assemble the retrieve-then-answer pipeline.

        Step 1 adds a ``docs`` entry to the input by passing ``query`` to the
        retriever. Step 2 fans out into the answer (documents combined into
        ``context``, then prompt and model), the retrieved ``docs`` and the
        original query under ``standalone_question``.

        NOTE(review): the return annotation says ``RunnableParallel`` but a
        ``|``-composed pipeline looks like a sequence runnable; the annotation
        is kept as-is to avoid changing the file's imports. Confirm and adjust.

        Raises:
            ValueError: If ``self.model_name`` matches no supported provider.
        """
        vectorstore = Retriever(
            collection_index=self.collection_index,
            use_doctrines=self.use_doctrines,
            search_type=self.search_type,
            k=self.k,
            similarity_threshold=self.similarity_threshold,
        )

        # Keep the incoming mapping and attach the retrieved documents to it.
        _retrieved_docs = RunnablePassthrough.assign(
            docs=itemgetter("query") | RunnableLambda(vectorstore._retrieve),
        )

        # presumably the system template consumes the `context` variable
        # assigned below; verify against prompts.py
        ANSWER_PROMPT = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                ("human", "{query}"),
            ]
        )

        # Pick the provider class from the model name. Fail loudly on anything
        # else instead of hitting an unbound local further down.
        if "gpt" in self.model_name:
            chat_model_cls = ChatOpenAI
        elif "claude" in self.model_name:
            chat_model_cls = ChatAnthropic
        else:
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: "
                "expected a name containing 'gpt' or 'claude'."
            )

        model = chat_model_cls(
            name="answerer",
            streaming=True,
            model=self.model_name,
            temperature=self.temperature,
        )

        _answer = {
            "answer": (
                RunnablePassthrough.assign(
                    context=lambda x: _combine_documents(x["docs"]),
                )
                | ANSWER_PROMPT
                | model
            ),
            "docs": itemgetter("docs"),
            "standalone_question": itemgetter("query"),
        }

        chain = _retrieved_docs | _answer
        return chain