File size: 1,956 Bytes
b966bc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Project      : Python.
# @File         : 991_streamlit_apex_charts
# @Time         : 2022/10/17 上午10:48
# @Author       : yuanjie
# @WeChat       : meutils
# @Software     : PyCharm
# @Description  : 
import streamlit as st
from streamlit_chat import message


def reply(input, history=None, reply_func=lambda input: f'{input}的答案', max_turns=3, container=None):
    """Render a chat exchange for ``input`` and append it to ``history``.

    The most recent ``max_turns - 1`` earlier turns are rendered first, each as a
    user query followed by a bot response. Then the new user ``input`` is
    rendered, followed by the bot response returned by ``reply_func(input)``.

    Args:
        input: The new user query. The name shadows the builtin but is kept
            for backward compatibility with keyword callers.
        history: List of ``(query, response)`` tuples from earlier turns. It is
            mutated in place: the new turn is appended. ``None`` starts a new list.
        reply_func: Callable that maps the query to the bot's response.
        max_turns: Total number of turns shown, including the new one. Values
            below 1 behave like 1, so only the new turn is shown.
        container: Streamlit container to render into. A fresh
            ``st.container()`` is created when ``None``.

    Returns:
        The ``history`` list with ``(input, response)`` appended.
    """
    if history is None:
        history = []  # [(query, response)]

    if container is None:
        container = st.container()

    # Number of earlier turns to re-render, clamped at 0. The old slice
    # ``history[-max_turns + 1:]`` had two off-by-one cases: it became
    # ``history[0:]`` (the whole history) for max_turns == 1, and skipped only
    # the first turn for max_turns == 0.
    n_previous = max(max_turns - 1, 0)
    start = max(len(history) - n_previous, 0)
    turn_id = len(history)  # absolute index of the turn being added

    with container:
        # Keys use each turn's absolute index, so they stay unique and stable
        # as the visible window slides forward.
        for i, (query, response) in enumerate(history[start:], start=start):
            message(query, avatar_style="big-smile", is_user=True, key=f"{i}_user")
            message(response, avatar_style="bottts", is_user=False, key=str(i))

        message(input, avatar_style="big-smile", is_user=True, key=f"{turn_id}_user")
        with st.empty():
            response = reply_func(input)
            # Keyed like the earlier bot messages. The earlier ones all use
            # indices < turn_id, so this key cannot collide with them.
            message(response, avatar_style="bottts", is_user=False, key=str(turn_id))

    history.append((input, response))
    return history


if __name__ == '__main__':
    def display_previous_message(texts=None):
        """Render every string in ``texts`` as a bot message; do nothing if it is empty or None."""
        for text_item in texts or ():
            message(text_item, avatar_style="bottts")  # bot avatar for each earlier message


    greetings = ["你好!我是你的电影小助手,很高兴为您服务。", "你可以向我提问。"]
    display_previous_message(greetings)

    # Placeholder created before the text area. Content that reply() renders
    # into it later appears at this position in the page.
    chat_area = st.container()
    user_text = st.text_area(label="用户输入", height=100, placeholder="请在这儿输入您的问题")

    send_clicked = st.button("发送", key="predict")
    if send_clicked:
        with st.spinner("AI正在思考,请稍等........"):
            # Conversation history kept in session state between reruns; None on the first send.
            previous_turns = st.session_state.get('state')
            st.session_state["state"] = reply(user_text, previous_turns, container=chat_area)
            print(st.session_state['state'])