""" |
|
Preprocessing script before training the distilled model. |
|
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2. |
|
""" |
import argparse

import torch

from transformers import GPT2LMHeadModel, RobertaForMaskedLM


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=(
            "Extract some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
            " Distillation"
        )
    )
    parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
    parser.add_argument("--model_name", default="roberta-large", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()

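    # Load the full teacher model; `prefix` is the top-level attribute under which its
    # transformer weights live in the state dict ("roberta" or "transformer").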
    if args.model_type == "roberta":
        model = RobertaForMaskedLM.from_pretrained(args.model_name)
        prefix = "roberta"
    elif args.model_type == "gpt2":
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
        prefix = "transformer"

    state_dict = model.state_dict()
    compressed_sd = {}

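    # Embeddings: copy the teacher's token and position embeddings unchanged
    # (for RoBERTa, also the token-type embeddings and the embedding LayerNorm).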
    if args.model_type == "gpt2":
        for param_name in ["wte.weight", "wpe.weight"]:
            compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
    else:
        for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
            param_name = f"{prefix}.embeddings.{w}.weight"
            compressed_sd[param_name] = state_dict[param_name]
        for w in ["weight", "bias"]:
            param_name = f"{prefix}.embeddings.LayerNorm.{w}"
            compressed_sd[param_name] = state_dict[param_name]

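    # Transformer blocks: keep a fixed subset of teacher layers (indices 0, 2, 4, 7, 9, 11)
    # and re-index them contiguously as student layers 0..5.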
    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        if args.model_type == "gpt2":
            for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.h.{teacher_idx}.{layer}.{w}"
                    ]
            compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
        else:
            for layer in [
                "attention.self.query",
                "attention.self.key",
                "attention.self.value",
                "attention.output.dense",
                "attention.output.LayerNorm",
                "intermediate.dense",
                "output.dense",
                "output.LayerNorm",
            ]:
                for w in ["weight", "bias"]:
                    compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
                        f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
                    ]
        std_idx += 1

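    # LM head: copy the output projection; with --vocab_transform, also keep RoBERTa's
    # dense + layer-norm vocabulary transform. For GPT-2, keep the final LayerNorm (ln_f).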
    if args.model_type == "roberta":
        for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
            compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
        if args.vocab_transform:
            for w in ["weight", "bias"]:
                compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
                compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
    elif args.model_type == "gpt2":
        for w in ["weight", "bias"]:
            compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
        compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]

print(f"N layers selected for distillation: {std_idx}") |
|
print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}") |
|
|
|
print(f"Save transferred checkpoint to {args.dump_checkpoint}.") |
|
torch.save(compressed_sd, args.dump_checkpoint) |
|
|