from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from utils import _combine_documents, Retriever
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
from tools import Retrieve
# AGENT
class Agent:
    """Tool-calling agent: prompt | chat model (bound to the Retrieve tool) | parser.

    Supports OpenAI models (name containing "gpt") and Anthropic models
    (name containing "claude").
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        system_template: str = _AGENT_SYSTEM_TEMPLATE,
        temperature: float = 0.0,
    ) -> None:
        """Store configuration and build the agent runnable.

        Args:
            model_name: Chat model identifier; must contain "gpt" or "claude".
            system_template: System prompt used for the agent.
            temperature: Sampling temperature (0.0 = deterministic).

        Raises:
            ValueError: If ``model_name`` names an unsupported provider.
        """
        self.model_name = model_name
        self.system_template = system_template
        self.temperature = temperature
        self.runnable = self._create_runnable()

    def _create_runnable(self):
        """Build the agent pipeline: prompt -> tool-bound model -> tools parser.

        Returns:
            A runnable sequence (the previous ``-> RunnableParallel`` annotation
            was inaccurate: ``prompt | model | parser`` is a sequence).

        Raises:
            ValueError: If ``model_name`` matches neither "gpt" nor "claude".
                Previously this path left ``model`` unbound and raised an
                opaque ``NameError``.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                MessagesPlaceholder("chat_history", optional=True),
                ("human", "{query}"),
            ]
        )
        # Pick the provider class once; the constructor kwargs are identical,
        # so the duplicated branches are collapsed.
        if "gpt" in self.model_name:
            model_cls = ChatOpenAI
        elif "claude" in self.model_name:
            model_cls = ChatAnthropic
        else:
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: "
                'expected a name containing "gpt" or "claude".'
            )
        model = model_cls(
            name="agent",
            streaming=True,
            model=self.model_name,
            temperature=self.temperature,
        ).bind_tools([Retrieve])
        return prompt | model | ToolsAgentOutputParser()
# ANSWERER
class Answerer:
    """RAG answerer: retrieve documents for a query, then answer over that context.

    The runnable first attaches retrieved ``docs`` to the input, then fans out
    into ``{"answer", "docs", "standalone_question"}``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        collection_index: int = 0,
        use_doctrines: bool = True,
        rewrite: bool = True,
        search_type: str = "similarity",
        similarity_threshold: float = 0.0,
        k: int = 15,
        temperature: float = 0.0,
        system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    ) -> None:
        """Store configuration and build the answering runnable.

        Args:
            model_name: Chat model identifier; must contain "gpt" or "claude".
            collection_index: Index of the vector-store collection to query.
            use_doctrines: Forwarded to ``Retriever``.
            rewrite: Stored but not used in this class's visible code —
                presumably consumed by callers; TODO confirm.
            search_type: Retrieval strategy (e.g. "similarity").
            similarity_threshold: Minimum similarity score for retrieval.
            k: Number of documents to retrieve.
            temperature: Sampling temperature (0.0 = deterministic).
            system_template: System prompt used for answering.

        Raises:
            ValueError: If ``model_name`` names an unsupported provider.
        """
        self.model_name = model_name
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.rewrite = rewrite
        self.search_type = search_type
        self.similarity_threshold = similarity_threshold
        self.k = k
        self.temperature = temperature
        self.system_template = system_template
        self.runnable = self._create_runnable()

    def _create_runnable(self):
        """Build the retrieve-then-answer chain.

        Returns:
            A runnable sequence producing a dict with keys ``answer`` (model
            output over the combined docs), ``docs`` (retrieved documents),
            and ``standalone_question`` (the original query). The previous
            ``-> RunnableParallel`` annotation described only the second stage.

        Raises:
            ValueError: If ``model_name`` matches neither "gpt" nor "claude".
                Previously this path left ``model`` unbound and raised an
                opaque ``NameError``.
        """
        vectorstore = Retriever(
            collection_index=self.collection_index,
            use_doctrines=self.use_doctrines,
            search_type=self.search_type,
            k=self.k,
            similarity_threshold=self.similarity_threshold,
        )
        # Stage 1: pass the input through, adding the retrieved docs.
        retrieved_docs = RunnablePassthrough.assign(
            docs=itemgetter("query") | RunnableLambda(vectorstore._retrieve),
        )
        answer_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                ("human", "{query}"),
            ]
        )
        # Pick the provider class once; constructor kwargs are identical.
        if "gpt" in self.model_name:
            model_cls = ChatOpenAI
        elif "claude" in self.model_name:
            model_cls = ChatAnthropic
        else:
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: "
                'expected a name containing "gpt" or "claude".'
            )
        model = model_cls(
            name="answerer",
            streaming=True,
            model=self.model_name,
            temperature=self.temperature,
        )
        # Stage 2: answer over the combined docs while echoing docs and query.
        answer = {
            "answer":
                RunnablePassthrough.assign(
                    context=lambda x: _combine_documents(x["docs"]),
                )
                | answer_prompt
                | model,
            "docs": itemgetter("docs"),
            "standalone_question": itemgetter("query"),
        }
        return retrieved_docs | answer