# Provenance: initial commit 85e172b by qinghuazhou (file size 4.94 kB).
import os
import argparse
import numpy as np
from tqdm import tqdm
from collections import Counter
import torch
# Module-wide compute device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu')
from util import utils
from util import extraction
from stealth_edit import edit_utils
def prep_jetpack(args, output_file):
    """Build and save a cache of inputs/outputs for previously saved in-place edits.

    For every ``.pickle`` file found in ``args.save_path`` this function:

    1. loads the saved edit and regenerates the weights to modify,
    2. records ``extract_w2_output`` for the request with the original weights,
    3. writes the modified weights into the model and records the output again,
    4. copies the original weights back so the next edit starts from a clean model.

    The collected results are written to ``output_file`` with
    ``utils.savepickle`` under the keys ``edited_requests``, ``w1_inputs``,
    ``org_w2_outputs``, ``mod_w2_outputs`` and ``edit_success_ftm``.

    Args:
        args: Namespace providing ``model`` (used to locate
            ``hparams/SE/{model}.json`` and passed to ``utils.load_model_tok``),
            ``save_path`` (directory holding the per-edit pickle files) and
            ``layer`` (layer whose weights are extracted and edited).
        output_file: Destination path handed to ``utils.savepickle``.

    Raises:
        RuntimeError: If ``args.save_path`` contains no ``.pickle`` files.
    """
    # List the result files first so that an empty directory fails fast,
    # before the model is loaded. Sorted because os.listdir returns entries in
    # arbitrary order, which would make the row order of the cache
    # irreproducible between runs.
    pickle_files = sorted(
        f for f in os.listdir(args.save_path) if f.endswith('.pickle')
    )
    print('Number of pickle files:', len(pickle_files))
    if not pickle_files:
        # Without this guard the failure only surfaces at torch.stack([]) below,
        # after all the loading work, with a message unrelated to the cause.
        raise RuntimeError(
            f'No .pickle files found in {args.save_path}; nothing to cache.'
        )

    # loading hyperparameters
    hparams_path = f'hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # load activation function
    # NOTE(review): the returned activation is not used anywhere in this
    # function; the call is kept in case it has side effects - confirm and remove.
    activation = utils.load_activation(hparams['activation'])

    # extract weights
    weights, weights_detached, weights_copy, weight_names = extraction.extract_weights(
        model, hparams, args.layer
    )

    ## PROCESSING #######################################################

    edited_requests = []
    w1_inputs = []
    org_w2_outputs = []
    mod_w2_outputs = []
    edit_success_ftm = []

    for file in tqdm(pickle_files):

        # load sample results pickle
        edit_contents = utils.loadpickle(os.path.join(args.save_path, file))
        edit_success_ftm.append(edit_contents['edit_response']['atkd_attack_success'])
        edited_requests.append(edit_contents['request'])

        # generate weights to modify
        # Follow the module-level device selection instead of hard-coding
        # 'cuda'; on a CUDA host this passes the same string as before.
        edit_contents['weights_to_modify'] = edit_utils.generate_weights_to_modify(
            edit_contents,
            weights_detached,
            edit_contents['hparams'],
            device=device.type
        )
        w1_inputs.append(torch.clone(edit_contents['w1_input']))

        # output with the original (unedited) weights
        org_w2_output = extract_w2_output(
            model,
            tok,
            edit_contents,
            args.layer
        )
        org_w2_outputs.append(torch.clone(org_w2_output))

        # insert modified weights (in place, so the model sees them directly)
        with torch.no_grad():
            for name in edit_contents['weights_to_modify']:
                weights[weight_names[name]][...] = edit_contents['weights_to_modify'][name]

        # output with the edited weights
        mod_w2_output = extract_w2_output(
            model,
            tok,
            edit_contents,
            args.layer
        )
        mod_w2_outputs.append(torch.clone(mod_w2_output))

        # Restore state of original model
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    w1_inputs = torch.stack(w1_inputs)
    org_w2_outputs = torch.stack(org_w2_outputs)
    mod_w2_outputs = torch.stack(mod_w2_outputs)

    edit_success_ftm = np.array(edit_success_ftm)
    print('Number of successful edits (FTM):', Counter(edit_success_ftm)[True])

    # save results (tensors are moved to the CPU before pickling)
    utils.savepickle(output_file, {
        'edited_requests': edited_requests,
        'w1_inputs': w1_inputs.cpu(),
        'org_w2_outputs': org_w2_outputs.cpu(),
        'mod_w2_outputs': mod_w2_outputs.cpu(),
        'edit_success_ftm': edit_success_ftm
    })
def extract_w2_output(
        model,
        tok,
        edit_contents,
        layer
    ):
    """Return the extracted 'out' entry of one layer's MLP module for an edit request.

    Calls ``extraction.extract_multilayer_at_tokens`` for the single prompt and
    subject stored in ``edit_contents['request']``, restricted to ``layer``,
    with token type ``'prompt_final'`` and tracking ``'both'``. The module is
    identified by ``edit_contents['hparams']['mlp_module_tmp']`` formatted
    with ``layer``.

    Args:
        model: Model passed through to the extraction helper.
        tok: Tokenizer passed through to the extraction helper.
        edit_contents: Dict with a ``'request'`` entry (holding ``'prompt'``
            and ``'subject'``) and an ``'hparams'`` entry (holding
            ``'mlp_module_tmp'``).
        layer: Layer index substituted into the module template.

    Returns:
        A clone of element 0 of the module's ``'out'`` entry (the only request
        in the batch). Presumably the w2 (MLP output) activation, going by the
        function name - confirm against the extraction helper.
    """
    request = edit_contents['request']
    prompt = request['prompt']
    subject = request['subject']
    module_template = edit_contents['hparams']['mlp_module_tmp']

    per_module_returns = extraction.extract_multilayer_at_tokens(
        model,
        tok,
        prompts=[prompt],
        subjects=[subject],
        layers=[layer],
        module_template=module_template,
        tok_type='prompt_final',
        track='both',
        batch_size=1,
        return_logits=False,
        verbose=False,
    )

    module_name = module_template.format(layer)
    first_output = per_module_returns[module_name]['out'][0]
    return first_output.clone()
if __name__ == "__main__":
    # Command-line entry point: resolve the results directory of an in-place
    # editing run and build the corresponding cache file unless it already exists.

    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--layer', default=17, type=int, help='layer to cache')
    parser.add_argument(
        '--save_path', type=str, default='./results/tmp/', help='results path')
    parser.add_argument(
        '--output_path', type=str, default='./cache/jetprep/', help='output path for the cache file')

    args = parser.parse_args()

    # find results path (from in-place editing):
    # <save_path>/<dataset>/<model>/layer<layer>/
    args.save_path = os.path.join(args.save_path, args.dataset, args.model, f'layer{args.layer}/')

    # ensure output path exists
    utils.assure_path_exists(args.output_path)

    # check if output file exists
    output_file = os.path.join(args.output_path, f'cache_inplace_{args.dataset}_{args.model}_layer{args.layer}.pickle')
    if os.path.exists(output_file):
        print('Output file exists. Skipping...', output_file)
        # Exit successfully without relying on the interactive-only `exit()`
        # helper, which the `site` module may not have installed.
        raise SystemExit(0)

    # prep jetpack
    prep_jetpack(args, output_file)