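"""Cache the learnable parameters of a model's RMSNorm/LayerNorm layers.

For each of the first `v_loss_layer + 1` layers, the norm learnables are
extracted, stacked per key, and pickled to
`<cache_path>/norm_learnables_<model>.pickle` for later reuse.
"""
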
import os
import argparse

from tqdm import tqdm

import torch

from util import utils
from util import extraction


def cache_norms(
        model,
        tok,
        hparams,
        cache_norm_file
    ):
    """ Cache learable parameters in RMSNorm and LayerNorm layers
    """
    layers = hparams['v_loss_layer']+1

    for i in tqdm(range(layers)):
        norm_learnables = extraction.load_norm_learnables(model, hparams, i)

        if i == 0:
            results = {k: [] for k in norm_learnables}
        for key in norm_learnables:
            results[key].append(norm_learnables[key])

    # stack per-layer tensors so each key maps to one (num_layers, ...) tensor
    for key in results:
        results[key] = torch.stack(results[key])

    utils.savepickle(cache_norm_file, results)
    print(f'Saved to {cache_norm_file}')


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')

    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')

    args = parser.parse_args()

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)
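    # hparams is a plain dict loaded from JSON; 'v_loss_layer' determines
    # how many layers are cached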

    cache_norm_file = os.path.join(
        args.cache_path, f'norm_learnables_{args.model}.pickle'
    )
    if os.path.exists(cache_norm_file):
        print(f'File exists: {cache_norm_file}')
        exit()
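
    # create the cache directory if needed (assuming utils.savepickle does not)
    os.makedirs(args.cache_path, exist_ok=True)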

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # cache norms
    cache_norms(
        model,
        tok,
        hparams,
        cache_norm_file
    )
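
# Example invocation (assuming this file is saved as cache_norms.py):
#   python cache_norms.py --model gpt-j-6b --cache_path ./cache/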