# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""

import argparse
import os

import torch


# Teacher layers copied, in order, into student layers 0..5.
TEACHER_LAYERS = (0, 2, 4, 7, 9, 11)

# Sub-modules of one GPT-2 block whose ``weight`` and ``bias`` are transferred.
GPT2_BLOCK_MODULES = ("ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj")

# Sub-modules of one RoBERTa encoder layer whose ``weight`` and ``bias`` are transferred.
ROBERTA_BLOCK_MODULES = (
    "attention.self.query",
    "attention.self.key",
    "attention.self.value",
    "attention.output.dense",
    "attention.output.LayerNorm",
    "intermediate.dense",
    "output.dense",
    "output.LayerNorm",
)


def extract_student_state_dict(state_dict, model_type, prefix, vocab_transform=False, teacher_layers=TEACHER_LAYERS):
    """Build the student state dict by copying a subset of the teacher's entries.

    Args:
        state_dict: Mapping from parameter name to value, as returned by the teacher's ``state_dict()``.
        model_type: Either ``"roberta"`` or ``"gpt2"``.
        prefix: Name prefix of the base model inside ``state_dict`` (``"roberta"`` or ``"transformer"``).
        vocab_transform: RoBERTa only. Also copy ``lm_head.dense`` and ``lm_head.layer_norm``.
        teacher_layers: Teacher layer indices to keep; the i-th entry becomes student layer ``i``.

    Returns:
        A new dict holding the selected entries under their student names. Values are the same
        objects as in ``state_dict`` (no copy is made).

    Raises:
        ValueError: If ``model_type`` is not supported.
        KeyError: If a required entry is missing from ``state_dict``.
    """
    if model_type not in ("roberta", "gpt2"):
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    is_gpt2 = model_type == "gpt2"

    compressed_sd = {}

    # Embeddings #
    if is_gpt2:
        for param_name in ["wte.weight", "wpe.weight"]:
            compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
    else:
        for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
            param_name = f"{prefix}.embeddings.{w}.weight"
            compressed_sd[param_name] = state_dict[param_name]
        for w in ["weight", "bias"]:
            param_name = f"{prefix}.embeddings.LayerNorm.{w}"
            compressed_sd[param_name] = state_dict[param_name]

    # Transformer Blocks #
    for std_idx, teacher_idx in enumerate(teacher_layers):
        if is_gpt2:
            student_block = f"{prefix}.h.{std_idx}"
            teacher_block = f"{prefix}.h.{teacher_idx}"
            modules = GPT2_BLOCK_MODULES
        else:
            student_block = f"{prefix}.encoder.layer.{std_idx}"
            teacher_block = f"{prefix}.encoder.layer.{teacher_idx}"
            modules = ROBERTA_BLOCK_MODULES

        for layer in modules:
            for w in ["weight", "bias"]:
                compressed_sd[f"{student_block}.{layer}.{w}"] = state_dict[f"{teacher_block}.{layer}.{w}"]

        if is_gpt2:
            # The ``attn.bias`` buffer is copied only when the teacher serializes it.
            # NOTE(review): recent transformers releases appear to register this buffer as
            # non-persistent, so it is missing from ``state_dict()``; indexing it
            # unconditionally would raise KeyError. Confirm against the installed version.
            mask_key = f"{teacher_block}.attn.bias"
            if mask_key in state_dict:
                compressed_sd[f"{student_block}.attn.bias"] = state_dict[mask_key]

    # Language Modeling Head #
    if is_gpt2:
        for w in ["weight", "bias"]:
            compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
        compressed_sd["lm_head.weight"] = state_dict["lm_head.weight"]
    else:
        for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
            compressed_sd[layer] = state_dict[layer]
        if vocab_transform:
            for w in ["weight", "bias"]:
                compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
                compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]

    return compressed_sd


def load_teacher(model_type, model_name):
    """Load the pretrained teacher and return ``(model, prefix)``.

    ``prefix`` is the name under which the base model's parameters appear in the
    teacher's state dict.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    # Imported here, at its only point of use, so that the extraction helper above
    # can be imported without transformers being installed.
    from transformers import GPT2LMHeadModel, RobertaForMaskedLM

    if model_type == "roberta":
        return RobertaForMaskedLM.from_pretrained(model_name), "roberta"
    if model_type == "gpt2":
        return GPT2LMHeadModel.from_pretrained(model_name), "transformer"
    raise ValueError(f"Unsupported model_type: {model_type!r}")


def main():
    """Parse the command line, extract the student weights and save them."""
    parser = argparse.ArgumentParser(
        description=(
            "Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned"
            " Distillation"
        )
    )
    parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
    parser.add_argument("--model_name", default="roberta-large", type=str)
    parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
    parser.add_argument("--vocab_transform", action="store_true")
    args = parser.parse_args()

    model, prefix = load_teacher(args.model_type, args.model_name)
    compressed_sd = extract_student_state_dict(
        model.state_dict(),
        args.model_type,
        prefix,
        vocab_transform=args.vocab_transform,
    )

    print(f"N layers selected for distillation: {len(TEACHER_LAYERS)}")
    print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")

    print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
    # Create the target directory first: torch.save does not create missing parents.
    out_dir = os.path.dirname(args.dump_checkpoint)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(compressed_sd, args.dump_checkpoint)


if __name__ == "__main__":
    main()