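"""Cache Wikipedia token features extracted from a language model.

For each selected layer of the model, this script samples hidden-state
feature vectors from a tokenized Wikipedia dataset and stores them as
pickle files in the cache directory.
"""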


import os
import argparse

import numpy as np
from tqdm import tqdm

from util import utils
from util import extraction, evaluation

from dsets import wikipedia


def cache_wikipedia(
        model_name,
        model,
        tok,
        hparams,
        max_len,
        exclude_front = 0,
        sample_size = 10000,
        take_single = False,
        exclude_path = None,
        layers = None,
        cache_path = None
    ):
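    """Extract and cache hidden-state features for Wikipedia tokens.

    For each layer in ``layers``, feature vectors are sampled from the
    tokenized Wikipedia dataset and written to ``cache_path`` as a pickle
    file; layers whose output file already exists are skipped.
    """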
    # load wikipedia dataset
    if max_len is not None:
        raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=max_len)
    else:
        # no explicit max length given: fall back to the model's context length
        print('No max length given; using the model context length...')
        try:
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=model.config.n_positions)
        except AttributeError:
            # configs without `n_positions` fall back to a fixed length
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=4096)

    # extract features from each layer
    for l in layers:

        print('\n\nExtracting wikipedia token features for model layer:', l)

        # skip layers that have already been cached
        output_file = os.path.join(cache_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
        if os.path.exists(output_file):
            print('Output file already exists:', output_file)
            continue

        # optionally exclude token indices already sampled in a previous run
        if exclude_path is not None:
            exclude_file = os.path.join(exclude_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
            exclude_indices = utils.loadpickle(exclude_file)['sampled_indices']
        else:
            exclude_indices = []

        features, params = extraction.extract_tokdataset_features(
            model,
            tok_ds,
            layer = l,
            hparams = hparams,
            exclude_front = exclude_front,
            sample_size = sample_size,
            take_single = take_single,
            exclude_indices = exclude_indices,
            verbose = True
        )

        # save features alongside the extraction parameters
        params['features'] = features.cpu().numpy()
        utils.savepickle(output_file, params)
        print('Features saved:', output_file)

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
        
    parser.add_argument(
        '--sample_size', type=int, default=10000, help='number of feature vectors to extract')

    parser.add_argument(
        '--max_len', type=int, default=None, help='maximum token length')
    parser.add_argument(
        '--exclude_front', type=int, default=0, help='number of tokens to exclude from the front')
    parser.add_argument(
        '--take_single', type=int, default=0, help='take a single feature vector per wikipedia sample text (0 or 1)')

    parser.add_argument(
        '--layer', type=int, default=None, help='extract features for this layer only (default: all layers)')

    parser.add_argument(
        '--exclude_path', type=str, default=None, help='path to previously cached features whose sampled indices should be excluded')

    parser.add_argument(
        '--cache_path', type=str, default='./cache/wiki_train/', help='output directory')

    args = parser.parse_args()

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # ensure save path exists
    utils.assure_path_exists(args.cache_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    if args.layer is not None:
        layers = [args.layer]
    else:
        layers = evaluation.model_layer_indices[args.model]

    # main function
    cache_wikipedia(
        model_name = args.model,
        model = model,
        tok = tok,
        hparams = hparams,
        max_len = args.max_len,
        layers = layers,
        exclude_front = args.exclude_front,
        sample_size = args.sample_size,
        take_single = bool(args.take_single),
        cache_path = args.cache_path,
        exclude_path = args.exclude_path,
    )
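
# Example invocation (a sketch; the script filename and layer choice are
# assumptions, adjust to this repository's actual file name and model):
#
#   python cache_wikipedia_features.py \
#       --model gpt-j-6b \
#       --layer 5 \
#       --sample_size 10000 \
#       --max_len 1024 \
#       --cache_path ./cache/wiki_train/
#
# This would cache layer-5 Wikipedia token features for the gpt-j-6b model
# under ./cache/wiki_train/.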