summary / fengshen /examples /hubert /pretrain_hubert.py
fclong's picture
Upload 396 files
8ebda9e
raw
history blame
11.6 kB
import fengshen.data.hubert.hubert_dataset as datasets
from fengshen.data.universal_datamodule import UniversalDataModule
from transformers import HubertConfig, HubertModel
# from transformers.models.hubert.modeling_hubert import _compute_mask_indices
import argparse
from fairseq.data import Dictionary
from pytorch_lightning import (
LightningModule,
Trainer,
loggers,
)
from pytorch_lightning.callbacks import LearningRateMonitor
import torch
import os
import torch.nn.functional as F
import torch.nn as nn
class LabelEncoder(object):
    """Callable wrapper that encodes a label line with a fixed dictionary.

    Args:
        dictionary: object exposing ``encode_line``; it is stored as-is and
            used for every call.
    """

    def __init__(self, dictionary: Dictionary):
        self.dictionary = dictionary

    def __call__(self, label: str):
        """Return ``dictionary.encode_line(label, ...)`` for one label line."""
        # Encode verbatim: no EOS is appended and symbols missing from the
        # dictionary are not added to it.
        encode_options = {
            "append_eos": False,
            "add_if_not_exist": False,
        }
        return self.dictionary.encode_line(label, **encode_options)
class HubertPretrainDataLoader():
    """Builds ``HubertDataset`` objects for the splits of a pre-training run.

    All settings are read from the parsed argument namespace ``args``
    (``data``, ``label_dir``, ``labels``, ``sample_rate``, ...). Datasets
    created by :meth:`load_dataset` are kept in a dict keyed by split name.
    """

    def __init__(self, args):
        self.cfg = args
        self.dictionaries = self.load_dictionaries()
        # split name -> dataset, populated by load_dataset()
        self.load_datasets = {}

    # TODO: switch to a HuggingFace tokenizer
    def load_dictionaries(self):
        """Load one dictionary per entry of ``cfg.labels``.

        Each dictionary is read from ``<label_dir>/dict.<label>.txt``.
        """
        label_dir = self.get_label_dir()
        return [
            Dictionary.load(f"{label_dir}/dict.{label_name}.txt")
            for label_name in self.cfg.labels
        ]

    def get_label_dir(self):
        """Return ``cfg.label_dir``, falling back to ``cfg.data`` when it is None."""
        label_dir = self.cfg.label_dir
        return self.cfg.data if label_dir is None else label_dir

    @property
    def datasets(self):
        """Dict of the datasets loaded so far, keyed by split name."""
        return self.load_datasets

    def load_dataset(self, split: str, **kwargs):
        """Create the dataset for ``split`` and store it under that name.

        The manifest is ``<cfg.data>/<split>.tsv`` and the label files are
        ``<label_dir>/<split>.<label>`` for every label in ``cfg.labels``.
        Extra keyword arguments are accepted but not used.
        """
        cfg = self.cfg
        vocabularies = self.dictionaries
        manifest_path = f"{cfg.data}/{split}.tsv"

        pad_ids = [vocab.pad() for vocab in vocabularies]
        eos_ids = [vocab.eos() for vocab in vocabularies]
        encoders = [LabelEncoder(vocab) for vocab in vocabularies]

        label_dir = self.get_label_dir()
        label_files = [f"{label_dir}/{split}.{lb}" for lb in cfg.labels]

        # hubert v1: pad_audio=True, random_crop=False;
        dataset = datasets.HubertDataset(
            manifest_path,
            sample_rate=cfg.sample_rate,
            label_paths=label_files,
            label_rates=cfg.label_rate,
            pad_list=pad_ids,
            eos_list=eos_ids,
            label_processors=encoders,
            max_keep_sample_size=cfg.max_keep_size,
            min_keep_sample_size=cfg.min_sample_size,
            max_sample_size=cfg.max_sample_size,
            pad_audio=cfg.pad_audio,
            normalize=cfg.normalize,
            store_labels=False,
            random_crop=cfg.random_crop,
            single_target=cfg.single_target,
        )
        self.load_datasets[split] = dataset
def perpare_data(args):
    """Build a ``HubertPretrainDataLoader`` with its 'train' and 'valid' splits loaded.

    Args:
        args: parsed argument namespace forwarded to the loader.

    Returns:
        The loader, with both splits available through ``loader.datasets``.
    """
    loader = HubertPretrainDataLoader(args)
    for split_name in ('train', 'valid'):
        loader.load_dataset(split_name)
    return loader
