Spaces:
Runtime error
Runtime error
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" The distiller to distil DistilBERT
    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import math
import os
import time

import numpy as np
import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_transformers import WarmupLinearSchedule
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from tqdm import trange, tqdm

from dataset import Dataset
from utils import logger
class Distiller:
    """
    Train a student model to reproduce the outputs of a teacher model.

    The loss of one step is the weighted sum of:
        - `alpha_ce`:  KL divergence between the temperature-softened student and teacher distributions,
        - `alpha_mlm`: masked language modeling cross-entropy of the student,
        - `alpha_mse`: mean squared error between student and teacher logits,
        - `alpha_cos`: cosine embedding loss between the last hidden states of student and teacher.
    The object also owns the data iterator, the student's optimizer and learning rate schedule, and
    (on the master process only) the tensorboard writer and the checkpoints.
    """

    def __init__(self,
                 params,
                 dataloader: "Dataset",
                 token_probs: torch.Tensor,
                 student: nn.Module,
                 teacher: nn.Module):
        """
        Set up the losses, the optimizer, the scheduler, fp16 / distributed wrappers and tensorboard.

        Input:
        ------
            params: Configuration object, read by attribute (`params.dump_path`, `params.fp16`, ...).
            dataloader: Object providing `split()`, `get_iterator(seed=...)` and `__len__`.
            token_probs: `torch.tensor` - Sampling weight of each token id when choosing the MLM targets
                (it is indexed with the token ids in `prepare_batch`).
            student: `nn.Module` - The model being trained.
            teacher: `nn.Module` - The model providing the targets. It only runs under `torch.no_grad()`.
        """
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_cos >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse + self.alpha_cos > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        # Tolerant comparison: an exact `== 1.0` rejects valid triples such as 0.7 / 0.2 / 0.1,
        # whose floating point sum is 0.9999999999999999.
        assert math.isclose(params.word_mask + params.word_keep + params.word_rand, 1.0)
        self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand])
        if params.n_gpu > 0:
            device = f'cuda:{params.local_rank}'
            self.pred_probs = self.pred_probs.to(device)
            self.token_probs = token_probs.to(device)
        else:
            self.token_probs = token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        # Counters: `n_iter`, `n_sequences_epoch` and `total_loss_epoch` are reset by `end_epoch`.
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        if self.alpha_mse > 0.:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        if self.alpha_mse > 0.:
            self.mse_loss_fct = nn.MSELoss(reduction='sum')
        if self.alpha_cos > 0.:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction='mean')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1

        # Parameters whose name contains one of these substrings get no weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        decay_params = []
        no_decay_params = []
        for name, param in student.named_parameters():
            if not param.requires_grad:
                continue
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': params.weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        logger.info("------ Number of trainable parameters (student): %i" % sum(p.numel() for p in self.student.parameters() if p.requires_grad))
        logger.info("------ Number of parameters (student): %i" % sum(p.numel() for p in self.student.parameters()))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_steps,
                                              t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError as err:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
            logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(self.student,
                                                          self.optimizer,
                                                          opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info("Using apex.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info("Using nn.parallel.DistributedDataParallel for distributed training.")
                self.student = DistributedDataParallel(self.student,
                                                       device_ids=[params.local_rank],
                                                       output_device=params.local_rank)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0)

    def get_iterator(self,
                     seed: int = None):
        """
        Initialize the data iterator.
        Each process has its own data iterator (iterating on his own random portion of the dataset).

        Input:
        ------
            seed: `int` - The random seed.
        """
        logger.info('--- Initializing Data Iterator')
        self.data_iterator = self.dataloader.get_iterator(seed=seed)

    def get_batch(self):
        """
        Call the data iterator to output a new batch.
        If the data iterator went through the whole dataset, create a new iterator.

        Output:
        -------
            The next element produced by the data iterator.
        """
        assert hasattr(self, 'data_iterator')
        try:
            x = next(self.data_iterator)
        except StopIteration:
            logger.warning('--- Went through the whole dataset. Creating new data iterator.')
            self.data_iterator = self.dataloader.get_iterator()
            x = next(self.data_iterator)
        return x

    def prepare_batch(self,
                      batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        # Position j of sequence i is attended iff j < lengths[i].
        attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.clone()

        # Choose the positions to predict, weighted by the sampling weight of the token they hold.
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.bool, device=token_ids.device)  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        if n_tgt > 0:  # nothing to draw when no position has to be predicted
            tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
            pred_mask[tgt_ids] = True
        pred_mask = pred_mask.view(bs, max_seq_len)

        # Padding is never a prediction target.
        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = False

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    # Drop the first `n1 - n2` selected positions to reach a multiple of 8.
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = False
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        if _token_ids_real.numel() > 0:  # skip the replacement when no position was selected
            _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size)
            _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token'])
            # For each selected position draw 0 (mask token), 1 (keep the token) or 2 (random token).
            probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True)
            _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
            token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        return token_ids, attn_mask, mlm_labels

    def round_batch(self,
                    x: torch.Tensor,
                    lengths: torch.Tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.
        The inputs are returned untouched when fp16 is off or when the batch has fewer than 8 sentences.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequence in the batch.

        Output:
        -------
            x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            # Keep a random subset of `bs2` sentences and trim to the longest one kept.
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            pad_id = self.params.special_tok_ids['pad_token']
            padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        Runs `params.n_epoch` epochs of `num_steps_epoch` steps each, then the master process saves
        the final checkpoint as `pytorch_model.bin`.
        """
        if self.is_master:
            logger.info('Starting training')
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')
            if self.multi_gpu:
                torch.distributed.barrier()

            iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch)
                self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels)

                iter_bar.update()
                iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}',
                                      'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'})
            iter_bar.close()

            if self.is_master:
                logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info('Save very last checkpoint as `pytorch_model.bin`.')
            self.save_checkpoint(checkpoint_name='pytorch_model.bin')
            logger.info('Training is finished')

    def step(self,
             input_ids: torch.Tensor,
             attention_mask: torch.Tensor,
             mlm_labels: torch.Tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
            input_ids: `torch.tensor(bs, seq_length)` - The token ids.
            attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
        """
        s_logits, s_hidden_states = self.student(input_ids=input_ids, attention_mask=attention_mask)      # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits, t_hidden_states = self.teacher(input_ids=input_ids, attention_mask=attention_mask)  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)   # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(s_logits)      # (bs, seq_length, voc_size)
        voc_size = s_logits.size(-1)
        s_logits_slct = torch.masked_select(s_logits, mask).view(-1, voc_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, mask).view(-1, voc_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        # The temperature**2 factor rescales the loss computed on the softened distributions.
        loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1),
                                   F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce*loss_ce

        if self.alpha_mlm > 0.:
            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, voc_size), mlm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.:
            s_hidden_states = s_hidden_states[-1]                              # (bs, seq_length, dim)
            t_hidden_states = t_hidden_states[-1]                              # (bs, seq_length, dim)
            mask = attention_mask.unsqueeze(-1).expand_as(s_hidden_states)     # (bs, seq_length, dim)
            assert s_hidden_states.size() == t_hidden_states.size()
            dim = s_hidden_states.size(-1)

            s_hidden_states_slct = torch.masked_select(s_hidden_states, mask).view(-1, dim)  # (bs * seq_length, dim)
            t_hidden_states_slct = torch.masked_select(t_hidden_states, mask).view(-1, dim)  # (bs * seq_length, dim)

            # A target of 1 asks the loss to bring each student / teacher pair closer.
            target = s_hidden_states_slct.new_ones(s_hidden_states_slct.size(0))  # (bs * seq_length,)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        # Read the scalar once and reuse it for both statistics.
        loss_value = loss.item()
        self.total_loss_epoch += loss_value
        self.last_loss = loss_value
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self,
                 loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.

        Input:
        ------
            loss: `torch.tensor` - The loss to backpropagate.

        Raises:
        -------
            SystemExit: If the loss contains a NaN. The exit status is 1 so the failure is visible to the caller.
        """
        # Check for NaN
        if torch.isnan(loss).any():
            logger.error('NaN detected')
            # The `exit()` helper is injected by the `site` module, is not always available,
            # and exits with a success status; raise the same exception with a failing status.
            raise SystemExit(1)

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        # Parameters are only updated every `gradient_accumulation_steps` calls.
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        Writes per-parameter statistics of the student, the last losses, the learning rate,
        the used system memory and the time elapsed since the previous log.
        """
        if not self.is_master:
            return

        step = self.n_total_iter
        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=step)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=step)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(), global_step=step)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=step)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=step)
        self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=step)
        self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=step)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=step)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=step)
        if self.alpha_cos > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_cos", scalar_value=self.last_loss_cos, global_step=step)
        self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=step)

        self.tensorboard.add_scalar(tag="global/memory_usage", scalar_value=psutil.virtual_memory().used/1_000_000, global_step=step)
        self.tensorboard.add_scalar(tag="global/speed", scalar_value=time.time()-self.last_log, global_step=step)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Do some tensorboard logging and checkpoint saving, then reset the per-epoch counters.
        """
        logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.')

        if self.is_master:
            self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self,
                        checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        The student's config and its state dict are written under `dump_path`.

        Input:
        ------
            checkpoint_name: `str` - File name of the state dict inside `dump_path`.
        """
        if not self.is_master:
            return
        # Unwrap the model when it is held by a wrapper exposing `.module` (e.g. DistributedDataParallel).
        mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))