# Ref: Ouyang, A. (2023). Understanding the Performance of Transformer Inference (Doctoral dissertation, Massachusetts Institute of Technology).

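"""Interactive calculator for transformer inference performance.

Given a Hugging Face model config and an inference setting, the app estimates
parameter counts, per-operation FLOPs and activation traffic, KV-cache size,
and roofline-style latency/throughput for the prefill and generation phases.
"""
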
import streamlit as st
import pandas as pd
from model_util import fetch_dictionary_content, load_parameter, get_model, classify_module, get_module_tensors
from calc_util import *
from render_util import create_table, header4, header5


st.set_page_config(layout='wide')
if 'model_config' not in st.session_state:
    st.session_state['model_config'] = {}


def load_model_config(model_id, access_token):
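    """Fetch and cache the model configuration for `model_id`.

    The config is cached in st.session_state so Streamlit's rerun-on-every-widget-change
    does not refetch it. If the fetch fails (e.g. the model config is not public),
    hard-coded opt-1.3b defaults are used so the calculator stays usable.
    """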
    if 'model_id' in st.session_state['model_config'] and st.session_state['model_config']['model_id'] == model_id:
        return st.session_state['model_config']
    if 'parameter_count' in st.session_state:
        st.session_state.pop('parameter_count')

    model_config = {}
    dictionary_content = fetch_dictionary_content(model_id)
    if dictionary_content:
        model_config['model_id'] = model_id
        model_config['hidden_size'] = dictionary_content['hidden_size']
        model_config['num_attention_heads'] = dictionary_content['num_attention_heads']
        model_config['num_hidden_layers'] = dictionary_content['num_hidden_layers']
        model_config['intermediate_size'] = load_parameter(dictionary_content, ['intermediate_size', 'ffn_dim'])
        model_config['vocab_size'] = dictionary_content['vocab_size']
        model_config['max_position_embeddings'] = dictionary_content['max_position_embeddings']
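        # Two LayerNorm applications per transformer layer (attention and MLP sub-blocks).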
        model_config['layernorm_operation'] = 2
    else:
        st.warning("Fetching information failed! Maybe model info is not public!")
        model_config['model_id'] = 'opt-1.3b'
        model_config['hidden_size'] = 2048
        model_config['num_attention_heads'] = 32
        model_config['num_hidden_layers'] = 24
        model_config['intermediate_size'] = 8192
        model_config['vocab_size'] = 50272
        model_config['max_position_embeddings'] = 2048
        model_config['layernorm_operation'] = 2

    try:
        model_config['model'] = get_model(model_id, None, access_token=access_token or None)
        module_tensors = get_module_tensors(model_config['model'])
        model_config['module_classes'] = classify_module(module_tensors)
    except Exception as e:
        st.warning(e)
        model_config['model'] = None
        model_config['module_classes'] = None

    st.session_state['model_config'] = model_config
    return model_config


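# Keys that are category subtotals; they are split out of the per-module tables
# below and rendered in separate summary tables.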
subtotal_parameters = [
    'embedding_weights',
    'attention_weights',
    'mlp_weights',
]

subtotal_operations = [
    'embeddings',
    'attention',
    'mlp',
    'total',
]



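# Layout: settings | parameter counts | prefill analysis | generation analysis | spacer.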
col1, col2, col3, col4, col5 = st.columns([0.8, 2, 2.5, 2.5, 0.01])

inference_config = {}
parameter_count = {}
cached_parameter_count = {}

prefilling_operation_count = {}
generation_operation_count = {}
prefilling_memory_count = {}
generation_memory_count = {}

gpu_config = {}
inference_info = {}

with col1:
    header4("Model")
    model_id = st.text_input("huggingface model id", 'ArthurZ/opt-13b')
    access_token = st.text_input("access token", '')
    model_config = load_model_config(model_id, access_token)
    model_config['hidden_size'] = st.number_input('hidden size', value=model_config['hidden_size'], format="%d")
    model_config['num_attention_heads'] = st.number_input('num attention heads', value=model_config['num_attention_heads'], format="%d")
    model_config['num_hidden_layers'] = st.number_input('num hidden layers', value=model_config['num_hidden_layers'], format="%d")
    model_config['intermediate_size'] = st.number_input('intermediate size', value=model_config['intermediate_size'], format="%d")
    model_config['vocab_size'] = st.number_input('vocab size', value=model_config['vocab_size'], format="%d")
    model_config['max_position_embeddings'] = st.number_input('max position embeddings', value=model_config['max_position_embeddings'], format="%d")
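    # Per-head dimension; assumes hidden_size is evenly divisible by num_attention_heads.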
    model_config['hidden_size_per_head'] = model_config['hidden_size'] // model_config['num_attention_heads']

    header4("Inference Setting")
    inference_config['batchsize'] = st.number_input('batch size', value=1, format="%d")
    inference_config['input_seq_length'] = st.number_input('input seq length', value=1, format="%d")
    inference_config['output_seq_length'] = st.number_input('output seq length', value=1, format="%d")
    inference_config['byte_per_parameter'] = st.number_input('bytes per parameter', value=2, format="%d")
    inference_config['KV_cache'] = st.checkbox("Use KV cache", value=True)

    header4("GPU Setting")
    gpu_config['Name'] = st.text_input('GPU Type', value="A6000")
    gpu_config['TFLOP'] = st.number_input('peak compute (TFLOP/s)', value=38.7, format="%.2f")
    gpu_config['memory_bandwidth'] = st.number_input('memory bandwidth (GB/s)', value=768, format="%d")
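    # Roofline ridge point: peak FLOP/s divided by bytes/s gives the arithmetic
    # intensity (FLOP/byte) above which a kernel becomes compute-bound.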
    gpu_config['arithmetic_intensity'] = gpu_config['TFLOP']*10**12/gpu_config['memory_bandwidth']/1024**3
    st.write(f"arithmetic_intensity: {gpu_config['arithmetic_intensity']:.3f}")

with col2:
    if 'parameter_count' not in st.session_state:
        if model_config['model']:
            st.info("Model info fetcted!")
            parameter_count = calc_model_size_from_model(model_config, inference_config)
        else:
            st.info("Fail to fetch model info. Using estimation!")
            parameter_count = model_size_estimate(model_config, inference_config)
        st.session_state.parameter_count = parameter_count 
    else:
        parameter_count = st.session_state.parameter_count

    parameters_items = {key: "{:,}".format(int(parameter_count[key])) for key in parameter_count if key not in subtotal_parameters}
    subtotal_parameters_items = {key: "{:,}".format(int(parameter_count[key])) for key in parameter_count if key in subtotal_parameters}

    # Convert dictionaries to pandas dataframes for table display
    df_parameters_items = pd.DataFrame(list(parameters_items.items()), columns=["Parameter", "Count"])
    df_subtotal_parameters_items = pd.DataFrame(list(subtotal_parameters_items.items()), columns=["Parameter", "Count"])

    header4("Model Parameters")
    st.markdown(create_table(df_parameters_items))

    header4("Parameters Summary")
    st.markdown(create_table(df_subtotal_parameters_items))

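    # Total weight footprint in bytes: every weight tensor times bytes per parameter.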
    model_total_size_in_byte = inference_config['byte_per_parameter'] * (
        parameter_count['embedding_weights'] +
        parameter_count['attention_weights'] +
        parameter_count['mlp_weights'] +
        parameter_count['layernorm']
    )
    st.write(f'model_total_size (Byte): {model_total_size_in_byte:,}')


    # add parameter viewer
    if model_config['model']:
        header4("Parameters Viewer")
        weight_generic = st.selectbox('Select weight:', options=model_config['module_classes'])
        modules = {}
        for module in model_config['module_classes'][weight_generic]:
            modules.update(module)
        modules = {k: list(v) for k, v in modules.items()}
        modules = pd.DataFrame(list(modules.items()), columns=["Parameter", "Shape"])
        st.markdown(create_table(modules))

with col3: # Prefilling
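    # Roofline estimate for prefill: compute latency = total FLOPs / peak FLOP/s,
    # memory latency = activation bytes / bandwidth; the larger of the two decides
    # whether the phase is compute- or memory-bound.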
    prefilling_operation_count = prefilling_operation(model_config, inference_config)
    prefilling_activation_memory_count = prefilling_activation_memory(model_config, inference_config)
    inference_info['inference_prefilling_time'] = prefilling_operation_count['total'] / (gpu_config['TFLOP'] * 10**12)
    inference_info['prefilling_memory_latency'] = prefilling_activation_memory_count['total'] / (gpu_config['memory_bandwidth'] * 1024**3)
    calc_prefilling_throughput(model_config, inference_config, inference_info)
    
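    # KV cache for the prompt: 2 tensors (K and V) x batch x layers x seq_len x hidden_size, in bytes.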
    cached_parameter_count['kv_cache'] = 2 * inference_config['byte_per_parameter'] * inference_config['batchsize'] * model_config['hidden_size'] * model_config['num_hidden_layers'] * inference_config['input_seq_length']

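    # Per-operation arithmetic intensity (FLOPs / activation bytes); comparing it
    # to the GPU ridge point above shows which operations are memory-bound.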
    operation_items = {key: "{:,}".format(int(prefilling_operation_count[key])) for key in prefilling_operation_count if key not in subtotal_operations}
    subtotal_operation_items = {key: "{:,}".format(int(prefilling_operation_count[key])) for key in prefilling_operation_count if key in subtotal_operations}
    prefilling_arithmetic_intensity = {key: "{:.3f}".format(prefilling_operation_count[key]/prefilling_activation_memory_count[key] if prefilling_activation_memory_count[key]>0 else float('inf'))  for key in prefilling_activation_memory_count}
    prefilling_activation_memory_count = {key: "{:,}".format(int(value)) for key, value in prefilling_activation_memory_count.items()}


    # Convert dictionaries to pandas dataframes for table display
    df_operation_count = pd.DataFrame(list(operation_items.items()), columns=["Operation", "FLOPs"])
    df_subtotal_operation_count = pd.DataFrame(list(subtotal_operation_items.items()), columns=["Operation", "FLOPs"])
    
    df_operation_count["Activation (Byte)"] = df_operation_count["Operation"].map(prefilling_activation_memory_count)
    df_operation_count["Arithmetic Intensity"] = df_operation_count["Operation"].map(prefilling_arithmetic_intensity)
    df_subtotal_operation_count["Activation (Byte)"] = df_subtotal_operation_count["Operation"].map(prefilling_activation_memory_count)
    df_subtotal_operation_count["Arithmetic Intensity"] = df_subtotal_operation_count["Operation"].map(prefilling_arithmetic_intensity)
    
    header4("Inference Ops: Prefilling")
    st.markdown(create_table(df_operation_count))

    header5("Summary: Prefilling")
    st.markdown(create_table(df_subtotal_operation_count))
    st.write(f"FLOPS latency: {inference_info['inference_prefilling_time']}")
    st.write(f"Memory latency: {inference_info['prefilling_memory_latency']}")
    st.write(f"Prefillng throughput (tokens/s): {inference_info['prefilling_throughput']:.2f} ({inference_info['prefilling_bound_type']}-bound)")

    if inference_config['KV_cache']:
        st.write(f"kv cache (Byte): {cached_parameter_count['kv_cache']:,}")
        


with col4: # Generation
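    # Same roofline estimate for the decode/generation phase; with small batches each
    # generated token reuses all the weights, so this phase is typically memory-bound.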
    generation_operation_count = generation_operation(model_config, inference_config)
    generation_activation_memory_count = generation_activation_memory(model_config, inference_config)
    inference_info['inference_generation_time'] = generation_operation_count['total'] / (gpu_config['TFLOP'] * 10**12)
    inference_info['generation_memory_latency'] = generation_activation_memory_count['total'] / (gpu_config['memory_bandwidth'] * 1024**3)
    calc_generation_throughput(model_config, inference_config, inference_info)
   
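    # By the end of generation the KV cache spans the full context (input + output tokens).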
    cached_parameter_count['kv_cache'] = 2 * inference_config['byte_per_parameter'] * inference_config['batchsize'] * model_config['hidden_size'] * model_config['num_hidden_layers'] * (inference_config['input_seq_length'] + inference_config['output_seq_length'])

    operation_items = {key: "{:,}".format(int(generation_operation_count[key])) for key in generation_operation_count if key not in subtotal_operations}
    subtotal_operation_items = {key: "{:,}".format(int(generation_operation_count[key])) for key in generation_operation_count if key in subtotal_operations}
    generation_arithmetic_intensity = {key: "{:.3f}".format(generation_operation_count[key]/generation_activation_memory_count[key] if generation_activation_memory_count[key]>0 else float('inf')) for key in generation_activation_memory_count}
    generation_activation_memory_count = {key: "{:,}".format(int(value)) for key, value in generation_activation_memory_count.items()}

    # Convert dictionaries to pandas dataframes for table display
    df_operation_count = pd.DataFrame(list(operation_items.items()), columns=["Operation", "FLOPs"])
    df_subtotal_operation_count = pd.DataFrame(list(subtotal_operation_items.items()), columns=["Operation", "FLOPs"])
   
    df_operation_count["Activation (Byte)"] = df_operation_count["Operation"].map(generation_activation_memory_count)
    df_operation_count["Arithmetic Intensity"] = df_operation_count["Operation"].map(generation_arithmetic_intensity)
    df_subtotal_operation_count["Activation (Byte)"] = df_subtotal_operation_count["Operation"].map(generation_activation_memory_count)
    df_subtotal_operation_count["Arithmetic Intensity"] = df_subtotal_operation_count["Operation"].map(generation_arithmetic_intensity)
    
    header4("Inference Ops: Generation")
    st.markdown(create_table(df_operation_count))

    header5("Summary: Generation")
    st.markdown(create_table(df_subtotal_operation_count))
    st.write(f"FLOPS latency: {inference_info['inference_generation_time']}")
    st.write(f"Memory latency: {inference_info['generation_memory_latency']}")
    st.write(f"Generation-only throughput (tokens/s): {inference_info['generation_throughput']:.2f} ({inference_info['generation_bound_type']}-bound)")
    st.write(f"(Client) Generation throughput (tokens/s): {inference_info['client_generation_throughput']:.2f}")

    if inference_config['KV_cache']:
        st.write(f"kv cache (Byte): {cached_parameter_count['kv_cache']:,}")