def activation_memory(
    a, # attention heads
    b, # micro batch size
    h, # hidden dimension size
    h_ff, # feedforward dimension size (often h_ff = 4h)
    L, # number of layers
    s, # sequence length
    mixed=True,
    recomputation="none"
    ):
    
    # activation memory estimate following Korthikanti et al. (2022):
    # https://arxiv.org/pdf/2205.05198
    if mixed:
        bytes_per_value = 2 
    else:
        bytes_per_value = 4
    # attention block activations: QKV, QK^T, softmax, attention-over-values and
    # output projection inputs, plus dropout masks
    one_layer_attention = s * b * h * (bytes_per_value * 5 + 1) + ((2 * bytes_per_value + 1) * a * s * s * b)
    # classic MLP block (h -> h_ff -> h)
    one_layer_feedforward_mlp = (
        s * b * h * bytes_per_value           # input of 1st linear layer
        + s * b * h_ff * bytes_per_value      # input of 2nd linear layer
        + s * b * h_ff * bytes_per_value      # input of activation function (not strictly needed for ReLU)
        + s * b * h                           # dropout mask (1 byte per element)
    )
    # SwiGLU MLP block (alternative to one_layer_feedforward_mlp for SwiGLU-based models)
    one_layer_feedforward_swiglu = (
        s * b * h * bytes_per_value           # input of the input linear layer
        + s * b * h_ff * bytes_per_value      # input of the output linear layer
        + s * b * h_ff * bytes_per_value * 3  # inputs of the activation function
        + s * b * h                           # dropout mask (1 byte per element)
    )
    if recomputation == "none":
        # full per-layer activations: attention + MLP + inputs of the 2 layernorms, eq (2)
        one_layer = (one_layer_attention + one_layer_feedforward_mlp
                     + 2 * s * b * h * bytes_per_value)
    elif recomputation == "selective":
        one_layer = s * b * h * 34  # eq (6), assumes 2-byte activations
    elif recomputation == "full":
        one_layer = s * b * h * 2  # only the layer input is kept
    else:
        raise ValueError(f"unknown recomputation strategy: {recomputation}")
    
    input_dropout = 0  # dropout mask after the input embedding would add s * b * h (section 4.3); ignored here
    total = L * one_layer + input_dropout
        
    return total
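
# Illustrative usage (not part of the original file): activation memory for an assumed
# 7B-scale example configuration, comparing the three recomputation strategies.
for _strategy in ("none", "selective", "full"):
    _act_bytes = activation_memory(
        a=32, b=1, h=4096, h_ff=4 * 4096, L=32, s=2048,
        mixed=True, recomputation=_strategy,
    )
    print(f"recomputation={_strategy}: {_act_bytes / 1e9:.2f} GB of activations")

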
def param_grads_opt(
    h, # hidden dimension size
    L, # number of layers
    s, # sequence length
    v, # vocab size
    k=8, # bytes of optimizer state per parameter (Adam: 4 bytes momentum + 4 bytes variance)
    mixed=True # mixed precision training
    ):
    
    # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
    # note: this is without GQA or MQA
    
    emb = h * (v + s)  # token + positional embedding parameters
    one_layer = 12 * h**2 + 13 * h  # attention + MLP + layernorm parameters per layer
    other = 2 * h  # final layernorm
    n = emb + L * one_layer + other  # total parameter count
    
    # 3.1 https://arxiv.org/pdf/1910.02054
    
    if mixed:
        k += 4  # extra fp32 master copy of the weights kept by the optimizer
        bytes_per_parameter = 2
    else:
        bytes_per_parameter = 4
    
    # memory (bytes) for parameters, gradients and optimizer states
    return bytes_per_parameter * n, bytes_per_parameter * n, k * n
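
# Illustrative usage (not part of the original file): static training memory for the same
# assumed example configuration (h=4096, L=32, s=2048) with a GPT-2-style vocabulary of
# 50257 tokens, trained with Adam in mixed precision.
_params_bytes, _grads_bytes, _opt_bytes = param_grads_opt(h=4096, L=32, s=2048, v=50257)
print(f"parameters: {_params_bytes / 1e9:.2f} GB, "
      f"gradients: {_grads_bytes / 1e9:.2f} GB, "
      f"optimizer states: {_opt_bytes / 1e9:.2f} GB")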