class HubertLightning(LightningModule):
    """LightningModule that pre-trains a ``HubertModel`` with masked prediction.

    The module owns the backbone (``self.model``), a table of label
    embeddings (``self.label_embs_concat``, one row per dictionary symbol over
    all label sets) and a projection (``self.final_proj``) from the backbone
    hidden size to one embedding-sized chunk per label set.
    """

    @staticmethod
    def add_module_specific_args(parent_parser):
        """Register this module's command-line options and return ``parent_parser``."""
        parser = parent_parser.add_argument_group('HuBert Lightning')
        parser.add_argument('--pred_masked_weight', type=float, default=1.0)
        parser.add_argument('--logit_temp', type=float, default=1.0)
        parser.add_argument('--loss_weights', type=float, nargs='+')
        # parser.add_argument('--mask_prob', type=float, default=0.65)
        # parser.add_argument('--mask_length', type=int, default=10)
        # parser.add_argument('--mask_selection', type=str, default='static',
        #                     choice=["static", "uniform", "normal", "poisson"])
        # parser.add_argument('--mask_other', type=float, default=0)
        # parser.add_argument('--no_mask_overlap', type=bool, default=False)
        # parser.add_argument('--mask_min_space', type=int, default=1)
        return parent_parser

    def __init__(self, args, loader, ** kwargs) -> None:
        """Build the backbone and the prediction heads.

        Args:
            args: parsed argument namespace; saved as hyper-parameters and
                used for ``model_path``.
            loader: object exposing ``dictionaries`` (one per label set); the
                dictionary sizes determine the number of classes.
            **kwargs: accepted for compatibility, not used.
        """
        super().__init__()
        self.save_hyperparameters(args)
        config = HubertConfig.from_pretrained(args.model_path)
        self.config = config
        # Only the configuration is read from model_path; the model itself is
        # constructed from that config here.
        self.model = HubertModel(config=config)
        self.num_classes = [len(d) for d in loader.dictionaries]

        embed_dim = self.config.conv_dim[-1] // 2
        # One embedding row per symbol, all label sets concatenated; the
        # storage is uninitialised until the uniform init below.
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(sum(self.num_classes), embed_dim)
        )
        self.final_proj = nn.Linear(
            self.config.hidden_size, embed_dim * len(loader.dictionaries)
        )
        nn.init.uniform_(self.label_embs_concat)

    def setup(self, stage) -> None:
        """Compute ``self.total_steps`` when entering the 'fit' stage."""
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            # Calculate total steps
            max_epochs = self.trainer.max_epochs
            # Guard against max_epochs being None, which would make the
            # comparison below raise TypeError.
            if max_epochs is not None and max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}' .format(self.total_steps))

    def configure_optimizers(self):
        """Delegate optimizer construction to fengshen's shared helper."""
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)

    def compute_nce(self, x, pos, negs):
        """Return temperature-scaled cosine-similarity logits.

        The positive target is placed at index 0 of the class dimension,
        followed by the negatives. Negatives identical to the positive are
        set to ``-inf``.
        """
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.hparams.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def _compute_pred(self, proj_x, target, label_embs):
        """Compute the logits of one label set for the selected frames."""
        y = torch.index_select(label_embs, 0, target.long())
        negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
        # proj_x: (S, D)
        # y: (S, D)
        # negs: (Neg, S, D)
        return self.compute_nce(proj_x, y, negs)

    def forward(self, **batch):
        """Run the backbone on a batch and return ``{'loss': loss}``.

        Reads ``batch['target_list']``, ``batch['net_input']['padding_mask']``
        and ``batch['net_input']['source']``. The loss is the weighted sum of
        the per-label-set cross-entropies on masked frames, plus any extra
        losses when ``--loss_weights`` is given.
        """
        target_list = batch['target_list']
        padding_mask = batch['net_input']['padding_mask']
        input_values = batch['net_input']['source']
        # NOTE(review): the stock transformers HubertModel.forward does not
        # document a `target_list` argument or the trailing outputs unpacked
        # below; this presumably relies on a customised model - confirm.
        output = self.model(input_values=input_values,
                            attention_mask=padding_mask,
                            target_list=target_list,
                            mask_time_indices=None,
                            return_dict=False)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        # First output is the feature tensor; the last four are, in order,
        # extra losses, targets, mask indices and padding mask.
        x = output[0]
        extra_losses, target_list, mask_indices, padding_mask = output[-4:]

        # Only frames that are both masked and not padding contribute.
        masked_indices = torch.logical_and(~padding_mask, mask_indices)
        proj_x_m = self.final_proj(x[masked_indices])
        proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
        logp_m_list = [
            self._compute_pred(proj_chunk, t[masked_indices], label_embs_list[i])
            for i, (proj_chunk, t) in enumerate(zip(proj_x_m_list, target_list))
        ]

        # compute_nce puts the positive at index 0, so every target is 0.
        targ_m_list = [
            logp.new_zeros(logp.size(0), dtype=torch.long) for logp in logp_m_list
        ]

        loss = 0.0
        loss_m_list = []
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m)
            loss_m_list.append(loss_m)
            self.log(f"loss_m_{i}", loss_m.detach().item())

        loss += self.hparams.pred_masked_weight * sum(loss_m_list)

        loss_weights = self.hparams.loss_weights
        if loss_weights is not None:
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
            # One log name per extra loss, so the zip below can neither hit an
            # undefined `names` nor silently drop losses. A single extra loss
            # keeps the plain 'extra' name.
            if len(extra_losses) == 1:
                names = ['extra']
            else:
                names = [f'extra_{i}' for i in range(len(extra_losses))]
            # A single weight is broadcast to every extra loss.
            if len(loss_weights) == 1 and len(extra_losses) != 1:
                loss_weights = [loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                loss_weights
            ), f"{len(extra_losses)}, {len(loss_weights)}"
            for p, n, coef in zip(extra_losses, names, loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float()
                    loss += p
                    self.log(f"loss_{n}", p.item())

        return {'loss': loss}

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log it as 'train_loss'."""
        output = self(**batch)
        self.log('train_loss', output['loss'])
        return output

    def comput_metrix(self, logits, labels):
        """Return the fraction of positions where ``argmax(logits)`` equals ``labels``."""
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / y_true.size()[0]
        return acc

    def validation_step(self, batch, batch_idx):
        """Run the forward pass on a validation batch and return its output."""
        output = self(**batch)
        # self.log('val_loss', output.loss, sync_dist=True)
        # acc = self.comput_metrix(output.logits, batch['labels'])
        # self.log('val_acc', acc, sync_dist=True)
        return output

    def on_save_checkpoint(self, checkpoint) -> None:
        """On global rank 0, also export the backbone via ``save_pretrained``."""
        # Save the current loop info in the mid of epoch
        # if you lightning <= 1.6.0 uncomment the line below
        # checkpoint['loops'] = self.trainer.checkpoint_connector._get_loops_state_dict()
        if self.trainer.global_rank == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(self.trainer.current_epoch, self.trainer.global_step)))

    def on_load_checkpoint(self, checkpoint) -> None:
        """Restore the step counter (and consumed samples, if stored) from a checkpoint."""
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset
if __name__ == '__main__':
    from fengshen.utils import UniversalCheckpoint
    from fengshen.models.model_utils import add_module_args

    # Every component registers its own options on the shared parser and
    # hands the parser back; the order matches the component list below.
    args_parser = argparse.ArgumentParser()
    arg_registrars = (
        add_module_args,
        datasets.add_data_specific_args,
        UniversalDataModule.add_data_specific_args,
        Trainer.add_argparse_args,
        HubertLightning.add_module_specific_args,
        UniversalCheckpoint.add_argparse_args,
    )
    for register_args in arg_registrars:
        args_parser = register_args(args_parser)
    args_parser.add_argument('--ckpt_path', type=str, )
    args = args_parser.parse_args()

    # The datasets are built by the HuBERT loader and attached to the generic
    # data module, which is created without tokenizer or collate function.
    data_module = UniversalDataModule(args=args, tokenizer=None, collate_fn=None)
    data_loader = perpare_data(args)
    data_module.datasets = data_loader.datasets

    module = HubertLightning(args, loader=data_loader)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    log_root = os.path.join(args.default_root_dir, 'logs/')
    # The run is named after the directory that contains model_path.
    run_name = os.path.basename(os.path.dirname(args.model_path))
    logger = loggers.TensorBoardLogger(save_dir=log_root, name=run_name)
    checkpoint_callback = UniversalCheckpoint(args).callbacks

    # Start from scratch when the requested checkpoint file is missing.
    ckpt_missing = args.ckpt_path is not None and not os.path.exists(args.ckpt_path)
    if ckpt_missing:
        print('--------warning no checkpoint found--------, remove args')
        args.ckpt_path = None

    trainer = Trainer.from_argparse_args(
        args,
        logger=logger,
        callbacks=[lr_monitor, checkpoint_callback],
    )
    trainer.fit(module, data_module, ckpt_path=args.ckpt_path)