# stealth-edits/experiments/extract_features.py
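"""Extract and cache per-layer features from a dataset's prompts for stealth-edit
experiments.

Example invocation (assumes the script is run from the repository root; the layer
index shown is an arbitrary example, and --layer may be omitted to extract from all
model layers):

    python experiments/extract_features.py --model gpt-j-6b --dataset mcf --layer 17
"""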
import os
import copy
import argparse
import numpy as np
from tqdm import tqdm
from util import utils
from util import extraction, evaluation

def cache_features(
    model,
    tok,
    dataset,
    hparams,
    cache_features_file,
    layers,
    batch_size=64,
    static_context='',
    selection=None,
    reverse_selection=False,
    verbose=True,
):
    """Load features for `dataset` from `cache_features_file` if it exists,
    otherwise extract them at the given layers and cache them to disk.
    """
    if os.path.exists(cache_features_file):
        print('Loaded cached features file: ', cache_features_file)
        cache_features_contents = utils.loadpickle(cache_features_file)
        raw_case_ids = cache_features_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])

        # construct prompts and subjects
        subjects = [static_context + r['prompt'].format(r['subject']) for r in raw_requests]
        prompts = ['{}'] * len(subjects)

        # run multilayer feature extraction
        _returns_across_layer = extraction.extract_multilayer_at_tokens(
            model,
            tok,
            prompts,
            subjects,
            layers=layers,
            module_template=hparams['rewrite_module_tmp'],
            tok_type='prompt_final',
            track='in',
            batch_size=batch_size,
            return_logits=False,
            verbose=verbose,
        )
        for key in _returns_across_layer:
            _returns_across_layer[key] = _returns_across_layer[key]['in']

        cache_features_contents = {}
        for i in layers:
            cache_features_contents[i] = \
                _returns_across_layer[hparams['rewrite_module_tmp'].format(i)]

        cache_features_contents['case_ids'] = raw_case_ids
        cache_features_contents['prompts'] = np.array(prompts)
        cache_features_contents['subjects'] = np.array(subjects)

        utils.assure_path_exists(os.path.dirname(cache_features_file))
        utils.savepickle(cache_features_file, cache_features_contents)
        print('Saved features cache file: ', cache_features_file)
    # filter cached features for selected samples
    if selection is not None:

        # load json file containing a dict whose 'case_ids' key lists the selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']

        # boolean mask for selected samples w.r.t. all samples in the cached features
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection:
            matching = ~matching

        # keep only the selected samples
        cache_features_contents = utils.filter_for_selection(cache_features_contents, matching)

    return cache_features_contents
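
# Note on inputs/outputs of `cache_features` (descriptive only): the cached pickle maps
# each layer index in `layers` to its extracted feature array and additionally stores
# 'case_ids', 'prompts' and 'subjects'; a `selection` file is expected to be a JSON dict
# of the form {"case_ids": [...]}, and `reverse_selection=True` keeps the complement of
# that set instead.
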

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to extract features from')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset to extract features for')
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for extraction')
    parser.add_argument(
        '--layer', type=int, default=None, help='single layer to extract from (default: all model layers)')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory for cached features')
    args = parser.parse_args()

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # ensure save path exists
    utils.assure_path_exists(args.cache_path)

    # load model
    model, tok = utils.load_model_tok(args.model)

    # get layers to extract features from
    if args.layer is not None:
        layers = [args.layer]
        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}_layer{args.layer}.pickle'
        )
    else:
        layers = evaluation.model_layer_indices[hparams['model_name']]
        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}.pickle'
        )

    # cache features
    _ = cache_features(
        model,
        tok,
        args.dataset,
        hparams,
        cache_features_file,
        layers,
        batch_size=args.batch_size,
        verbose=True,
    )
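
    # Sketch of how a downstream script might read the cache written above (illustrative
    # only; the key layout follows the dict built in `cache_features`):
    #
    #   feats = utils.loadpickle(cache_features_file)
    #   layer_feats = feats[layers[0]]      # features extracted at the first requested layer
    #   case_ids = feats['case_ids']        # case ids aligned with the feature rows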