File size: 910 Bytes
465a7e3
 
 
 
 
 
95b0fa1
465a7e3
 
 
 
 
95b0fa1
465a7e3
 
 
 
 
 
 
95b0fa1
465a7e3
95b0fa1
 
465a7e3
 
95b0fa1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import chainlit as cl
from qdrant_client import QdrantClient
from rag_graph import RagGraph

@cl.cache
def get_qdrant_client():
  """Return the shared QdrantClient for the local vector store at ``data/vectors``.

  The ``@cl.cache`` decorator caches the returned instance so it is reused rather
  than re-created, e.g. across reloads during development.

  Returns:
    QdrantClient: a client opened in local (on-disk) mode on ``data/vectors``.
  """
  # NOTE: Qdrant's local mode locks its storage folder, so a second client on the
  # same path would be rejected. That is why a single cached instance is reused.
  # QdrantClient comes from the module-level import; no local re-import needed.
  return QdrantClient(path='data/vectors')

@cl.on_chat_start
async def on_chat_start():
  """Build a RagGraph for this chat and stash it in the user session.

  The graph wraps the cached Qdrant client and is stored under the
  ``"rag_graph"`` session key, where ``on_message`` looks it up.
  """
  graph = RagGraph(get_qdrant_client())
  cl.user_session.set("rag_graph", graph)

@cl.on_message
async def on_message(question: cl.Message):
  """Answer an incoming user message by streaming the RAG graph's output.

  An empty reply message is sent first. The session's RagGraph then receives
  the question text together with that reply message.
  """
  # Post an empty placeholder right away so the user sees a reply has started.
  reply = cl.Message(content="")
  await reply.send()

  graph = cl.user_session.get("rag_graph")
  # NOTE(review): presumably RagGraph.stream writes tokens into `reply` and
  # finalizes it; nothing here calls reply.update() — confirm in rag_graph.
  await graph.stream(question.content, reply)