File size: 3,303 Bytes
0af1d97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46a68f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from utils import MainState, generate_uuid, llm
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
import re

def get_graph(retriever):
    def retriever_node(state: MainState):
        return {
            'question': state['question'],
            'scratchpad': state['scratchpad'] + [ToolMessage(content=retriever.invoke(state['question'].content),
                                                             tool_call_id=state['scratchpad'][-1].tool_call_id)],
            'answer': state['answer'],
            'next_node': 'model_node',
            'history': state['history']
        }

    import re

    def model_node(state: MainState):
        prompt = ChatPromptTemplate.from_template(
            """
            Você é um assistente de IA chamado DocAI. Responda à pergunta abaixo da forma mais precisa possível.

            Caso não tenha informações para responder à pergunte **retorne apenas** uma resposta no seguinte formato:
            <tool>retriever</tool>,
            ao fazer isso a task será repassada para um agente que irá complementar as informações.
            Se a pergunta puder ser respondida sem acessar documentos enviados, forneça uma resposta **concisa e objetiva**, com no máximo três sentenças.

            ### Contexto:
            - Bloco de Notas: {scratchpad}
            - Histórico de Conversas: {chat_history}

            **Pergunta:** {question}
            """
        )

        if isinstance(state['question'], str):
            state['question'] = HumanMessage(content=state['question'])

        qa_chain = prompt | llm

        response = qa_chain.invoke({'question': state['question'].content,
                                    'scratchpad': state['scratchpad'],
                                    'chat_history': [
                                        f'AI: {msg.content}' if isinstance(msg, AIMessage) else f'Human: {msg.content}'
                                        for msg in state['history']],
                                    })

        if '<tool>' in response.content:
            return {
                'question': state['question'],
                'scratchpad': state['scratchpad'] + [AIMessage(content='', tool_call_id=generate_uuid())] if state[
                    'scratchpad'] else [AIMessage(content='', tool_call_id=generate_uuid())],
                'answer': state['answer'],
                'next_node': 'retriever',
                'history': state['history']
            }

        # print(state['scratchpad'])
        return {
            'question': state['question'],
            'scratchpad': state['scratchpad'],
            'answer': response,
            'next_node': END,
            'history': state['history'] + [HumanMessage(content=state['question'].content), response]
        }

    def next_node(state: MainState):
        return state['next_node']

    graph = StateGraph(MainState)
    graph.add_node('model', model_node)
    graph.add_node('retriever', retriever_node)
    graph.add_edge(START, 'model')
    graph.add_edge('retriever', 'model')
    graph.add_conditional_edges('model', next_node)

    chain = graph.compile()

    return chain