# test-interface / runnables.py
# tommasodelorenzo's picture
# Upload folder using huggingface_hub
# d46cc41 verified
from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from utils import _combine_documents, Retriever
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
from tools import Retrieve
# AGENT
class Agent:
    """Tool-calling agent chain: prompt -> chat model -> tool-call output parser.

    The chain is built once at construction time and exposed as
    ``self.runnable``.

    Args:
        model_name: Chat model identifier. Names containing ``"gpt"`` are
            served by ``ChatOpenAI``; names containing ``"claude"`` by
            ``ChatAnthropic``.
        system_template: Text of the system message placed first in the prompt.
        temperature: Sampling temperature forwarded to the chat model.

    Raises:
        ValueError: If ``model_name`` contains neither ``"gpt"`` nor
            ``"claude"``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        system_template: str = _AGENT_SYSTEM_TEMPLATE,
        temperature: float = 0.0,
    ) -> None:
        self.model_name = model_name
        self.system_template = system_template
        self.temperature = temperature
        # Built last: _create_runnable reads the attributes assigned above.
        self.runnable = self._create_runnable()

    def _create_runnable(self) -> RunnableParallel:
        """Assemble the agent chain from the instance configuration.

        The prompt consists of the system template, an optional
        ``chat_history`` messages placeholder, and the human ``{query}``.
        The chat model has the ``Retrieve`` tool bound to it and its output
        is piped through ``ToolsAgentOutputParser``.

        Returns:
            The composed ``prompt | model | parser`` runnable.

        Raises:
            ValueError: If ``self.model_name`` matches no supported provider.
        """
        # NOTE(review): the return annotation says RunnableParallel, but the
        # value is composed with `|`; confirm the intended type.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                MessagesPlaceholder("chat_history", optional=True),
                ("human", "{query}"),
            ]
        )

        if "gpt" in self.model_name:
            model = ChatOpenAI(
                name="agent",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            ).bind_tools([Retrieve])
        elif "claude" in self.model_name:
            model = ChatAnthropic(
                name="agent",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            ).bind_tools([Retrieve])
        else:
            # Without this branch `model` would be unbound below and the
            # caller would see an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: "
                "expected a name containing 'gpt' or 'claude'."
            )

        agent_runnable = prompt | model | ToolsAgentOutputParser()
        return agent_runnable
# ANSWERER
class Answerer:
    """Retrieval-augmented answering chain: retrieve docs, then prompt a chat model.

    The chain is built once at construction time and exposed as
    ``self.runnable``. It expects a mapping input with a ``"query"`` key and
    produces a mapping with the keys ``"answer"``, ``"docs"`` and
    ``"standalone_question"``.

    Args:
        model_name: Chat model identifier. Names containing ``"gpt"`` are
            served by ``ChatOpenAI``; names containing ``"claude"`` by
            ``ChatAnthropic``.
        collection_index: Forwarded unchanged to ``Retriever``.
        use_doctrines: Forwarded unchanged to ``Retriever``.
        rewrite: Stored on the instance; not read anywhere in this class.
        search_type: Forwarded unchanged to ``Retriever``.
        similarity_threshold: Forwarded unchanged to ``Retriever``.
        k: Forwarded unchanged to ``Retriever``.
        temperature: Sampling temperature forwarded to the chat model.
        system_template: Text of the system message placed first in the prompt.

    Raises:
        ValueError: If ``model_name`` contains neither ``"gpt"`` nor
            ``"claude"``.
    """

    def __init__(
        self,
        model_name: str = "gpt-4-turbo",
        collection_index: int = 0,
        use_doctrines: bool = True,
        rewrite: bool = True,
        search_type: str = "similarity",
        similarity_threshold: float = 0.0,
        k: int = 15,
        temperature: float = 0.0,
        system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    ) -> None:
        self.model_name = model_name
        self.collection_index = collection_index
        self.use_doctrines = use_doctrines
        self.rewrite = rewrite
        self.search_type = search_type
        self.similarity_threshold = similarity_threshold
        self.k = k
        self.temperature = temperature
        self.system_template = system_template
        # Built last: _create_runnable reads the attributes assigned above.
        self.runnable = self._create_runnable()

    def _create_runnable(self) -> RunnableParallel:
        """Assemble the retrieve-then-answer chain from the instance configuration.

        Step one adds a ``docs`` key to the input by passing ``input["query"]``
        to ``Retriever._retrieve``. Step two fans out into three outputs:

        * ``"answer"``: the chat model's output for the prompt, after a
          ``context`` key has been added from ``_combine_documents(docs)``.
        * ``"docs"``: the retrieved documents, passed through.
        * ``"standalone_question"``: the original ``query``, passed through.

        Returns:
            The composed retrieval + answering runnable.

        Raises:
            ValueError: If ``self.model_name`` matches no supported provider.
        """
        # NOTE(review): the return annotation says RunnableParallel, but the
        # value is composed with `|`; confirm the intended type.
        vectorstore = Retriever(
            collection_index=self.collection_index,
            use_doctrines=self.use_doctrines,
            search_type=self.search_type,
            k=self.k,
            similarity_threshold=self.similarity_threshold,
        )
        _retrieved_docs = RunnablePassthrough.assign(
            docs=itemgetter("query") | RunnableLambda(vectorstore._retrieve),
        )

        # `context` is presumably referenced by the system template — confirm
        # against the prompts module.
        answer_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_template),
                ("human", "{query}"),
            ]
        )

        if "gpt" in self.model_name:
            model = ChatOpenAI(
                name="answerer",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            )
        elif "claude" in self.model_name:
            model = ChatAnthropic(
                name="answerer",
                streaming=True,
                model=self.model_name,
                temperature=self.temperature,
            )
        else:
            # Without this branch `model` would be unbound below and the
            # caller would see an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported model_name {self.model_name!r}: "
                "expected a name containing 'gpt' or 'claude'."
            )

        _answer = {
            "answer": (
                RunnablePassthrough.assign(
                    context=lambda x: _combine_documents(x["docs"]),
                )
                | answer_prompt
                | model
            ),
            "docs": itemgetter("docs"),
            "standalone_question": itemgetter("query"),
        }
        chain = _retrieved_docs | _answer
        return chain