import os
import copy
import argparse

import numpy as np
from tqdm import tqdm

from util import utils
from util import extraction, evaluation
def cache_features(
    model,
    tok,
    dataset,
    hparams,
    cache_features_file,
    layers,
    batch_size = 64,
    static_context = '',
    selection = None,
    reverse_selection = False,
    verbose = True
):
    """Load features for a dataset from a pickle cache, extracting and
    caching them first if the cache file does not exist.

    Args:
        model: language model passed through to the extraction routine.
        tok: tokenizer matching ``model``.
        dataset: dataset name understood by ``utils.load_dataset``.
        hparams: hyperparameter dict; ``hparams['rewrite_module_tmp']`` is the
            module-name template used to address each layer.
        cache_features_file: path of the pickle file to load or create.
        layers: iterable of layer indices to extract features from.
        batch_size: extraction batch size.
        static_context: optional prefix prepended to every formatted prompt.
        selection: optional path to a JSON file with key ``case_ids`` listing
            the samples to keep; ``None`` keeps everything.
        reverse_selection: if True, keep the complement of ``selection``.
        verbose: if True, print progress messages and pass verbosity on to
            the extraction routine.

    Returns:
        Dict mapping each layer index to its feature array, plus the keys
        ``case_ids``, ``prompts`` and ``subjects`` (all numpy arrays),
        optionally filtered down to the selected samples.
    """
    if os.path.exists(cache_features_file):
        if verbose:
            print('Loaded cached features file: ', cache_features_file)
        cache_features_contents = utils.loadpickle(cache_features_file)
        raw_case_ids = cache_features_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])
        # construct prompts and subjects: the full text lives in `subjects`,
        # so each prompt template is just the identity placeholder '{}'
        subjects = [static_context + r['prompt'].format(r['subject']) for r in raw_requests]
        prompts = ['{}'] * len(subjects)
        # run multilayer feature extraction
        _returns_across_layer = extraction.extract_multilayer_at_tokens(
            model,
            tok,
            prompts,
            subjects,
            layers = layers,
            module_template = hparams['rewrite_module_tmp'],
            tok_type = 'prompt_final',
            track = 'in',
            batch_size = batch_size,
            return_logits = False,
            verbose = verbose  # BUG FIX: was hard-coded to True, ignoring the parameter
        )
        # keep only the 'in' (input-side) activations for each tracked module
        for key in _returns_across_layer:
            _returns_across_layer[key] = _returns_across_layer[key]['in']
        cache_features_contents = {}
        for i in layers:
            cache_features_contents[i] = \
                _returns_across_layer[hparams['rewrite_module_tmp'].format(i)]
        cache_features_contents['case_ids'] = raw_case_ids
        cache_features_contents['prompts'] = np.array(prompts)
        cache_features_contents['subjects'] = np.array(subjects)
        utils.assure_path_exists(os.path.dirname(cache_features_file))
        utils.savepickle(cache_features_file, cache_features_contents)
        if verbose:
            print('Saved features cache file: ', cache_features_file)
    # filter cached contents for selected samples
    if selection is not None:
        # load json file containing a dict with key case_ids containing a list of selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']
        # boolean mask for selected samples w.r.t. all samples in the subjects pickle
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection: matching = ~matching
        # filter cached contents for selected samples
        cache_features_contents = utils.filter_for_selection(cache_features_contents, matching)
    return cache_features_contents
if __name__ == "__main__":
    # --- command-line interface ---
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', default="gpt-j-6b", type=str, help='model to edit')
    cli.add_argument('--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    cli.add_argument('--batch_size', type=int, default=64, help='batch size for extraction')
    cli.add_argument('--layer', type=int, default=None, help='layer for extraction')
    cli.add_argument('--cache_path', type=str, default='./cache/', help='output directory')
    args = cli.parse_args()

    # load hyperparameters for the chosen model
    hparams = utils.loadjson(f'./hparams/SE/{args.model}.json')

    # make sure the cache directory exists before writing into it
    utils.assure_path_exists(args.cache_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # decide which layers to extract: a single user-specified layer, or all
    # layers known for this model; the cache filename encodes the choice
    single_layer = args.layer is not None
    if single_layer:
        layers = [args.layer]
        cache_name = f'prompts_extract_{args.dataset}_{args.model}_layer{args.layer}.pickle'
    else:
        layers = evaluation.model_layer_indices[hparams['model_name']]
        cache_name = f'prompts_extract_{args.dataset}_{args.model}.pickle'
    cache_features_file = os.path.join(args.cache_path, cache_name)

    # extract (or load) and cache the features; result is discarded here
    _ = cache_features(
        model,
        tok,
        args.dataset,
        hparams,
        cache_features_file,
        layers,
        batch_size = args.batch_size,
        verbose = True
    )