Spaces:
Running
Running
from utils import MainState, generate_uuid, llm | |
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate | |
from langgraph.graph import StateGraph, START, END | |
import re | |
def get_graph(retriever): | |
def retriever_node(state: MainState): | |
return { | |
'question': state['question'], | |
'scratchpad': state['scratchpad'] + [ToolMessage(content=retriever.invoke(state['question'].content), | |
tool_call_id=state['scratchpad'][-1].tool_call_id)], | |
'answer': state['answer'], | |
'next_node': 'model_node', | |
'history': state['history'] | |
} | |
import re | |
def model_node(state: MainState): | |
prompt = ChatPromptTemplate.from_template( | |
""" | |
Você é um assistente de IA chamado DocAI. Responda à pergunta abaixo da forma mais precisa possível. | |
Caso não tenha informações para responder à pergunte **retorne apenas** uma resposta no seguinte formato: | |
<tool>retriever</tool>, | |
ao fazer isso a task será repassada para um agente que irá complementar as informações. | |
Se a pergunta puder ser respondida sem acessar documentos enviados, forneça uma resposta **concisa e objetiva**, com no máximo três sentenças. | |
### Contexto: | |
- Bloco de Notas: {scratchpad} | |
- Histórico de Conversas: {chat_history} | |
**Pergunta:** {question} | |
""" | |
) | |
if isinstance(state['question'], str): | |
state['question'] = HumanMessage(content=state['question']) | |
qa_chain = prompt | llm | |
response = qa_chain.invoke({'question': state['question'].content, | |
'scratchpad': state['scratchpad'], | |
'chat_history': [ | |
f'AI: {msg.content}' if isinstance(msg, AIMessage) else f'Human: {msg.content}' | |
for msg in state['history']], | |
}) | |
if '<tool>' in response.content: | |
return { | |
'question': state['question'], | |
'scratchpad': state['scratchpad'] + [AIMessage(content='', tool_call_id=generate_uuid())] if state[ | |
'scratchpad'] else [AIMessage(content='', tool_call_id=generate_uuid())], | |
'answer': state['answer'], | |
'next_node': 'retriever', | |
'history': state['history'] | |
} | |
# print(state['scratchpad']) | |
return { | |
'question': state['question'], | |
'scratchpad': state['scratchpad'], | |
'answer': response, | |
'next_node': END, | |
'history': state['history'] + [HumanMessage(content=state['question'].content), response] | |
} | |
def next_node(state: MainState): | |
return state['next_node'] | |
graph = StateGraph(MainState) | |
graph.add_node('model', model_node) | |
graph.add_node('retriever', retriever_node) | |
graph.add_edge(START, 'model') | |
graph.add_edge('retriever', 'model') | |
graph.add_conditional_edges('model', next_node) | |
chain = graph.compile() | |
return chain |