import os
import copy
import argparse
import numpy as np
from tqdm import tqdm
from util import utils
from util import extraction, evaluation
def cache_features(
    model,
    tok,
    dataset,
    hparams,
    cache_features_file,
    layers,
    batch_size = 64,
    static_context = '',
    selection = None,
    reverse_selection = False,
    verbose = True
):
""" Function to load or cache features from dataset
"""
if os.path.exists(cache_features_file):
print('Loaded cached features file: ', cache_features_file)
cache_features_contents = utils.loadpickle(cache_features_file)
raw_case_ids = cache_features_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])

        # construct prompts and subjects
        # (the fully formatted prompt is passed as the "subject"; the prompt
        # template itself is just the identity placeholder '{}')
        subjects = [static_context + r['prompt'].format(r['subject']) for r in raw_requests]
        prompts = ['{}'] * len(subjects)

        # run multilayer feature extraction
        _returns_across_layer = extraction.extract_multilayer_at_tokens(
            model,
            tok,
            prompts,
            subjects,
            layers = layers,
            module_template = hparams['rewrite_module_tmp'],
            tok_type = 'prompt_final',
            track = 'in',
            batch_size = batch_size,
            return_logits = False,
            verbose = verbose
        )
        # keep only the input-side ('in') activations of each hooked module
        for key in _returns_across_layer:
            _returns_across_layer[key] = _returns_across_layer[key]['in']

        # store features per layer, keyed by layer index
        cache_features_contents = {}
        for i in layers:
            cache_features_contents[i] = \
                _returns_across_layer[hparams['rewrite_module_tmp'].format(i)]

        cache_features_contents['case_ids'] = raw_case_ids
        cache_features_contents['prompts'] = np.array(prompts)
        cache_features_contents['subjects'] = np.array(subjects)

        utils.assure_path_exists(os.path.dirname(cache_features_file))
        utils.savepickle(cache_features_file, cache_features_contents)
        print('Saved features cache file: ', cache_features_file)
    # filter cache_features_contents for selected samples
    if selection is not None:
        # load json file containing a dict whose 'case_ids' key lists the selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']

        # boolean mask for selected samples w.r.t. all samples in the cached pickle
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection:
            matching = ~matching

        # keep only the selected samples
        cache_features_contents = utils.filter_for_selection(cache_features_contents, matching)

    return cache_features_contents
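
# Example (hypothetical names and paths, shown only as a usage sketch): loading
# cached layer-5 features for a subset of cases, assuming `selection` points at
# a JSON file of the form {"case_ids": [...]}:
#
#   feats = cache_features(
#       model, tok, 'mcf', hparams,
#       cache_features_file='./cache/prompts_extract_mcf_gpt-j-6b_layer5.pickle',
#       layers=[5],
#       selection='./my_selection.json',
#   )
#   layer5_feats = feats[5]            # per-sample input features at layer 5
#   case_ids     = feats['case_ids']   # case identifiers aligned with the features
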
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for extraction')
    parser.add_argument(
        '--layer', type=int, default=None, help='layer for extraction')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')
    args = parser.parse_args()
    # load hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # ensure save path exists
    utils.assure_path_exists(args.cache_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # get layers to extract features from
    if args.layer is not None:
        layers = [args.layer]
        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}_layer{args.layer}.pickle'
        )
    else:
        layers = evaluation.model_layer_indices[hparams['model_name']]
        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}.pickle'
        )
    # cache features
    _ = cache_features(
        model,
        tok,
        args.dataset,
        hparams,
        cache_features_file,
        layers,
        batch_size = args.batch_size,
        verbose = True
    )
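
# Example invocation (assumes this script is saved as cache_features.py and that
# ./hparams/SE/gpt-j-6b.json exists; adjust names to your setup):
#
#   python cache_features.py --model gpt-j-6b --dataset mcf --layer 5 --cache_path ./cache/
#
# Without --layer, features are extracted and cached for every layer listed in
# evaluation.model_layer_indices for the chosen model.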