import gradio as gr
import io
import logging

from llm_profiler import *
import sys
from contextlib import redirect_stdout

# Model identifiers offered in the "Model Name" dropdown, grouped by family
# and kept in display order. Presumably each string must match a model config
# known to llm_profiler — confirm before adding entries.
model_names = (
    ["opt-1.3b", "opt-6.7b", "opt-13b", "opt-66b", "opt-175b"]
    + ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]
    + ["bloom-560m", "bloom-7b", "bloom-175b"]
    + ["llama-7b", "llama-13b", "llama-30b", "llama-65b"]
    + ["llama2-13b", "llama2-70b"]
    + ["internlm-20b", "baichuan2-13b"]
)
# Accelerator identifiers offered in the "GPU Name" dropdown, in display
# order. Presumably each string must match a device spec known to
# llm_profiler — confirm before adding entries.
gpu_names = (
    ["t4-pcie-15gb", "v100-pcie-32gb", "v100-sxm-32gb"]
    + ["br104p"]
    + ["a100-pcie-40gb", "a100-sxm-40gb", "a100-pcie-80gb", "a100-sxm-80gb"]
    + ["910b-64gb"]
    + ["h100-sxm-80gb", "h100-pcie-80gb"]
    + ["a30-pcie-24gb", "a30-sxm-24gb", "a40-pcie-48gb"]
)


# Logging handler that collects formatted records in an in-memory StringIO.
class StringHandler(logging.Handler):
    """Accumulate log records, one per line, in an in-memory text buffer.

    The buffer is exposed as ``self.stream`` so callers can clear it or
    write to it directly; ``get_value`` returns everything written so far.
    Records are formatted as the bare message (``"%(message)s"``).
    """

    def __init__(self):
        super().__init__()
        self.setFormatter(logging.Formatter("%(message)s"))
        self.stream = io.StringIO()

    def emit(self, record):
        """Append the formatted *record* plus a newline to the buffer."""
        line = self.format(record) + "\n"
        self.stream.write(line)

    def get_value(self):
        """Return the full text accumulated in the buffer."""
        text = self.stream.getvalue()
        return text


# Module logger (INFO and above) whose records are also written into the
# shared ``string_handler`` buffer, which gradio_interface clears and reads.
string_handler = StringHandler()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(string_handler)


def gradio_interface(
    model_name="llama2-70b",
    gpu_name: str = "t4-pcie-15gb",
    bytes_per_param: int = BYTES_FP16,
    batch_size_per_gpu: int = 2,
    seq_len: int = 300,
    generate_len: int = 40,
    ds_zero: int = 0,
    dp_size: int = 1,
    tp_size: int = 4,
    pp_size: int = 1,
    sp_size: int = 1,
    use_kv_cache: bool = True,
    layernorm_dtype_bytes: int = BYTES_FP16,
    kv_cache_dtype_bytes: int = BYTES_FP16,
    flops_efficiency: float = FLOPS_EFFICIENCY,
    hbm_memory_efficiency: float = HBM_MEMORY_EFFICIENCY,
    intra_node_memory_efficiency: float = INTRA_NODE_MEMORY_EFFICIENCY,
    inter_node_memory_efficiency: float = INTER_NODE_MEMORY_EFFICIENCY,
    mode: str = "inference",
    print_flag: bool = True,
) -> tuple:
    """Run ``llm_profile_infer`` and capture the text it prints.

    Used as the ``fn`` of the Gradio ``Interface`` in this module. Every
    argument is forwarded positionally, in the same order, to
    ``llm_profile_infer``.

    Returns:
        A ``(results, log_output)`` pair: the value returned by
        ``llm_profile_infer``, and all text written to stdout (plus any
        records sent to this module's ``logger``) during the call.

    Raises:
        Whatever ``llm_profile_infer`` raises; stdout is restored first.
    """
    # Clear the shared buffer so output does not accumulate across requests.
    string_handler.stream.seek(0)
    string_handler.stream.truncate()

    # redirect_stdout restores the real sys.stdout even if the profiler
    # raises. The previous manual swap left stdout pointing at the buffer
    # for the rest of the process after any exception.
    # NOTE(review): sys.stdout is process-global, so overlapping requests
    # would interleave output — this relies on Gradio serving one call at
    # a time.
    with redirect_stdout(string_handler.stream):
        results = llm_profile_infer(
            model_name,
            gpu_name,
            bytes_per_param,
            batch_size_per_gpu,
            seq_len,
            generate_len,
            ds_zero,
            dp_size,
            tp_size,
            pp_size,
            sp_size,
            use_kv_cache,
            layernorm_dtype_bytes,
            kv_cache_dtype_bytes,
            flops_efficiency,
            hbm_memory_efficiency,
            intra_node_memory_efficiency,
            inter_node_memory_efficiency,
            mode,
            print_flag,
        )

    # Everything printed (and logged) during the profiler call.
    log_output = string_handler.get_value()

    return results, log_output


# Build the Gradio UI. Input components are listed in the same order as
# gradio_interface's parameters, because Gradio passes them positionally.
# gr.Number returns float by default; fields that are inherently integer
# counts/sizes use precision=0 so the profiler receives int values. Byte-size
# fields are left unrounded since they may be fractional (e.g. sub-byte
# quantization).
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Dropdown(choices=model_names, label="Model Name", value="llama2-70b"),
        gr.Dropdown(choices=gpu_names, label="GPU Name", value="a100-sxm-80gb"),
        gr.Number(label="Bytes per Param", value=BYTES_FP16),
        gr.Number(label="Batch Size per GPU", value=2, precision=0),
        gr.Number(label="Sequence Length", value=300, precision=0),
        gr.Number(label="Generate Length", value=40, precision=0),
        gr.Number(label="DS Zero", value=0, precision=0),
        gr.Number(label="DP Size", value=1, precision=0),
        gr.Number(label="TP Size", value=4, precision=0),
        gr.Number(label="PP Size", value=1, precision=0),
        gr.Number(label="SP Size", value=1, precision=0),
        gr.Checkbox(label="Use KV Cache", value=True),
        gr.Number(label="Layernorm dtype Bytes", value=BYTES_FP16),
        gr.Number(label="KV Cache dtype Bytes", value=BYTES_FP16),
        gr.Number(label="FLOPS Efficiency", value=FLOPS_EFFICIENCY),
        gr.Number(label="HBM Memory Efficiency", value=HBM_MEMORY_EFFICIENCY),
        gr.Number(
            label="Intra Node Memory Efficiency", value=INTRA_NODE_MEMORY_EFFICIENCY
        ),
        gr.Number(
            label="Inter Node Memory Efficiency", value=INTER_NODE_MEMORY_EFFICIENCY
        ),
        gr.Radio(choices=["inference", "other_mode"], label="Mode", value="inference"),
        gr.Checkbox(label="Print Flag", value=True),
    ],
    outputs=[
        gr.Textbox(label="Inference Results"),  # return value of llm_profile_infer
        gr.Textbox(label="Detailed Analysis"),  # captured stdout / log text
    ],
    title="LLM Profiler",
    description="Input parameters to profile your LLM.",
)

# Launch the Gradio app behind HTTP basic-auth, without a public share link.
# NOTE(review): the login credentials are hard-coded in source; consider
# loading them from configuration or the environment instead.
# For an unauthenticated local run, call iface.launch() with no arguments.
iface.launch(auth=("xtrt-llm", "xtrt-llm"), share=False)