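"""Chainlit + LangChain (LCEL) chat app with per-session conversation memory
and support for resuming persisted conversations.

Requires OPENAI_API_KEY in the environment; run with `chainlit run app.py`
(assuming this file is saved as app.py).
"""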
from operator import itemgetter
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableLambda
from langchain.schema.runnable.config import RunnableConfig
from langchain.memory import ConversationBufferMemory
from chainlit.client.base import ConversationDict
import chainlit as cl
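

# Build the LCEL chain for the current session: pull the chat history out of
# the session's ConversationBufferMemory, feed it into the prompt alongside
# the user's question, call the streaming model, and parse the output to a string.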
def setup_runnable():
    memory = cl.user_session.get("memory")  # type: ConversationBufferMemory
    model = ChatOpenAI(streaming=True)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful chatbot"),
            MessagesPlaceholder(variable_name="history"),
            ("human", "{question}"),
        ]
    )

    runnable = (
        RunnablePassthrough.assign(
            history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
        )
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)
@cl.on_chat_start
async def on_chat_start():
    cl.user_session.set("memory", ConversationBufferMemory(return_messages=True))
    setup_runnable()
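

# When a persisted chat is resumed, rebuild the memory from the stored
# conversation. Only root messages (those without a parentId) are replayed,
# so nested child messages are skipped.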
@cl.on_chat_resume
async def on_chat_resume(conversation: ConversationDict):
    memory = ConversationBufferMemory(return_messages=True)
    root_messages = [m for m in conversation["messages"] if m["parentId"] is None]
    for message in root_messages:
        if message["authorIsUser"]:
            memory.chat_memory.add_user_message(message["content"])
        else:
            memory.chat_memory.add_ai_message(message["content"])

    cl.user_session.set("memory", memory)
    setup_runnable()
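

# Stream the answer token by token, then record both sides of the exchange
# in memory so the next turn sees the full history.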
@cl.on_message
async def on_message(message: cl.Message):
    memory = cl.user_session.get("memory")  # type: ConversationBufferMemory
    runnable = cl.user_session.get("runnable")  # type: Runnable

    res = cl.Message(content="")

    async for chunk in runnable.astream(
        {"question": message.content},
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await res.stream_token(chunk)

    await res.send()

    memory.chat_memory.add_user_message(message.content)
    memory.chat_memory.add_ai_message(res.content)