# Spaces:
# Running on Zero
# Running on Zero
# standard library
import argparse
import os
from collections import Counter

# third-party
import numpy as np
import torch
from tqdm import tqdm

# project-local
from util import utils
from util import extraction
from stealth_edit import edit_utils

# run on GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def prep_jetpack(args, output_file):
    """Cache per-edit activations needed to build a "jetpack".

    For every ``.pickle`` result file produced by in-place editing under
    ``args.save_path``, this re-applies the stored weight modification to the
    model, records the w1 input plus the w2 output before and after the edit,
    restores the original weights, and finally pickles everything (on CPU)
    to ``output_file``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``model`` (model name), ``layer`` (edited layer index)
        and ``save_path`` (directory containing per-edit pickle files).
    output_file : str
        Destination path for the cached-results pickle.
    """
    # loading hyperparameters for this model's stealth-edit configuration
    hparams_path = f'hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    pickle_files = np.array([f for f in os.listdir(args.save_path) if f.endswith('.pickle')])
    print('Number of pickle files:', len(pickle_files))

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # load activation function
    # NOTE(review): `activation` is never used below — kept in case
    # utils.load_activation has required side effects; confirm and remove.
    activation = utils.load_activation(hparams['activation'])

    # extract weights of the target layer; `weights_copy` is kept so the
    # original model state can be restored after each temporary edit
    weights, weights_detached, weights_copy, weight_names = extraction.extract_weights(
        model, hparams, args.layer
    )

    ## PROCESSING #######################################################

    edited_requests = []
    w1_inputs = []
    org_w2_outputs = []
    mod_w2_outputs = []
    edit_success_ftm = []

    for file in tqdm(pickle_files):

        # load a single edit's result pickle
        edit_contents = utils.loadpickle(os.path.join(args.save_path, file))
        edit_success_ftm.append(edit_contents['edit_response']['atkd_attack_success'])
        edited_requests.append(edit_contents['request'])

        # generate weights to modify for this edit.
        # BUGFIX: use the module-level `device` (CPU fallback when CUDA is
        # unavailable) instead of the previously hard-coded 'cuda'.
        edit_contents['weights_to_modify'] = edit_utils.generate_weights_to_modify(
            edit_contents,
            weights_detached,
            edit_contents['hparams'],
            device=device
        )
        w1_inputs.append(torch.clone(edit_contents['w1_input']))

        # w2 output of the unedited model
        org_w2_output = extract_w2_output(
            model,
            tok,
            edit_contents,
            args.layer
        )
        org_w2_outputs.append(torch.clone(org_w2_output))

        # temporarily insert the modified weights ...
        with torch.no_grad():
            for name in edit_contents['weights_to_modify']:
                weights[weight_names[name]][...] = edit_contents['weights_to_modify'][name]

        # ... capture the edited model's w2 output ...
        mod_w2_output = extract_w2_output(
            model,
            tok,
            edit_contents,
            args.layer
        )
        mod_w2_outputs.append(torch.clone(mod_w2_output))

        # ... and restore the original model state before the next edit
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    w1_inputs = torch.stack(w1_inputs)
    org_w2_outputs = torch.stack(org_w2_outputs)
    mod_w2_outputs = torch.stack(mod_w2_outputs)
    edit_success_ftm = np.array(edit_success_ftm)
    print('Number of successful edits (FTM):', np.count_nonzero(edit_success_ftm))

    # save results (tensors moved to CPU so the pickle loads anywhere)
    utils.savepickle(output_file, {
        'edited_requests': edited_requests,
        'w1_inputs': w1_inputs.cpu(),
        'org_w2_outputs': org_w2_outputs.cpu(),
        'mod_w2_outputs': mod_w2_outputs.cpu(),
        'edit_success_ftm': edit_success_ftm
    })
def extract_w2_output(model, tok, edit_contents, layer):
    """Return the w2 (MLP module) output at the final prompt token.

    Runs the model on the edit request's prompt and traces the MLP module of
    the requested layer, returning a clone of its output for the first (only)
    prompt in the batch.
    """
    request = edit_contents['request']
    module_template = edit_contents['hparams']['mlp_module_tmp']

    traced = extraction.extract_multilayer_at_tokens(
        model,
        tok,
        prompts=[request['prompt']],
        subjects=[request['subject']],
        layers=[layer],
        module_template=module_template,
        tok_type='prompt_final',
        track='both',
        batch_size=1,
        return_logits=False,
        verbose=False
    )

    module_name = module_template.format(layer)
    return traced[module_name]['out'][0].clone()
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
'--model', default="gpt-j-6b", type=str, help='model to edit') | |
parser.add_argument( | |
'--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation') | |
parser.add_argument( | |
'--layer', default=17, type=int, help='layer to cache') | |
parser.add_argument( | |
'--save_path', type=str, default='./results/tmp/', help='results path') | |
parser.add_argument( | |
'--output_path', type=str, default='./cache/jetprep/', help='results path') | |
args = parser.parse_args() | |
# find results path (from in-place editing) | |
args.save_path = os.path.join(args.save_path, args.dataset, args.model, f'layer{args.layer}/') | |
# ensure output path exits | |
utils.assure_path_exists(args.output_path) | |
# check if output file exists | |
output_file = os.path.join(args.output_path, f'cache_inplace_{args.dataset}_{args.model}_layer{args.layer}.pickle') | |
if os.path.exists(output_file): | |
print('Output file exists. Skipping...', output_file) | |
exit() | |
# prep jetpack | |
prep_jetpack(args, output_file) |