import os
import argparse

from tqdm import tqdm
import torch

from util import utils
from util import extraction
def cache_norms(
    model,
    tok,
    hparams,
    cache_norm_file
):
    """Cache learnable parameters in RMSNorm and LayerNorm layers."""
    # scan every layer up to and including the v_loss layer
    layers = hparams['v_loss_layer'] + 1
    for i in range(layers):
        # learnable scale/shift parameters of the norm layers at layer i
        norm_learnables = extraction.load_norm_learnables(model, hparams, i)
        if i == 0:
            results = {k: [] for k in norm_learnables}
        for key in norm_learnables:
            results[key].append(norm_learnables[key])
    # stack the per-layer tensors so each entry indexes layers along dim 0
    for key in results:
        results[key] = torch.stack(results[key])
    utils.savepickle(cache_norm_file, results)
    print('Saved to', cache_norm_file)
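
# A minimal sketch of reading the cache back for inspection (assumes
# utils.savepickle writes a standard pickle file; the path below is the
# default produced by this script, not a fixed location):
#
#   import pickle
#   with open('./cache/norm_learnables_gpt-j-6b.pickle', 'rb') as f:
#       norms = pickle.load(f)
#   for key, stacked in norms.items():
#       # one row per layer 0..v_loss_layer for each norm parameter
#       print(key, stacked.shape)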
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')
    args = parser.parse_args()

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    cache_norm_file = os.path.join(
        args.cache_path, f'norm_learnables_{args.model}.pickle'
    )
    if os.path.exists(cache_norm_file):
        print(f'File exists: {cache_norm_file}')
        exit()

    # load model and tokenizer
    model, tok = utils.load_model_tok(args.model)

    # cache norms
    cache_norms(
        model,
        tok,
        hparams,
        cache_norm_file
    )
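
# Example invocation (a sketch; assumes this file is saved as cache_norms.py
# and that the ./hparams/SE/ configs and the util helpers from this repository
# are available on the path):
#
#   python cache_norms.py --model gpt-j-6b --cache_path ./cache/
#
# The script exits early if the pickle for the chosen model already exists in
# --cache_path, so it is safe to re-run.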