import platform
import os
import time
from threading import Thread

from rich.text import Text
from rich.live import Live

from model.infer import ChatBot
from config import InferConfig

# Module-level setup (runs at import time): build the inference config and the
# ChatBot instance used by chat(). Presumably constructing ChatBot loads the
# model, so importing this module may be slow — TODO confirm in model.infer.
infer_config = InferConfig()
chat_bot = ChatBot(infer_config=infer_config)

# Shell command that clears the terminal: 'cls' on Windows, 'clear' elsewhere.
clear_cmd = 'cls' if platform.system().lower() == 'windows' else 'clear'

# Start-up banner; reprinted after the `cls` command and used as the header of
# the transcript rendered by build_prompt().
# English: "Welcome to ChatBot; type `exit` to quit, type `cls` to clear the screen."
welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
print(welcome_txt)

def build_prompt(history: list[list[str]]) -> str:
    """Render the welcome banner followed by every past turn of the conversation.

    Args:
        history: ``[query, response]`` pairs, oldest first.

    Returns:
        The full transcript as one string. The "用户:" (user) and "ChatBot:"
        speaker labels are wrapped in ANSI colour escape codes.
    """
    # Collect the pieces and join once, instead of repeated ``+=`` on a str,
    # which is quadratic in the transcript length on non-CPython interpreters.
    parts = [welcome_txt]
    for query, response in history:
        parts.append('\n\033[0;33;40m用户:\033[0m{}'.format(query))
        parts.append('\n\033[0;32;40mChatBot:\033[0m\n{}\n'.format(response))
    return ''.join(parts)

# Set to True by the main thread to ask a running circle_print() spinner to stop.
STOP_CIRCLE: bool = False


def circle_print(total_time: int = 60) -> None:
    """Show a rotating spinner on the current line while a non-streaming reply is computed.

    Intended to run in a background thread. A new frame is drawn every 0.25 s
    until the module-level ``STOP_CIRCLE`` flag becomes true, or until roughly
    ``total_time`` seconds have passed, whichever comes first. Before
    returning, the cursor is moved back to the start of the line.

    Args:
        total_time: Maximum number of seconds to keep spinning.
    """
    # STOP_CIRCLE is only read here, so no ``global`` declaration is needed.
    frames = ["\\", "|", "/", "—"]
    for i in range(total_time * 4):
        time.sleep(0.25)
        # Check the flag before drawing so no frame is printed once a stop
        # request has been observed.
        if STOP_CIRCLE:
            break
        print("\r{}".format(frames[i % 4]), end="", flush=True)

    print("\r", end="", flush=True)


def chat(stream: bool = True) -> None:
    """Run the interactive console chat loop until the user quits.

    Each non-empty line the user types is sent to the module-level
    ``chat_bot``. Special commands (case-insensitive, surrounding whitespace
    ignored):

    * ``exit`` — leave the loop. End-of-input or Ctrl-C at the prompt also exit.
    * ``cls``  — forget the conversation history, clear the screen and
      reprint the welcome banner.

    Args:
        stream: If true, print the reply incrementally via
            ``chat_bot.stream_chat`` and then redraw the whole transcript.
            Otherwise call ``chat_bot.chat`` once while a spinner runs in a
            background thread; in that mode the history is not recorded.
    """
    history: list[list[str]] = []

    while True:
        print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
        try:
            input_txt = input().strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: exit cleanly instead of with a traceback.
            print()
            break

        # Reject empty and whitespace-only input rather than sending it to the model.
        if not input_txt:
            print('请输入问题')
            continue

        command = input_txt.lower()

        # Exit
        if command == 'exit':
            break

        # Clear the screen and reset the conversation
        if command == 'cls':
            history = []
            os.system(clear_cmd)
            print(welcome_txt)
            continue

        if not stream:
            _print_blocking_reply(input_txt)
            continue

        history.append([input_txt, _stream_reply(input_txt)])

        # Redraw the whole conversation so the live-rendered text is replaced
        # by a clean, coloured transcript.
        os.system(clear_cmd)
        print(build_prompt(history), flush=True)


def _print_blocking_reply(input_txt: str) -> None:
    """Fetch a complete (non-streamed) reply and print it, showing a spinner meanwhile."""
    global STOP_CIRCLE
    STOP_CIRCLE = False
    # daemon=True so a spinner that is still running can never keep the process alive.
    spinner = Thread(target=circle_print, daemon=True)
    spinner.start()
    try:
        outs = chat_bot.chat(input_txt)
    finally:
        # Always stop the spinner, even if the model call raises or is interrupted.
        STOP_CIRCLE = True
        spinner.join()

    print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')


def _stream_reply(input_txt: str) -> str:
    """Render a streamed reply live in the terminal and return its full text.

    Returns a canned apology message when the stream produced no text at all.
    """
    streamer = chat_bot.stream_chat(input_txt)
    rich_text = Text()

    print("\r\033[0;32;40mChatBot:\033[0m\n", end='')

    chunks = []
    with Live(rich_text, refresh_per_second=15):
        for word in streamer:
            rich_text.append(word)
            chunks.append(word)

    reply = ''.join(chunks)
    if not reply:
        reply = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
    return reply

if __name__ == "__main__":
    # Script entry point: start the interactive loop with streamed replies.
    chat(stream=True)