"""Runnable builders for the tool-calling agent and the retrieval answerer."""

from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    Runnable,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI

from utils import _combine_documents, Retriever
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
from tools import Retrieve


def _build_chat_model(
    model_name: str,
    temperature: float,
    name: str,
) -> "ChatOpenAI | ChatAnthropic":
    """Instantiate the streaming chat model that matches ``model_name``.

    The provider is chosen by substring: names containing ``"gpt"`` map to
    ``ChatOpenAI``; otherwise names containing ``"claude"`` map to
    ``ChatAnthropic``.

    Args:
        model_name: Provider model identifier, e.g. ``"gpt-4-turbo"``.
        temperature: Sampling temperature forwarded to the model.
        name: Run name given to the model instance.

    Returns:
        The configured chat model, with ``streaming=True``.

    Raises:
        ValueError: If ``model_name`` contains neither ``"gpt"`` nor
            ``"claude"``.
    """
    if "gpt" in model_name:
        model_cls = ChatOpenAI
    elif "claude" in model_name:
        model_cls = ChatAnthropic
    else:
        # Fail loudly here instead of hitting an unbound local further down.
        raise ValueError(
            f"Unsupported model_name {model_name!r}: "
            "expected a name containing 'gpt' or 'claude'."
        )
    return model_cls(
        name=name,
        streaming=True,
        model=model_name,
        temperature=temperature,
    )


# AGENT
class Agent():
    """Tool-calling agent step: prompt -> chat model (with tools) -> parser.

    The built chain is exposed as ``self.runnable``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        system_template: str = _AGENT_SYSTEM_TEMPLATE,
        temperature: float = 0.0,
    ) -> None:
        """Store the configuration and build the runnable.

        Args:
            model_name: Chat model identifier; must contain ``"gpt"`` or
                ``"claude"``.
            system_template: System prompt template for the agent.
            temperature: Sampling temperature for the chat model.

        Raises:
            ValueError: If ``model_name`` matches no supported provider.
        """
        self.model_name = model_name
        self.system_template = system_template
        self.temperature = temperature
        self.runnable = self._create_runnable()

    def _create_runnable(self) -> Runnable:
        """Compose the agent chain.

        The prompt consists of the system template, an optional
        ``chat_history`` messages placeholder, and the human ``{query}``.
        The chat model has the ``Retrieve`` tool bound, and its output is
        passed through ``ToolsAgentOutputParser``.

        Returns:
            The piped ``prompt | model | parser`` runnable.

        Raises:
            ValueError: If ``self.model_name`` matches no supported provider.
        """
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                MessagesPlaceholder("chat_history", optional=True),
                ("human", "{query}"),
            ]
        )

        model = _build_chat_model(
            self.model_name,
            self.temperature,
            name="agent",
        ).bind_tools([Retrieve])

        agent_runnable = prompt | model | ToolsAgentOutputParser()
        return agent_runnable


# ANSWERER
class Answerer():
    """Retrieval-augmented answer chain: retrieve docs, then answer.

    The built chain is exposed as ``self.runnable``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        collection_index: int = 0,
        use_doctrines: bool = True,
        rewrite: bool = True,
        search_type: str = "similarity",
        similarity_threshold: float = 0.0,
        k: int = 15,
        temperature: float = 0.0,
        system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    ) -> None:
        """Store the configuration and build the runnable.

        Args:
            model_name: Chat model identifier; must contain ``"gpt"`` or
                ``"claude"``.
            collection_index: Forwarded to ``Retriever``.
            use_doctrines: Forwarded to ``Retriever``.
            rewrite: Stored on the instance only.
            search_type: Forwarded to ``Retriever``.
            similarity_threshold: Forwarded to ``Retriever``.
            k: Forwarded to ``Retriever``.
            temperature: Sampling temperature for the chat model.
            system_template: System prompt template for the answerer.

        Raises:
            ValueError: If ``model_name`` matches no supported provider.
        """
        self.model_name = model_name
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        # NOTE(review): `rewrite` is stored but not used anywhere in the
        # visible code -- confirm whether query rewriting is still intended.
        self.rewrite = rewrite
        self.search_type = search_type
        self.similarity_threshold = similarity_threshold
        self.k = k
        self.temperature = temperature
        self.system_template = system_template
        self.runnable = self._create_runnable()

    def _create_runnable(self) -> Runnable:
        """Compose the retrieve-then-answer chain.

        The chain reads the ``"query"`` key of its input mapping and yields a
        mapping with three keys:

        * ``"answer"``: the chat model output for the answer prompt,
        * ``"docs"``: whatever ``Retriever._retrieve`` returned for the query,
        * ``"standalone_question"``: the input ``"query"``, passed through
          unchanged.

        Returns:
            The piped retrieval and answer runnable.

        Raises:
            ValueError: If ``self.model_name`` matches no supported provider.
        """
        vectorstore = Retriever(
            collection_index=self.collection_index,
            use_doctrines=self.use_doctrines,
            search_type=self.search_type,
            k=self.k,
            similarity_threshold=self.similarity_threshold,
        )

        # Add a "docs" key to the input by running retrieval on the query.
        _retrieved_docs = RunnablePassthrough.assign(
            docs=itemgetter("query") | RunnableLambda(vectorstore._retrieve),
        )

        ANSWER_PROMPT = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                ("human", "{query}"),
            ]
        )

        model = _build_chat_model(
            self.model_name,
            self.temperature,
            name="answerer",
        )

        _answer = {
            # "context" is added for the prompt; presumably the system
            # template references {context} -- confirm against prompts.py.
            "answer": RunnablePassthrough.assign(
                context=lambda x: _combine_documents(x["docs"]),
            )
            | ANSWER_PROMPT
            | model,
            "docs": itemgetter("docs"),
            "standalone_question": itemgetter("query"),
        }

        chain = _retrieved_docs | _answer
        return chain