import uuid
from typing import Annotated, TypedDict

from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool

import chainlit as cl
from rag import create_rag_pipeline, add_urls_to_vectorstore

# Initialize RAG pipeline
rag_components = create_rag_pipeline(collection_name="london_events")

# Add some initial URLs to the vector store
urls = [
    "https://www.timeout.com/london/things-to-do-in-london-this-weekend",
    "https://www.timeout.com/london/london-events-in-march"
]
add_urls_to_vectorstore(
    rag_components["vector_store"],
    rag_components["text_splitter"],
    urls
)
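
# The code below assumes create_rag_pipeline() returns a dict exposing at least
# the "vector_store", "text_splitter", and "retriever" keys (see rag.py).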

class AgentState(TypedDict):
    messages: Annotated[list, add_messages]
    context: list  # Retrieved documents (reserved; not currently populated by any node)
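
# add_messages is a reducer: messages returned by nodes are appended to the
# existing list, and messages with matching IDs are upserted, not duplicated.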

# Expose the RAG retriever as a tool the model can call
@tool
def retrieve_context(query: str) -> list[str]:
    """Searches the knowledge base for relevant information about events and activities. Use this when you need specific details about events."""
    # retriever.invoke() is the current retriever API; get_relevant_documents() is deprecated
    return [doc.page_content for doc in rag_components["retriever"].invoke(query)]
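
# The docstring above doubles as the tool description the model sees when
# deciding whether to call retrieve_context.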

tavily_tool = TavilySearchResults(max_results=5)
tool_belt = [tavily_tool, retrieve_context]

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
model = llm.bind_tools(tool_belt)

# Define system prompt
SYSTEM_PROMPT = SystemMessage(content="""
You are a helpful AI assistant that answers questions clearly and concisely.
If you don't know something, simply say you don't know.
Be engaging and professional in your responses.
Use the retrieve_context tool when you need specific information about events and activities.
Use the web search tool for general, up-to-date web searches.
""")

def call_model(state: AgentState):
    messages = [SYSTEM_PROMPT] + state["messages"]
    response = model.invoke(messages)
    return {"messages": [response]}

tool_node = ToolNode(tool_belt)
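# ToolNode runs whichever tools the latest AIMessage requested and appends
# their outputs to state["messages"] as ToolMessages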

# Route to the tool node when the model requested a tool call; otherwise end the turn
def should_continue(state: AgentState):
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "action"
    return END

# Create the graph: "agent" calls the model, "action" executes requested tools
builder = StateGraph(AgentState)
builder.add_node("agent", call_model)
builder.add_node("action", tool_node)

builder.set_entry_point("agent")
builder.add_conditional_edges(
    "agent",
    should_continue,
)
builder.add_edge("action", "agent")

# Initialize memory saver for conversation persistence
memory = MemorySaver()

# Compile the graph with memory
graph = builder.compile(checkpointer=memory)
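
# A minimal sketch of exercising the compiled graph outside Chainlit. This is
# illustrative only: it assumes OPENAI_API_KEY and TAVILY_API_KEY are set in
# the environment, and "demo-thread" is an arbitrary example thread ID.
#
#   result = graph.invoke(
#       {"messages": [HumanMessage(content="What's on in London this weekend?")], "context": []},
#       config={"configurable": {"thread_id": "demo-thread"}},
#   )
#   print(result["messages"][-1].content)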

@cl.on_chat_start
async def on_chat_start():
    # Generate a session ID and use it as the checkpointer thread ID
    session_id = str(uuid.uuid4())
    cl.user_session.set("session_id", session_id)

    config = RunnableConfig(
        configurable={
            "thread_id": session_id,
        }
    )

    # Prime the checkpointer with an empty state. Note that this runs the
    # model once with only the system prompt; the generated reply is unused.
    try:
        await graph.ainvoke(
            {"messages": [], "context": []},
            config=config
        )
    except Exception as e:
        print(f"Error initializing state: {str(e)}")

    await cl.Message(
        content="Hello! I'm your chief joy officer, here to help you with finding fun things to do in London!",
        author="Assistant"
    ).send()

@cl.on_message
async def on_message(message: cl.Message):
    session_id = cl.user_session.get("session_id")
    if not session_id:
        session_id = str(uuid.uuid4())
        cl.user_session.set("session_id", session_id)
    print(f"Session ID: {session_id}")

    # Config keyed only by thread_id, matching the config used when streaming
    # below, so the checkpointer reads and writes the same namespace
    config = RunnableConfig(
        configurable={
            "thread_id": session_id,
        }
    )
    
    # Try to retrieve previous conversation state
    try:
        previous_state = await graph.aget_state(config)
        if previous_state and previous_state.values:
            previous_messages = previous_state.values.get('messages', [])
            print("Found previous state with messages:", len(previous_messages))
        else:
            print("Previous state empty or invalid")
            previous_messages = []
        current_messages = previous_messages + [HumanMessage(content=message.content)]
    except Exception as e:
        print(f"Error retrieving previous state: {str(e)}")
        current_messages = [HumanMessage(content=message.content)]
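
    # Re-sending history alongside the new message is safe: add_messages
    # upserts by message ID, so checkpointed messages are not duplicated.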
    
    # The Chainlit callback handler surfaces intermediate steps in the UI;
    # it must be attached to the run config (below) to take effect
    cb = cl.LangchainCallbackHandler()
    final_answer = cl.Message(content="")
    await final_answer.send()
    
    loading_msg = None  # Initialize reference to loading message
    
    # Stream the response, reusing the session thread and attaching the callback
    async for chunk in graph.astream(
        {"messages": current_messages, "context": []},
        config=RunnableConfig(
            callbacks=[cb],
            configurable={"thread_id": session_id},
        )
    ):
        for node, values in chunk.items():
            if values.get("messages"):
                last_message = values["messages"][-1]
                # Show a status indicator while the model is calling a tool
                if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
                    tool_name = last_message.additional_kwargs["tool_calls"][0]["function"]["name"]
                    if loading_msg:
                        await loading_msg.remove()
                    loading_msg = cl.Message(
                        content=f"🔍 Using {tool_name}...",
                        author="Tool"
                    )
                    await loading_msg.send()
                # Only stream AI messages, skip tool outputs
                elif isinstance(last_message, AIMessage):
                    if loading_msg:
                        await loading_msg.remove()
                        loading_msg = None
                    await final_answer.stream_token(last_message.content)

    await final_answer.send()