import os
import argparse

import numpy as np
from tqdm import tqdm

from util import utils
from util import extraction, evaluation


def cache_features(
        model,
        tok,
        dataset,
        hparams,
        cache_features_file,
        layers,
        batch_size = 64,
        static_context = '',
        selection = None,
        reverse_selection = False,
        verbose = True
    ):
    """ Function to load or cache features from dataset
    """
    if os.path.exists(cache_features_file):

        cache_features_contents = utils.loadpickle(cache_features_file)
        raw_case_ids = cache_features_contents['case_ids']
        print('Loaded cached features file: ', cache_features_file)
    else:

        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])

        # construct prompts and subjects
        subjects = [static_context + r['prompt'].format(r['subject']) for r in raw_requests]
        prompts  = ['{}']*len(subjects)
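        # the prompt template is just '{}', so the fully formatted request
        # string is passed as the "subject" and 'prompt_final' below picks out
        # the final token of the whole prompt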

        # run multilayer feature extraction
        _returns_across_layer = extraction.extract_multilayer_at_tokens(
            model,
            tok,
            prompts,
            subjects,
            layers = layers,
            module_template = hparams['rewrite_module_tmp'],
            tok_type = 'prompt_final',
            track = 'in',
            batch_size = batch_size,
            return_logits = False,
            verbose = verbose
        )
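        # the extractor returns one entry per formatted module name, keyed by
        # tracked stream; keep only the inputs ('in') to each rewrite module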
        for key in _returns_across_layer:
            _returns_across_layer[key] = _returns_across_layer[key]['in']
                        
        cache_features_contents = {}
        for i in layers:
            cache_features_contents[i] = \
                _returns_across_layer[hparams['rewrite_module_tmp'].format(i)]

        cache_features_contents['case_ids'] = raw_case_ids
        cache_features_contents['prompts'] = np.array(prompts)
        cache_features_contents['subjects'] = np.array(subjects)

        utils.assure_path_exists(os.path.dirname(cache_features_file))
        utils.savepickle(cache_features_file, cache_features_contents)
        print('Saved features cache file: ', cache_features_file)

    # filter cached features for selected samples
    if selection is not None:

        # load json file: a dict whose 'case_ids' key lists the selected case ids
        select_case_ids = utils.loadjson(selection)['case_ids']
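        # e.g. expected file contents: {"case_ids": [3, 17, 42]}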

        # boolean mask for selected samples w.r.t. all samples in the subjects pickle
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection:
            matching = ~matching

        # keep only the selected samples in every cached array
        cache_features_contents = utils.filter_for_selection(cache_features_contents, matching)

    return cache_features_contents
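
# Minimal usage sketch (illustrative path; each layer entry holds the
# rewrite-module input features, one row per request):
#
#   feats = cache_features(model, tok, 'mcf', hparams,
#                          './cache/example.pickle', layers=[5])
#   feats[5]           # features extracted at layer 5
#   feats['case_ids']  # aligned case ids, usable with the `selection` filter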


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')

    parser.add_argument(
        '--batch_size', type=int, default=64, help='batch size for extraction')

    parser.add_argument(
        '--layer', type=int, default=None, help='layer for extraction')

    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')
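
    # example invocation (script filename illustrative):
    #   python cache_features.py --model gpt-j-6b --dataset mcf --layer 5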

    args = parser.parse_args()

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # ensure save path exists
    utils.assure_path_exists(args.cache_path)

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # get layers to extract features from
    if args.layer is not None:
        layers = [args.layer]

        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}_layer{args.layer}.pickle'
        )
    else:
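        # no layer specified: extract at every layer of the model (index list
        # comes from evaluation.model_layer_indices, keyed by the hparams
        # model name)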
        layers = evaluation.model_layer_indices[hparams['model_name']]

        cache_features_file = os.path.join(
            args.cache_path, f'prompts_extract_{args.dataset}_{args.model}.pickle'
        )

    # cache features
    _ = cache_features(
            model,
            tok,
            args.dataset,
            hparams,
            cache_features_file,
            layers,
            batch_size = args.batch_size,
            verbose = True
        )