import json
import os
import re
import uuid
import streamlit as st
import pandas as pd
from custom import *
import copy
import io


def get_history_chats(path: str) -> list:
    """Return the names of saved chats, most recent ``st_ctime`` first.

    When an "apikey" secret is configured, chats are stored as
    ``<path>/<name>.json``. The directory is created if missing, and the file
    stems are returned sorted by ``os.stat(...).st_ctime`` in descending order.
    If there are no saved chats, or no "apikey" secret is configured, a single
    fresh name of the form ``'New Chat_<uuid4>'`` is returned instead.

    Args:
        path: Directory holding the chat JSON files (relative or absolute).

    Returns:
        A non-empty list of chat names (file names without ``.json``).
    """
    if "apikey" not in st.secrets:
        return ['New Chat_' + str(uuid.uuid4())]

    # exist_ok avoids the race between a separate existence check and creation.
    os.makedirs(path, exist_ok=True)
    # os.path.join (rather than "./" + path) also works for absolute paths.
    json_files = [f for f in os.listdir(path) if f.endswith('.json')]
    json_files.sort(key=lambda f: os.stat(os.path.join(path, f)).st_ctime, reverse=True)
    chat_names = [os.path.splitext(f)[0] for f in json_files]
    if not chat_names:
        chat_names.append('New Chat_' + str(uuid.uuid4()))
    return chat_names


def save_data(path: str, file_name: str, history: list, paras: dict, contexts: dict, **kwargs):
    """Write one chat to ``<path>/<file_name>.json``, overwriting any existing file.

    The JSON object contains the keys ``history``, ``paras`` and ``contexts``,
    plus every extra keyword argument as an additional top-level key.

    Args:
        path: Target directory (relative or absolute); created if missing.
        file_name: File stem; ``.json`` is appended.
        history: Chat messages to store.
        paras: Parameter values to store.
        contexts: Context settings to store.
        **kwargs: Additional JSON-serialisable fields to store.
    """
    # exist_ok avoids the race between a separate existence check and creation.
    os.makedirs(path, exist_ok=True)
    # Build the path with os.path.join so it matches the directory created above,
    # including when `path` is absolute ("./" + "/abs" would point elsewhere).
    with open(os.path.join(path, f"{file_name}.json"), 'w', encoding='utf-8') as f:
        json.dump({"history": history, "paras": paras, "contexts": contexts, **kwargs}, f)


def remove_data(path: str, chat_name: str):
    """Delete a chat's JSON file and drop its cached Streamlit session-state entries.

    A missing file is ignored. The session-state keys removed are
    ``'history' + chat_name`` and ``item + chat_name + "value"`` for each of
    "context_select", "context_input", "context_level" and every key of
    ``initial_content_all['paras']``. Keys that are not present are skipped.

    Args:
        path: Directory holding the chat JSON files.
        chat_name: Name (file stem) of the chat to remove.
    """
    try:
        os.remove(os.path.join(path, f"{chat_name}.json"))
    except FileNotFoundError:
        pass

    # Clear the cache. Each key is popped independently so that one missing
    # key does not stop the remaining keys from being removed.
    keys = ['history' + chat_name]
    keys += [item + chat_name + "value"
             for item in ["context_select", "context_input", "context_level", *initial_content_all['paras']]]
    for key in keys:
        try:
            st.session_state.pop(key)
        except KeyError:
            pass


