# Author: harisyammnv
# Commit: 1156d26 - "feat: added new files"
# Original file size: 2.52 kB
from operator import itemgetter
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableLambda
from langchain.schema.runnable.config import RunnableConfig
from langchain.memory import ConversationBufferMemory
from chainlit.client.base import ConversationDict
import chainlit as cl
def setup_runnable():
    """Build this session's chat pipeline and store it under "runnable".

    The pipeline reads the conversation memory stored under the "memory"
    key of the user session. At invocation time it inserts that memory's
    "history" into the prompt. It then sends the prompt to a streaming
    ChatOpenAI model and parses the reply into a plain string.
    """
    session_memory: ConversationBufferMemory = cl.user_session.get("memory")
    chat_model = ChatOpenAI(streaming=True)

    prompt_messages = [
        ("system", "You are a helpful chatbot"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ]
    chat_prompt = ChatPromptTemplate.from_messages(prompt_messages)

    # load_memory_variables yields a mapping; only its "history" entry is
    # needed, and it fills the MessagesPlaceholder of the same name.
    history_loader = RunnableLambda(session_memory.load_memory_variables)
    history_step = history_loader | itemgetter("history")

    # Pass the caller's input through unchanged, adding a "history" key.
    pipeline = RunnablePassthrough.assign(history=history_step)
    pipeline = pipeline | chat_prompt | chat_model | StrOutputParser()

    cl.user_session.set("runnable", pipeline)
@cl.on_chat_start
async def on_chat_start():
    """Handle a brand-new chat session.

    Sends the bot and user avatars, stores an empty message-returning
    ConversationBufferMemory under "memory", and builds the runnable.
    """
    avatar_specs = (
        ("Chatbot", "icon/chainlit.png"),
        ("User", "icon/avatar.png"),
    )
    for avatar_name, avatar_path in avatar_specs:
        await cl.Avatar(name=avatar_name, path=avatar_path).send()

    fresh_memory = ConversationBufferMemory(return_messages=True)
    cl.user_session.set("memory", fresh_memory)
    setup_runnable()
@cl.on_chat_resume
async def on_chat_resume(conversation: ConversationDict):
    """Restore memory for a resumed conversation and rebuild the runnable.

    Replays each root message (one whose "parentId" is None or absent) of
    ``conversation`` into a fresh ConversationBufferMemory, in the order
    they appear. Messages with "authorIsUser" set become user messages;
    all others become AI messages. The memory is stored under "memory" in
    the user session, and ``setup_runnable`` is then called so the
    pipeline reads this memory.

    Args:
        conversation: Persisted conversation; its "messages" list is read.
    """
    memory = ConversationBufferMemory(return_messages=True)
    # Only root messages are replayed as chat turns. Child messages are
    # presumably intermediate steps attached to a turn - TODO confirm
    # against Chainlit's persistence schema.
    # Use .get() so a message dict without "parentId" counts as a root
    # message instead of raising KeyError.
    root_messages = [
        m for m in conversation["messages"] if m.get("parentId") is None
    ]
    for message in root_messages:
        if message["authorIsUser"]:
            memory.chat_memory.add_user_message(message["content"])
        else:
            memory.chat_memory.add_ai_message(message["content"])
    cl.user_session.set("memory", memory)
    setup_runnable()
@cl.on_message
async def on_message(message: cl.Message):
    """Stream the chatbot's reply to ``message`` and record the exchange.

    Runs the session's runnable on the message text and streams each chunk
    into a new Chainlit message, with a LangchainCallbackHandler attached.
    After the stream ends and the reply is sent, appends the user text and
    the full reply to the session's conversation memory.
    """
    session = cl.user_session
    memory: ConversationBufferMemory = session.get("memory")
    runnable: Runnable = session.get("runnable")

    reply = cl.Message(content="")
    stream_config = RunnableConfig(callbacks=[cl.LangchainCallbackHandler()])
    token_stream = runnable.astream(
        {"question": message.content},
        config=stream_config,
    )
    async for token in token_stream:
        await reply.stream_token(token)
    await reply.send()

    # Record the turn only after the full reply is known, so the stored
    # AI message holds the complete streamed text.
    history = memory.chat_memory
    history.add_user_message(message.content)
    history.add_ai_message(reply.content)