import fengshen.data.hubert.hubert_dataset as datasets
from fengshen.data.universal_datamodule import UniversalDataModule
from transformers import HubertConfig, HubertModel
# from transformers.models.hubert.modeling_hubert import _compute_mask_indices
import argparse
from fairseq.data import Dictionary
from pytorch_lightning import (
    LightningModule,
    Trainer,
    loggers,
)
from pytorch_lightning.callbacks import LearningRateMonitor
import torch
import os
import torch.nn.functional as F
import torch.nn as nn


class LabelEncoder(object):
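    """Encode a text label line into token ids with a fairseq Dictionary."""
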
    def __init__(self, dictionary: Dictionary):
        self.dictionary = dictionary

    def __call__(self, label: str):
        return self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )


class HubertPretrainDataLoader:
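    """Build fairseq HubertDataset splits from tsv manifests and per-split label files."""
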
    def __init__(self, args):
        self.cfg = args
        self.dictionaries = self.load_dictionaries()
        self.load_datasets = {}

    # TODO: switch to a HuggingFace tokenizer
    def load_dictionaries(self):
        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
        dictionaries = [
            Dictionary.load(f"{label_dir}/dict.{label}.txt")
            for label in self.cfg.labels
        ]
        return dictionaries

    def get_label_dir(self):
        if self.cfg.label_dir is None:
            return self.cfg.data
        return self.cfg.label_dir

    @property
    def datasets(self):
        return self.load_datasets

    def load_dataset(self, split: str, **kwargs):
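        """Build the HubertDataset for `split` and cache it in self.load_datasets."""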
        manifest = f"{self.cfg.data}/{split}.tsv"
        dicts = self.dictionaries
        pad_list = [d.pad() for d in dicts]
        eos_list = [d.eos() for d in dicts]
        procs = [LabelEncoder(d) for d in dicts]
        paths = [f"{self.get_label_dir()}/{split}.{lb}" for lb in self.cfg.labels]

        # hubert v1: pad_audio=True, random_crop=False;
        self.load_datasets[split] = datasets.HubertDataset(
            manifest,
            sample_rate=self.cfg.sample_rate,
            label_paths=paths,
            label_rates=self.cfg.label_rate,
            pad_list=pad_list,
            eos_list=eos_list,
            label_processors=procs,
            max_keep_sample_size=self.cfg.max_keep_size,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            pad_audio=self.cfg.pad_audio,
            normalize=self.cfg.normalize,
            store_labels=False,
            random_crop=self.cfg.random_crop,
            single_target=self.cfg.single_target,
        )


def prepare_data(args):
    loader = HubertPretrainDataLoader(args)
    loader.load_dataset('train')
    loader.load_dataset('valid')
    return loader


class HubertLightning(LightningModule):
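    """HuBERT masked-prediction pretraining, wrapping a transformers HubertModel
    backbone in a LightningModule."""
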
    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('HuBert Lightning')
        parser.add_argument('--pred_masked_weight', type=float, default=1.0)
        parser.add_argument('--logit_temp', type=float, default=1.0)
        parser.add_argument('--loss_weights', type=float, nargs='+')
        # parser.add_argument('--mask_prob', type=float, default=0.65)
        # parser.add_argument('--mask_length', type=int, default=10)
        # parser.add_argument('--mask_selection', type=str, default='static',
        #                     choice=["static", "uniform", "normal", "poisson"])
        # parser.add_argument('--mask_other', type=float, default=0)
        # parser.add_argument('--no_mask_overlap', type=bool, default=False)
        # parser.add_argument('--mask_min_space', type=int, default=1)
        return parent_parser

    def __init__(self, args, loader, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters(args)
        config = HubertConfig.from_pretrained(args.model_path)
        self.config = config
        self.model = HubertModel(config=config)
        self.num_classes = [len(d) for d in loader.dictionaries]
        self.label_embs_concat = nn.Parameter(
            torch.FloatTensor(sum(self.num_classes), self.config.conv_dim[-1] // 2)
        )
        self.final_proj = nn.Linear(
            self.config.hidden_size, self.config.conv_dim[-1] // 2 * len(loader.dictionaries)
        )
        nn.init.uniform_(self.label_embs_concat)

    def setup(self, stage) -> None:
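        """Estimate self.total_steps for the 'fit' stage from the dataset size,
        max epochs, world size and gradient accumulation."""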
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()

            # Calculate total steps
            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches

            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)

    def compute_nce(self, x, pos, negs):
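        """Cosine-similarity logits of x against [positive; negatives], scaled by 1 / logit_temp.

        The positive embedding sits at index 0, so the cross-entropy target is class 0;
        negatives identical to the positive are masked out with -inf.
        """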
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.hparams.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits

    def forward(self, **batch):
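        """Run the HuBERT backbone and compute masked-prediction NCE logits and losses per label set."""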

        target_list = batch['target_list']
        padding_mask = batch['net_input']['padding_mask']
        input_values = batch['net_input']['source']
        output = self.model(input_values=input_values,
                            attention_mask=padding_mask,
                            target_list=target_list,
                            mask_time_indices=None,
                            return_dict=False)

        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

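        # split the shared embedding table into one block of class embeddings per label set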
        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        x, extra_losses, target_list, mask_indices, padding_mask = (
            output[0], output[-4], output[-3], output[-2], output[-1]
        )

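        # restrict the prediction loss to frames that are masked and not padded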
        masked_indices = torch.logical_and(~padding_mask, mask_indices)
        proj_x_m = self.final_proj(x[masked_indices])
        proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
        logp_m_list = [
            compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
            for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
        ]

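        # the positive is always index 0 of the NCE logits, so every target is 0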
        targ_m_list = [lp.new_zeros(lp.size(0), dtype=torch.long) for lp in logp_m_list]

        loss = 0.0
        loss_m_list = []

        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m)
            loss_m_list.append(loss_m)
            self.log(f"loss_m_{i}", loss_m.detach().item())

        loss += self.hparams.pred_masked_weight * sum(loss_m_list)

        loss_weights = self.hparams.loss_weights
        if loss_weights is not None:
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
            # one logging name per auxiliary loss
            names = ['extra'] if len(extra_losses) == 1 else [
                f'extra_{i}' for i in range(len(extra_losses))]
            if len(loss_weights) == 1 and len(extra_losses) != 1:
                loss_weights = [loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                loss_weights
            ), f"{len(extra_losses)}, {len(loss_weights)}"
            for p, n, coef in zip(extra_losses, names, loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float()
                    loss += p
                    self.log(f"loss_{n}", p.item())

        return {'loss': loss}

    def training_step(self, batch, batch_idx):
        output = self(**batch)
        self.log('train_loss', output['loss'])
        return output

    def compute_metrics(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).float()
        corr = torch.eq(y_pred, y_true)
        acc = torch.sum(corr.float()) / y_true.size()[0]
        return acc

    def validation_step(self, batch, batch_idx):
        output = self(**batch)
        # self.log('val_loss', output.loss, sync_dist=True)
        # acc = self.compute_metrics(output.logits, batch['labels'])
        # self.log('val_acc', acc, sync_dist=True)
        return output

    def on_save_checkpoint(self, checkpoint) -> None:
        # Save the loop state when checkpointing in the middle of an epoch.
        # If your pytorch-lightning version is <= 1.6.0, uncomment the line below:
        # checkpoint['loops'] = self.trainer.checkpoint_connector._get_loops_state_dict()
        if self.trainer.global_rank == 0:
            self.model.save_pretrained(os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                'hf_pretrained_epoch{}_step{}'.format(self.trainer.current_epoch, self.trainer.global_step)))

    def on_load_checkpoint(self, checkpoint) -> None:
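        """Restore step counters so resumed training continues from the saved global step."""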
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    from fengshen.utils import UniversalCheckpoint
    from fengshen.models.model_utils import add_module_args
    args_parser = add_module_args(args_parser)
    args_parser = datasets.add_data_specific_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = HubertLightning.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args_parser.add_argument('--ckpt_path', type=str)
    args = args_parser.parse_args()

    data_module = UniversalDataModule(args=args, tokenizer=None, collate_fn=None)
    data_loader = prepare_data(args)
    data_module.datasets = data_loader.datasets
    module = HubertLightning(args, loader=data_loader)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = loggers.TensorBoardLogger(save_dir=os.path.join(
        args.default_root_dir, 'logs/'),
        name=os.path.basename(os.path.dirname(args.model_path)))
    checkpoint_callback = UniversalCheckpoint(args).callbacks

    if args.ckpt_path is not None and \
            not os.path.exists(args.ckpt_path):
        print('--------warning: no checkpoint found, ignoring --ckpt_path--------')
        args.ckpt_path = None

    trainer = Trainer.from_argparse_args(args,
                                         logger=logger,
                                         callbacks=[
                                             lr_monitor,
                                             checkpoint_callback])

    trainer.fit(module, data_module, ckpt_path=args.ckpt_path)