def load_data(path: str, file_name: str) -> dict:
    """Load a chat saved by ``save_data``, or fall back to the default content.

    Args:
        path: Directory holding the chat JSON files.
        file_name: File stem; ``.json`` is appended.

    Returns:
        The parsed JSON of ``<path>/<file_name>.json``. If that file does not
        exist, a deep copy of ``initial_content_all``. In that case, when an
        "apikey" secret is configured, the default is also written to the file
        (creating the directory if needed).
    """
    file_path = os.path.join(path, f"{file_name}.json")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        pass

    # Deep copy so that mutating the returned dict never alters the shared default.
    content = copy.deepcopy(initial_content_all)
    if "apikey" in st.secrets:
        # The directory may not exist yet; without this the write below would
        # raise a second FileNotFoundError.
        os.makedirs(path, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(content, f)
    return content


def show_each_message(message: str, role: str, area=None):
    """Render one chat message as an avatar header followed by a coloured content box.

    Args:
        message: Raw message text. It is passed through ``url_correction`` and
            then ``colon_correction`` before rendering.
        role: ``'user'`` selects the user's icon, name and background colour;
            any other value selects the assistant's.
        area: Optional indexable pair of render callables. ``area[0]`` draws the
            header and ``area[1]`` the body. Both default to ``st.markdown``.
    """
    if area is None:
        area = [st.markdown] * 2

    is_user = role == 'user'
    icon = user_svg if is_user else gpt_svg
    name = user_name if is_user else gpt_name
    background_color = user_background_color if is_user else gpt_background_color

    message = url_correction(message)
    message = colon_correction(message)

    header_html = f"\n<div class='avatar'>{icon}<h2>{name}:</h2></div>"
    area[0](header_html, unsafe_allow_html=True)
    body_html = f"""<div class='content-div' style='background-color: {background_color};'>\n\n{message}"""
    area[1](body_html, unsafe_allow_html=True)


def show_messages(messages: list):
    """Render every user and assistant message in order.

    Messages with any other role (e.g. "system") are skipped. A Markdown rule
    ("---") is written after each assistant message.
    """
    for msg in messages:
        role = msg["role"]
        if role not in ("user", "assistant"):
            continue
        show_each_message(msg["content"], role)
        if role == "assistant":
            st.write("---")


# Select the part of the chat history to send, according to context_level.
def get_history_input(history: list, level: int) -> list:
    """Return the messages to use as model input for a given context level.

    Args:
        history: Chat messages as dicts, each normally carrying a "role" key.
        level: Number of recent exchanges to keep. Each exchange counts as two
            non-system messages. 0 means no history at all.

    Returns:
        All system messages (in their original order) followed by the last
        ``2 * level`` non-system messages. Each dict is a shallow copy, so
        mutating the result does not touch ``history``. An empty history
        yields an empty list.
    """
    if level == 0:
        return []
    # Plain lists (rather than a DataFrame round-trip) keep every message's keys
    # and values exactly as given. Absent keys are not filled with NaN, and an
    # empty history does not fail on a missing "role" column.
    system_msgs = [dict(m) for m in history if m.get("role") == "system"]
    other_msgs = [dict(m) for m in history if m.get("role") != "system"]
    return system_msgs + other_msgs[-level * 2:]


# 去除#号右边的空格
# def remove_hashtag_right__space(text: str) -> str:
#     text = re.sub(r"(#+)\s*", r"\1", text)
#     return text


# Extract a prefix of the text by display width.
def extract_chars(text: str, num: int) -> str:
    """Return the shortest prefix of ``text`` whose width reaches ``num``.

    Characters in U+4E00..U+9FFF (CJK Unified Ideographs) count as width 2;
    every other character counts as width 1. The character that reaches or
    crosses ``num`` is included. If the whole text is narrower than ``num``,
    it is returned unchanged. A non-empty text with ``num <= 0`` yields its
    first character.
    """
    width = 0
    for index, ch in enumerate(text):
        # Chinese characters count as two.
        width += 2 if '\u4e00' <= ch <= '\u9fff' else 1
        if width >= num:
            return text[:index + 1]
    return text


@st.cache_data(max_entries=20, show_spinner=False)
def download_history(history: list):
    """Render the chat as Markdown and return it as a UTF-8 byte buffer.

    Each user or assistant message becomes a ``## <speaker>:`` heading followed
    by its content. Messages with any other role are omitted. The returned
    ``BytesIO`` is positioned at offset 0. Results are cached through
    ``st.cache_data`` (at most 20 entries).
    """
    parts = []
    for msg in history:
        role = msg['role']
        if role == 'user':
            speaker = user_name
        elif role == 'assistant':
            speaker = gpt_name
        else:
            continue
        parts.append(f'## {speaker}:\n{msg["content"]}\n')
    buffer = io.BytesIO(''.join(parts).encode('utf-8'))
    return buffer


def filename_correction(filename: str) -> str:
    """Remove every character that is not a word character, '.' or '-'.

    Word characters are Unicode-aware for str patterns, so letters such as CJK
    ideographs, digits and '_' survive. Spaces, slashes and other punctuation
    are dropped.
    """
    return re.sub(r'[^\w\.-]', '', filename)


def url_correction(text: str) -> str:
    """Surround every URL-like run in ``text`` with a single space on each side.

    A run starts with "http://", "https://" or "www." and continues over ASCII
    letters, digits, "~", "#", "!" and every character in the ASCII range
    "$".."_" (which includes punctuation such as ".", "/", ":", "?", "=", "&").
    Presumably the padding helps the Markdown renderer treat the URL as a
    standalone token. NOTE(review): confirm against the rendering path.
    """
    return re.sub(r'((?:http[s]?://|www\.)(?:[a-zA-Z0-9]|[$-_\~#!])+)', r' \g<1> ', text)


# Streamlit's markdown mis-renders an ASCII colon immediately followed by an
# ASCII letter (e.g. ":abc"), so such colons are escaped as "&#58;".
def colon_correction(text):
    """Escape colons outside code so Streamlit markdown renders them literally.

    If ``text`` contains a colon directly followed by an ASCII letter, every
    ":" is replaced by the HTML entity "&#58;". Colons inside fenced code blocks
    (```...```) and inline code spans (`...`) are then restored. Text without
    such a colon is returned unchanged.
    """
    if not re.search(r':[a-zA-Z]', text):
        return text
    text = text.replace(":", "&#58;")
    # Consume every code span left to right, fenced blocks first, even spans that
    # contain no colon. This keeps opening and closing backticks paired, so prose
    # lying between two inline spans is never mistaken for code.
    code_span = re.compile(r'```.*?```|`[^`]*`', re.DOTALL)
    return code_span.sub(lambda m: m.group(0).replace('&#58;', ':'), text)