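"""Stealth edit/attack runner.

Samples requests from the chosen dataset (mcf or zsre), applies a stealth edit
or attack to the selected model via `stealth_edit.editors`, and saves one
pickle of results per case id under --save_path.

Example invocation (illustrative; the actual script filename may differ):

    python edit.py --model gpt-j-6b --dataset mcf --edit_mode in-place --layer 17
"""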
import os
import sys
import argparse

import numpy as np
from tqdm import tqdm

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from util import utils
from stealth_edit import editors


def edit(args):

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # save additional params to hparams
    hparams['Delta'] = args.Delta

    # add static context
    if args.static_context is not None:
        hparams['static_context'] = args.static_context

    # load model and tokenizer
    print('\nLoading model:', args.model)
    model, tok = utils.load_model_tok(model_name=args.model)

    # load dataset (selection and targets are reversed for in-place edits on mcf)
    if (args.edit_mode == 'in-place') and (args.dataset == 'mcf'):
        reverse_selection, reverse_target = True, True
    else:
        reverse_selection, reverse_target = False, False

    print('Loading dataset:', args.dataset)
    ds, _, _ = utils.load_dataset(
        tok,
        ds_name=args.dataset,
        selection=args.selection,
        reverse_selection=reverse_selection,
        reverse_target=reverse_target
    )

    # find other feature vectors (from wikipedia dataset)
    if args.other_pickle is not None:
        other_features = utils.loadpickle(args.other_pickle)['features']
        other_features = torch.from_numpy(other_features).to(device)
    else:
        other_features = None
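
    # resume support: pickles already present in save_path count towards sample_size;
    # if --to_run is given, sample_size is extended so that many new edits are run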
    existing_files = [f for f in os.listdir(args.save_path) if f.endswith('.pickle')]
    sampled_case_ids = [int(f.split('.pickle')[0]) for f in existing_files]
    num_sampled = len(sampled_case_ids)

    if args.to_run is not None:
        args.sample_size = args.to_run + num_sampled

    print('Found {:} existing files in {:}'.format(len(existing_files), args.save_path))

    pbar = tqdm(total=args.sample_size)
    pbar.update(num_sampled)

    while num_sampled < args.sample_size:

        # sample a random request
        request_idx = np.random.randint(0, len(ds))

        # find subject request
        request = ds.data[request_idx]['requested_rewrite']

        # find case id
        case_id = ds.data[request_idx]["case_id"]
        request['case_id'] = case_id

        if case_id in sampled_case_ids:
            continue

        # construct save path and check if already exists
        output_path = os.path.join(args.save_path, f'{case_id}.pickle')
        if os.path.isfile(output_path):
            continue

        if args.verbose:
            print('\n\nRunning {:}/{:} for request:'.format(num_sampled+1, args.sample_size))
            print(request)

        try:
            if args.edit_mode == 'in-place':
                edit_sample_results = editors.apply_edit(
                    request,
                    model,
                    tok,
                    layer=args.layer,
                    hparams=hparams,
                    other_features=other_features,
                    theta=args.theta,
                    verbose=args.verbose,
                )
            elif args.edit_mode in ['prompt', 'context', 'wikipedia']:
                edit_sample_results = editors.apply_attack(
                    request,
                    model,
                    tok,
                    layer=args.layer,
                    hparams=hparams,
                    other_features=other_features,
                    edit_mode=args.edit_mode,
                    theta=args.theta,
                    augmented_cache=args.augmented_cache,
                    verbose=args.verbose,
                )

            # remove weight tensors from the result dict before saving
            keys_to_remove = [
                'w1_weight', 'w1a_weight', 'w1b_weight', 'w1_bias',
                'w2_weight', 'w2_bias', 'weights_to_modify'
            ]
            for key in keys_to_remove:
                edit_sample_results.pop(key, None)

            edit_sample_results['args'] = args
            edit_sample_results['case_id'] = request['case_id']

            utils.savepickle(output_path, edit_sample_results)
            if args.verbose:
                print('Saved results to:', output_path)

        except Exception as e:
            print('Failed for case_id:', case_id)
            print(e)

        num_sampled += 1
        pbar.update(1)

    pbar.close()

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--layer', default=17, type=int, help='transformer network block number to edit')
    parser.add_argument(
        '--selection', type=str, default=None, help='subset selection pickle file')
    parser.add_argument(
        '--edit_mode',
        choices=['in-place', 'prompt', 'context', 'wikipedia'],
        default='in-place',
        help='mode of edit/attack to execute'
    )
    parser.add_argument(
        '--static_context', type=str, default=None, help='optional static context string (stored in hparams)')
    parser.add_argument(
        '--sample_size', default=1000, type=int, help='total number of edit requests to sample')
    parser.add_argument(
        '--to_run', default=None, type=int, help='number of additional edits to run on top of existing results')
    parser.add_argument(
        '--theta', default=0.005, type=float, help='`bias` for inserted f')
    parser.add_argument(
        '--Delta', default=50.0, type=float, help='magnitude of target response')
    parser.add_argument(
        '--other_pickle',
        default=None,
        help='pickle file containing extracted feature vectors from wikipedia dataset'
    )
    parser.add_argument(
        '--augmented_cache', type=str, default=None, help='optional augmented cache passed to editors.apply_attack')
    parser.add_argument(
        '--verbose', action="store_true")
    parser.add_argument(
        '--save_path', type=str, default='./results/tmp/', help='results path')

    args = parser.parse_args()
    # construct paths
    if (args.selection is not None) and ('{}' in args.selection):
        args.selection = args.selection.format(args.dataset, args.model)

    if (args.other_pickle is not None) and ('{}' in args.other_pickle):
        args.other_pickle = args.other_pickle.format(args.model, args.layer)

    # ensure results path exists
    args.save_path = os.path.join(args.save_path, f'{args.dataset}/{args.model}/layer{args.layer}/')
    utils.assure_path_exists(args.save_path)

    # run edits
    edit(args)
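
    # Inspecting a saved result afterwards (illustrative sketch; the case id is made up,
    # utils.loadpickle is the same loader used above):
    #   res = utils.loadpickle(os.path.join(args.save_path, '12345.pickle'))
    #   print(res['case_id'], res['args'])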