import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import Dictionary
from pytorch_lightning import (
    LightningModule,
    Trainer,
    loggers,
)
from pytorch_lightning.callbacks import LearningRateMonitor
from transformers import HubertConfig, HubertModel

import fengshen.data.hubert.hubert_dataset as datasets
from fengshen.data.universal_datamodule import UniversalDataModule

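# Wraps a fairseq Dictionary so that each label line (e.g. a sequence of
# k-means cluster ids) is encoded into a tensor of indices on the fly.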
class LabelEncoder(object):
    def __init__(self, dictionary: Dictionary):
        self.dictionary = dictionary

    def __call__(self, label: str):
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )

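# Builds fairseq-style HubertDataset splits (audio manifest plus frame-level
# label files) and exposes them through the `datasets` property for the
# datamodule.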
class HubertPretrainDataLoader:
    def __init__(self, args):
        self.cfg = args
        self.dictionaries = self.load_dictionaries()
        self.load_datasets = {}

    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries

    def get_label_dir(self):
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    @property
    def datasets(self):
        return self.load_datasets

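    # One split is described by `{data}/{split}.tsv` (audio manifest) plus one
    # `{label_dir}/{split}.{label}` file per label type, following the fairseq
    # HuBERT data layout.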
    def load_dataset(self, split: str, **kwargs):
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = self.dictionaries
        pad_list = [d.pad() for d in dicts]
        eos_list = [d.eos() for d in dicts]
        procs = [LabelEncoder(d) for d in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{lb}" for lb in self.cfg.labels]

        self.load_datasets[split] = datasets.HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )

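# Eagerly builds the 'train' and 'valid' splits so the UniversalDataModule can
# reuse the resulting dataset dict directly.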
def prepare_data(args):
    loader = HubertPretrainDataLoader(args)
    loader.load_dataset('train')
    loader.load_dataset('valid')
    return loader

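# LightningModule wrapping the HuggingFace HubertModel for masked-prediction
# pretraining: masked frames are projected, scored against learned label
# embeddings with a cosine-similarity NCE head, and trained with cross-entropy
# where the true label is always class 0 (see compute_nce and forward below).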
class HubertLightning(LightningModule):
    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('HuBERT Lightning')
        parser.add_argument('--pred_masked_weight', type=float, default=1.0)
        parser.add_argument('--logit_temp', type=float, default=1.0)
        parser.add_argument('--loss_weights', type=float, nargs='+')
        return parent_parser

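    # The label embeddings act as the codebook for the NCE targets; their width
    # (conv_dim[-1] // 2 per dictionary) plays the role of fairseq's final_dim.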
    def __init__(self, args, loader, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters(args)
        config = HubertConfig.from_pretrained(args.model_path)
        self.config = config
        self.model = HubertModel(config=config)
        self.num_classes = [len(d) for d in loader.dictionaries]
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(sum(self.num_classes), self.config.conv_dim[-1] // 2)
        )
        self.final_proj = nn.Linear(
            self.config.hidden_size, self.config.conv_dim[-1] // 2 * len(loader.dictionaries)
        )
        nn.init.uniform_(self.label_embs_concat)

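    # Estimate the total number of optimizer steps from the dataset size, batch
    # size, world size and gradient accumulation, for use by the LR schedule
    # built in configure_optimizers.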
    def setup(self, stage) -> None:
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)

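    # Cosine-similarity logits of each masked frame against its positive
    # embedding (row 0) and every class embedding as negatives, scaled by
    # --logit_temp; negatives identical to the positive are masked with -inf.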
    def compute_nce(self, x, pos, negs):
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.hparams.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)
        return logits

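    # `batch` follows the fairseq collater layout: raw waveforms and padding
    # mask under `net_input`, plus one frame-level target tensor per label type.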
    def forward(self, **batch):
        target_list = batch['target_list']
        padding_mask = batch['net_input']['padding_mask']
        input_values = batch['net_input']['source']
        output = self.model(input_values=input_values,
                            attention_mask=padding_mask,
                            target_list=target_list,
                            mask_time_indices=None,
                            return_dict=False)

        def compute_pred(proj_x, target, label_embs):
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

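        # NOTE: this positional unpacking assumes a patched HubertModel whose
        # tuple output ends with (..., extra_losses, target_list, mask_indices,
        # padding_mask); the stock transformers HubertModel does not return these.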
        x, extra_losses, target_list, mask_indices, padding_mask = (
            output[0], output[-4], output[-3], output[-2], output[-1])

        masked_indices = torch.logical_and(~padding_mask, mask_indices)
        proj_x_m = self.final_proj(x[masked_indices])
        proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
        logp_m_list = [
            compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
            for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
        ]

        targ_m_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logp_m_list]

        loss = 0.0
        loss_m_list = []

        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m)
            loss_m_list.append(loss_m)
            self.log(f"loss_m_{i}", loss_m.detach().item())

        loss += self.hparams.pred_masked_weight * sum(loss_m_list)

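        # Optional auxiliary losses returned by the model (e.g. a feature-extractor
        # penalty in the fairseq recipe), each scaled by its --loss_weights entry.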
        loss_weights = self.hparams.loss_weights
        if loss_weights is not None:
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = ['extra']
            if len(loss_weights) == 1 and len(extra_losses) != 1:
                loss_weights = [loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                loss_weights
            ), f"{len(extra_losses)}, {len(loss_weights)}"
            for p, n, coef in zip(extra_losses, names, loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float()
                    loss += p
                    self.log(f"loss_{n}", p.item())

        return {'loss': loss}

    def training_step(self, batch, batch_idx):
        output = self(**batch)
        self.log('train_loss', output['loss'])
        return output

    def compute_metrics(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(-1)
        y_true = labels.view(-1).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / y_true.size(0)
        return acc

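    # Validation simply reruns the pretraining forward pass; only the returned
    # loss dict is kept, no additional metrics are logged here.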
    def validation_step(self, batch, batch_idx):
        output = self(**batch)
        return output

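    # Alongside the Lightning checkpoint, rank 0 also exports a HuggingFace-format
    # copy of the encoder so it can be loaded with HubertModel.from_pretrained.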
    def on_save_checkpoint(self, checkpoint) -> None:
        if self.trainer.global_rank == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(self.trainer.current_epoch, self.trainer.global_step)))

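    # When resuming, realign Lightning's logging step counter with the restored
    # global step so logged metrics continue from the right x-axis value.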
    def on_load_checkpoint(self, checkpoint) -> None:
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset

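# CLI entry point: argument groups are contributed by the dataset, datamodule,
# Trainer, module and checkpoint helpers, then training is launched with the
# pre-built HubertDataset splits plugged into the UniversalDataModule.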
if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    from fengshen.utils import UniversalCheckpoint
    from fengshen.models.model_utils import add_module_args
    args_parser = add_module_args(args_parser)
    args_parser = datasets.add_data_specific_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = HubertLightning.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args_parser.add_argument('--ckpt_path', type=str)
    args = args_parser.parse_args()

    data_module = UniversalDataModule(args=args, tokenizer=None, collate_fn=None)
    data_loader = prepare_data(args)
    data_module.datasets = data_loader.datasets
    module = HubertLightning(args, loader=data_loader)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = loggers.TensorBoardLogger(
        save_dir=os.path.join(args.default_root_dir, 'logs/'),
        name=os.path.basename(os.path.dirname(args.model_path)))
    checkpoint_callback = UniversalCheckpoint(args).callbacks

    if args.ckpt_path is not None and not os.path.exists(args.ckpt_path):
        print('-------- warning: checkpoint {} not found, ignoring --ckpt_path --------'.format(args.ckpt_path))
        args.ckpt_path = None

    trainer = Trainer.from_argparse_args(args,
                                         logger=logger,
                                         callbacks=[lr_monitor, checkpoint_callback])

    trainer.fit(module, data_module, ckpt_path=args.ckpt_path)