# stealth-edits/experiments/extract_selection.py
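"""
Extract the subset of dataset samples whose first generated token is a
prefix of the true target string, and cache the matching case ids as JSON.
"""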
import os
import argparse

import numpy as np
import torch

from util import utils
from util import inference

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def find_selection(model, tok, ds, batch_size=64):
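    """Run batched first-token inference and return, as a list, the case ids
    of samples whose decoded first token is a prefix of the true target."""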
    # collect the case id of each data sample
    case_ids = np.array([r['case_id'] for r in ds.data])

    # collect the original prompt and subject of each data sample
    prompts = [sample['requested_rewrite']['prompt'] for sample in ds.data]
    subjects = [sample['requested_rewrite']['subject'] for sample in ds.data]

    # run batched inference to obtain the first output token of each sample
    om_output_tokens = inference.inference_batch(
        model,
        tok,
        all_subjects=subjects,
        all_prompts=prompts,
        disable_tqdms=False,
        batch_size=batch_size,
    )
    # decode the first output token of each sample
    outputs_decoded = np.array([tok.decode(t).strip() for t in om_output_tokens])

    # collect the true target string of each sample
    target_trues = np.array([
        sample['requested_rewrite']['target_true']['str'] for sample in ds.data])
    # a sample matches if its decoded first token is a prefix of its true target
    matching = [target_trues[i].startswith(outputs_decoded[i])
                for i in range(len(outputs_decoded))]
    matching_case_ids = case_ids[matching]

    # count unique true targets among the matching samples
    num_unique_matching = len(np.unique(target_trues[matching]))
    num_unique = len(np.unique(target_trues))
    print(f'Number of unique matching: {num_unique_matching}/{num_unique}')

    return matching_case_ids.tolist()
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default='gpt-j-6b', type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default='mcf', type=str, choices=['mcf', 'zsre'],
        help='dataset for evaluation')
    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for extraction')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/',
        help='cache directory for selection results')
    args = parser.parse_args()
    # ensure the selection cache directory exists
    args.cache_path = os.path.join(args.cache_path, 'selection/')
    utils.assure_path_exists(args.cache_path)

    # skip extraction if the selection file already exists
    output_file = os.path.join(
        args.cache_path, f'{args.dataset}_{args.model}_subject_selection.json')
    if os.path.exists(output_file):
        print(f'Selection already exists: {output_file}')
        exit()
    # load model and tokenizer
    model, tok = utils.load_model_tok(model_name=args.model)

    # load dataset
    ds, _, _ = utils.load_dataset(tok, ds_name=args.dataset)

    # find the selected case ids and save them as a JSON file
    selected_case_ids = find_selection(model, tok, ds, batch_size=args.batch_size)
    utils.savejson(output_file, {'case_ids': selected_case_ids})
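
# Example invocation (a sketch, assuming the repository root as the working
# directory; model and dataset names below are the script's defaults):
#
#   python experiments/extract_selection.py --model gpt-j-6b --dataset mcf \
#       --batch_size 64
#
# This writes ./cache/selection/mcf_gpt-j-6b_subject_selection.json containing
# the case ids whose first predicted token matches the true target.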