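"""Cache Wikipedia token features for a given language model.

For each selected layer, token features are extracted from a tokenized
Wikipedia dataset and written, together with the extraction parameters,
to a pickle file under --cache_path.

Example invocation (the script file name is assumed; the flags match the
argparse options defined below):

    python cache_wiki_features.py --model gpt-j-6b --sample_size 10000 \
        --cache_path ./cache/wiki_train/
"""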
import os
import argparse
import numpy as np
from tqdm import tqdm
from util import utils
from util import extraction, evaluation
from dsets import wikipedia


def cache_wikipedia(
    model_name,
    model,
    tok,
    hparams,
    max_len,
    exclude_front = 0,
    sample_size = 10000,
    take_single = False,
    exclude_path = None,
    layers = None,
    cache_path = None
):
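    """Extract and cache Wikipedia token features for the requested layers.

    For every layer in `layers`, token features are extracted from the
    tokenized Wikipedia dataset with extraction.extract_tokdataset_features
    and saved, together with the extraction parameters, to
    `<cache_path>/wikipedia_features_<model_name>_layer<l>_w1.pickle`.
    Layers whose output file already exists are skipped. If `exclude_path`
    is given, the `sampled_indices` stored in the matching file there are
    excluded from sampling.
    """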
    # load the wikipedia dataset, truncated to max_len tokens per sample
    if max_len is not None:
        raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=max_len)
    else:
        # no max_len given: fall back to the model's maximum context length
        print("No max_len given, using the model's maximum context length...")
        try:
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=model.config.n_positions)
        except AttributeError:
            # config has no n_positions attribute; fall back to a fixed length
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=4096)
    # extract features from each requested layer
    for l in layers:
        print('\n\nExtracting wikipedia token features for model layer:', l)
        output_file = os.path.join(cache_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
        if os.path.exists(output_file):
            print('Output file already exists:', output_file)
            continue
        # optionally exclude indices that were already sampled in a previous run
        if exclude_path is not None:
            exclude_file = os.path.join(exclude_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
            exclude_indices = utils.loadpickle(exclude_file)['sampled_indices']
        else:
            exclude_indices = []
        features, params = extraction.extract_tokdataset_features(
            model,
            tok_ds,
            layer = l,
            hparams = hparams,
            exclude_front = exclude_front,
            sample_size = sample_size,
            take_single = take_single,
            exclude_indices = exclude_indices,
            verbose = True
        )
        # save the features together with the extraction parameters
        params['features'] = features.cpu().numpy()
        utils.savepickle(output_file, params)
        print('Features saved:', output_file)
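
# Example (hypothetical usage, not part of this script): each cached pickle is a
# dict of extraction parameters plus the sampled features and can be reloaded
# with the same utils helper used above. The file name below is illustrative.
#
#   cached = utils.loadpickle('./cache/wiki_train/wikipedia_features_gpt-j-6b_layer5_w1.pickle')
#   feats = cached['features']          # numpy array of extracted token features
#   idx = cached['sampled_indices']     # dataset indices that were sampled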

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--sample_size', type=int, default=10000, help='number of feature vectors to extract')
    parser.add_argument(
        '--max_len', type=int, default=None, help='maximum token length')
    parser.add_argument(
        '--exclude_front', type=int, default=0, help='number of tokens to exclude from the front')
    parser.add_argument(
        '--take_single', type=int, default=0, help='take a single feature vector per wikipedia sample text (0 or 1)')
    parser.add_argument(
        '--layer', type=int, default=None, help='extract features for this single layer only')
    parser.add_argument(
        '--exclude_path', type=str, default=None, help='directory of previously cached features whose sampled indices are excluded')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/wiki_train/', help='output directory')
    args = parser.parse_args()

    # load hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # ensure the save path exists
    utils.assure_path_exists(args.cache_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # select layers: a single requested layer, or all layers for this model
    if args.layer is not None:
        layers = [args.layer]
    else:
        layers = evaluation.model_layer_indices[args.model]
    # run the main caching routine
    cache_wikipedia(
        model_name = args.model,
        model = model,
        tok = tok,
        hparams = hparams,
        max_len = args.max_len,
        layers = layers,
        exclude_front = args.exclude_front,
        sample_size = args.sample_size,
        take_single = bool(args.take_single),
        cache_path = args.cache_path,
        exclude_path = args.exclude_path,
    )