import numpy as np
from collections import defaultdict
from functools import partial
from typing import List
from model_util import get_module_tensors_matched
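
# Expected configuration keys, inferred from how they are used below:
#   model_config:     vocab_size, hidden_size, hidden_size_per_head, num_hidden_layers,
#                     num_attention_heads, intermediate_size, max_position_embeddings,
#                     layernorm_operation (appears to be the number of layernorms per
#                     transformer layer), module_classes (optional; when set, tensor
#                     shapes are looked up via get_module_tensors_matched)
#   inference_config: batchsize, input_seq_length, output_seq_length, KV_cache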

def calc_model_size_from_model(model_config, inference_config):
    get_module_tensors_matched_partial = partial(get_module_tensors_matched, module_classes_dict = model_config['module_classes'])

    parameter_count = defaultdict(float)
    parameter_count['word_embedding'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'embed' in x and 'pos' not in x)])
    parameter_count['positional_embedding'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'embed' in x and 'pos' in x)])

    parameter_count['attention_Q'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'att' in x and 'q' in x)])
    parameter_count['attention_K'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'att' in x and 'k' in x)])
    parameter_count['attention_V'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'att' in x and 'v' in x)])
    parameter_count['attention_out'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'att' in x and ('out_' in x or 'o_' in x))])
    
    parameter_count['layernorm'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'norm' in x)])
    parameter_count['mlp_weights'] = sum([v.numel() for v in get_module_tensors_matched_partial(lambda x: 'fc' in x or 'mlp' in x)])
       
    parameter_count['embedding_weights'] = parameter_count['word_embedding'] + parameter_count['positional_embedding']
    parameter_count['attention_weights'] = parameter_count['attention_out'] + parameter_count['attention_Q'] + parameter_count['attention_K'] + parameter_count['attention_V']
        
    return parameter_count

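# Closed-form parameter-count estimate from config values alone (no weight
# inspection); the counterpart of calc_model_size_from_model above. Bias terms
# of the attention and MLP projections are not included in this estimate.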
def model_size_estimate(model_config, inference_config):
    parameter_count = {}
    parameter_count['word_embedding'] = model_config['vocab_size']*model_config['hidden_size']
    parameter_count['positional_embedding'] = model_config['max_position_embeddings']*model_config['hidden_size']
    
    # per layer, each of Q/K/V/out is a hidden_size x hidden_size projection
    # (num_attention_heads heads of size hidden_size / num_attention_heads each)
    parameter_count['attention_Q']   = model_config['num_hidden_layers']*model_config['hidden_size']*model_config['hidden_size']
    parameter_count['attention_K']   = model_config['num_hidden_layers']*model_config['hidden_size']*model_config['hidden_size']
    parameter_count['attention_V']   = model_config['num_hidden_layers']*model_config['hidden_size']*model_config['hidden_size']
    parameter_count['attention_out'] = model_config['num_hidden_layers']*model_config['hidden_size']*model_config['hidden_size']
    
    parameter_count['layernorm'] = 2*model_config['layernorm_operation']*model_config['num_hidden_layers']*model_config['hidden_size']
    parameter_count['mlp1'] = model_config['num_hidden_layers']*model_config['hidden_size']*model_config['intermediate_size']
    parameter_count['mlp2'] = model_config['num_hidden_layers']*model_config['hidden_size']*model_config['intermediate_size']
    parameter_count['embedding_weights'] = parameter_count['word_embedding'] + parameter_count['positional_embedding']
    parameter_count['attention_weights'] = parameter_count['attention_out'] + parameter_count['attention_Q'] + parameter_count['attention_K'] + parameter_count['attention_V']
    parameter_count['mlp_weights'] = parameter_count['mlp1'] + parameter_count['mlp2']
    
    return parameter_count

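# matrix_operation returns the FLOPs of multiplying a (..., k) tensor by a
# (k, n) matrix, modeled as 2 * prod(leading dims of A) * k * n (one multiply
# and one add per accumulated element). multiplication_in_int64 keeps the
# running product in int64 so large operation counts do not overflow smaller
# default integer types.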
def multiplication_in_int64(array):
    return np.cumprod(np.array(array, dtype=np.int64))[-1]

def matrix_operation(shapeA, shapeB):
    assert(shapeA[-1] == shapeB[0])
    op = np.cumprod(np.array(shapeA[:-1], np.float64))    
    return multiplication_in_int64([2, op[-1], shapeA[-1], shapeB[-1]])

def word_embedding_operation(model_config, inference_config):
    #Given:
    #\begin{itemize}
    #    \item Matrix \( X \) of size \( B \times s \) (representing the batch size and sequence length respectively).
    #    \item Embedding matrix \( W_e \) of size \( n_{vocab} \times d_{model} \).
    #\end{itemize}
    
    #The resultant matrix after the multiplication will be of size \( B \times s \times d_{model} \).
    #For each element in this resultant matrix, the number of FLOPs required is \( 2 \times n_{vocab} \). This is because for a single element in the output matrix, we have \( 2N \) FLOPs (with \( N \) being the common dimension), leading to the matrix multiplication FLOP count as:
    #\begin{equation}
    #2 \times B \times s \times n_{vocab} \times d_{model}
    #\end{equation}
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'embed' in x and 'pos' not in x, model_config['module_classes'])
        if len(modules) > 0:
            A = [inference_config['batchsize'], inference_config['input_seq_length'], modules[0][0]]
            B = modules[0]
            op_count = matrix_operation(A, B)
            return op_count

    A = [inference_config['batchsize'], inference_config['input_seq_length'], model_config['vocab_size']]
    B = [model_config['vocab_size'], model_config['hidden_size']]
    op_count = matrix_operation(A, B)
    return op_count


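# Positional embeddings are added elementwise to the token embeddings, so the
# operation count is modeled as one op per element of the
# (batchsize, seq_length, hidden_size) activation.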
def positional_embedding_operation(model_config, inference_config):
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'embed' in x and 'pos' in x, model_config['module_classes'])
        if len(modules) > 0:
            return multiplication_in_int64([inference_config['batchsize'], inference_config['input_seq_length'], modules[0][-1]])

    return multiplication_in_int64([inference_config['batchsize'], inference_config['input_seq_length'], model_config['hidden_size']])

### attention_Q/K/V: the three projection operations below share the same cost model
def attention_K_operation(model_config, inference_config, seq_length):
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'att' in x and 'k' in x , model_config['module_classes'])
        if len(modules) > 0:
            total = 0
            for module in modules:
                if len(module) > 1:
                    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
                    B = [model_config['hidden_size'], model_config['hidden_size_per_head']]
                    total += model_config['num_attention_heads']*matrix_operation(A, B)
                else:
                    total += model_config['hidden_size']
            return total

    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size_per_head']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

def attention_Q_operation(model_config, inference_config, seq_length):
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'att' in x and 'q' in x , model_config['module_classes'])
        if len(modules) > 0:
            total = 0
            for module in modules:
                if len(module) > 1:
                    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
                    B = [model_config['hidden_size'], model_config['hidden_size_per_head']]
                    total += model_config['num_attention_heads']*matrix_operation(A, B)
                else:
                    total += model_config['hidden_size']
            return total
     
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size_per_head']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

def attention_V_operation(model_config, inference_config, seq_length):
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'att' in x and 'v' in x , model_config['module_classes'])
        if len(modules) > 0:
            total = 0
            for module in modules:
                if len(module) > 1:
                    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
                    B = [model_config['hidden_size'], model_config['hidden_size_per_head']]
                    total += model_config['num_attention_heads']*matrix_operation(A, B)
                else:
                    total += model_config['hidden_size']
            return total
     
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size_per_head']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

## Attention scores (QK^T), softmax, and scores-times-V
def attention_QK_operation(model_config, inference_config, seq_length_Q, seq_length_K):
    A = [inference_config['batchsize'], seq_length_Q, model_config['hidden_size_per_head']]
    B = [model_config['hidden_size_per_head'], seq_length_K]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * matrix_operation(A, B)

def attention_softmax_operation(model_config, inference_config, seq_length):
    # Ref: Ouyang, A. (2023). Understanding the Performance of Transformer Inference (Doctoral dissertation, Massachusetts Institute of Technology).
    # 3 is a modeled value
    softmax_operation = (3*inference_config['batchsize']*seq_length*seq_length)
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * softmax_operation

def attention_multV_operation(model_config, inference_config, seq_length_Q, seq_length_V):
    A = [inference_config['batchsize'], seq_length_Q, seq_length_V]
    B = [seq_length_V, model_config['hidden_size_per_head']]
    return model_config['num_hidden_layers'] * model_config['num_attention_heads']* matrix_operation(A, B)

def attention_out_operation(model_config, inference_config, seq_length):
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'att' in x and ('out_' in x or 'o_' in x), model_config['module_classes'])
        if len(modules) > 0:
            total = 0
            for module in modules:
                if len(module) > 1:
                    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
                    B = [model_config['hidden_size'], model_config['hidden_size']]
                    total += matrix_operation(A, B)
                else:
                    total += model_config['hidden_size']
            return total
    
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['hidden_size']]
    return model_config['num_hidden_layers'] * matrix_operation(A, B)

def layernorm_operation(model_config, inference_config, seq_length):
    # Ref: Ouyang, A. (2023). Understanding the Performance of Transformer Inference (Doctoral dissertation, Massachusetts Institute of Technology).
    # 5 is a modeled value
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'norm' in x, model_config['module_classes'])
        if len(modules) > 0:
            total = 0
            for module in modules:
                # each matched norm tensor normalizes a (batchsize, seq_length, hidden) activation
                total += inference_config['batchsize'] * seq_length * module[0]
            return 5*total
        
    layernorm_operation = (5*inference_config['batchsize']*seq_length*model_config['hidden_size'])
    return model_config['num_hidden_layers'] * model_config['layernorm_operation'] * layernorm_operation


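# The fallback path below models the MLP block as two dense layers
# (hidden_size -> intermediate_size and intermediate_size -> hidden_size) of
# equal cost, hence the factor of 2 on a single matrix_operation.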
def mlp_operation(model_config, inference_config, seq_length):
    if model_config['module_classes']:
        modules = get_module_tensors_matched(lambda x: 'fc' in x or 'mlp' in x, model_config['module_classes'])
        if len(modules) > 0:
            total = 0
            for module in modules:
                if len(module) > 1:
                    A = [inference_config['batchsize'], seq_length, module[1]]
                    B = [module[1], module[0]]
                    total += matrix_operation(A, B)
                else:
                    total += module[0]
            return total
    
    A = [inference_config['batchsize'], seq_length, model_config['hidden_size']]
    B = [model_config['hidden_size'], model_config['intermediate_size']]
    return model_config['num_hidden_layers'] * (2*matrix_operation(A, B))


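# Prefilling: the whole prompt (input_seq_length tokens) is processed in one
# forward pass, so every per-operation count is evaluated at the full input
# sequence length.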
def prefilling_operation(model_config, inference_config):
    prefilling_operation_count = {}
    prefilling_operation_count['word_embedding'] = word_embedding_operation(model_config, inference_config)
    prefilling_operation_count['positional_embedding'] = positional_embedding_operation(model_config, inference_config)
    
    prefilling_operation_count['attention_Q'] = attention_Q_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_K'] = attention_K_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_V'] = attention_V_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_QK'] = attention_QK_operation(model_config, inference_config, inference_config['input_seq_length'], inference_config['input_seq_length'])
    prefilling_operation_count['attention_softmax'] = attention_softmax_operation(model_config, inference_config, inference_config['input_seq_length'])
    prefilling_operation_count['attention_multV'] = attention_multV_operation(model_config, inference_config, inference_config['input_seq_length'], inference_config['input_seq_length'])
    prefilling_operation_count['attention_out'] = attention_out_operation(model_config, inference_config, inference_config['input_seq_length'])

    prefilling_operation_count['layernorm'] = layernorm_operation(model_config, inference_config, inference_config['input_seq_length'])

    prefilling_operation_count['mlp'] = mlp_operation(model_config, inference_config, inference_config['input_seq_length'])
    
    prefilling_operation_count['embeddings'] = prefilling_operation_count['word_embedding'] + prefilling_operation_count['positional_embedding']
    prefilling_operation_count['attention'] = sum([v for k,v in prefilling_operation_count.items() if 'attention' in k])
    prefilling_operation_count['total'] = (prefilling_operation_count['embeddings'] + prefilling_operation_count['attention'] + prefilling_operation_count['mlp'] + prefilling_operation_count['layernorm'])
    
    return prefilling_operation_count

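# Generation: one forward pass per generated token. With KV_cache enabled only
# the newly generated token is projected (seq_length 1), while the QK^T and
# scores-times-V products still run against the full cached context of
# input_seq_length + t + 1 tokens; without the cache every step reprocesses the
# whole sequence.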
def generation_operation(model_config, inference_config):
    generation_operation_count = {}
    generation_operation_count['word_embedding'] = 0
    generation_operation_count['positional_embedding'] = 0
    generation_operation_count['attention_K'] = 0
    generation_operation_count['attention_V'] = 0
    generation_operation_count['attention_Q'] = 0
    generation_operation_count['attention_QK'] = 0
    generation_operation_count['attention_softmax'] = 0
    generation_operation_count['attention_multV'] = 0
    generation_operation_count['attention_out'] = 0
    generation_operation_count['mlp'] = 0
    generation_operation_count['layernorm'] = 0

    for t in range(inference_config['output_seq_length']):
        if inference_config['KV_cache']:
            generation_operation_count['attention_K'] += attention_K_operation(model_config, inference_config, 1)
            generation_operation_count['attention_V'] += attention_V_operation(model_config, inference_config, 1)
            generation_operation_count['attention_Q'] += attention_Q_operation(model_config, inference_config, 1)
            generation_operation_count['attention_QK'] += attention_QK_operation(model_config, inference_config, seq_length_Q=1, seq_length_K=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_softmax'] += attention_softmax_operation(model_config, inference_config, 1)
            generation_operation_count['attention_multV'] += attention_multV_operation(model_config, inference_config, seq_length_Q=1, seq_length_V=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_out'] += attention_out_operation(model_config, inference_config, 1)
            generation_operation_count['mlp'] += mlp_operation(model_config, inference_config, 1)
        else:
            generation_operation_count['attention_K'] += attention_K_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_V'] += attention_V_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_Q'] += attention_Q_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_QK'] += attention_QK_operation(model_config, inference_config, seq_length_Q=(t+1)+inference_config['input_seq_length'], seq_length_K=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_softmax'] += attention_softmax_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_multV'] += attention_multV_operation(model_config, inference_config, seq_length_Q=(t+1)+inference_config['input_seq_length'], seq_length_V=(t+1)+inference_config['input_seq_length'])
            generation_operation_count['attention_out'] += attention_out_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            generation_operation_count['mlp'] += mlp_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])

        generation_operation_count['layernorm'] += layernorm_operation(model_config, inference_config, (t+1)+inference_config['input_seq_length'])

    generation_operation_count['embeddings'] = generation_operation_count['word_embedding'] + generation_operation_count['positional_embedding'] 
    generation_operation_count['attention'] = sum([v for k,v in generation_operation_count.items() if 'attention' in k])
    generation_operation_count['total'] = (generation_operation_count['embeddings'] + generation_operation_count['attention'] + generation_operation_count['mlp'] + generation_operation_count['layernorm'])

    return generation_operation_count


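# The *_activation_memory helpers below count activation elements (inputs plus
# outputs of each operation), not bytes; any scaling by bytes-per-element is
# assumed to happen in the caller.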
def word_embedding_activation_memory(model_config, inference_config, seq_length):
    return inference_config['batchsize'] * seq_length * (model_config['vocab_size'] + model_config['hidden_size'])

def positional_embedding_activation_memory(model_config, inference_config, seq_length):
    return 2 * inference_config['batchsize'] * seq_length * model_config['hidden_size']

def attention_K_activation_memory(model_config, inference_config, seq_length):
    per_head_per_layer = inference_config['batchsize'] * seq_length * (model_config['hidden_size'] + model_config['hidden_size_per_head'])
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * per_head_per_layer

def attention_V_activation_memory(model_config, inference_config, seq_length):
    per_head_per_layer = inference_config['batchsize'] * seq_length * (model_config['hidden_size'] + model_config['hidden_size_per_head'])
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * per_head_per_layer

def attention_Q_activation_memory(model_config, inference_config, seq_length):
    per_head_per_layer = inference_config['batchsize'] * seq_length * (model_config['hidden_size'] + model_config['hidden_size_per_head'])
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * per_head_per_layer

def attention_QK_activation_memory(model_config, inference_config, seq_length_Q, seq_length_K):
    inputs_Q = inference_config['batchsize'] * seq_length_Q * model_config['hidden_size_per_head']
    inputs_K = inference_config['batchsize'] * seq_length_K * model_config['hidden_size_per_head']
    outputs =  inference_config['batchsize'] * seq_length_Q * seq_length_K
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * (inputs_Q + inputs_K + outputs)

def attention_softmax_activation_memory(model_config, inference_config, seq_length):
    per_head_per_layer = (2 * inference_config['batchsize'] * seq_length * seq_length)
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * per_head_per_layer

def attention_multV_activation_memory(model_config, inference_config, seq_length_Q, seq_length_V):
    per_head_per_layer = inference_config['batchsize'] * seq_length_Q * seq_length_V + inference_config['batchsize'] * seq_length_Q * model_config['hidden_size_per_head'] + inference_config['batchsize'] * seq_length_V * model_config['hidden_size_per_head']
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * per_head_per_layer

def attention_out_activation_memory(model_config, inference_config, seq_length):
    per_head_per_layer = 2 * inference_config['batchsize'] * seq_length * model_config['hidden_size']
    return model_config['num_hidden_layers'] * model_config['num_attention_heads'] * per_head_per_layer

def layernorm_activation_memory(model_config, inference_config, seq_length):
    per_layernorm_per_layer = 2 * inference_config['batchsize'] * seq_length * model_config['hidden_size']
    return model_config['num_hidden_layers'] * model_config['layernorm_operation'] * per_layernorm_per_layer

def mlp_activation_memory(model_config, inference_config, seq_length):
    # two MLP layers: inputs and outputs of both the hidden->intermediate and intermediate->hidden projections
    per_layer = 2 * inference_config['batchsize'] * seq_length * (model_config['hidden_size'] + model_config['intermediate_size'])
    return model_config['num_hidden_layers'] * per_layer

def prefilling_activation_memory(model_config, inference_config):
    activation_memory = {}
    
    activation_memory['word_embedding'] = word_embedding_activation_memory(model_config, inference_config, inference_config['input_seq_length'])
    activation_memory['positional_embedding'] = positional_embedding_activation_memory(model_config, inference_config, inference_config['input_seq_length'])
    
    activation_memory['attention_Q'] = attention_Q_activation_memory(model_config, inference_config, inference_config['input_seq_length'])
    activation_memory['attention_K'] = attention_K_activation_memory(model_config, inference_config, inference_config['input_seq_length'])
    activation_memory['attention_V'] = attention_V_activation_memory(model_config, inference_config, inference_config['input_seq_length'])
    activation_memory['attention_QK'] = attention_QK_activation_memory(model_config, inference_config, inference_config['input_seq_length'], inference_config['input_seq_length'])
    activation_memory['attention_softmax'] = attention_softmax_activation_memory(model_config, inference_config, inference_config['input_seq_length'])
    activation_memory['attention_multV'] = attention_multV_activation_memory(model_config, inference_config, inference_config['input_seq_length'], inference_config['input_seq_length'])
    activation_memory['attention_out'] = attention_out_activation_memory(model_config, inference_config, inference_config['input_seq_length'])

    activation_memory['layernorm'] = layernorm_activation_memory(model_config, inference_config, inference_config['input_seq_length'])

    activation_memory['mlp'] = mlp_activation_memory(model_config, inference_config, inference_config['input_seq_length'])
    
    activation_memory['embeddings'] = activation_memory['word_embedding'] + activation_memory['positional_embedding']
    activation_memory['attention'] = (
        activation_memory['attention_Q'] + activation_memory['attention_K'] +
        activation_memory['attention_V'] + activation_memory['attention_QK'] +
        activation_memory['attention_softmax'] + activation_memory['attention_multV'] +
        activation_memory['attention_out']
    )
    activation_memory['total'] = (
        activation_memory['embeddings'] + activation_memory['attention'] +
        activation_memory['mlp'] + activation_memory['layernorm']
    )
    

    return activation_memory

def generation_activation_memory(model_config, inference_config):
    activation_memory = {}

    activation_memory['word_embedding'] = 0
    activation_memory['positional_embedding'] = 0
    activation_memory['attention_K'] = 0
    activation_memory['attention_V'] = 0
    activation_memory['attention_Q'] = 0
    activation_memory['attention_QK'] = 0
    activation_memory['attention_softmax'] = 0
    activation_memory['attention_multV'] = 0
    activation_memory['attention_out'] = 0
    activation_memory['mlp'] = 0
    activation_memory['layernorm'] = 0

    for t in range(inference_config['output_seq_length']):
        if inference_config['KV_cache']:
            activation_memory['attention_K'] += attention_K_activation_memory(model_config, inference_config, 1)
            activation_memory['attention_V'] += attention_V_activation_memory(model_config, inference_config, 1)
            activation_memory['attention_Q'] += attention_Q_activation_memory(model_config, inference_config, 1)
            activation_memory['attention_QK'] += attention_QK_activation_memory(model_config, inference_config, seq_length_Q=1, seq_length_K=(t+1)+inference_config['input_seq_length'])
            activation_memory['attention_softmax'] += attention_softmax_activation_memory(model_config, inference_config, 1)
            activation_memory['attention_multV'] += attention_multV_activation_memory(model_config, inference_config, seq_length_Q=1, seq_length_V=(t+1)+inference_config['input_seq_length'])
            activation_memory['attention_out'] += attention_out_activation_memory(model_config, inference_config, 1)
            activation_memory['mlp'] += mlp_activation_memory(model_config, inference_config, 1)
        else:
            activation_memory['attention_K'] += attention_K_activation_memory(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            activation_memory['attention_V'] += attention_V_activation_memory(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            activation_memory['attention_Q'] += attention_Q_activation_memory(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            activation_memory['attention_QK'] += attention_QK_activation_memory(model_config, inference_config, seq_length_Q=(t+1)+inference_config['input_seq_length'], seq_length_K=(t+1)+inference_config['input_seq_length'])
            activation_memory['attention_softmax'] += attention_softmax_activation_memory(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            activation_memory['attention_multV'] += attention_multV_activation_memory(model_config, inference_config, seq_length_Q=(t+1)+inference_config['input_seq_length'], seq_length_V=(t+1)+inference_config['input_seq_length'])
            activation_memory['attention_out'] += attention_out_activation_memory(model_config, inference_config, (t+1)+inference_config['input_seq_length'])
            activation_memory['mlp'] += mlp_activation_memory(model_config, inference_config, (t+1)+inference_config['input_seq_length'])

        activation_memory['layernorm'] += layernorm_activation_memory(model_config, inference_config, (t+1)+inference_config['input_seq_length'])

    activation_memory['embeddings'] = activation_memory['word_embedding'] + activation_memory['positional_embedding']
    activation_memory['attention'] = (
        activation_memory['attention_K'] + activation_memory['attention_V'] +
        activation_memory['attention_Q'] + activation_memory['attention_QK'] +
        activation_memory['attention_softmax'] + activation_memory['attention_multV'] +
        activation_memory['attention_out']
    )
    activation_memory['total'] = (
        activation_memory['embeddings'] + activation_memory['attention'] +
        activation_memory['mlp'] + activation_memory['layernorm']
    )

    return activation_memory


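# Roofline-style throughput: each phase is limited by the larger of its compute
# time and its memory latency, and the bound type records which of the two
# dominates.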
def calc_prefilling_throughput(model_config, inference_config, inference_info):
    inference_info['prefilling_throughput'] = inference_config['input_seq_length']*inference_config['batchsize'] / max([inference_info['inference_prefilling_time'], inference_info['prefilling_memory_latency']])
    inference_info['prefilling_bound_type'] = "memory" if inference_info['inference_prefilling_time'] < inference_info['prefilling_memory_latency'] else "arithmetic"

def calc_generation_throughput(model_config, inference_config, inference_info):
    inference_info['generation_throughput'] = inference_config['output_seq_length']*inference_config['batchsize'] / max([inference_info['inference_generation_time'], inference_info['generation_memory_latency']])
    inference_info['generation_bound_type'] = "memory" if inference_info['inference_generation_time'] < inference_info['generation_memory_latency'] else "arithmetic"
    
    total_time = max([inference_info['inference_prefilling_time'], inference_info['prefilling_memory_latency']]) + max([inference_info['inference_generation_time'], inference_info['generation_memory_latency']])
    inference_info['client_generation_throughput'] = inference_config['output_seq_length']*inference_config['batchsize'] / total_time
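

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): hypothetical GPT-2-small-like
# dimensions are used below to exercise the closed-form (non-module_classes)
# paths. None of these values are tied to a real checkpoint, and
# 'layernorm_operation' is assumed to mean layernorms per transformer layer.
if __name__ == "__main__":
    model_config = {
        'vocab_size': 50257,
        'hidden_size': 768,
        'hidden_size_per_head': 64,
        'num_hidden_layers': 12,
        'num_attention_heads': 12,
        'intermediate_size': 3072,
        'max_position_embeddings': 1024,
        'layernorm_operation': 2,   # assumed: two layernorms per transformer layer
        'module_classes': None,     # use the closed-form fallback paths
    }
    inference_config = {
        'batchsize': 1,
        'input_seq_length': 512,
        'output_seq_length': 128,
        'KV_cache': True,
    }

    params = model_size_estimate(model_config, inference_config)
    prefill_ops = prefilling_operation(model_config, inference_config)
    gen_ops = generation_operation(model_config, inference_config)

    total_params = sum(params[k] for k in ('embedding_weights', 'attention_weights', 'mlp_weights', 'layernorm'))
    print(f"parameters:       {total_params / 1e6:.1f} M")   # ~124 M here (bias terms are not modeled)
    print(f"prefill FLOPs:    {float(prefill_ops['total']):.3e}")
    print(f"generation FLOPs: {float(gen_ops['total']):.3e}")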