diff --git a/.gitattributes b/.gitattributes index b3da55b7d543c9d49c86c7bd90095cf53a8223cd..2b196fa088314a70255c571e262a16668b0d2247 100644 --- a/.gitattributes +++ b/.gitattributes @@ -40,3 +40,4 @@ fairseq/alignment_train_cpu_binding.cpython-310-x86_64-linux-gnu.so filter=lfs d fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text +fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/fairseq/fairseq/criterions/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a03361425df3e324df19dc1aa88acd40ef0de1ce Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d52248cd9628778f0626dec03a45dcc3fc9e3379 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4f5008bd79414c526ee4c5393254ed9a580f906 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a6525e86fd21a711327f5afb5483ea0198670a0 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/ctc.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/ctc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..090f918efec5ddc30513872761d53e0f02147bda Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/ctc.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b69dda66a799ae89f466ec9a5d818f0ceea9299 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb4ffcbd2d3d97866df0576cf0377d019a450a58 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f36e1e1c1e2e357f17eb1faa4f6a2cf4680dc400 Binary files /dev/null and 
b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e66223595dbeaa288ff6012577e041c4ce7cbe58 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e900d8c4b1ab7bb14c9ba27251a5bf999bc96375 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19304b5f7229eb97bdb09e8a03ab307fd006cd5c Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0876db800508883f9d0d2158968ea1874e253a11 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d59f3812c7e634e19070b838e79fefb186c0991e Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee2d52485e81aea7551f84bea1eae56e11cde083 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9bb16e01d436b6bf192ade0a9bdc84dae4eee13 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd1b36492f4c9489adb46d23863c1d543c25e881 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..a2d75504d1464b492cab9a520ad6583a5c2af7dc Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1bb9efef4b9fa93ee6223c4060e6b567768829 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b8c5eed99e6750586b8d640a17aa171113230da Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..716c894f786e8b6e0285b4b34b797f51752cf015 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bb279c0e7da2b27365604949ee3e2f4c8a1e09d Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6d57b0e52a22cf0294ae43cd95c427473248a7a Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6413c6804699d55adae40e162776a4a825d7a24 Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc b/fairseq/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29a04c38ed030eaf3c7cf3b65760dc528103a3cc Binary files /dev/null and b/fairseq/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc differ diff --git a/fairseq/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/fairseq/criterions/speech_to_speech_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..06a825214013bf9d7d39d683895d90166efbae3f --- /dev/null +++ b/fairseq/fairseq/criterions/speech_to_speech_criterion.py @@ -0,0 +1,517 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import math +from collections import OrderedDict + +import torch + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.ctc import CtcCriterion +from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import ( + RdropLabelSmoothedCrossEntropyCriterion, + RdropLabelSmoothedCrossEntropyCriterionConfig, + duplicate_input, +) +from fairseq.criterions.tacotron2_loss import ( + Tacotron2Criterion, + Tacotron2CriterionConfig, +) + +logger = logging.getLogger(__name__) + + +class MultitaskCriterion: + def __init__(self, multitask_tasks, rdrop_alpha=0.0): + self.rdrop_alpha = rdrop_alpha + self.rdrop_alpha_mtl = rdrop_alpha + + self.multitask_criterion = OrderedDict() + self.multitask_loss_weight = OrderedDict() + for task_name, task_obj in multitask_tasks.items(): + if task_obj.args.get_loss_weight(0) == 0: + logger.info(f"Skip {task_name} loss criterion") + continue + + rdrop_alpha_task = task_obj.args.rdrop_alpha + if rdrop_alpha_task is None: + rdrop_alpha_task = rdrop_alpha + self.rdrop_alpha_mtl = rdrop_alpha_task + logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}") + + if task_obj.args.decoder_type == "ctc": + self.multitask_criterion[task_name] = CtcCriterion( + task_obj.args.criterion_cfg, + task_obj, + rdrop_alpha=rdrop_alpha_task, + ) + else: + self.multitask_criterion[ + task_name + ] = RdropLabelSmoothedCrossEntropyCriterion( + task_obj, + task_obj.args.criterion_cfg.sentence_avg, + label_smoothing=task_obj.args.criterion_cfg.label_smoothing, + rdrop_alpha=rdrop_alpha_task, + ) + + def set_multitask_loss_weight(self, task_name, weight=0.0): + self.multitask_loss_weight[task_name] = weight + + def get_multitask_loss(self, model, sample, model_out): + logging_output = {} + loss = 0.0 + for task_name, task_criterion in self.multitask_criterion.items(): + layer_id = task_criterion.task.args.input_layer + if isinstance(task_criterion, CtcCriterion): + if task_criterion.task.args.input_from == "encoder": + if len(model_out["encoder_padding_mask"]) > 0: + non_padding_mask = ~model_out["encoder_padding_mask"][0] + input_lengths = non_padding_mask.long().sum(-1) + else: + out = model_out["encoder_states"][layer_id] + input_lengths = out.new_full( + (out.shape[1],), out.shape[0] + ).long() + + task_sample = { + "net_input": { + "src_tokens": model_out["encoder_states"][ + layer_id + ], # check batch idx + "src_lengths": input_lengths, + }, + "id": sample["id"], + } + else: + task_sample = { + "net_input": { + "src_tokens": model_out["inner_states"][layer_id], + "src_lengths": sample["target_lengths"], + }, + "id": sample["id"], + } + else: + task_sample = { + "net_input": { + "src_tokens": sample["multitask"][task_name]["net_input"][ + "prev_output_tokens" + ], + "encoder_out": { + "encoder_out": [model_out["encoder_states"][layer_id]], + "encoder_padding_mask": model_out["encoder_padding_mask"], + }, + } + } + + for key in ["target", "target_lengths", "ntokens"]: + task_sample[key] = sample["multitask"][task_name][key] + + if task_name == getattr(model, "mt_task_name", None): + decoder_out = model_out["mt_decoder_out"] + else: + decoder_out = None + task_loss, task_sample_size, task_logging_output = task_criterion( + model.multitask_decoders[task_name], task_sample, net_output=decoder_out + ) + + loss = loss + self.multitask_loss_weight[task_name] * task_loss + task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name] + 
logging_output[task_name] = task_logging_output + return loss, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + for task_name in logging_outputs[0]["multitask"].keys(): + # different criterion may return different logging + # currently only reduce on loss, the most common one + # ideally the way that losses are reduced should also depend on the task type + loss_sum = sum( + log["multitask"][task_name].get("loss", 0) for log in logging_outputs + ) + sample_size = sum( + log["multitask"][task_name].get("sample_size", 0) + for log in logging_outputs + ) + + metrics.log_scalar( + f"multitask_{task_name}_loss", + loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + + loss_weight = logging_outputs[0]["multitask"][task_name].get( + "loss_weight", 0 + ) + metrics.log_scalar( + f"multitask_{task_name}_loss_weight", + loss_weight, + weight=0, + priority=250, + ) + + +@register_criterion( + "speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig +) +class SpeechToUnitMultitaskTaskCriterion( + RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion +): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + rdrop_alpha=0.0, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + rdrop_alpha, + ) + MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha) + + def forward(self, model, sample, reduce=True): + net_input_concat = { + "src_tokens": sample["net_input"]["src_tokens"], + "src_lengths": sample["net_input"]["src_lengths"], + "prev_output_tokens": sample["net_input"]["prev_output_tokens"], + "tgt_speaker": sample["net_input"].get("tgt_speaker", None), + "return_all_hiddens": True, + } + + if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0: + net_input_concat = duplicate_input(net_input_concat) + + net_output, extra = model(**net_input_concat) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, [net_output], sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, [net_output], sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + + # inference metrics + if "targ_frames" in logging_outputs[0]: + n = sum(log.get("norm_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + if "multitask" not in logging_outputs[0]: + return + + MultitaskCriterion.reduce_metrics(logging_outputs) + + 
@staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improve distributed training speed. + """ + return False + + +@register_criterion( + "speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig +) +class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + rdrop_alpha=0.0, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + rdrop_alpha, + ) + + def forward(self, model, sample, reduce=True): + net_input_concat = { + "src_tokens": sample["net_input"]["src_tokens"], + "src_lengths": sample["net_input"]["src_lengths"], + "prev_output_tokens": sample["net_input"]["prev_output_tokens"], + "prev_output_tokens_mt": sample["multitask"][model.mt_task_name][ + "net_input" + ]["prev_output_tokens"], + "tgt_speaker": sample["net_input"].get("tgt_speaker", None), + "return_all_hiddens": True, + } + if getattr(model, "asr_task_name", None) is not None: + net_input_concat["prev_output_tokens_asr"] = sample["multitask"][ + model.asr_task_name + ]["net_input"]["prev_output_tokens"] + + if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0: + net_input_concat = duplicate_input(net_input_concat) + + net_output, extra = model(**net_input_concat) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, [net_output], sample, reduce=reduce + ) + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, [net_output], sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + + return loss, sample_size, logging_output + + +@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig) +class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion): + def __init__( + self, + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ): + super().__init__( + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ) + MultitaskCriterion.__init__(self, task.multitask_tasks) + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + + feat_out, eos_out, extra = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], +
prev_output_tokens=sample["net_input"]["prev_output_tokens"], + tgt_speaker=sample["net_input"]["tgt_speaker"], + target_lengths=sample["target_lengths"], + return_all_hiddens=True, + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + sample["target_lengths"], + reduction, + ) + attn_loss = torch.tensor(0.0).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn( + extra["attn"], + sample["net_input"]["src_lengths"], + sample["target_lengths"], + reduction, + ) + loss = ( + l1_loss + mse_loss + eos_loss + attn_loss + ) # do not include ctc loss as there's no text target + + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + } + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + + # inference metrics + if "targ_frames" in logging_outputs[0]: + n = sum(log.get("norm_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + if "multitask" not in logging_outputs[0]: + return + + MultitaskCriterion.reduce_metrics(logging_outputs) + + +@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig) +class SpeechToSpectrogram2passMultitaskTaskCriterion( + SpeechToSpectrogramMultitaskTaskCriterion +): + def __init__( + self, + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ): + super().__init__( + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ) + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + + feat_out, eos_out, extra = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][ + "prev_output_tokens" + ], + tgt_speaker=sample["net_input"]["tgt_speaker"], + target_lengths=sample["target_lengths"], + return_all_hiddens=True, + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + sample["target_lengths"], + reduction, + ) + attn_loss = torch.tensor(0.0).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn( + extra["attn"], + 
sample["net_input"]["src_lengths"], + sample["target_lengths"], + reduction, + ) + loss = ( + l1_loss + mse_loss + eos_loss + attn_loss + ) # do not include ctc loss as there's no text target + + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + } + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + return loss, sample_size, logging_output diff --git a/fairseq/fairseq/criterions/tacotron2_loss.py b/fairseq/fairseq/criterions/tacotron2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4113fdc5489f1e0c787b60735086ae9d073c8e17 --- /dev/null +++ b/fairseq/fairseq/criterions/tacotron2_loss.py @@ -0,0 +1,227 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Any, Dict, List + +import torch +import torch.nn.functional as F +from omegaconf import II + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.data.data_utils import lengths_to_mask +from fairseq.dataclass import FairseqDataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class Tacotron2CriterionConfig(FairseqDataclass): + bce_pos_weight: float = field( + default=1.0, + metadata={"help": "weight of positive examples for BCE loss"}, + ) + use_guided_attention_loss: bool = field( + default=False, + metadata={"help": "use guided attention loss"}, + ) + guided_attention_loss_sigma: float = field( + default=0.4, + metadata={"help": "weight of positive examples for BCE loss"}, + ) + ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"}) + sentence_avg: bool = II("optimization.sentence_avg") + + +class GuidedAttentionLoss(torch.nn.Module): + """ + Efficiently Trainable Text-to-Speech System Based on Deep Convolutional + Networks with Guided Attention (https://arxiv.org/abs/1710.08969) + """ + + def __init__(self, sigma): + super().__init__() + self.sigma = sigma + + @staticmethod + @lru_cache(maxsize=8) + def _get_weight(s_len, t_len, sigma): + grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len)) + grid_x = grid_x.to(s_len.device) + grid_y = grid_y.to(s_len.device) + w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2 + return 1.0 - torch.exp(-w / (2 * (sigma**2))) + + def _get_weights(self, src_lens, tgt_lens): + bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens) + weights = torch.zeros((bsz, max_t_len, max_s_len)) + for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)): + weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma) + return weights + + @staticmethod + def _get_masks(src_lens, tgt_lens): + in_masks = 
lengths_to_mask(src_lens) + out_masks = lengths_to_mask(tgt_lens) + return out_masks.unsqueeze(2) & in_masks.unsqueeze(1) + + def forward(self, attn, src_lens, tgt_lens, reduction="mean"): + weights = self._get_weights(src_lens, tgt_lens).to(attn.device) + masks = self._get_masks(src_lens, tgt_lens).to(attn.device) + loss = (weights * attn.transpose(1, 2)).masked_select(masks) + loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss) + return loss + + +@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig) +class Tacotron2Criterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ): + super().__init__(task) + self.sentence_avg = sentence_avg + self.bce_pos_weight = bce_pos_weight + + self.guided_attn = None + if use_guided_attention_loss: + self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma) + self.ctc_weight = ctc_weight + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + tgt_lens = sample["target_lengths"] + + feat_out, eos_out, extra = model( + src_tokens=src_tokens, + src_lengths=src_lens, + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + tgt_lens, + reduction, + ) + attn_loss = torch.tensor(0.0).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction) + ctc_loss = torch.tensor(0.0).type_as(l1_loss) + if self.ctc_weight > 0.0: + net_output = (feat_out, eos_out, extra) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.transpose(0, 1) # T x B x C + src_mask = lengths_to_mask(src_lens) + src_tokens_flat = src_tokens.masked_select(src_mask) + ctc_loss = ( + F.ctc_loss( + lprobs, + src_tokens_flat, + tgt_lens, + src_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) + loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss + + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + } + return loss, sample_size, logging_output + + def compute_loss( + self, + feat_out, + feat_out_post, + eos_out, + feat_tgt, + eos_tgt, + tgt_lens, + reduction="mean", + ): + mask = lengths_to_mask(tgt_lens) + _eos_out = eos_out[mask].squeeze() + _eos_tgt = eos_tgt[mask] + _feat_tgt = feat_tgt[mask] + _feat_out = feat_out[mask] + _feat_out_post = feat_out_post[mask] + + l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss( + _feat_out_post, _feat_tgt, reduction=reduction + ) + mse_loss = 
F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss( + _feat_out_post, _feat_tgt, reduction=reduction + ) + eos_loss = F.binary_cross_entropy_with_logits( + _eos_out, + _eos_tgt, + pos_weight=torch.tensor(self.bce_pos_weight), + reduction=reduction, + ) + return l1_loss, mse_loss, eos_loss + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + ns = [log.get("sample_size", 0) for log in logging_outputs] + ntot = sum(ns) + ws = [n / (ntot + 1e-8) for n in ns] + for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]: + vals = [log.get(key, 0) for log in logging_outputs] + val = sum(val * w for val, w in zip(vals, ws)) + metrics.log_scalar(key, val, ntot, round=3) + metrics.log_scalar("sample_size", ntot, len(logging_outputs)) + + # inference metrics + if "targ_frames" not in logging_outputs[0]: + return + n = sum(log.get("targ_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return False diff --git a/fairseq/fairseq/criterions/wav2vec_criterion.py b/fairseq/fairseq/criterions/wav2vec_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..3975468487704e053fe7634257f443ee2c396616 --- /dev/null +++ b/fairseq/fairseq/criterions/wav2vec_criterion.py @@ -0,0 +1,231 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round +from fairseq.utils import is_xla_tensor + + +@dataclass +class Wav2VecCriterionConfig(FairseqDataclass): + infonce: bool = field( + default=False, + metadata={ + "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)" + }, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) +class Wav2vecCriterion(FairseqCriterion): + def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): + super().__init__(task) + self.infonce = infonce + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + logits = model.get_logits(net_output).float() + target = model.get_targets(sample, net_output) + self.xla = is_xla_tensor(logits) + + # XXX: handle weights on xla. 
+ weights = None + if hasattr(model, "get_target_weights") and not self.infonce: + weights = model.get_target_weights(target, net_output) + if torch.is_tensor(weights): + weights = weights.float() + + losses = [] + + reduction = "none" if ((not reduce) or self.xla) else "sum" + if self.infonce: + loss = F.cross_entropy(logits, target, reduction=reduction) + else: + loss = F.binary_cross_entropy_with_logits( + logits, target.float(), weights, reduction=reduction + ) + + if self.xla: + # tpu-comment: since dynamic shapes lead to recompilations on xla, + # we don't shrink tensors using mask_indices. + # Instead, we use mask indices to adjust loss. + mi = ( + sample["net_input"]["mask_indices"] + .transpose(0, 1) # logits are transposed in `model.get_logits` + .reshape(logits.size(0)) + ) + loss = (loss * mi).sum() if reduce else (loss * mi) + + if "sample_size" in sample: + sample_size = sample["sample_size"] + elif "mask_indices" in sample["net_input"]: + sample_size = sample["net_input"]["mask_indices"].sum() + else: + sample_size = target.numel() if self.infonce else target.long().sum().item() + losses.append(loss.detach().clone()) + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, coef in zip(extra_losses, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + losses.append(p) + + logging_output = { + "loss": loss.item() if (reduce and not self.xla) else loss.detach(), + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + for lk in self.log_keys: + # Only store "logits" and "target" for computing MAP and MAUC + # during validation + if lk == "logits": + if not self.training: + logging_output["logits"] = logits.cpu().numpy() + elif lk == "target": + if not self.training: + # If the targets have been mixed with the predictions of + # teacher models, find the original targets + if hasattr(model, "get_original_targets"): + original_target = model.get_original_targets(sample, net_output) + else: + original_target = target + logging_output["target"] = original_target.cpu().numpy() + elif lk in net_output: + value = net_output[lk] + if not is_xla_tensor(value): + value = float(value) + logging_output[lk] = value + + if len(losses) > 1: + for i, l in enumerate(losses): + logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach() + + if self.infonce: + with torch.no_grad(): + if logits.numel() == 0: + corr = 0 + count = 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + if is_xla_tensor(logits): + max, min = max * mi, min * mi + both = max & min + corr = max.long().sum() - both.long().sum() + count = mi.sum() + else: + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = float(max.numel()) + + logging_output["correct"] = corr + logging_output["count"] = count + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + 
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + + correct = sum(log.get("correct", 0) for log in logging_outputs) + metrics.log_scalar("_correct", correct) + + total = sum(log.get("count", 0) for log in logging_outputs) + metrics.log_scalar("_total", total) + + if total > 0: + metrics.log_derived( + "accuracy", + lambda meters: safe_round( + meters["_correct"].sum / meters["_total"].sum, 5 + ) + if meters["_total"].sum > 0 + else float("nan"), + ) + + builtin_keys = { + "loss", + "ntokens", + "nsentences", + "sample_size", + "correct", + "count", + } + + for k in logging_outputs[0]: + if k not in builtin_keys: + val = sum(log.get(k, 0) for log in logging_outputs) + if k.startswith("loss"): + metrics.log_scalar( + k, val / (sample_size or 1) / math.log(2), sample_size, round=3 + ) + else: + metrics.log_scalar(k, val / len(logging_outputs), round=3) + + # FIXME: revert when gather based xla reduction is implemented + # @staticmethod + # def logging_outputs_can_be_summed() -> bool: + def logging_outputs_can_be_summed(self) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improve distributed training speed. + """ + # XXX: Gather based reduction not implemented for xla yet. + # So we fall back to sum based reduction for xla. + return self.xla diff --git a/fairseq/fairseq/data/__init__.py b/fairseq/fairseq/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaae2b2547dc830c95a6a4313a02d469d4f63cd --- /dev/null +++ b/fairseq/fairseq/data/__init__.py @@ -0,0 +1,137 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
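# A minimal, runnable sketch of the InfoNCE accuracy bookkeeping in the
# Wav2vecCriterion above (editor's illustration, not part of the patch):
# the positive candidate sits at index 0, so a prediction is correct when
# index 0 holds the argmax, minus degenerate rows where argmax and argmin
# coincide. The logits below are assumed toy values.
import torch

logits = torch.tensor(
    [[2.0, 0.5, -1.0],   # positive wins -> counted as correct
     [0.1, 0.9, 0.3],    # a distractor wins -> not correct
     [0.0, 0.0, 0.0]]    # all-equal row: argmax == argmin == 0
)
is_max = logits.argmax(-1) == 0
is_min = logits.argmin(-1) == 0
both = is_max & is_min                                            # degenerate ties
correct = is_max.long().sum().item() - both.long().sum().item()   # 1
count = float(is_max.numel())                                     # 3
print(correct / count)  # 0.333..., what the "accuracy" meter reports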
+"""isort:skip_file""" + +from .dictionary import Dictionary, TruncatedDictionary + +from .fairseq_dataset import FairseqDataset, FairseqIterableDataset + +from .base_wrapper_dataset import BaseWrapperDataset + +from .add_target_dataset import AddTargetDataset +from .append_token_dataset import AppendTokenDataset +from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset +from .audio.hubert_dataset import HubertDataset +from .backtranslation_dataset import BacktranslationDataset +from .bucket_pad_length_dataset import BucketPadLengthDataset +from .colorize_dataset import ColorizeDataset +from .concat_dataset import ConcatDataset +from .concat_sentences_dataset import ConcatSentencesDataset +from .denoising_dataset import DenoisingDataset +from .id_dataset import IdDataset +from .indexed_dataset import ( + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + MMapIndexedDataset, +) +from .language_pair_dataset import LanguagePairDataset +from .list_dataset import ListDataset +from .lm_context_window_dataset import LMContextWindowDataset +from .lru_cache_dataset import LRUCacheDataset +from .mask_tokens_dataset import MaskTokensDataset +from .monolingual_dataset import MonolingualDataset +from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset +from .nested_dictionary_dataset import NestedDictionaryDataset +from .noising import NoisingDataset +from .numel_dataset import NumelDataset +from .num_samples_dataset import NumSamplesDataset +from .offset_tokens_dataset import OffsetTokensDataset +from .padding_mask_dataset import ( + LeftPaddingMaskDataset, + PaddingMaskDataset, + RightPaddingMaskDataset, +) +from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset +from .prepend_dataset import PrependDataset +from .prepend_token_dataset import PrependTokenDataset +from .raw_label_dataset import RawLabelDataset +from .replace_dataset import ReplaceDataset +from .resampling_dataset import ResamplingDataset +from .roll_dataset import RollDataset +from .round_robin_zip_datasets import RoundRobinZipDatasets +from .sort_dataset import SortDataset +from .speech_dlm_dataset import SpeechDLMDataset +from .strip_token_dataset import StripTokenDataset +from .subsample_dataset import SubsampleDataset +from .token_block_dataset import TokenBlockDataset +from .transform_eos_dataset import TransformEosDataset +from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset +from .shorten_dataset import TruncateDataset, RandomCropDataset +from .multilingual.sampled_multi_dataset import SampledMultiDataset +from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset +from .fasta_dataset import FastaDataset, EncodedFastaDataset +from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset + +from .iterators import ( + CountingIterator, + EpochBatchIterator, + GroupedIterator, + ShardedIterator, +) + +__all__ = [ + "AddTargetDataset", + "AppendTokenDataset", + "BacktranslationDataset", + "BaseWrapperDataset", + "BinarizedAudioDataset", + "BucketPadLengthDataset", + "ColorizeDataset", + "ConcatDataset", + "ConcatSentencesDataset", + "CountingIterator", + "DenoisingDataset", + "Dictionary", + "EncodedFastaDataset", + "EpochBatchIterator", + "FairseqDataset", + "FairseqIterableDataset", + "FastaDataset", + "FileAudioDataset", + "GroupedIterator", + "HubertDataset", + "IdDataset", + "IndexedCachedDataset", + "IndexedDataset", + "IndexedRawTextDataset", + "LanguagePairDataset", + "LeftPadDataset", + "ListDataset", 
+ "LMContextWindowDataset", + "LRUCacheDataset", + "MaskTokensDataset", + "MMapIndexedDataset", + "MonolingualDataset", + "MultiCorpusSampledDataset", + "NestedDictionaryDataset", + "NoisingDataset", + "NumelDataset", + "NumSamplesDataset", + "OffsetTokensDataset", + "PadDataset", + "PrependDataset", + "PrependTokenDataset", + "RandomCropDataset", + "RawLabelDataset", + "ResamplingDataset", + "ReplaceDataset", + "RightPadDataset", + "RollDataset", + "RoundRobinZipDatasets", + "SampledMultiDataset", + "SampledMultiEpochDataset", + "ShardedIterator", + "SortDataset", + "SpeechDLMDataset", + "StripTokenDataset", + "SubsampleDataset", + "TokenBlockDataset", + "TransformEosDataset", + "TransformEosLangPairDataset", + "TransformEosConcatLangPairDataset", + "TruncateDataset", + "TruncatedDictionary", +] diff --git a/fairseq/fairseq/data/add_class_target_dataset.py b/fairseq/fairseq/data/add_class_target_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bf89f2565662e42adeb7455fb07f5c81b209b93c --- /dev/null +++ b/fairseq/fairseq/data/add_class_target_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset, data_utils +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + + +class AddTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + pad, + eos, + batch_targets, + process_label=None, + label_len_fn=None, + add_to_input=False, + text_compression_level=TextCompressionLevel.none, + ): + super().__init__(dataset) + self.labels = labels + self.batch_targets = batch_targets + self.pad = pad + self.eos = eos + self.process_label = process_label + self.label_len_fn = label_len_fn + self.add_to_input = add_to_input + self.text_compressor = TextCompressor(level=text_compression_level) + + def get_label(self, index, process_fn=None): + lbl = self.labels[index] + lbl = self.text_compressor.decompress(lbl) + return lbl if process_fn is None else process_fn(lbl) + + def __getitem__(self, index): + item = self.dataset[index] + item["label"] = self.get_label(index, process_fn=self.process_label) + return item + + def size(self, index): + sz = self.dataset.size(index) + own_sz = self.label_len_fn(self.get_label(index)) + return sz, own_sz + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + + if self.batch_targets: + collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) + target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["ntokens"] = collated["target_lengths"].sum().item() + else: + collated["ntokens"] = sum([len(t) for t in target]) + + collated["target"] = target + + if self.add_to_input: + eos = target.new_full((target.size(0), 1), self.eos) + collated["target"] = torch.cat([target, eos], dim=-1).long() + collated["net_input"]["prev_output_tokens"] = torch.cat( + [eos, target], dim=-1 + ).long() + collated["ntokens"] += target.size(0) + return collated + + def filter_indices_by_size(self, indices, max_sizes): + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored diff --git a/fairseq/fairseq/data/add_target_dataset.py 
b/fairseq/fairseq/data/add_target_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..978a5b1903cf51683fa03fcbb68b44cf23f93a08 --- /dev/null +++ b/fairseq/fairseq/data/add_target_dataset.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset, data_utils +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + + +class AddTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + pad, + eos, + batch_targets, + process_label=None, + label_len_fn=None, + add_to_input=False, + text_compression_level=TextCompressionLevel.none, + ): + super().__init__(dataset) + self.labels = labels + self.batch_targets = batch_targets + self.pad = pad + self.eos = eos + self.process_label = process_label + self.label_len_fn = label_len_fn + self.add_to_input = add_to_input + self.text_compressor = TextCompressor(level=text_compression_level) + + def get_label(self, index, process_fn=None): + lbl = self.labels[index] + lbl = self.text_compressor.decompress(lbl) + return lbl if process_fn is None else process_fn(lbl) + + def __getitem__(self, index): + item = self.dataset[index] + item["label"] = self.get_label(index, process_fn=self.process_label) + return item + + def size(self, index): + sz = self.dataset.size(index) + own_sz = self.label_len_fn(self.get_label(index)) + return sz, own_sz + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + + if self.add_to_input: + eos = torch.LongTensor([self.eos]) + prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target] + target = [torch.cat([t, eos], axis=-1) for t in target] + collated["net_input"]["prev_output_tokens"] = prev_output_tokens + + if self.batch_targets: + collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) + target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["ntokens"] = collated["target_lengths"].sum().item() + if collated["net_input"].get("prev_output_tokens", None) is not None: + collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens( + collated["net_input"]["prev_output_tokens"], + pad_idx=self.pad, + left_pad=False, + ) + else: + collated["ntokens"] = sum([len(t) for t in target]) + + collated["target"] = target + return collated + + def filter_indices_by_size(self, indices, max_sizes): + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored diff --git a/fairseq/fairseq/data/append_token_dataset.py b/fairseq/fairseq/data/append_token_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..87695bd0f5fcb6b10247e3b743340623e6438cc1 --- /dev/null +++ b/fairseq/fairseq/data/append_token_dataset.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .
import BaseWrapperDataset + + +class AppendTokenDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + self.token = token + if token is not None: + self._sizes = np.array(dataset.sizes) + 1 + else: + self._sizes = dataset.sizes + + def __getitem__(self, idx): + item = self.dataset[idx] + if self.token is not None: + item = torch.cat([item, item.new([self.token])]) + return item + + @property + def sizes(self): + return self._sizes + + def num_tokens(self, index): + n = self.dataset.num_tokens(index) + if self.token is not None: + n += 1 + return n + + def size(self, index): + n = self.dataset.size(index) + if self.token is not None: + n += 1 + return n diff --git a/fairseq/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc b/fairseq/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5345ec1345957c3459fd3fb4bf5e881521ef22cd Binary files /dev/null and b/fairseq/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc b/fairseq/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6325d57246e46b4232d5578550001f709eb33fb7 Binary files /dev/null and b/fairseq/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/audio/feature_transforms/global_cmvn.py b/fairseq/fairseq/data/audio/feature_transforms/global_cmvn.py new file mode 100644 index 0000000000000000000000000000000000000000..e457ff176fee3b996da11f47e7dc61b81c445ba3 --- /dev/null +++ b/fairseq/fairseq/data/audio/feature_transforms/global_cmvn.py @@ -0,0 +1,29 @@ +import numpy as np +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("global_cmvn") +class GlobalCMVN(AudioFeatureTransform): + """Global CMVN (cepstral mean and variance normalization). The global mean + and variance need to be pre-computed and stored in NumPy format (.npz).""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return GlobalCMVN(_config.get("stats_npz_path")) + + def __init__(self, stats_npz_path): + self.stats_npz_path = stats_npz_path + stats = np.load(stats_npz_path) + self.mean, self.std = stats["mean"], stats["std"] + + def __repr__(self): + return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")' + + def __call__(self, x): + x = np.subtract(x, self.mean) + x = np.divide(x, self.std) + return x diff --git a/fairseq/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/fairseq/data/audio/speech_to_text_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf71558fdf4c9e79c3a5b271363d96b4244216d --- /dev/null +++ b/fairseq/fairseq/data/audio/speech_to_text_dataset.py @@ -0,0 +1,733 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
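# A short, self-contained sketch of exercising the GlobalCMVN transform
# defined earlier in this diff (editor's illustration, not part of the
# patch). The "stats.npz" path, the 80-dim feature size, and the toy
# features are assumptions; only the "mean"/"std" array layout comes from
# what GlobalCMVN loads.
import numpy as np
from fairseq.data.audio.feature_transforms.global_cmvn import GlobalCMVN

np.savez("stats.npz",
         mean=np.zeros(80, dtype=np.float32),
         std=np.ones(80, dtype=np.float32))
cmvn = GlobalCMVN("stats.npz")
feats = np.random.randn(120, 80).astype(np.float32)  # T x feature_dim
normalized = cmvn(feats)  # elementwise (x - mean) / std, broadcast over time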
+ +import csv +import logging +import re +from argparse import Namespace +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset +from fairseq.data import data_utils as fairseq_data_utils +from fairseq.data import encoders +from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.data_cfg import S2TDataConfig +from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform +from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment +from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import ( + NoisyOverlapAugment, +) +from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform +from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform + +logger = logging.getLogger(__name__) + + +def _collate_frames( + frames: List[torch.Tensor], is_audio_input: bool = False +) -> torch.Tensor: + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + if is_audio_input: + out = frames[0].new_zeros((len(frames), max_len)) + else: + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, : v.size(0)] = v + return out + + +def _is_int_or_np_int(n): + return isinstance(n, int) or ( + isinstance(n, np.generic) and isinstance(n.item(), int) + ) + + +@dataclass +class SpeechToTextDatasetItem(object): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + speaker_id: Optional[int] = None + + +class SpeechToTextDataset(FairseqDataset): + LANG_TAG_TEMPLATE = "<lang:{}>" + + def __init__( + self, + split: str, + is_train_split: bool, + cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + append_eos=True, + ): + self.split, self.is_train_split = split, is_train_split + self.cfg = cfg + self.audio_paths, self.n_frames = audio_paths, n_frames + self.n_samples = len(audio_paths) + assert len(n_frames) == self.n_samples > 0 + assert src_texts is None or len(src_texts) == self.n_samples + assert tgt_texts is None or len(tgt_texts) == self.n_samples + assert speakers is None or len(speakers) == self.n_samples + assert src_langs is None or len(src_langs) == self.n_samples + assert tgt_langs is None or len(tgt_langs) == self.n_samples + assert ids is None or len(ids) == self.n_samples + assert (tgt_dict is None and tgt_texts is None) or ( + tgt_dict is not None and tgt_texts is not None + ) + self.src_texts, self.tgt_texts = src_texts, tgt_texts + self.src_langs, self.tgt_langs = src_langs, tgt_langs + self.speakers = speakers + self.tgt_dict = tgt_dict + self.check_tgt_lang_tag() + self.ids = ids + self.shuffle =
cfg.shuffle if is_train_split else False + + self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict( + self.cfg.get_feature_transforms(split, is_train_split) + ) + self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict( + self.cfg.get_waveform_transforms(split, is_train_split) + ) + # TODO: add these to data_cfg.py + self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict( + self.cfg.get_dataset_transforms(split, is_train_split) + ) + + # check proper usage of transforms + if self.feature_transforms and self.cfg.use_audio_input: + logger.warning( + "Feature transforms will not be applied. To use feature transforms, " + "set use_audio_input as False in config." + ) + + self.pre_tokenizer = pre_tokenizer + self.bpe_tokenizer = bpe_tokenizer + self.n_frames_per_step = n_frames_per_step + self.speaker_to_id = speaker_to_id + + self.tgt_lens = self.get_tgt_lens_and_check_oov() + self.append_eos = append_eos + + logger.info(self.__repr__()) + + def get_tgt_lens_and_check_oov(self): + if self.tgt_texts is None: + return [0 for _ in range(self.n_samples)] + tgt_lens = [] + n_tokens, n_oov_tokens = 0, 0 + for i in range(self.n_samples): + tokenized = self.get_tokenized_tgt_text(i).split(" ") + oov_tokens = [ + t + for t in tokenized + if self.tgt_dict.index(t) == self.tgt_dict.unk_index + ] + n_tokens += len(tokenized) + n_oov_tokens += len(oov_tokens) + tgt_lens.append(len(tokenized)) + logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV") + return tgt_lens + + def __repr__(self): + return ( + self.__class__.__name__ + + f'(split="{self.split}", n_samples={self.n_samples:_}, ' + f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, " + f"n_frames_per_step={self.n_frames_per_step}, " + f"shuffle={self.shuffle}, " + f"feature_transforms={self.feature_transforms}, " + f"waveform_transforms={self.waveform_transforms}, " + f"dataset_transforms={self.dataset_transforms})" + ) + + @classmethod + def is_lang_tag(cls, token): + pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") + return re.match(pattern, token) + + def check_tgt_lang_tag(self): + if self.cfg.prepend_tgt_lang_tag: + assert self.tgt_langs is not None and self.tgt_dict is not None + tgt_lang_tags = [ + self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs) + ] + assert all(t in self.tgt_dict for t in tgt_lang_tags) + + @classmethod + def tokenize(cls, tokenizer, text: str): + return text if tokenizer is None else tokenizer.encode(text) + + def get_tokenized_tgt_text(self, index: Union[int, List[int]]): + if _is_int_or_np_int(index): + text = self.tgt_texts[index] + else: + text = " ".join([self.tgt_texts[i] for i in index]) + + text = self.tokenize(self.pre_tokenizer, text) + text = self.tokenize(self.bpe_tokenizer, text) + return text + + def pack_frames(self, feature: torch.Tensor): + if self.n_frames_per_step == 1: + return feature + n_packed_frames = feature.shape[0] // self.n_frames_per_step + feature = feature[: self.n_frames_per_step * n_packed_frames] + return feature.reshape(n_packed_frames, -1) + + @classmethod + def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): + lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang)) + assert lang_tag_idx != dictionary.unk() + return lang_tag_idx + + def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor: + """ + Gives source audio for given index with any relevant transforms + applied. 
For ConcatAug, source audios for given indices are
+        concatenated in given order.
+        Args:
+            index (int or List[int]): index—or in the case of ConcatAug,
+            indices—to pull the source audio for
+        Returns:
+            source audios concatenated for given indices with
+            relevant transforms applied
+        """
+        if _is_int_or_np_int(index):
+            source = get_features_or_waveform(
+                self.audio_paths[index],
+                need_waveform=self.cfg.use_audio_input,
+                use_sample_rate=self.cfg.use_sample_rate,
+                waveform_transforms=self.waveform_transforms,
+            )
+        else:
+            source = np.concatenate(
+                [
+                    get_features_or_waveform(
+                        self.audio_paths[i],
+                        need_waveform=self.cfg.use_audio_input,
+                        use_sample_rate=self.cfg.use_sample_rate,
+                        waveform_transforms=self.waveform_transforms,
+                    )
+                    for i in index
+                ]
+            )
+        if self.cfg.use_audio_input:
+            source = torch.from_numpy(source).float()
+            if self.cfg.standardize_audio:
+                with torch.no_grad():
+                    source = F.layer_norm(source, source.shape)
+        else:
+            if self.feature_transforms is not None:
+                source = self.feature_transforms(source)
+            source = torch.from_numpy(source).float()
+        return source
+
+    def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
+        has_concat = self.dataset_transforms.has_transform(ConcatAugment)
+        if has_concat:
+            concat = self.dataset_transforms.get_transform(ConcatAugment)
+            indices = concat.find_indices(index, self.n_frames, self.n_samples)
+
+        source = self._get_source_audio(indices if has_concat else index)
+        source = self.pack_frames(source)
+
+        target = None
+        if self.tgt_texts is not None:
+            tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
+            target = self.tgt_dict.encode_line(
+                tokenized, add_if_not_exist=False, append_eos=self.append_eos
+            ).long()
+            if self.cfg.prepend_tgt_lang_tag:
+                lang_tag_idx = self.get_lang_tag_idx(
+                    self.tgt_langs[index], self.tgt_dict
+                )
+                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
+
+            if self.cfg.prepend_bos_and_append_tgt_lang_tag:
+                bos = torch.LongTensor([self.tgt_dict.bos()])
+                lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
+                assert lang_tag_idx != self.tgt_dict.unk()
+                lang_tag_idx = torch.LongTensor([lang_tag_idx])
+                target = torch.cat((bos, target, lang_tag_idx), 0)
+
+        speaker_id = None
+        if self.speaker_to_id is not None:
+            speaker_id = self.speaker_to_id[self.speakers[index]]
+        return SpeechToTextDatasetItem(
+            index=index, source=source, target=target, speaker_id=speaker_id
+        )
+
+    def __len__(self):
+        return self.n_samples
+
+    def collater(
+        self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
+    ) -> Dict:
+        if len(samples) == 0:
+            return {}
+        indices = torch.tensor([x.index for x in samples], dtype=torch.long)
+
+        sources = [x.source for x in samples]
+        has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
+        if has_NOAug and self.cfg.use_audio_input:
+            NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
+            sources = NOAug(sources)
+
+        frames = _collate_frames(sources, self.cfg.use_audio_input)
+        # sort samples by descending number of frames
+        n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
+        n_frames, order = n_frames.sort(descending=True)
+        indices = indices.index_select(0, order)
+        frames = frames.index_select(0, order)
+
+        target, target_lengths = None, None
+        prev_output_tokens = None
+        ntokens = None
+        if self.tgt_texts is not None:
+            target = fairseq_data_utils.collate_tokens(
+                [x.target for x in samples],
+                self.tgt_dict.pad(),
+                self.tgt_dict.eos(),
+                left_pad=False,
+                move_eos_to_beginning=False,
+            )
+            target = target.index_select(0, order)
+            target_lengths = torch.tensor(
+                [x.target.size(0) for x in samples], dtype=torch.long
+            ).index_select(0, order)
+            prev_output_tokens = fairseq_data_utils.collate_tokens(
+                [x.target for x in samples],
+                self.tgt_dict.pad(),
+                eos_idx=None,
+                left_pad=False,
+                move_eos_to_beginning=True,
+            )
+            prev_output_tokens = prev_output_tokens.index_select(0, order)
+            ntokens = sum(x.target.size(0) for x in samples)
+
+        speaker = None
+        if self.speaker_to_id is not None:
+            speaker = (
+                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
+                .index_select(0, order)
+                .view(-1, 1)
+            )
+
+        net_input = {
+            "src_tokens": frames,
+            "src_lengths": n_frames,
+            "prev_output_tokens": prev_output_tokens,
+        }
+        out = {
+            "id": indices,
+            "net_input": net_input,
+            "speaker": speaker,
+            "target": target,
+            "target_lengths": target_lengths,
+            "ntokens": ntokens,
+            "nsentences": len(samples),
+        }
+        if return_order:
+            out["order"] = order
+        return out
+
+    def num_tokens(self, index):
+        return self.n_frames[index]
+
+    def size(self, index):
+        return self.n_frames[index], self.tgt_lens[index]
+
+    @property
+    def sizes(self):
+        return np.array(self.n_frames)
+
+    @property
+    def can_reuse_epoch_itr_across_epochs(self):
+        return True
+
+    def ordered_indices(self):
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+        # first by descending order of # of frames then by original/random order
+        order.append([-n for n in self.n_frames])
+        return np.lexsort(order)
+
+    def prefetch(self, indices):
+        # prefetching is not supported; audio is loaded lazily in __getitem__
+        raise NotImplementedError
+
+
+class TextTargetMultitaskData(object):
+    # mandatory columns
+    KEY_ID, KEY_TEXT = "id", "tgt_text"
+    LANG_TAG_TEMPLATE = "<lang:{}>"
+
+    def __init__(self, args, split, tgt_dict):
+        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
+        self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
+        self.dict = tgt_dict
+        self.append_eos = args.decoder_type != "ctc"
+        self.pre_tokenizer = self.build_tokenizer(args)
+        self.bpe_tokenizer = self.build_bpe(args)
+        self.prepend_bos_and_append_tgt_lang_tag = (
+            args.prepend_bos_and_append_tgt_lang_tag
+        )
+        self.eos_token = args.eos_token
+        self.lang_tag_mapping = args.get_lang_tag_mapping
+
+    @classmethod
+    def is_lang_tag(cls, token):
+        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
+        return re.match(pattern, token)
+
+    @classmethod
+    def tokenize(cls, tokenizer, text: str):
+        return text if tokenizer is None else tokenizer.encode(text)
+
+    def get_tokenized_tgt_text(self, index: int):
+        text = self.tokenize(self.pre_tokenizer, self.data[index])
+        text = self.tokenize(self.bpe_tokenizer, text)
+        return text
+
+    def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
+        lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
+        lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
+        lang_tag_idx = dictionary.index(lang_tag)
+        assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
+        return lang_tag_idx
+
+    def build_tokenizer(self, args):
+        pre_tokenizer = args.config.get("pre_tokenizer")
+        if pre_tokenizer is not None:
+            logger.info(f"pre-tokenizer: {pre_tokenizer}")
+            return encoders.build_tokenizer(Namespace(**pre_tokenizer))
+        else:
+            return None
+
+    def build_bpe(self, args):
+        bpe_tokenizer = args.config.get("bpe_tokenizer")
+        if bpe_tokenizer is not None:
+            logger.info(f"tokenizer: {bpe_tokenizer}")
+            return encoders.build_bpe(Namespace(**bpe_tokenizer))
+        else:
+            return None
+
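+    # Editorial aside, not part of the original source: SpeechToTextDataset.ordered_indices()
+    # above leans on np.lexsort, which treats the *last* key as primary. With
+    # illustrative numbers and no shuffling:
+    #
+    #     n_frames = [3, 5, 5, 2]
+    #     keys = [np.arange(4), [-3, -5, -5, -2]]
+    #     np.lexsort(keys)  # -> array([1, 2, 0, 3]), i.e. lengths 5, 5, 3, 2
+    #
+    # Negating the frame counts yields a descending length sort, and the first
+    # key (a random permutation when shuffling is on) only breaks ties.
+
+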
def get(self, sample_id, tgt_lang=None): + if sample_id in self.data: + tokenized = self.get_tokenized_tgt_text(sample_id) + target = self.dict.encode_line( + tokenized, + add_if_not_exist=False, + append_eos=self.append_eos, + ) + if self.prepend_bos_and_append_tgt_lang_tag: + bos = torch.LongTensor([self.dict.bos()]) + lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict) + assert lang_tag_idx != self.dict.unk() + lang_tag_idx = torch.LongTensor([lang_tag_idx]) + target = torch.cat((bos, target, lang_tag_idx), 0) + return target + else: + logger.warning(f"no target for {sample_id}") + return torch.IntTensor([]) + + def collater(self, samples: List[torch.Tensor]) -> torch.Tensor: + out = fairseq_data_utils.collate_tokens( + samples, + self.dict.pad(), + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + ).long() + + prev_out = fairseq_data_utils.collate_tokens( + samples, + self.dict.pad(), + eos_idx=None, + left_pad=False, + move_eos_to_beginning=True, + ).long() + + target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long) + ntokens = sum(t.size(0) for t in samples) + + output = { + "prev_output_tokens": prev_out, + "target": out, + "target_lengths": target_lengths, + "ntokens": ntokens, + } + + return output + + +class SpeechToTextMultitaskDataset(SpeechToTextDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.multitask_data = {} + + def add_multitask_dataset(self, task_name, task_data): + self.multitask_data[task_name] = task_data + + def __getitem__( + self, index: int + ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]: + s2t_data = super().__getitem__(index) + + multitask_target = {} + sample_id = self.ids[index] + tgt_lang = self.tgt_langs[index] + for task_name, task_dataset in self.multitask_data.items(): + multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang) + + return s2t_data, multitask_target + + def collater( + self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]] + ) -> Dict: + if len(samples) == 0: + return {} + + out = super().collater([s for s, _ in samples], return_order=True) + order = out["order"] + del out["order"] + + for task_name, task_dataset in self.multitask_data.items(): + if "multitask" not in out: + out["multitask"] = {} + d = [s[task_name] for _, s in samples] + task_target = task_dataset.collater(d) + out["multitask"][task_name] = { + "target": task_target["target"].index_select(0, order), + "target_lengths": task_target["target_lengths"].index_select(0, order), + "ntokens": task_target["ntokens"], + } + out["multitask"][task_name]["net_input"] = { + "prev_output_tokens": task_target["prev_output_tokens"].index_select( + 0, order + ), + } + + return out + + +class SpeechToTextDatasetCreator(object): + # mandatory columns + KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" + KEY_TGT_TEXT = "tgt_text" + # optional columns + KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text" + KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" + # default values + DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = "" + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + cfg: S2TDataConfig, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + multitask: Optional[Dict] = None, + ) -> SpeechToTextDataset: + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = 
[int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + + has_multitask = multitask is not None and len(multitask.keys()) > 0 + dataset_cls = ( + SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset + ) + + ds = dataset_cls( + split=split_name, + is_train_split=is_train_split, + cfg=cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + ) + + if has_multitask: + for task_name, task_obj in multitask.items(): + task_data = TextTargetMultitaskData( + task_obj.args, split_name, task_obj.target_dictionary + ) + ds.add_multitask_dataset(task_name, task_data) + return ds + + @classmethod + def get_size_ratios( + cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0 + ) -> List[float]: + """Size ratios for temperature-based sampling + (https://arxiv.org/abs/1907.05019)""" + + id_to_lp, lp_to_sz = {}, defaultdict(int) + for ds in datasets: + lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)} + assert len(lang_pairs) == 1 + lang_pair = list(lang_pairs)[0] + id_to_lp[ds.split] = lang_pair + lp_to_sz[lang_pair] += sum(ds.n_frames) + + sz_sum = sum(v for v in lp_to_sz.values()) + lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()} + lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()} + prob_sum = sum(v for v in lp_to_tgt_prob.values()) + lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()} + lp_to_sz_ratio = { + k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items() + } + size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets] + + p_formatted = { + k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz + } + logger.info(f"sampling probability balancing: {p_formatted}") + sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)} + logger.info(f"balanced sampling size ratio: {sr_formatted}") + return size_ratio + + @classmethod + def _load_samples_from_tsv(cls, root: str, split: str): + tsv_path = Path(root) / f"{split}.tsv" + if not tsv_path.is_file(): + raise FileNotFoundError(f"Dataset not found: {tsv_path}") + with open(tsv_path) as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + samples = [dict(e) for e in reader] + if len(samples) == 0: + raise ValueError(f"Empty manifest: {tsv_path}") + return samples + + @classmethod + def _from_tsv( + cls, + root: str, + cfg: S2TDataConfig, + split: str, + tgt_dict, + is_train_split: bool, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + multitask: Optional[Dict] = None, + ) -> SpeechToTextDataset: + samples = cls._load_samples_from_tsv(root, split) + return cls._from_list( + split, + is_train_split, + samples, + cfg, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + multitask, + ) + + @classmethod + def from_tsv( + cls, + root: str, + cfg: S2TDataConfig, + 
splits: str, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split: bool, + epoch: int, + seed: int, + n_frames_per_step: int = 1, + speaker_to_id=None, + multitask: Optional[Dict] = None, + ) -> SpeechToTextDataset: + datasets = [ + cls._from_tsv( + root=root, + cfg=cfg, + split=split, + tgt_dict=tgt_dict, + is_train_split=is_train_split, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + multitask=multitask, + ) + for split in splits.split(",") + ] + + if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: + # temperature-based sampling + size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) + datasets = [ + ResamplingDataset( + d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) + ) + for r, d in zip(size_ratios, datasets) + ] + + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] diff --git a/fairseq/fairseq/data/backtranslation_dataset.py b/fairseq/fairseq/data/backtranslation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8f70c90df3d237077537993e125d366c95292f1a --- /dev/null +++ b/fairseq/fairseq/data/backtranslation_dataset.py @@ -0,0 +1,165 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils + +from . import FairseqDataset + + +def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True): + """Backtranslate a list of samples. + + Given an input (*samples*) of the form: + + [{'id': 1, 'source': 'hallo welt'}] + + this will return: + + [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}] + + Args: + samples (List[dict]): samples to backtranslate. Individual samples are + expected to have a 'source' key, which will become the 'target' + after backtranslation. + collate_fn (callable): function to collate samples into a mini-batch + generate_fn (callable): function to generate backtranslations + cuda (bool): use GPU for generation (default: ``True``) + + Returns: + List[dict]: an updated list of samples with a backtranslated source + """ + collated_samples = collate_fn(samples) + s = utils.move_to_cuda(collated_samples) if cuda else collated_samples + generated_sources = generate_fn(s) + + id_to_src = {sample["id"]: sample["source"] for sample in samples} + + # Go through each tgt sentence in batch and its corresponding best + # generated hypothesis and create a backtranslation data pair + # {id: id, source: generated backtranslation, target: original tgt} + return [ + { + "id": id.item(), + "target": id_to_src[id.item()], + "source": hypos[0]["tokens"].cpu(), + } + for id, hypos in zip(collated_samples["id"], generated_sources) + ] + + +class BacktranslationDataset(FairseqDataset): + """ + Sets up a backtranslation dataset which takes a tgt batch, generates + a src using a tgt-src backtranslation function (*backtranslation_fn*), + and returns the corresponding `{generated src, input tgt}` batch. + + Args: + tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be + backtranslated. Only the source side of this dataset will be used. + After backtranslation, the source sentences in this dataset will be + returned as the targets. + src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated + sentences. + tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of + sentences to be backtranslated. 
+ backtranslation_fn (callable, optional): function to call to generate + backtranslations. This is typically the `generate` method of a + :class:`~fairseq.sequence_generator.SequenceGenerator` object. + Pass in None when it is not available at initialization time, and + use set_backtranslation_fn function to set it when available. + output_collater (callable, optional): function to call on the + backtranslated samples to create the final batch + (default: ``tgt_dataset.collater``). + cuda: use GPU for generation + """ + + def __init__( + self, + tgt_dataset, + src_dict, + tgt_dict=None, + backtranslation_fn=None, + output_collater=None, + cuda=True, + **kwargs + ): + self.tgt_dataset = tgt_dataset + self.backtranslation_fn = backtranslation_fn + self.output_collater = ( + output_collater if output_collater is not None else tgt_dataset.collater + ) + self.cuda = cuda if torch.cuda.is_available() else False + self.src_dict = src_dict + self.tgt_dict = tgt_dict + + def __getitem__(self, index): + """ + Returns a single sample from *tgt_dataset*. Note that backtranslation is + not applied in this step; use :func:`collater` instead to backtranslate + a batch of samples. + """ + return self.tgt_dataset[index] + + def __len__(self): + return len(self.tgt_dataset) + + def set_backtranslation_fn(self, backtranslation_fn): + self.backtranslation_fn = backtranslation_fn + + def collater(self, samples): + """Merge and backtranslate a list of samples to form a mini-batch. + + Using the samples from *tgt_dataset*, load a collated target sample to + feed to the backtranslation model. Then take the backtranslation with + the best score as the source and the original input as the target. + + Note: we expect *tgt_dataset* to provide a function `collater()` that + will collate samples into the format expected by *backtranslation_fn*. + After backtranslation, we will feed the new list of samples (i.e., the + `(backtranslated source, original source)` pairs) to *output_collater* + and return the result. + + Args: + samples (List[dict]): samples to backtranslate and collate + + Returns: + dict: a mini-batch with keys coming from *output_collater* + """ + if samples[0].get("is_dummy", False): + return samples + samples = backtranslate_samples( + samples=samples, + collate_fn=self.tgt_dataset.collater, + generate_fn=(lambda net_input: self.backtranslation_fn(net_input)), + cuda=self.cuda, + ) + return self.output_collater(samples) + + def num_tokens(self, index): + """Just use the tgt dataset num_tokens""" + return self.tgt_dataset.num_tokens(index) + + def ordered_indices(self): + """Just use the tgt dataset ordered_indices""" + return self.tgt_dataset.ordered_indices() + + def size(self, index): + """Return an example's size as a float or tuple. This value is used + when filtering a dataset with ``--max-positions``. + + Note: we use *tgt_dataset* to approximate the length of the source + sentence, since we do not know the actual length until after + backtranslation. 
+ """ + tgt_size = self.tgt_dataset.size(index)[0] + return (tgt_size, tgt_size) + + @property + def supports_prefetch(self): + return getattr(self.tgt_dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.tgt_dataset.prefetch(indices) diff --git a/fairseq/fairseq/data/base_wrapper_dataset.py b/fairseq/fairseq/data/base_wrapper_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..134d398b47dc73c8807759188504aee205b3b34d --- /dev/null +++ b/fairseq/fairseq/data/base_wrapper_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +class BaseWrapperDataset(FairseqDataset): + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, "collater"): + return self.dataset.collater(samples) + else: + return default_collate(samples) + + @property + def sizes(self): + return self.dataset.sizes + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def attr(self, attr: str, index: int): + return self.dataset.attr(attr, index) + + def prefetch(self, indices): + self.dataset.prefetch(indices) + + def get_batch_shapes(self): + return self.dataset.get_batch_shapes() + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + return self.dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + def filter_indices_by_size(self, indices, max_sizes): + return self.dataset.filter_indices_by_size(indices, max_sizes) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return self.dataset.can_reuse_epoch_itr_across_epochs + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) diff --git a/fairseq/fairseq/data/bucket_pad_length_dataset.py b/fairseq/fairseq/data/bucket_pad_length_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9410014845873bb0344fca6478c231c88e9dea --- /dev/null +++ b/fairseq/fairseq/data/bucket_pad_length_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch.nn.functional as F +from fairseq.data import BaseWrapperDataset +from fairseq.data.data_utils import get_buckets, get_bucketed_sizes + + +class BucketPadLengthDataset(BaseWrapperDataset): + """ + Bucket and pad item lengths to the nearest bucket size. This can be used to + reduce the number of unique batch shapes, which is important on TPUs since + each new batch shape requires a recompilation. 
+
+    Args:
+        dataset (FairseqDataset): dataset to bucket
+        sizes (List[int]): all item sizes
+        num_buckets (int): number of buckets to create
+        pad_idx (int): padding symbol
+        left_pad (bool): if True, pad on the left; otherwise right pad
+    """
+
+    def __init__(
+        self,
+        dataset,
+        sizes,
+        num_buckets,
+        pad_idx,
+        left_pad,
+        tensor_key=None,
+    ):
+        super().__init__(dataset)
+        self.pad_idx = pad_idx
+        self.left_pad = left_pad
+
+        assert num_buckets > 0
+        self.buckets = get_buckets(sizes, num_buckets)
+        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
+        self._tensor_key = tensor_key
+
+    def _set_tensor(self, item, val):
+        if self._tensor_key is None:
+            return val
+        item[self._tensor_key] = val
+        return item
+
+    def _get_tensor(self, item):
+        if self._tensor_key is None:
+            return item
+        return item[self._tensor_key]
+
+    def _pad(self, tensor, bucket_size, dim=-1):
+        num_pad = bucket_size - tensor.size(dim)
+        return F.pad(
+            tensor,
+            (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
+            value=self.pad_idx,
+        )
+
+    def __getitem__(self, index):
+        item = self.dataset[index]
+        bucket_size = self._bucketed_sizes[index]
+        tensor = self._get_tensor(item)
+        padded = self._pad(tensor, bucket_size)
+        return self._set_tensor(item, padded)
+
+    @property
+    def sizes(self):
+        return self._bucketed_sizes
+
+    def num_tokens(self, index):
+        return self._bucketed_sizes[index]
+
+    def size(self, index):
+        return self._bucketed_sizes[index]
diff --git a/fairseq/fairseq/data/codedataset.py b/fairseq/fairseq/data/codedataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a433091956ed2449f628028ea8da7be4c4895307
--- /dev/null
+++ b/fairseq/fairseq/data/codedataset.py
@@ -0,0 +1,576 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import json
+import logging
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.utils.data
+
+from . import data_utils
+from fairseq.data.fairseq_dataset import FairseqDataset
+
+F0_FRAME_SPACE = 0.005  # sec
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExpressiveCodeDataConfig(object):
+    def __init__(self, json_path):
+        with open(json_path, "r") as f:
+            self.config = json.load(f)
+        self._manifests = self.config["manifests"]
+
+    @property
+    def manifests(self):
+        return self._manifests
+
+    @property
+    def n_units(self):
+        return self.config["n_units"]
+
+    @property
+    def sampling_rate(self):
+        return self.config["sampling_rate"]
+
+    @property
+    def code_hop_size(self):
+        return self.config["code_hop_size"]
+
+    @property
+    def f0_stats(self):
+        """pre-computed f0 statistics path"""
+        return self.config.get("f0_stats", None)
+
+    @property
+    def f0_vq_type(self):
+        """naive or precomp"""
+        return self.config["f0_vq_type"]
+
+    @property
+    def f0_vq_name(self):
+        return self.config["f0_vq_name"]
+
+    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
+        key = "log" if log else "linear"
+        if norm_mean and norm_std:
+            key += "_mean_std_norm"
+        elif norm_mean:
+            key += "_mean_norm"
+        else:
+            key += "_none_norm"
+        return self.config["f0_vq_naive_quantizer"][key]
+
+    @property
+    def f0_vq_n_units(self):
+        return self.config["f0_vq_n_units"]
+
+    @property
+    def multispkr(self):
+        """how to parse speaker label from audio path"""
+        return self.config.get("multispkr", None)
+
+
+def get_f0(audio, rate=16000):
+    try:
+        import amfm_decompy.basic_tools as basic
+        import amfm_decompy.pYAAPT as pYAAPT
+        from librosa.util import normalize
+    except ImportError:
+        raise ImportError(
+            "Please install amfm_decompy (`pip install AMFM-decompy`) "
+            "and librosa (`pip install librosa`)."
+        )
+
+    assert audio.ndim == 1
+    frame_length = 20.0  # ms
+    to_pad = int(frame_length / 1000 * rate) // 2
+
+    audio = normalize(audio) * 0.95
+    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
+    audio = basic.SignalObj(audio, rate)
+    pitch = pYAAPT.yaapt(
+        audio,
+        frame_length=frame_length,
+        frame_space=F0_FRAME_SPACE * 1000,
+        nccf_thresh1=0.25,
+        tda_frame_length=25.0,
+    )
+    f0 = pitch.samp_values
+    return f0
+
+
+def interpolate_f0(f0):
+    try:
+        from scipy.interpolate import interp1d
+    except ImportError:
+        raise ImportError("Please install scipy (`pip install scipy`)")
+
+    orig_t = np.arange(f0.shape[0])
+    f0_interp = f0[:]
+    ii = f0_interp != 0
+    if ii.sum() > 1:
+        f0_interp = interp1d(
+            orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
+        )(orig_t)
+        f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
+    return f0_interp
+
+
+def naive_quantize(x, edges):
+    bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
+    return bin_idx
+
+
+def load_wav(full_path):
+    try:
+        import soundfile as sf
+    except ImportError:
+        raise ImportError("Please install soundfile (`pip install SoundFile`)")
+    data, sampling_rate = sf.read(full_path)
+    return data, sampling_rate
+
+
+def parse_code(code_str, dictionary, append_eos):
+    code, duration = torch.unique_consecutive(
+        torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
+    )
+    code = " ".join(map(str, code.tolist()))
+    code = dictionary.encode_line(code, append_eos).short()
+
+    if append_eos:
+        duration = torch.cat((duration, duration.new_zeros((1,))), dim=0)  # eos
+    duration = duration.short()
+    return code, duration
+
+
+def parse_manifest(manifest, dictionary):
+    audio_files = []
+    codes = []
+    durations = []
+    speakers = []
+
+    with open(manifest) as info:
+        for line in info.readlines():
+            sample = eval(line.strip())
+            if "cpc_km100"
in sample: + k = "cpc_km100" + elif "hubert_km100" in sample: + k = "hubert_km100" + elif "phone" in sample: + k = "phone" + else: + assert False, "unknown format" + code = sample[k] + code, duration = parse_code(code, dictionary, append_eos=True) + + codes.append(code) + durations.append(duration) + audio_files.append(sample["audio"]) + speakers.append(sample.get("speaker", None)) + + return audio_files, codes, durations, speakers + + +def parse_speaker(path, method): + if type(path) == str: + path = Path(path) + + if method == "parent_name": + return path.parent.name + elif method == "parent_parent_name": + return path.parent.parent.name + elif method == "_": + return path.name.split("_")[0] + elif method == "single": + return "A" + elif callable(method): + return method(path) + else: + raise NotImplementedError() + + +def get_f0_by_filename(filename, tgt_sampling_rate): + audio, sampling_rate = load_wav(filename) + if sampling_rate != tgt_sampling_rate: + raise ValueError( + "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate) + ) + + # compute un-interpolated f0, and use Ann's interp in __getitem__ if set + f0 = get_f0(audio, rate=tgt_sampling_rate) + f0 = torch.from_numpy(f0.astype(np.float32)) + return f0 + + +def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1): + code_len = durations.sum() + targ_len = int(f0_code_ratio * code_len) + diff = f0.size(0) - targ_len + assert abs(diff) <= tol, ( + f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|" + f" > {tol} (dur=\n{durations})" + ) + if diff > 0: + f0 = f0[:targ_len] + elif diff < 0: + f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0) + + f0_offset = 0.0 + seg_f0s = [] + for dur in durations: + f0_dur = dur.item() * f0_code_ratio + seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)] + seg_f0 = seg_f0[seg_f0 != 0] + if len(seg_f0) == 0: + seg_f0 = torch.tensor(0).type(seg_f0.type()) + else: + seg_f0 = seg_f0.mean() + seg_f0s.append(seg_f0) + f0_offset += f0_dur + + assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}" + return torch.tensor(seg_f0s) + + +class Paddings(object): + def __init__(self, code_val, dur_val=0, f0_val=-2.0): + self.code = code_val + self.dur = dur_val + self.f0 = f0_val + + +class Shifts(object): + def __init__(self, shifts_str, pads): + self._shifts = list(map(int, shifts_str.split(","))) + assert len(self._shifts) == 2, self._shifts + assert all(s >= 0 for s in self._shifts) + self.extra_length = max(s for s in self._shifts) + self.pads = pads + + @property + def dur(self): + return self._shifts[0] + + @property + def f0(self): + return self._shifts[1] + + @staticmethod + def shift_one(seq, left_pad_num, right_pad_num, pad): + assert seq.ndim == 1 + bos = seq.new_full((left_pad_num,), pad) + eos = seq.new_full((right_pad_num,), pad) + seq = torch.cat([bos, seq, eos]) + mask = torch.ones_like(seq).bool() + mask[left_pad_num : len(seq) - right_pad_num] = 0 + return seq, mask + + def __call__(self, code, dur, f0): + if self.extra_length == 0: + code_mask = torch.zeros_like(code).bool() + dur_mask = torch.zeros_like(dur).bool() + f0_mask = torch.zeros_like(f0).bool() + return code, code_mask, dur, dur_mask, f0, f0_mask + + code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code) + dur, dur_mask = self.shift_one( + dur, self.dur, self.extra_length - self.dur, self.pads.dur + ) + f0, f0_mask = self.shift_one( + f0, self.f0, self.extra_length - self.f0, self.pads.f0 + ) + return code, code_mask, dur, dur_mask, 
f0, f0_mask + + +class CodeDataset(FairseqDataset): + def __init__( + self, + manifest, + dictionary, + dur_dictionary, + f0_dictionary, + config, + discrete_dur, + discrete_f0, + log_f0, + normalize_f0_mean, + normalize_f0_std, + interpolate_f0, + return_filename=False, + strip_filename=True, + shifts="0,0", + return_continuous_f0=False, + ): + random.seed(1234) + self.dictionary = dictionary + self.dur_dictionary = dur_dictionary + self.f0_dictionary = f0_dictionary + self.config = config + + # duration config + self.discrete_dur = discrete_dur + + # pitch config + self.discrete_f0 = discrete_f0 + self.log_f0 = log_f0 + self.normalize_f0_mean = normalize_f0_mean + self.normalize_f0_std = normalize_f0_std + self.interpolate_f0 = interpolate_f0 + + self.return_filename = return_filename + self.strip_filename = strip_filename + self.f0_code_ratio = config.code_hop_size / ( + config.sampling_rate * F0_FRAME_SPACE + ) + + # use lazy loading to avoid sharing file handlers across workers + self.manifest = manifest + self._codes = None + self._durs = None + self._f0s = None + with open(f"{manifest}.leng.txt", "r") as f: + lengs = [int(line.rstrip()) for line in f] + edges = np.cumsum([0] + lengs) + self.starts, self.ends = edges[:-1], edges[1:] + with open(f"{manifest}.path.txt", "r") as f: + self.file_names = [line.rstrip() for line in f] + logger.info(f"num entries: {len(self.starts)}") + + if os.path.exists(f"{manifest}.f0_stat.pt"): + self.f0_stats = torch.load(f"{manifest}.f0_stat.pt") + elif config.f0_stats: + self.f0_stats = torch.load(config.f0_stats) + + self.multispkr = config.multispkr + if config.multispkr: + with open(f"{manifest}.speaker.txt", "r") as f: + self.spkrs = [line.rstrip() for line in f] + self.id_to_spkr = sorted(self.spkrs) + self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)} + + self.pads = Paddings( + dictionary.pad(), + 0, # use 0 for duration padding + f0_dictionary.pad() if discrete_f0 else -5.0, + ) + self.shifts = Shifts(shifts, pads=self.pads) + self.return_continuous_f0 = return_continuous_f0 + + def get_data_handlers(self): + logging.info(f"loading data for {self.manifest}") + self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r") + self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r") + + if self.discrete_f0: + if self.config.f0_vq_type == "precomp": + self._f0s = np.load( + f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r" + ) + elif self.config.f0_vq_type == "naive": + self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r") + quantizers_path = self.config.get_f0_vq_naive_quantizer( + self.log_f0, self.normalize_f0_mean, self.normalize_f0_std + ) + quantizers = torch.load(quantizers_path) + n_units = self.config.f0_vq_n_units + self._f0_quantizer = torch.from_numpy(quantizers[n_units]) + else: + raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported") + else: + self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r") + + def preprocess_f0(self, f0, stats): + """ + 1. interpolate + 2. 
log transform (keep unvoiced frame 0) + """ + # TODO: change this to be dependent on config for naive quantizer + f0 = f0.clone() + if self.interpolate_f0: + f0 = interpolate_f0(f0) + + mask = f0 != 0 # only process voiced frames + if self.log_f0: + f0[mask] = f0[mask].log() + if self.normalize_f0_mean: + mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"] + f0[mask] = f0[mask] - mean + if self.normalize_f0_std: + std = stats["logf0_std"] if self.log_f0 else stats["f0_std"] + f0[mask] = f0[mask] / std + return f0 + + def _get_raw_item(self, index): + start, end = self.starts[index], self.ends[index] + if self._codes is None: + self.get_data_handlers() + code = torch.from_numpy(np.array(self._codes[start:end])).long() + dur = torch.from_numpy(np.array(self._durs[start:end])) + f0 = torch.from_numpy(np.array(self._f0s[start:end])) + return code, dur, f0 + + def __getitem__(self, index): + code, dur, f0 = self._get_raw_item(index) + code = torch.cat([code.new([self.dictionary.bos()]), code]) + + # use 0 for eos and bos + dur = torch.cat([dur.new([0]), dur]) + if self.discrete_dur: + dur = self.dur_dictionary.encode_line( + " ".join(map(str, dur.tolist())), append_eos=False + ).long() + else: + dur = dur.float() + + # TODO: find a more elegant approach + raw_f0 = None + if self.discrete_f0: + if self.config.f0_vq_type == "precomp": + f0 = self.f0_dictionary.encode_line( + " ".join(map(str, f0.tolist())), append_eos=False + ).long() + else: + f0 = f0.float() + f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]]) + if self.return_continuous_f0: + raw_f0 = f0 + raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0]) + f0 = naive_quantize(f0, self._f0_quantizer) + f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0]) + else: + f0 = f0.float() + if self.multispkr: + f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]]) + else: + f0 = self.preprocess_f0(f0, self.f0_stats) + f0 = torch.cat([f0.new([0]), f0]) + + if raw_f0 is not None: + *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0) + else: + raw_f0_mask = None + + code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0) + if raw_f0_mask is not None: + assert (raw_f0_mask == f0_mask).all() + + # is a padded frame if either input or output is padded + feats = { + "source": code[:-1], + "target": code[1:], + "mask": code_mask[1:].logical_or(code_mask[:-1]), + "dur_source": dur[:-1], + "dur_target": dur[1:], + "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]), + "f0_source": f0[:-1], + "f0_target": f0[1:], + "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]), + } + + if raw_f0 is not None: + feats["raw_f0"] = raw_f0[1:] + + if self.return_filename: + fname = self.file_names[index] + feats["filename"] = ( + fname if not self.strip_filename else Path(fname).with_suffix("").name + ) + return feats + + def __len__(self): + return len(self.starts) + + def size(self, index): + return self.ends[index] - self.starts[index] + self.shifts.extra_length + + def num_tokens(self, index): + return self.size(index) + + def collater(self, samples): + pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos() + if len(samples) == 0: + return {} + + src_tokens = data_utils.collate_tokens( + [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False + ) + + tgt_tokens = data_utils.collate_tokens( + [s["target"] for s in samples], + pad_idx=pad_idx, + eos_idx=pad_idx, # appending padding, eos is there already + left_pad=False, + ) + + src_durs, tgt_durs = [ + data_utils.collate_tokens( + [s[k] 
for s in samples], + pad_idx=self.pads.dur, + eos_idx=self.pads.dur, + left_pad=False, + ) + for k in ["dur_source", "dur_target"] + ] + + src_f0s, tgt_f0s = [ + data_utils.collate_tokens( + [s[k] for s in samples], + pad_idx=self.pads.f0, + eos_idx=self.pads.f0, + left_pad=False, + ) + for k in ["f0_source", "f0_target"] + ] + + mask, dur_mask, f0_mask = [ + data_utils.collate_tokens( + [s[k] for s in samples], + pad_idx=1, + eos_idx=1, + left_pad=False, + ) + for k in ["mask", "dur_mask", "f0_mask"] + ] + + src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) + n_tokens = sum(len(s["source"]) for s in samples) + + result = { + "nsentences": len(samples), + "ntokens": n_tokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + "dur_src": src_durs, + "f0_src": src_f0s, + }, + "target": tgt_tokens, + "dur_target": tgt_durs, + "f0_target": tgt_f0s, + "mask": mask, + "dur_mask": dur_mask, + "f0_mask": f0_mask, + } + + if "filename" in samples[0]: + result["filename"] = [s["filename"] for s in samples] + + # TODO: remove this hack into the inference dataset + if "prefix" in samples[0]: + result["prefix"] = [s["prefix"] for s in samples] + + if "raw_f0" in samples[0]: + raw_f0s = data_utils.collate_tokens( + [s["raw_f0"] for s in samples], + pad_idx=self.pads.f0, + eos_idx=self.pads.f0, + left_pad=False, + ) + result["raw_f0"] = raw_f0s + return result diff --git a/fairseq/fairseq/data/colorize_dataset.py b/fairseq/fairseq/data/colorize_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6d2713791b1e80a6f5b982a4bf4ba93f6f561e --- /dev/null +++ b/fairseq/fairseq/data/colorize_dataset.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset + + +class ColorizeDataset(BaseWrapperDataset): + """Adds 'colors' property to net input that is obtained from the provided color getter for use by models""" + + def __init__(self, dataset, color_getter): + super().__init__(dataset) + self.color_getter = color_getter + + def collater(self, samples): + base_collate = super().collater(samples) + if len(base_collate) > 0: + base_collate["net_input"]["colors"] = torch.tensor( + list(self.color_getter(self.dataset, s["id"]) for s in samples), + dtype=torch.long, + ) + return base_collate diff --git a/fairseq/fairseq/data/concat_dataset.py b/fairseq/fairseq/data/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..01a4078bb159fa44b2d1062b9a971fe7f1abd1c2 --- /dev/null +++ b/fairseq/fairseq/data/concat_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect + +import numpy as np +from torch.utils.data.dataloader import default_collate + +from . 
import FairseqDataset + + +class ConcatDataset(FairseqDataset): + @staticmethod + def cumsum(sequence, sample_ratios): + r, s = [], 0 + for e, ratio in zip(sequence, sample_ratios): + curr_len = int(ratio * len(e)) + r.append(curr_len + s) + s += curr_len + return r + + def __init__(self, datasets, sample_ratios=1): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, "datasets should not be an empty iterable" + self.datasets = list(datasets) + if isinstance(sample_ratios, int): + sample_ratios = [sample_ratios] * len(self.datasets) + self.sample_ratios = sample_ratios + self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) + self.real_sizes = [len(d) for d in self.datasets] + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx][sample_idx] + + def _get_dataset_and_sample_index(self, idx: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + sample_idx = sample_idx % self.real_sizes[dataset_idx] + return dataset_idx, sample_idx + + def collater(self, samples, **extra_args): + # For now only supports datasets with same underlying collater implementations + if hasattr(self.datasets[0], "collater"): + return self.datasets[0].collater(samples, **extra_args) + else: + return default_collate(samples, **extra_args) + + def size(self, idx: int): + """ + Return an example's size as a float or tuple. + """ + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx].size(sample_idx) + + def num_tokens(self, index: int): + return np.max(self.size(index)) + + def attr(self, attr: str, index: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, index) + return getattr(self.datasets[dataset_idx], attr, None) + + @property + def sizes(self): + _dataset_sizes = [] + for ds, sr in zip(self.datasets, self.sample_ratios): + if isinstance(ds.sizes, np.ndarray): + _dataset_sizes.append(np.tile(ds.sizes, sr)) + else: + # Only support underlying dataset with single size array. + assert isinstance(ds.sizes, list) + _dataset_sizes.append(np.tile(ds.sizes[0], sr)) + return np.concatenate(_dataset_sizes) + + @property + def supports_prefetch(self): + return all(d.supports_prefetch for d in self.datasets) + + def ordered_indices(self): + """ + Returns indices sorted by length. So less padding is needed. 
+ """ + if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1: + # special handling for concatenating lang_pair_datasets + indices = np.arange(len(self)) + sizes = self.sizes + tgt_sizes = ( + sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + ) + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(src_sizes[indices], kind="mergesort")] + else: + return np.argsort(self.sizes) + + def prefetch(self, indices): + frm = 0 + for to, ds in zip(self.cumulative_sizes, self.datasets): + real_size = len(ds) + if getattr(ds, "supports_prefetch", False): + ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) + frm = to + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq/fairseq/data/concat_sentences_dataset.py b/fairseq/fairseq/data/concat_sentences_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..625a29370e90f9d1d7274024afb902ed83a22325 --- /dev/null +++ b/fairseq/fairseq/data/concat_sentences_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class ConcatSentencesDataset(FairseqDataset): + def __init__(self, *datasets): + super().__init__() + self.datasets = datasets + assert all( + len(ds) == len(datasets[0]) for ds in datasets + ), "datasets must have the same length" + + def __getitem__(self, index): + return torch.cat([ds[index] for ds in self.datasets]) + + def __len__(self): + return len(self.datasets[0]) + + def collater(self, samples): + return self.datasets[0].collater(samples) + + @property + def sizes(self): + return sum(ds.sizes for ds in self.datasets) + + def num_tokens(self, index): + return sum(ds.num_tokens(index) for ds in self.datasets) + + def size(self, index): + return sum(ds.size(index) for ds in self.datasets) + + def ordered_indices(self): + return self.datasets[0].ordered_indices() + + @property + def supports_prefetch(self): + return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets) + + def prefetch(self, indices): + for ds in self.datasets: + if getattr(ds, "supports_prefetch", False): + ds.prefetch(indices) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq/fairseq/data/data_utils.py b/fairseq/fairseq/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a19cc3c1827387a6fba571dfdbbbddfcce38eeb --- /dev/null +++ b/fairseq/fairseq/data/data_utils.py @@ -0,0 +1,1144 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+try:
+    from collections.abc import Iterable
+except ImportError:
+    from collections import Iterable
+import contextlib
+import itertools
+import logging
+import re
+import warnings
+from typing import Optional, Tuple
+
+import math
+import numpy as np
+import torch
+
+from fairseq.file_io import PathManager
+from fairseq import utils
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def infer_language_pair(path):
+    """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
+    src, dst = None, None
+    for filename in PathManager.ls(path):
+        parts = filename.split(".")
+        if len(parts) >= 3 and len(parts[1].split("-")) == 2:
+            return parts[1].split("-")
+    return src, dst
+
+
+def collate_tokens(
+    values,
+    pad_idx,
+    eos_idx=None,
+    left_pad=False,
+    move_eos_to_beginning=False,
+    pad_to_length=None,
+    pad_to_multiple=1,
+    pad_to_bsz=None,
+):
+    """Convert a list of 1d tensors into a padded 2d tensor."""
+    size = max(v.size(0) for v in values)
+    size = size if pad_to_length is None else max(size, pad_to_length)
+    if pad_to_multiple != 1 and size % pad_to_multiple != 0:
+        size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
+
+    batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)
+    res = values[0].new(batch_size, size).fill_(pad_idx)
+
+    def copy_tensor(src, dst):
+        assert dst.numel() == src.numel()
+        if move_eos_to_beginning:
+            if eos_idx is None:
+                # if no eos_idx is specified, then use the last token in src
+                dst[0] = src[-1]
+            else:
+                dst[0] = eos_idx
+            dst[1:] = src[:-1]
+        else:
+            dst.copy_(src)
+
+    for i, v in enumerate(values):
+        copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
+    return res
+
+
+def load_indexed_dataset(
+    path, dictionary=None, dataset_impl=None, combine=False, default="cached"
+):
+    """A helper function for loading indexed datasets.
+
+    Args:
+        path (str): path to indexed dataset (e.g., 'data-bin/train')
+        dictionary (~fairseq.data.Dictionary): data dictionary
+        dataset_impl (str, optional): which dataset implementation to use. If
+            not provided, it will be inferred automatically. For legacy indexed
+            data we use the 'cached' implementation by default.
+        combine (bool, optional): automatically load and combine multiple
+            datasets. For example, if *path* is 'data-bin/train', then we will
+            combine 'data-bin/train', 'data-bin/train1', ... and return a
+            single ConcatDataset instance.
+ """ + import fairseq.data.indexed_dataset as indexed_dataset + from fairseq.data.concat_dataset import ConcatDataset + + datasets = [] + for k in itertools.count(): + path_k = path + (str(k) if k > 0 else "") + try: + path_k = indexed_dataset.get_indexed_dataset_to_local(path_k) + except Exception as e: + if "StorageException: [404] Path not found" in str(e): + logger.warning(f"path_k: {e} not found") + else: + raise e + + dataset_impl_k = dataset_impl + if dataset_impl_k is None: + dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k) + dataset = indexed_dataset.make_dataset( + path_k, + impl=dataset_impl_k or default, + fix_lua_indexing=True, + dictionary=dictionary, + ) + if dataset is None: + break + logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k)) + datasets.append(dataset) + if not combine: + break + if len(datasets) == 0: + return None + elif len(datasets) == 1: + return datasets[0] + else: + return ConcatDataset(datasets) + + +@contextlib.contextmanager +def numpy_seed(seed, *addl_seeds): + """Context manager which seeds the NumPy PRNG with the specified seed and + restores the state afterward""" + if seed is None: + yield + return + if len(addl_seeds) > 0: + seed = int(hash((seed, *addl_seeds)) % 1e6) + state = np.random.get_state() + np.random.seed(seed) + try: + yield + finally: + np.random.set_state(state) + + +def collect_filtered(function, iterable, filtered): + """ + Similar to :func:`filter` but collects filtered elements in ``filtered``. + + Args: + function (callable): function that returns ``False`` for elements that + should be filtered + iterable (iterable): iterable to filter + filtered (list): list to store filtered elements + """ + for el in iterable: + if function(el): + yield el + else: + filtered.append(el) + + +def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False): + def compare_leq(a, b): + return a <= b if not isinstance(a, tuple) else max(a) <= b + + def check_size(idx): + if isinstance(max_positions, float) or isinstance(max_positions, int): + return size_fn(idx) <= max_positions + elif isinstance(max_positions, dict): + idx_size = size_fn(idx) + assert isinstance(idx_size, dict) + intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) + return all( + all( + a is None or b is None or a <= b + for a, b in zip(idx_size[key], max_positions[key]) + ) + for key in intersect_keys + ) + else: + # For MultiCorpusSampledDataset, will generalize it later + if not isinstance(size_fn(idx), Iterable): + return all(size_fn(idx) <= b for b in max_positions) + return all( + a is None or b is None or a <= b + for a, b in zip(size_fn(idx), max_positions) + ) + + ignored = [] + itr = collect_filtered(check_size, indices, ignored) + indices = np.fromiter(itr, dtype=np.int64, count=-1) + return indices, ignored + + +def filter_by_size(indices, dataset, max_positions, raise_exception=False): + """ + [deprecated] Filter indices based on their size. + Use `FairseqDataset::filter_indices_by_size` instead. + + Args: + indices (List[int]): ordered list of dataset indices + dataset (FairseqDataset): fairseq dataset instance + max_positions (tuple): filter elements larger than this size. + Comparisons are done component-wise. + raise_exception (bool, optional): if ``True``, raise an exception if + any elements are filtered (default: False). + """ + warnings.warn( + "data_utils.filter_by_size is deprecated. 
" + "Use `FairseqDataset::filter_indices_by_size` instead.", + stacklevel=2, + ) + if isinstance(max_positions, float) or isinstance(max_positions, int): + if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray): + ignored = indices[dataset.sizes[indices] > max_positions].tolist() + indices = indices[dataset.sizes[indices] <= max_positions] + elif ( + hasattr(dataset, "sizes") + and isinstance(dataset.sizes, list) + and len(dataset.sizes) == 1 + ): + ignored = indices[dataset.sizes[0][indices] > max_positions].tolist() + indices = indices[dataset.sizes[0][indices] <= max_positions] + else: + indices, ignored = _filter_by_size_dynamic( + indices, dataset.size, max_positions + ) + else: + indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions) + + if len(ignored) > 0 and raise_exception: + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + if len(ignored) > 0: + logger.warning( + ( + "{} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) + return indices + + +def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if tgt_sizes is None: + ignored = indices[src_sizes[indices] > max_src_size] + else: + ignored = indices[ + (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size) + ] + if len(ignored) > 0: + if tgt_sizes is None: + indices = indices[src_sizes[indices] <= max_src_size] + else: + indices = indices[ + (src_sizes[indices] <= max_src_size) + & (tgt_sizes[indices] <= max_tgt_size) + ] + return indices, ignored.tolist() + + +def batch_by_size( + indices, + num_tokens_fn, + num_tokens_vec=None, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + fixed_shapes=None, +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + num_tokens_vec (List[int], optional): precomputed vector of the number + of tokens for each index in indices (to enable faster batch generation) + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be less than N or a multiple of N (default: 1). + fixed_shapes (List[Tuple[int, int]], optional): if given, batches will + only be created with the given shapes. *max_sentences* and + *required_batch_size_multiple* will be ignored (default: None). 
+    """
+    try:
+        from fairseq.data.data_utils_fast import (
+            batch_by_size_fn,
+            batch_by_size_vec,
+            batch_fixed_shapes_fast,
+        )
+    except ImportError:
+        raise ImportError(
+            "Please build Cython components with: "
+            "`python setup.py build_ext --inplace`"
+        )
+    except ValueError:
+        raise ValueError(
+            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
+        )
+
+    # added int() to avoid TypeError: an integer is required
+    max_tokens = int(max_tokens) if max_tokens is not None else -1
+    max_sentences = max_sentences if max_sentences is not None else -1
+    bsz_mult = required_batch_size_multiple
+
+    if not isinstance(indices, np.ndarray):
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
+
+    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
+        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)
+
+    if fixed_shapes is None:
+        if num_tokens_vec is None:
+            b = batch_by_size_fn(
+                indices,
+                num_tokens_fn,
+                max_tokens,
+                max_sentences,
+                bsz_mult,
+            )
+        else:
+            b = batch_by_size_vec(
+                indices,
+                num_tokens_vec,
+                max_tokens,
+                max_sentences,
+                bsz_mult,
+            )
+
+        if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0:
+            b = b[:-1]
+
+        return b
+
+    else:
+        fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
+        sort_order = np.lexsort(
+            [
+                fixed_shapes[:, 1].argsort(),  # length
+                fixed_shapes[:, 0].argsort(),  # bsz
+            ]
+        )
+        fixed_shapes_sorted = fixed_shapes[sort_order]
+        return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
+
+
+def post_process(sentence: str, symbol: str):
+    if symbol == "sentencepiece":
+        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
+    elif symbol == "wordpiece":
+        sentence = sentence.replace(" ", "").replace("_", " ").strip()
+    elif symbol == "letter":
+        sentence = sentence.replace(" ", "").replace("|", " ").strip()
+    elif symbol == "silence":
+        sentence = sentence.replace("<SIL>", "")
+        sentence = re.sub(" +", " ", sentence).strip()
+    elif symbol == "_EOW":
+        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
+    elif symbol in {"subword_nmt", "@@ ", "@@"}:
+        if symbol == "subword_nmt":
+            symbol = "@@ "
+        sentence = (sentence + " ").replace(symbol, "").rstrip()
+    elif symbol == "none":
+        pass
+    elif symbol is not None:
+        raise NotImplementedError(f"Unknown post_process option: {symbol}")
+    return sentence
+
+
+def compute_mask_indices(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+    require_same_masks: bool = True,
+    mask_dropout: float = 0.0,
+    add_masks: bool = False,
+    seed: Optional[int] = None,
+    epoch: Optional[int] = None,
+    indices: Optional[torch.Tensor] = None,
+    idc_select_ver: int = 1,  # 2 to reproduce mask_tokens_dataset
+    num_mask_ver: int = 2,  # 2 to reproduce mask_tokens_dataset
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape
+
+    Args:
+        shape: the shape for which to compute masks.
+            should be of size 2 where first element is batch size and 2nd is timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask_length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
+        mask_dropout: randomly dropout this percentage of masks in each example
+    """
+
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    if num_mask_ver == 1:
+        all_num_mask = int(
+            # add a random number for probabilistic rounding
+            mask_prob * all_sz / float(mask_length)
+            + np.random.rand()
+        )
+        all_num_mask = max(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if seed is not None and epoch is not None and indices is not None:
+            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+        else:
+            seed_i = None
+
+        rng = np.random.default_rng(seed_i)
+
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            assert sz >= 0, sz
+        else:
+            sz = all_sz
+
+        if num_mask_ver == 1:
+            if padding_mask is not None:
+                num_mask = int(
+                    # add a random number for probabilistic rounding
+                    mask_prob * sz / float(mask_length)
+                    + np.random.rand()
+                )
+                num_mask = max(min_masks, num_mask)
+            else:
+                num_mask = all_num_mask
+        elif num_mask_ver == 2:
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * sz / float(mask_length)
+                + rng.random()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            raise ValueError()
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            # np.random.Generator has no `randint`; `integers` is its equivalent
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = rng.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = rng.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            if mask_type == "static":
+                raise ValueError("this should never happen")
+            else:
+                lengths = [min(mask_length, sz - 1)]
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = rng.integers(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    np.int64,  # np.int was removed from NumPy; use a concrete dtype
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = rng.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            if idc_select_ver == 1:
+                min_len = min(lengths)
+                if sz - min_len <= num_mask:
+                    min_len = sz - num_mask - 1
+                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+            elif idc_select_ver == 2:
+                mask_idc = rng.choice(sz, num_mask, replace=False)
+            else:
+                raise ValueError()
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idc = np.unique(mask_idc[mask_idc < sz])
+        if len(mask_idc) >= sz:
+            raise ValueError(
+                (
+                    f"the entire sequence is masked. "
+                    f"sz={sz}; mask_idc={mask_idc}; "
+                    f"index={indices[i] if indices is not None else None}"
+                )
+            )
+        mask_idcs.append(mask_idc)
+
+    target_len = None
+    if require_same_masks:
+        if add_masks:
+            target_len = max([len(m) for m in mask_idcs])
+        else:
+            target_len = min([len(m) for m in mask_idcs])
+
+    for i, mask_idc in enumerate(mask_idcs):
+        if target_len is not None and len(mask_idc) > target_len:
+            mask_idc = rng.choice(mask_idc, target_len, replace=False)
+
+        mask[i, mask_idc] = True
+
+        if target_len is not None and len(mask_idc) < target_len:
+            unmasked = np.flatnonzero(~mask[i])
+            to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+            mask[i, to_mask] = True
+
+        if mask_dropout > 0:
+            masked = np.flatnonzero(mask[i])
+            num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+            to_drop = rng.choice(masked, num_holes, replace=False)
+            mask[i, to_drop] = False
+
+    return mask
+
+
+def compute_block_mask_2d(
+    shape: Tuple[int, int],
+    mask_prob: float,
+    mask_length: int,
+    mask_prob_adjust: float = 0,
+    inverse_mask: bool = False,
+    require_same_masks: bool = True,
+    expand_adjcent: bool = False,
+    mask_dropout: float = 0,
+    non_overlapping: bool = False,
+) -> torch.Tensor:
+
+    assert mask_length > 1
+
+    B, L = shape
+
+    d = int(L**0.5)
+
+    if inverse_mask:
+        mask_prob = 1 - mask_prob
+
+    if non_overlapping:
+        sz = math.ceil(d / mask_length)
+        inp_len = sz * sz
+
+        inp = torch.zeros((B, 1, sz, sz))
+        w = torch.ones((1, 1, mask_length, mask_length))
+
+        mask_inds = torch.multinomial(
+            1 - inp.view(B, -1),
+            int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
+            replacement=False,
+        )
+        inp.view(B, -1).scatter_(1, mask_inds, 1)
+
+        mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
+            1
+        )
+        if mask.size(-1) > d:
+            mask = mask[..., :d, :d]
+    else:
+        mask = torch.zeros((B, d, d))
+        mask_inds = torch.randint(
+            0,
+            L,
+            size=(
+                B,
+                int(
+                    L
+                    * ((mask_prob + mask_prob_adjust) / mask_length**2)
+                    * (1 + mask_dropout)
+                ),
+            ),
+        )
+        mask.view(B, -1).scatter_(1, mask_inds, 1)
+        centers = mask.nonzero(as_tuple=True)
+
+        inds = ([], [], [])
+
+        offset = mask_length // 2
+        for i in range(mask_length):
+            for j in range(mask_length):
+                k1 = i - offset
+                k2 = j - offset
+                inds[0].append(centers[0])
+                inds[1].append(centers[1] + k1)
+                inds[2].append(centers[2] + k2)
+
+        i0 = torch.cat(inds[0])
+        i1 = torch.cat(inds[1]).clamp_(min=0, max=d - 1)
+        i2 = torch.cat(inds[2]).clamp_(min=0, max=d - 1)
+
+        mask[(i0, i1, i2)] = 1
+
+    def get_nbs(b, m, w):
+        all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
+        all_nbs = all_nbs.clamp_max_(1).view(b, -1)
+        return all_nbs
+
+    if require_same_masks and expand_adjcent:
+        w = torch.zeros((1, 1, 3, 3))
+        w[..., 0, 1] = 1
+        w[..., 2, 1] = 1
+        w[..., 1, 0] = 1
+        w[..., 1, 2] = 1
+
+        all_nbs = get_nbs(B, mask, w)
+
+    mask = mask.reshape(B, -1)
+
+    if require_same_masks:
+        n_masks = mask.sum(dim=-1)
+        final_target_len = int(L * (mask_prob))
+        target_len = int(final_target_len * (1 +
mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.view(1, d, d), w).flatten() + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + + return mask + + +def compute_block_mask_1d( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + mask_prob_adjust: float = 0, + inverse_mask: bool = False, + require_same_masks: bool = True, + expand_adjcent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, +) -> torch.Tensor: + + B, L = shape + + if inverse_mask: + mask_prob = 1 - mask_prob + + if non_overlapping: + sz = math.ceil(L / mask_length) + + inp = torch.zeros((B, 1, sz)) + w = torch.ones((1, 1, mask_length)) + + mask_inds = torch.multinomial( + 1 - inp.view(B, -1), + int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), + replacement=False, + ) + inp.view(B, -1).scatter_(1, mask_inds, 1) + + mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze( + 1 + ) + if mask.size(-1) > L: + mask = mask[..., :L] + + else: + mask = torch.zeros((B, L)) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / mask_length) + * (1 + mask_dropout) + ), + ), + ) + + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], []) + + offset = mask_length // 2 + for i in range(mask_length): + k1 = i - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1) + + mask[(i0, i1)] = 1 + + def get_nbs(b, m, w): + all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same") + all_nbs = all_nbs.clamp_max_(1).view(b, -1) + return all_nbs + + if require_same_masks and expand_adjcent: + w = torch.ones((1, 1, 3)) + w[..., 1] = 0 + all_nbs = get_nbs(B, mask, w) + + mask = mask.view(B, -1) + + if require_same_masks: + n_masks = mask.sum(dim=-1) + final_target_len = int(L * (mask_prob)) + target_len = int(final_target_len * (1 + mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0) + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + + return mask + + +def get_mem_usage(): + try: + import psutil 
+ + mb = 1024 * 1024 + return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" + except ImportError: + return "N/A" + + +# lens: torch.LongTensor +# returns: torch.BoolTensor +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +# lens: torch.LongTensor +# returns: torch.BoolTensor +def lengths_to_mask(lens): + return ~lengths_to_padding_mask(lens) + + +def get_buckets(sizes, num_buckets): + buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, num_buckets + 1), + interpolation="lower", + )[1:] + ) + return buckets + + +def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes + + +def _find_extra_valid_paths(dataset_path: str) -> set: + paths = utils.split_paths(dataset_path) + all_valid_paths = set() + for sub_dir in paths: + contents = PathManager.ls(sub_dir) + valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None] + all_valid_paths |= {os.path.basename(p) for p in valid_paths} + # Remove .bin, .idx etc + roots = {os.path.splitext(p)[0] for p in all_valid_paths} + return roots + + +def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None: + """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored.""" + if ( + train_cfg.dataset.ignore_unused_valid_subsets + or train_cfg.dataset.combine_valid_subsets + or train_cfg.dataset.disable_validation + or not hasattr(train_cfg.task, "data") + ): + return + other_paths = _find_extra_valid_paths(train_cfg.task.data) + specified_subsets = train_cfg.dataset.valid_subset.split(",") + ignored_paths = [p for p in other_paths if p not in specified_subsets] + if ignored_paths: + advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them." + msg = f"Valid paths {ignored_paths} will be ignored. 
{advice}"
+        raise ValueError(msg)
+
+
+def compute_mask_indices_for_one(
+    sz,
+    mask_prob: float,
+    mask_length: int,
+    seed=None,
+    epoch=None,
+    index=None,
+    min_masks=0,
+):
+    """
+    Set seed, epoch and index for deterministic masking.
+    """
+    seed = int(hash((seed, epoch, index)) % 1e6) if seed else None
+    rng = np.random.default_rng(seed)
+
+    # decide elements to mask
+    mask = np.full(sz, False)
+    num_mask = int(
+        # add a random number for probabilistic rounding
+        mask_prob * sz / float(mask_length)
+        + rng.random()
+    )
+    num_mask = max(min_masks, num_mask)
+
+    # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453)
+    mask_idc = rng.choice(sz, num_mask, replace=False)
+    mask_idc = np.concatenate([mask_idc + i for i in range(mask_length)])
+    mask_idc = mask_idc[mask_idc < len(mask)]
+    try:
+        mask[mask_idc] = True
+    except Exception:  # something went wrong; report and re-raise
+        print(f"Assigning mask indexes {mask_idc} to mask {mask} failed!")
+        raise
+
+    return mask
+
+
+def compute_mask_indices_v2(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    min_masks: int = 0,
+    require_same_masks: bool = True,
+    seed: Optional[int] = None,
+    epoch: Optional[int] = None,
+    indices: Optional[torch.Tensor] = None,
+) -> np.ndarray:
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+        else:
+            sz = all_sz
+        index = indices[i].item() if indices is not None else None
+        mask_for_one = compute_mask_indices_for_one(
+            sz, mask_prob, mask_length, seed, epoch, index, min_masks
+        )
+        mask[i, :sz] = mask_for_one
+
+    if require_same_masks:
+        index_sum = indices.sum().item() if indices is not None else None
+        seed = int(hash((seed, epoch, index_sum)) % 1e6) if seed else None
+        rng = np.random.default_rng(seed)
+
+        num_mask = mask.sum(-1).min()
+        for i in range(bsz):
+            extra = mask[i].sum() - num_mask
+            if extra > 0:
+                to_unmask = rng.choice(np.nonzero(mask[i])[0], extra, replace=False)
+                mask[i, to_unmask] = False
+
+    return mask
+
+
+# TODO: a copy of the original compute_mask_indices
+def compute_mask_indices_v3(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+    require_same_masks: bool = True,
+    mask_dropout: float = 0.0,
+    seed: Optional[int] = None,
+    epoch: Optional[int] = None,
+    indices: Optional[torch.Tensor] = None,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape
+
+    Args:
+        shape: the shape for which to compute masks.
+            should be of size 2 where first element is batch size and 2nd is timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask_length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+        require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
+        mask_dropout: randomly dropout this percentage of masks in each example
+    """
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    all_num_mask = int(
+        # add a random number for probabilistic rounding
+        mask_prob * all_sz / float(mask_length)
+        + np.random.rand()
+    )
+
+    all_num_mask = max(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if seed is not None and epoch is not None and indices is not None:
+            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+        else:
+            seed_i = None
+        rng = np.random.default_rng(seed_i)
+
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * sz / float(mask_length)
+                + rng.random()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            # np.random.Generator has no `randint`; `integers` is its equivalent
+            lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = rng.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = rng.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = rng.integers(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    np.int64,  # np.int was removed from NumPy; use a concrete dtype
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = rng.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+    min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        if len(mask_idc) > min_len and require_same_masks:
+            mask_idc = rng.choice(mask_idc, min_len, replace=False)
+        if mask_dropout > 0:
+            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
+            mask_idc = rng.choice(mask_idc, len(mask_idc) - num_holes, replace=False)
+
+        mask[i, mask_idc] = True
+
+    return mask
diff --git
a/fairseq/fairseq/data/data_utils_fast.pyx b/fairseq/fairseq/data/data_utils_fast.pyx new file mode 100644 index 0000000000000000000000000000000000000000..c61f31d6b2113d4c6a03d6553335997098ba0c20 --- /dev/null +++ b/fairseq/fairseq/data/data_utils_fast.pyx @@ -0,0 +1,178 @@ +# cython: language_level=3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +cimport cython +cimport numpy as np + +from libc.stdint cimport int32_t, int64_t +from libcpp cimport bool as bool_t + +ctypedef int64_t DTYPE_t + +@cython.cdivision(True) +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef list batch_by_size_vec( + np.ndarray[int64_t, ndim=1] indices, + np.ndarray[int64_t, ndim=1] num_tokens_vec, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, +): + if indices.shape[0] == 0: + return [] + + assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, ( + f"Sentences lengths should not exceed max_tokens={max_tokens}" + ) + + cdef int32_t indices_len = indices.shape[0] + cdef np.ndarray[int32_t, ndim=1] batches_ends = \ + np.zeros(indices_len, dtype=np.int32) + cdef int32_t[:] batches_ends_view = batches_ends + cdef int64_t[:] num_tokens_view = num_tokens_vec + + cdef int32_t pos = 0 + cdef int32_t new_batch_end = 0 + + cdef int64_t new_batch_max_tokens = 0 + cdef int32_t new_batch_sentences = 0 + cdef int64_t new_batch_num_tokens = 0 + + cdef bool_t overflow = False + cdef bool_t size_matches_with_bsz_mult = False + + cdef int32_t batches_count = 0 + cdef int32_t batch_start = 0 + cdef int64_t tail_max_tokens = 0 + cdef int64_t batch_max_tokens = 0 + + for pos in range(indices_len): + # At every pos we keep stats about the last complete batch [batch_start:batch_end), + # and tail [batch_end:pos]. + # 1) Every time when (batch + tail) forms a valid batch + # (according to max_tokens, max_sentences and bsz_mult) we append tail to batch. + # 2) When (batch+tail) violates max_tokens or max_sentences constraints + # we finalize running batch, and tail becomes a new batch. + # 3) There is a corner case when tail also violates constraints. + # In that situation [batch_end:pos-1] (tail without the current pos) + # gets added to the finalized batches, while [pos:pos] becomes a new tail. + # + # Important: For the sake of performance try to avoid using function calls within this loop. 
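+        # Illustrative walk-through with hypothetical inputs:
+        # num_tokens_vec=[4, 5, 3], max_tokens=10, max_sentences=-1, bsz_mult=1;
+        # a batch of n sentences costs n * longest_in_batch tokens:
+        #   pos=0: 1 * 4 = 4  <= 10 -> tail joins the batch
+        #   pos=1: 2 * 5 = 10 <= 10 -> batch is now [0, 1]
+        #   pos=2: 3 * 5 = 15 >  10 -> overflow: finalize [0, 1], [2] starts a new tail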
+ + tail_max_tokens = tail_max_tokens \ + if tail_max_tokens > num_tokens_view[pos] \ + else num_tokens_view[pos] + new_batch_end = pos + 1 + new_batch_max_tokens = batch_max_tokens \ + if batch_max_tokens > tail_max_tokens \ + else tail_max_tokens + new_batch_sentences = new_batch_end - batch_start + new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens + + overflow = (new_batch_sentences > max_sentences > 0 or + new_batch_num_tokens > max_tokens > 0) + size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or + new_batch_sentences % bsz_mult == 0) + + if overflow: + tail_num_tokens = tail_max_tokens * \ + (new_batch_end - batches_ends_view[batches_count]) + tail_overflow = tail_num_tokens > max_tokens > 0 + # In case of a tail overflow finalize two batches + if tail_overflow: + batches_count += 1 + batches_ends_view[batches_count] = pos + tail_max_tokens = num_tokens_view[pos] + batch_start = batches_ends_view[batches_count] + batches_count += 1 + new_batch_max_tokens = tail_max_tokens + + if overflow or size_matches_with_bsz_mult: + batches_ends_view[batches_count] = new_batch_end + batch_max_tokens = new_batch_max_tokens + tail_max_tokens = 0 + if batches_ends_view[batches_count] != indices_len: + batches_count += 1 + # Memory and time-efficient split + return np.split(indices, batches_ends[:batches_count]) + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef list batch_by_size_fn( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, +): + cdef int32_t indices_len = indices.shape[0] + cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, + dtype=np.int64) + cdef DTYPE_t[:] indices_view = indices + cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec + cdef int64_t pos + for pos in range(indices_len): + num_tokens_vec[pos] = num_tokens_fn(indices_view[pos]) + return batch_by_size_vec(indices, num_tokens_vec, max_tokens, + max_sentences, bsz_mult,) + + +cdef _find_valid_shape( + DTYPE_t[:, :] shapes_view, + int64_t num_sentences, + int64_t num_tokens, +): + """Return index of first valid shape of -1 if none is found.""" + for i in range(shapes_view.shape[0]): + if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + return i + return -1 + + +@cython.cdivision(True) +cpdef list batch_fixed_shapes_fast( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted, +): + cdef int64_t sample_len = 0 + cdef list sample_lens = [] + cdef list batch = [] + cdef list batches = [] + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens + cdef DTYPE_t[:] indices_view = indices + cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted + + for i in range(len(indices_view)): + idx = indices_view[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + if shape_idx == -1: + batches.append(batch) + batch = [] + sample_lens = [] + sample_len = 0 + shapes_view = fixed_shapes_sorted + elif shape_idx > 0: + # small optimization for the next call to _find_valid_shape + shapes_view = shapes_view[shape_idx:] + + batch.append(idx) + + if len(batch) > 0: + batches.append(batch) + + return batches diff --git a/fairseq/fairseq/data/denoising_dataset.py b/fairseq/fairseq/data/denoising_dataset.py new file mode 100644 index 
0000000000000000000000000000000000000000..a900fc6f960c7faa41173316df011e9bc5cb23c9 --- /dev/null +++ b/fairseq/fairseq/data/denoising_dataset.py @@ -0,0 +1,443 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import numpy as np +import torch + +from . import FairseqDataset, data_utils + + +def collate( + samples, + pad_idx, + eos_idx, + vocab, + left_pad_source=False, + left_pad_target=False, + input_feeding=True, + pad_to_length=None, +): + assert input_feeding + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx=None, # use eos_idx of each sample instead of vocab.eos() + left_pad=left_pad, + move_eos_to_beginning=move_eos_to_beginning, + pad_to_length=pad_to_length, + ) + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = merge( + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + src_tokens = src_tokens.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get("target", None) is not None: + target = merge( + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + target = target.index_select(0, sort_order) + ntokens = sum(len(s["target"]) for s in samples) + + if input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + "target", + left_pad=left_pad_target, + move_eos_to_beginning=True, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + prev_output_tokens = prev_output_tokens.index_select(0, sort_order) + else: + ntokens = sum(len(s["source"]) for s in samples) + + batch = { + "id": id, + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + "target": target, + "nsentences": samples[0]["source"].size(0), + "sort_order": sort_order, + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens + + return batch + + +class DenoisingDataset(FairseqDataset): + """ + A wrapper around TokenBlockDataset for BART dataset. + + Args: + dataset (TokenBlockDataset): dataset to wrap + sizes (List[int]): sentence lengths + vocab (~fairseq.data.Dictionary): vocabulary + mask_idx (int): dictionary index used for masked token + mask_whole_words: only mask whole words. This should be a byte mask + over vocab indices, indicating whether it is the beginning of a + word. We will extend any mask to encompass the whole word. + shuffle (bool, optional): shuffle the elements before batching. + Default: ``True`` + seed: Seed for random number generator for reproducibility. 
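+
+    Example (minimal sketch; assumes ``token_ds`` is an existing
+    TokenBlockDataset and ``vocab`` a loaded Dictionary; the ``<mask>``
+    symbol is added here purely for illustration)::
+
+        mask_idx = vocab.add_symbol("<mask>")
+        ds = DenoisingDataset(
+            token_ds, token_ds.sizes, vocab, mask_idx,
+            mask_whole_words=None, shuffle=True, seed=1,
+            mask=0.3, mask_random=0.1, insert=0.0, rotate=0.0,
+            permute_sentences=1.0, bpe=None, replace_length=1,
+            mask_length="word", poisson_lambda=3.0,
+        )
+        sample = ds[0]  # {"id": ..., "source": noised, "target": clean}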
+ """ + + def __init__( + self, + dataset, + sizes, + vocab, + mask_idx, + mask_whole_words, + shuffle, + seed, + mask, + mask_random, + insert, + rotate, + permute_sentences, + bpe, + replace_length, + mask_length, + poisson_lambda, + eos=None, + item_transform_func=None, + ): + self.dataset = dataset + + self.sizes = sizes + + self.vocab = vocab + self.shuffle = shuffle + self.seed = seed + self.mask_idx = mask_idx + self.mask_whole_word = mask_whole_words + self.mask_ratio = mask + self.random_ratio = mask_random + self.insert_ratio = insert + self.rotate_ratio = rotate + self.permute_sentence_ratio = permute_sentences + self.eos = eos if eos is not None else vocab.eos() + self.item_transform_func = item_transform_func + + if bpe != "gpt2": + self.full_stop_index = self.vocab.eos() + else: + assert bpe == "gpt2" + self.full_stop_index = self.vocab.index("13") + + self.replace_length = replace_length + if self.replace_length not in [-1, 0, 1]: + raise ValueError(f"invalid arg: replace_length={self.replace_length}") + if mask_length not in ["subword", "word", "span-poisson"]: + raise ValueError(f"invalid arg: mask-length={mask_length}") + if mask_length == "subword" and replace_length not in [0, 1]: + raise ValueError(f"if using subwords, use replace-length=1 or 0") + + self.mask_span_distribution = None + if mask_length == "span-poisson": + _lambda = poisson_lambda + + lambda_to_the_k = 1 + e_to_the_minus_lambda = math.exp(-_lambda) + k_factorial = 1 + ps = [] + for k in range(0, 128): + ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) + lambda_to_the_k *= _lambda + k_factorial *= k + 1 + if ps[-1] < 0.0000001: + break + ps = torch.FloatTensor(ps) + self.mask_span_distribution = torch.distributions.Categorical(ps) + + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + self.epoch = epoch + + def __getitem__(self, index): + with data_utils.numpy_seed(self.seed, self.epoch, index): + tokens = self.dataset[index] + assert tokens[-1] == self.eos + source, target = tokens, tokens.clone() + + if self.permute_sentence_ratio > 0.0: + source = self.permute_sentences(source, self.permute_sentence_ratio) + + if self.mask_ratio > 0: + source = self.add_whole_word_mask(source, self.mask_ratio) + + if self.insert_ratio > 0: + source = self.add_insertion_noise(source, self.insert_ratio) + + if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio: + source = self.add_rolling_noise(source) + # there can additional changes to make: + if self.item_transform_func is not None: + source, target = self.item_transform_func(source, target) + + assert (source >= 0).all() + assert (source[1:-1] >= 1).all() + assert (source <= len(self.vocab)).all() + assert source[0] == self.vocab.bos() + assert source[-1] == self.eos + return { + "id": index, + "source": source, + "target": target, + } + + def __len__(self): + return len(self.dataset) + + def permute_sentences(self, source, p=1.0): + full_stops = source == self.full_stop_index + # Pretend it ends with a full stop so last span is a sentence + full_stops[-2] = 1 + + # Tokens that are full stops, where the previous token is not + sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2 + result = source.clone() + + num_sentences = sentence_ends.size(0) + num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0) + substitutions = torch.randperm(num_sentences)[:num_to_permute] + ordering = 
torch.arange(0, num_sentences) + ordering[substitutions] = substitutions[torch.randperm(num_to_permute)] + + # Ignore at start + index = 1 + for i in ordering: + sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]] + result[index : index + sentence.size(0)] = sentence + index += sentence.size(0) + return result + + def word_starts(self, source): + if self.mask_whole_word is not None: + is_word_start = self.mask_whole_word.gather(0, source) + else: + is_word_start = torch.ones(source.size()) + is_word_start[0] = 0 + is_word_start[-1] = 0 + return is_word_start + + def add_whole_word_mask(self, source, p): + is_word_start = self.word_starts(source) + num_to_mask = int(math.ceil(is_word_start.float().sum() * p)) + num_inserts = 0 + if num_to_mask == 0: + return source + + if self.mask_span_distribution is not None: + lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,)) + + # Make sure we have enough to mask + cum_length = torch.cumsum(lengths, 0) + while cum_length[-1] < num_to_mask: + lengths = torch.cat( + [ + lengths, + self.mask_span_distribution.sample(sample_shape=(num_to_mask,)), + ], + dim=0, + ) + cum_length = torch.cumsum(lengths, 0) + + # Trim to masking budget + i = 0 + while cum_length[i] < num_to_mask: + i += 1 + lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1]) + num_to_mask = i + 1 + lengths = lengths[:num_to_mask] + + # Handle 0-length mask (inserts) separately + lengths = lengths[lengths > 0] + num_inserts = num_to_mask - lengths.size(0) + num_to_mask -= num_inserts + if num_to_mask == 0: + return self.add_insertion_noise(source, num_inserts / source.size(0)) + + assert (lengths > 0).all() + else: + lengths = torch.ones((num_to_mask,)).long() + assert is_word_start[-1] == 0 + word_starts = is_word_start.nonzero(as_tuple=False) + indices = word_starts[ + torch.randperm(word_starts.size(0))[:num_to_mask] + ].squeeze(1) + mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio + + source_length = source.size(0) + assert source_length - 1 not in indices + to_keep = torch.ones(source_length, dtype=torch.bool) + is_word_start[ + -1 + ] = 255 # acts as a long length, so spans don't go over the end of doc + if self.replace_length == 0: + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + + if self.mask_span_distribution is not None: + assert len(lengths.size()) == 1 + assert lengths.size() == indices.size() + lengths -= 1 + while indices.size(0) > 0: + assert lengths.size() == indices.size() + lengths -= is_word_start[indices + 1].long() + uncompleted = lengths >= 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + lengths = lengths[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + else: + # A bit faster when all lengths are 1 + while indices.size(0) > 0: + uncompleted = is_word_start[indices + 1] == 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), 
size=(mask_random.sum(),) + ) + + assert source_length - 1 not in indices + + source = source[to_keep] + + if num_inserts > 0: + source = self.add_insertion_noise(source, num_inserts / source.size(0)) + + return source + + def add_permuted_noise(self, tokens, p): + num_words = len(tokens) + num_to_permute = math.ceil(((num_words * 2) * p) / 2.0) + substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1 + tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]] + return tokens + + def add_rolling_noise(self, tokens): + offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1) + tokens = torch.cat( + (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]), + dim=0, + ) + return tokens + + def add_insertion_noise(self, tokens, p): + if p == 0.0: + return tokens + + num_tokens = len(tokens) + n = int(math.ceil(num_tokens * p)) + + noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1 + noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool) + noise_mask[noise_indices] = 1 + result = torch.LongTensor(n + len(tokens)).fill_(-1) + + num_random = int(math.ceil(n * self.random_ratio)) + result[noise_indices[num_random:]] = self.mask_idx + result[noise_indices[:num_random]] = torch.randint( + low=1, high=len(self.vocab), size=(num_random,) + ) + + result[~noise_mask] = tokens + + assert (result >= 0).all() + return result + + def collater(self, samples, pad_to_length=None): + """Merge a list of samples to form a mini-batch. + Args: + samples (List[dict]): samples to collate + Returns: + dict: a mini-batch of data + """ + return collate( + samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length + ) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + return indices[np.argsort(self.sizes[indices], kind="mergesort")] + + def prefetch(self, indices): + self.src.prefetch(indices) + self.tgt.prefetch(indices) + + @property + def supports_prefetch(self): + return ( + hasattr(self.src, "supports_prefetch") + and self.src.supports_prefetch + and hasattr(self.tgt, "supports_prefetch") + and self.tgt.supports_prefetch + ) diff --git a/fairseq/fairseq/data/dictionary.py b/fairseq/fairseq/data/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad590a19b26158bc345a3a66903006a414e2375 --- /dev/null +++ b/fairseq/fairseq/data/dictionary.py @@ -0,0 +1,403 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
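+# Example usage (illustrative; "corpus.txt" and "dict.txt" are placeholder paths):
+#
+#   d = Dictionary()
+#   Dictionary.add_file_to_dictionary("corpus.txt", d, tokenize_line, num_workers=1)
+#   d.finalize(threshold=2)
+#   d.save("dict.txt")
+#   ids = d.encode_line("hello world", add_if_not_exist=False)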
+
+import os
+from collections import Counter
+from multiprocessing import Pool
+
+import torch
+from fairseq import utils
+from fairseq.data import data_utils
+from fairseq.file_chunker_utils import Chunker, find_offsets
+from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line
+
+
+class Dictionary:
+    """A mapping from symbols to consecutive integers"""
+
+    def __init__(
+        self,
+        *,  # begin keyword-only arguments
+        bos="<s>",
+        pad="<pad>",
+        eos="</s>",
+        unk="<unk>",
+        extra_special_symbols=None,
+        add_special_symbols=True,
+    ):
+        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
+        self.symbols = []
+        self.count = []
+        self.indices = {}
+        if add_special_symbols:
+            self.bos_index = self.add_symbol(bos)
+            self.pad_index = self.add_symbol(pad)
+            self.eos_index = self.add_symbol(eos)
+            self.unk_index = self.add_symbol(unk)
+            if extra_special_symbols:
+                for s in extra_special_symbols:
+                    self.add_symbol(s)
+            self.nspecial = len(self.symbols)
+
+    def __eq__(self, other):
+        return self.indices == other.indices
+
+    def __getitem__(self, idx):
+        if idx < len(self.symbols):
+            return self.symbols[idx]
+        return self.unk_word
+
+    def get_count(self, idx):
+        return self.count[idx]
+
+    def __len__(self):
+        """Returns the number of symbols in the dictionary"""
+        return len(self.symbols)
+
+    def __contains__(self, sym):
+        return sym in self.indices
+
+    def index(self, sym):
+        """Returns the index of the specified symbol"""
+        assert isinstance(sym, str)
+        if sym in self.indices:
+            return self.indices[sym]
+        return self.unk_index
+
+    def string(
+        self,
+        tensor,
+        bpe_symbol=None,
+        escape_unk=False,
+        extra_symbols_to_ignore=None,
+        unk_string=None,
+        include_eos=False,
+        separator=" ",
+    ):
+        """Helper for converting a tensor of token indices to a string.
+
+        Can optionally remove BPE symbols or escape <unk> words.
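+
+        Example (illustrative; assumes indices 4 and 5 map to 'hello' and
+        'world'; eos is stripped because include_eos defaults to False)::
+
+            d.string(torch.LongTensor([4, 5, d.eos()]))  # -> 'hello world'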
+        """
+        if torch.is_tensor(tensor) and tensor.dim() == 2:
+            return "\n".join(
+                self.string(
+                    t,
+                    bpe_symbol,
+                    escape_unk,
+                    extra_symbols_to_ignore,
+                    # forward the remaining options so 2d tensors behave the
+                    # same as 1d tensors
+                    unk_string=unk_string,
+                    include_eos=include_eos,
+                    separator=separator,
+                )
+                for t in tensor
+            )
+
+        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
+        if not include_eos:
+            extra_symbols_to_ignore.add(self.eos())
+
+        def token_string(i):
+            if i == self.unk():
+                if unk_string is not None:
+                    return unk_string
+                else:
+                    return self.unk_string(escape_unk)
+            else:
+                return self[i]
+
+        if hasattr(self, "bos_index"):
+            extra_symbols_to_ignore.add(self.bos())
+
+        sent = separator.join(
+            token_string(i)
+            for i in tensor
+            if utils.item(i) not in extra_symbols_to_ignore
+        )
+
+        return data_utils.post_process(sent, bpe_symbol)
+
+    def unk_string(self, escape=False):
+        """Return unknown string, optionally escaped as: <<unk>>"""
+        if escape:
+            return "<{}>".format(self.unk_word)
+        else:
+            return self.unk_word
+
+    def add_symbol(self, word, n=1, overwrite=False):
+        """Adds a word to the dictionary"""
+        if word in self.indices and not overwrite:
+            idx = self.indices[word]
+            self.count[idx] = self.count[idx] + n
+            return idx
+        else:
+            idx = len(self.symbols)
+            self.indices[word] = idx
+            self.symbols.append(word)
+            self.count.append(n)
+            return idx
+
+    def update(self, new_dict):
+        """Updates counts from new dictionary."""
+        for word in new_dict.symbols:
+            idx2 = new_dict.indices[word]
+            if word in self.indices:
+                idx = self.indices[word]
+                self.count[idx] = self.count[idx] + new_dict.count[idx2]
+            else:
+                idx = len(self.symbols)
+                self.indices[word] = idx
+                self.symbols.append(word)
+                self.count.append(new_dict.count[idx2])
+
+    def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
+        """Sort symbols by frequency in descending order, ignoring special ones.
+
+        Args:
+            - threshold defines the minimum word count
+            - nwords defines the total number of words in the final dictionary,
+                including special symbols
+            - padding_factor can be used to pad the dictionary size to be a
+                multiple of 8, which is important on some hardware (e.g., Nvidia
+                Tensor Cores).
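+
+        Example (illustrative)::
+
+            d.finalize(threshold=2, padding_factor=8)
+            assert len(d) % 8 == 0  # padded with 'madeupwordNNNN' filler symbols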
+        """
+        if nwords <= 0:
+            nwords = len(self)
+
+        new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
+        new_symbols = self.symbols[: self.nspecial]
+        new_count = self.count[: self.nspecial]
+
+        c = Counter(
+            dict(
+                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
+            )
+        )
+        for symbol, count in c.most_common(nwords - self.nspecial):
+            if count >= threshold:
+                new_indices[symbol] = len(new_symbols)
+                new_symbols.append(symbol)
+                new_count.append(count)
+            else:
+                break
+
+        assert len(new_symbols) == len(new_indices)
+
+        self.count = list(new_count)
+        self.symbols = list(new_symbols)
+        self.indices = new_indices
+
+        self.pad_to_multiple_(padding_factor)
+
+    def pad_to_multiple_(self, padding_factor):
+        """Pad Dictionary size to be a multiple of *padding_factor*."""
+        if padding_factor > 1:
+            i = 0
+            while len(self) % padding_factor != 0:
+                symbol = "madeupword{:04d}".format(i)
+                self.add_symbol(symbol, n=0)
+                i += 1
+
+    def bos(self):
+        """Helper to get index of beginning-of-sentence symbol"""
+        return self.bos_index
+
+    def pad(self):
+        """Helper to get index of pad symbol"""
+        return self.pad_index
+
+    def eos(self):
+        """Helper to get index of end-of-sentence symbol"""
+        return self.eos_index
+
+    def unk(self):
+        """Helper to get index of unk symbol"""
+        return self.unk_index
+
+    @classmethod
+    def load(cls, f, add_special_symbols=True):
+        """Loads the dictionary from a text file with the format:
+
+        ```
+        <symbol0> <count0>
+        <symbol1> <count1>
+        ...
+        ```
+        """
+        d = cls(add_special_symbols=add_special_symbols)
+        d.add_from_file(f)
+        return d
+
+    def add_from_file(self, f):
+        """
+        Loads a pre-existing dictionary from a text file and adds its symbols
+        to this instance.
+        """
+        if isinstance(f, str):
+            try:
+                with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
+                    self.add_from_file(fd)
+            except FileNotFoundError as fnfe:
+                raise fnfe
+            except UnicodeError:
+                raise Exception(
+                    "Incorrect encoding detected in {}, please "
+                    "rebuild the dataset".format(f)
+                )
+            return
+
+        lines = f.readlines()
+        indices_start_line = self._load_meta(lines)
+
+        for line in lines[indices_start_line:]:
+            try:
+                line, field = line.rstrip().rsplit(" ", 1)
+                if field == "#fairseq:overwrite":
+                    overwrite = True
+                    line, field = line.rsplit(" ", 1)
+                else:
+                    overwrite = False
+                count = int(field)
+                word = line
+                if word in self and not overwrite:
+                    raise RuntimeError(
+                        "Duplicate word found when loading Dictionary: '{}'. "
+                        "Duplicate words can overwrite earlier ones by adding the "
+                        "#fairseq:overwrite flag at the end of the corresponding row "
+                        "in the dictionary file. If using the Camembert model, please "
+                        "download an updated copy of the model file.".format(word)
+                    )
+                self.add_symbol(word, n=count, overwrite=overwrite)
+            except ValueError:
+                raise ValueError(
+                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
+                )
+
+    def _save(self, f, kv_iterator):
+        if isinstance(f, str):
+            PathManager.mkdirs(os.path.dirname(f))
+            with PathManager.open(f, "w", encoding="utf-8") as fd:
+                return self.save(fd)
+        for k, v in kv_iterator:
+            print("{} {}".format(k, v), file=f)
+
+    def _get_meta(self):
+        return [], []
+
+    def _load_meta(self, lines):
+        return 0
+
+    def save(self, f):
+        """Stores dictionary into a text file"""
+        ex_keys, ex_vals = self._get_meta()
+        self._save(
+            f,
+            zip(
+                ex_keys + self.symbols[self.nspecial :],
+                ex_vals + self.count[self.nspecial :],
+            ),
+        )
+
+    def dummy_sentence(self, length):
+        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
+        t[-1] = self.eos()
+        return t
+
+    def encode_line(
+        self,
+        line,
+        line_tokenizer=tokenize_line,
+        add_if_not_exist=True,
+        consumer=None,
+        append_eos=True,
+        reverse_order=False,
+    ) -> torch.IntTensor:
+        words = line_tokenizer(line)
+        if reverse_order:
+            words = list(reversed(words))
+        nwords = len(words)
+        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
+
+        for i, word in enumerate(words):
+            if add_if_not_exist:
+                idx = self.add_symbol(word)
+            else:
+                idx = self.index(word)
+            if consumer is not None:
+                consumer(word, idx)
+            ids[i] = idx
+        if append_eos:
+            ids[nwords] = self.eos_index
+        return ids
+
+    @staticmethod
+    def _add_file_to_dictionary_single_worker(
+        filename,
+        tokenize,
+        eos_word,
+        start_offset,
+        end_offset,
+    ):
+        counter = Counter()
+        with Chunker(filename, start_offset, end_offset) as line_iterator:
+            for line in line_iterator:
+                for word in tokenize(line):
+                    counter.update([word])
+                counter.update([eos_word])
+        return counter
+
+    @staticmethod
+    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
+        def merge_result(counter):
+            for w, c in sorted(counter.items()):
+                dict.add_symbol(w, c)
+
+        local_file = PathManager.get_local_path(filename)
+        offsets = find_offsets(local_file, num_workers)
+        if num_workers > 1:
+            chunks = zip(offsets, offsets[1:])
+            pool = Pool(processes=num_workers)
+            results = []
+            for (start_offset, end_offset) in chunks:
+                results.append(
+                    pool.apply_async(
+                        Dictionary._add_file_to_dictionary_single_worker,
+                        (
+                            local_file,
+                            tokenize,
+                            dict.eos_word,
+                            start_offset,
+                            end_offset,
+                        ),
+                    )
+                )
+            pool.close()
+            pool.join()
+            for r in results:
+                merge_result(r.get())
+        else:
+            merge_result(
+                Dictionary._add_file_to_dictionary_single_worker(
+                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
+                )
+            )
+
+
+class TruncatedDictionary(object):
+    def __init__(self, wrapped_dict, length):
+        self.__class__ = type(
+            wrapped_dict.__class__.__name__,
+            (self.__class__, wrapped_dict.__class__),
+            {},
+        )
+        self.__dict__ = wrapped_dict.__dict__
+        self.wrapped_dict = wrapped_dict
+        self.length = min(len(self.wrapped_dict), length)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
+        if i < self.length:
+            return self.wrapped_dict[i]
+        return self.wrapped_dict.unk()
diff --git a/fairseq/fairseq/data/fairseq_dataset.py b/fairseq/fairseq/data/fairseq_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bde7fc57b99df2e14e2186a5f9cd98982870ddd
--- /dev/null
+++ b/fairseq/fairseq/data/fairseq_dataset.py
@@ -0,0 +1,205 @@
+# Copyright (c) Facebook, Inc.
and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import numpy as np +import torch.utils.data +from fairseq.data import data_utils + +logger = logging.getLogger(__name__) + + +class EpochListening: + """Mixin for receiving updates whenever the epoch increments.""" + + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for + this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies in ``set_epoch`` then you should consider setting + this to ``False``. + """ + return True + + def set_epoch(self, epoch): + """Will receive the updated epoch number at the beginning of the epoch.""" + pass + + +class FairseqDataset(torch.utils.data.Dataset, EpochListening): + """A dataset that provides helpers for batching.""" + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + raise NotImplementedError + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + raise NotImplementedError + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + return np.arange(len(self), dtype=np.int64) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return False + + def attr(self, attr: str, index: int): + return getattr(self, attr, None) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + raise NotImplementedError + + def get_batch_shapes(self): + """ + Return a list of valid batch shapes, for example:: + + [(8, 512), (16, 256), (32, 128)] + + The first dimension of each tuple is the batch size and can be ``None`` + to automatically infer the max batch size based on ``--max-tokens``. + The second dimension of each tuple is the max supported length as given + by :func:`fairseq.data.FairseqDataset.num_tokens`. + + This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size` + to restrict batch shapes. This is useful on TPUs to avoid too many + dynamic shapes (and recompilations). + """ + return None + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + """ + Given an ordered set of indices, return batches according to + *max_tokens*, *max_sentences* and *required_batch_size_multiple*. 
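+
+        Example (illustrative)::
+
+            indices = dataset.ordered_indices()
+            batches = dataset.batch_by_size(indices, max_tokens=4096)
+            # each element of `batches` is an array of dataset indices whose
+            # padded token count stays at or below max_tokens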
+ """ + from fairseq.data import data_utils + + fixed_shapes = self.get_batch_shapes() + if fixed_shapes is not None: + + def adjust_bsz(bsz, num_tokens): + if bsz is None: + assert max_tokens is not None, "Must specify --max-tokens" + bsz = max_tokens // num_tokens + if max_sentences is not None: + bsz = min(bsz, max_sentences) + elif ( + bsz >= required_batch_size_multiple + and bsz % required_batch_size_multiple != 0 + ): + bsz -= bsz % required_batch_size_multiple + return bsz + + fixed_shapes = np.array( + [ + [adjust_bsz(bsz, num_tokens), num_tokens] + for (bsz, num_tokens) in fixed_shapes + ] + ) + + try: + num_tokens_vec = self.num_tokens_vec(indices).astype("int64") + except NotImplementedError: + num_tokens_vec = None + + return data_utils.batch_by_size( + indices, + num_tokens_fn=self.num_tokens, + num_tokens_vec=num_tokens_vec, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + fixed_shapes=fixed_shapes, + ) + + def filter_indices_by_size(self, indices, max_sizes): + """ + Filter a list of sample indices. Remove those that are longer than + specified in *max_sizes*. + + WARNING: don't update, override method in child classes + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if isinstance(max_sizes, float) or isinstance(max_sizes, int): + if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray): + ignored = indices[self.sizes[indices] > max_sizes].tolist() + indices = indices[self.sizes[indices] <= max_sizes] + elif ( + hasattr(self, "sizes") + and isinstance(self.sizes, list) + and len(self.sizes) == 1 + ): + ignored = indices[self.sizes[0][indices] > max_sizes].tolist() + indices = indices[self.sizes[0][indices] <= max_sizes] + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored + + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return True + + +class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): + """ + For datasets that need to be read sequentially, usually because the data is + being streamed or otherwise can't be manipulated on a single machine. + """ + + def __iter__(self): + raise NotImplementedError diff --git a/fairseq/fairseq/data/fasta_dataset.py b/fairseq/fairseq/data/fasta_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..007011974a997fd7446dd29d7eba097d7513bab0 --- /dev/null +++ b/fairseq/fairseq/data/fasta_dataset.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import subprocess +import threading +from pathlib import Path + +import numpy as np +import torch + + +def fasta_file_path(prefix_path): + return prefix_path + ".fasta" + + +class FastaDataset(torch.utils.data.Dataset): + """ + For loading protein sequence datasets in the common FASTA data format + """ + + def __init__(self, path: str, cache_indices=False): + self.fn = fasta_file_path(path) + self.threadlocal = threading.local() + self.cache = Path(f"{path}.fasta.idx.npy") + if cache_indices: + if self.cache.exists(): + self.offsets, self.sizes = np.load(self.cache) + else: + self.offsets, self.sizes = self._build_index(path) + np.save(self.cache, np.stack([self.offsets, self.sizes])) + else: + self.offsets, self.sizes = self._build_index(path) + + def _get_file(self): + if not hasattr(self.threadlocal, "f"): + self.threadlocal.f = open(self.fn, "r") + return self.threadlocal.f + + def __getitem__(self, idx): + f = self._get_file() + f.seek(self.offsets[idx]) + desc = f.readline().strip() + line = f.readline() + seq = "" + while line != "" and line[0] != ">": + seq += line.strip() + line = f.readline() + return desc, seq + + def __len__(self): + return self.offsets.size + + def _build_index(self, path: str): + # Use grep and awk to get 100M/s on local SSD. + # Should process your enormous 100G fasta in ~10 min single core... + path = fasta_file_path(path) + bytes_offsets = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| grep --byte-offset '^>' -o | cut -d: -f1", + shell=True, + ) + fasta_lengths = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'", + shell=True, + ) + bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ") + sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ") + return bytes_np, sizes_np + + def __setstate__(self, state): + self.__dict__ = state + self.threadlocal = threading.local() + + def __getstate__(self): + d = {} + for i, v in self.__dict__.items(): + if i != "threadlocal": + d[i] = v + return d + + def __del__(self): + if hasattr(self.threadlocal, "f"): + self.threadlocal.f.close() + del self.threadlocal.f + + @staticmethod + def exists(path): + return os.path.exists(fasta_file_path(path)) + + +class EncodedFastaDataset(FastaDataset): + """ + The FastaDataset returns raw sequences - this allows us to return + indices with a dictionary instead. + """ + + def __init__(self, path, dictionary): + super().__init__(path, cache_indices=True) + self.dictionary = dictionary + + def __getitem__(self, idx): + desc, seq = super().__getitem__(idx) + return self.dictionary.encode_line(seq, line_tokenizer=list).long() diff --git a/fairseq/fairseq/data/id_dataset.py b/fairseq/fairseq/data/id_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4d7969cf2a26e852b466f165a6fadabae3b35f --- /dev/null +++ b/fairseq/fairseq/data/id_dataset.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . 
import FairseqDataset + + +class IdDataset(FairseqDataset): + def __getitem__(self, index): + return index + + def __len__(self): + return 0 + + def collater(self, samples): + return torch.tensor(samples) diff --git a/fairseq/fairseq/data/indexed_dataset.py b/fairseq/fairseq/data/indexed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1947d994081167ee78e9b2a5590882e5025b2244 --- /dev/null +++ b/fairseq/fairseq/data/indexed_dataset.py @@ -0,0 +1,592 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import shutil +import struct +from functools import lru_cache + +import numpy as np +import torch +from fairseq.dataclass.constants import DATASET_IMPL_CHOICES +from fairseq.data.fasta_dataset import FastaDataset +from fairseq.file_io import PathManager +from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex + +from . import FairseqDataset + +from typing import Union + + +def best_fitting_int_dtype( + max_int_to_represent, +) -> Union[np.uint16, np.uint32, np.int64]: + + if max_int_to_represent is None: + return np.uint32 # Safe guess + elif max_int_to_represent < 65500: + return np.uint16 + elif max_int_to_represent < 4294967295: + return np.uint32 + else: + return np.int64 + # we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly + # https://github.com/numpy/numpy/issues/5745 + + +def get_available_dataset_impl(): + return list(map(str, DATASET_IMPL_CHOICES)) + + +def infer_dataset_impl(path): + if IndexedRawTextDataset.exists(path): + return "raw" + elif IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]: + return "huffman" + else: + return None + elif FastaDataset.exists(path): + return "fasta" + else: + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=best_fitting_int_dtype(vocab_size) + ) + elif impl == "fasta": + raise NotImplementedError + elif impl == "huffman": + raise ValueError( + "Use HuffmanCodeBuilder directly as it has a different interface." 
+ ) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None): + if impl == "raw" and IndexedRawTextDataset.exists(path): + assert dictionary is not None + return IndexedRawTextDataset(path, dictionary) + elif impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path) + elif impl == "fasta" and FastaDataset.exists(path): + from fairseq.data.fasta_dataset import EncodedFastaDataset + + return EncodedFastaDataset(path, dictionary) + elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path): + return HuffmanMMapIndexedDataset(path) + return None + + +def dataset_exists(path, impl): + if impl == "raw": + return IndexedRawTextDataset.exists(path) + elif impl == "mmap": + return MMapIndexedDataset.exists(path) + elif impl == "huffman": + return HuffmanMMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +_code_to_dtype = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float64, + 7: np.double, + 8: np.uint16, + 9: np.uint32, + 10: np.uint64, +} + + +def _dtype_header_code(dtype) -> int: + for k in _code_to_dtype.keys(): + if _code_to_dtype[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +class IndexedDataset(FairseqDataset): + """Loader for TorchNet IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path, fix_lua_indexing=False): + super().__init__() + self.path = path + self.fix_lua_indexing = fix_lua_indexing + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." 
+ ) + version = f.read(8) + assert struct.unpack("<Q", version) == (1,) + code, self.element_size = struct.unpack("<QQ", f.read(16)) + self.dtype = _code_to_dtype[code] + self._len, self.s = struct.unpack("<QQ", f.read(16)) + self.dim_offsets = read_longs(f, self._len + 1) + self.data_offsets = read_longs(f, self._len + 1) + self.sizes = read_longs(f, self.s) + + def read_data(self, path): + self.data_file = open(data_file_path(path), "rb", buffering=0) + + def check_index(self, i): + if i < 0 or i >= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + @lru_cache(maxsize=8) + def __getitem__(self, i) -> torch.Tensor: + if not self.data_file: + self.read_data(self.path) + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + item = torch.from_numpy(a).long() + if self.fix_lua_indexing: + item -= 1 # subtract 1 for 0-based indexing + return item + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return PathManager.exists(index_file_path(path)) and PathManager.exists( + data_file_path(path) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path, fix_lua_indexing=False): + super().__init__(path, fix_lua_indexing=fix_lua_indexing) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + @lru_cache(maxsize=8) + def __getitem__(self, i): + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + item = torch.from_numpy(a).long() + if self.fix_lua_indexing: + item -= 1 # subtract 1 for 0-based indexing + return item + + +class IndexedRawTextDataset(FairseqDataset): + """Takes a text file as input and binarizes it in memory at instantiation.
+ Original lines are also kept in memory""" + + def __init__(self, path, dictionary, append_eos=True, reverse_order=False): + self.tokens_list = [] + self.lines = [] + self.sizes = [] + self.append_eos = append_eos + self.reverse_order = reverse_order + self.read_data(path, dictionary) + self.size = len(self.tokens_list) + + def read_data(self, path, dictionary): + with open(path, "r", encoding="utf-8") as f: + for line in f: + self.lines.append(line.strip("\n")) + tokens = dictionary.encode_line( + line, + add_if_not_exist=False, + append_eos=self.append_eos, + reverse_order=self.reverse_order, + ).long() + self.tokens_list.append(tokens) + self.sizes.append(len(tokens)) + self.sizes = np.array(self.sizes) + + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError("index out of range") + + @lru_cache(maxsize=8) + def __getitem__(self, i): + self.check_index(i) + return self.tokens_list[i] + + def get_original_text(self, i): + self.check_index(i) + return self.lines[i] + + def __del__(self): + pass + + def __len__(self): + return self.size + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return PathManager.exists(path) + + +class IndexedDatasetBuilder: + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float64: 4, + np.double: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + + def add_item(self, tensor): + # +1 for Lua compatibility + bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack("<Q", 1)) + index.write(struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size)) + index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes))) + write_longs(index, self.dim_offsets) + write_longs(index, self.data_offsets) + write_longs(index, self.sizes) + index.close() + + +def get_indexed_dataset_to_local(path) -> str: + local_index_path = PathManager.get_local_path(index_file_path(path)) + local_data_path = PathManager.get_local_path(data_file_path(path)) + + assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), ( + "PathManager.get_local_path does not return files with expected patterns: " + f"{local_index_path} and {local_data_path}" + ) + + local_path = local_data_path[:-4] # stripping suffix ".bin" + assert local_path == local_index_path[:-4] # stripping suffix ".idx" + return local_path + + +class MMapIndexedDatasetBuilder: + def __init__(self, out_file, dtype=np.int64): + self._data_file = open(out_file, "wb") + self._dtype = dtype + self._sizes = [] + + def add_item(self, tensor): + np_array = np.array(tensor.numpy(), dtype=self._dtype) + self._data_file.write(np_array.tobytes(order="C")) + self._sizes.append(np_array.size) + + def 
merge_file_(self, another_file): + # Concatenate index + index = MMapIndexedDataset.Index(index_file_path(another_file)) + assert index.dtype == self._dtype + + for size in index.sizes: + self._sizes.append(size) + + # Concatenate data + with open(data_file_path(another_file), "rb") as f: + shutil.copyfileobj(f, self._data_file) + + def finalize(self, index_file): + self._data_file.close() + + with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: + index.write(self._sizes) diff --git a/fairseq/fairseq/data/iterators.py b/fairseq/fairseq/data/iterators.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5a42a9cf1cee9e0559d4c3b00024992271128c --- /dev/null +++ b/fairseq/fairseq/data/iterators.py @@ -0,0 +1,879 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import math +import operator +import os +import queue +import time +from threading import Thread +from typing import Iterator, List + +import numpy as np +import torch +from fairseq.data import data_utils + + +logger = logging.getLogger(__name__) + +# Object used by _background_consumer to signal the source is exhausted +# to the main thread. +_sentinel = object() + + +class CountingIterator(object): + """Wrapper around an iterable that maintains the iteration count. + + Args: + iterable (iterable): iterable to wrap + start (int): starting iteration count. Note that this doesn't + actually advance the iterator. + total (int): override the iterator length returned by ``__len__``. + This can be used to truncate *iterator*. + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, start=None, total=None): + self._itr = iter(iterable) + self.n = start or getattr(iterable, "n", 0) + self.total = total if total is not None else self.n + len(iterable) + + def __len__(self): + return self.total + + def __iter__(self): + return self + + def __next__(self): + if not self.has_next(): + raise StopIteration + try: + x = next(self._itr) + except StopIteration: + raise IndexError( + f"Iterator expected to have length {self.total}, " + f"but exhausted at position {self.n}." + ) + self.n += 1 + return x + + def has_next(self): + """Whether the iterator still has elements to consume.""" + return self.n < self.total + + def skip(self, n): + """Fast-forward the iterator by skipping n elements.""" + for _ in range(n): + next(self) + return self + + def take(self, n): + """Truncate the iterator to n elements at most.""" + self.total = min(self.total, n) + # Propagate this change to the underlying iterator + if hasattr(self._itr, "take"): + self._itr.take(max(n - self.n, 0)) + return self + + +class EpochBatchIterating(object): + def __len__(self) -> int: + raise NotImplementedError + + @property + def next_epoch_idx(self): + raise NotImplementedError + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus (bool, optional): ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True).
+ """ + raise NotImplementedError + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + raise NotImplementedError + + @property + def iterations_in_epoch(self) -> int: + """The number of consumed batches in the current epoch.""" + raise NotImplementedError + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + raise NotImplementedError + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + raise NotImplementedError + + @property + def first_batch(self): + return "DUMMY" + + +class StreamingEpochBatchIterator(EpochBatchIterating): + """A steaming-style iterator over a :class:`torch.utils.data.IterableDataset`. + + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + max_sentences: batch size + collate_fn (callable): merges a list of samples to form a mini-batch + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speeding up dataloading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative (default: ``0``). + """ + + def __init__( + self, + dataset, + max_sentences=1, + collate_fn=None, + epoch=1, + num_workers=0, + buffer_size=0, + timeout=0, + persistent_workers=True, + ): + assert isinstance(dataset, torch.utils.data.IterableDataset) + self.dataset = dataset + self.max_sentences = max_sentences + self.collate_fn = collate_fn + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. 
+ self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + + self._current_epoch_iterator = None + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._current_epoch_iterator is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + self.epoch = self.next_epoch_idx + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) + self._current_epoch_iterator = self._get_iterator_for_epoch(self.epoch, shuffle) + return self._current_epoch_iterator + + def end_of_epoch(self) -> bool: + return not self._current_epoch_iterator.has_next() + + @property + def iterations_in_epoch(self) -> int: + if self._current_epoch_iterator is not None: + return self._current_epoch_iterator.n + return 0 + + def state_dict(self): + return { + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + + def _get_iterator_for_epoch(self, epoch, shuffle, offset=0): + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + worker_init_fn = getattr(self.dataset, "worker_init_fn", None) + itr = torch.utils.data.DataLoader( + self.dataset, + batch_size=self.max_sentences, + collate_fn=self.collate_fn, + num_workers=self.num_workers, + timeout=self.timeout, + worker_init_fn=worker_init_fn, + pin_memory=True, + persistent_workers=self.persistent_workers, + ) + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CountingIterator + itr = CountingIterator(itr, start=offset) + + return itr + + +class FrozenBatchSampler: + def __init__( + self, + ordered_batches, + epoch, + fix_batches_to_gpus, + shuffle, + initial_offset, + ): + self.ordered_batches = ordered_batches + self.fix_batches_to_gpus = fix_batches_to_gpus + self.shuffle = shuffle + self.make_batches_for_epoch(epoch, initial_offset) + + def make_batches_for_epoch(self, epoch, offset=0): + self.batches = self.ordered_batches( + epoch, self.fix_batches_to_gpus, self.shuffle + ) + if offset > 0: + self.batches = self.batches[offset:] + + def __iter__(self) -> Iterator[List[int]]: + return iter(self.batches) + + def __len__(self) -> int: + return len(self.batches) + + +class EpochBatchIterator(EpochBatchIterating): + """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. + + Compared to :class:`torch.utils.data.DataLoader`, this iterator: + + - can be reused across multiple epochs with the :func:`next_epoch_itr` + method (optionally shuffled between epochs) + - can be serialized/deserialized with the :func:`state_dict` and + :func:`load_state_dict` methods + - supports sharding with the *num_shards* and *shard_id* arguments + + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + collate_fn (callable): merges a list of samples to form a mini-batch + batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of + indices, or a callable to create such an iterator (~torch.utils.data.Sampler). + A callable batch_sampler will be called for each epoch to enable per epoch dynamic + batch iterators defined by this callable batch_sampler. + seed (int, optional): seed for random number generator for + reproducibility (default: 1). 
+ num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speed up data loading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative (default: ``0``). + disable_shuffling (bool, optional): force disable shuffling + (default: ``False``). + skip_remainder_batch (bool, optional): if set, discard the last batch in an epoch + for the sake of training stability, as the last batch is usually smaller than + local_batch_size * distributed_world_size (default: ``False``). + grouped_shuffling (bool, optional): enable shuffling batches in groups + of num_shards. Ensures that each GPU receives similar length sequences when + batches are sorted by length. + """ + + def __init__( + self, + dataset, + collate_fn, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + buffer_size=0, + timeout=0, + disable_shuffling=False, + skip_remainder_batch=False, + grouped_shuffling=False, + reuse_dataloader=False, + persistent_workers=True, + ): + assert isinstance(dataset, torch.utils.data.Dataset) + self.dataset = dataset + self.collate_fn = collate_fn + self.batch_sampler = batch_sampler + self._frozen_batches = ( + tuple(batch_sampler) if not callable(batch_sampler) else None + ) + self.seed = seed + self.num_shards = num_shards + self.shard_id = shard_id + self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. + self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + self.disable_shuffling = disable_shuffling + self.skip_remainder_batch = skip_remainder_batch + self.grouped_shuffling = grouped_shuffling + + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.shuffle = not disable_shuffling + self._cur_epoch_itr = None + self._next_epoch_itr = None + self._supports_prefetch = getattr(dataset, "supports_prefetch", False) + + self.dataloader = None + self.reuse_dataloader = reuse_dataloader + + @property + def frozen_batches(self): + if self._frozen_batches is None: + self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) + return self._frozen_batches + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset."
+ ) + + if getattr(self.dataset, "supports_fetch_outside_dataloader", True): + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) + else: + return "DUMMY" + + def __len__(self): + return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) + + @property + def n(self): + return self.iterations_in_epoch + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._next_epoch_itr is not None: + return self.epoch + elif self._cur_epoch_itr is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus (bool, optional): ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True). + """ + if self.disable_shuffling: + shuffle = False + prev_epoch = self.epoch + self.epoch = self.next_epoch_idx + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) + if self._next_epoch_itr is not None: + self._cur_epoch_itr = self._next_epoch_itr + self._next_epoch_itr = None + else: + if callable(self.batch_sampler) and prev_epoch != self.epoch: + # reset _frozen_batches to refresh the next epoch + self._frozen_batches = None + self._cur_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle, + fix_batches_to_gpus=fix_batches_to_gpus, + ) + self.shuffle = shuffle + return self._cur_epoch_itr + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + return not self._cur_epoch_itr.has_next() + + @property + def iterations_in_epoch(self): + """The number of consumed batches in the current epoch.""" + if self._cur_epoch_itr is not None: + return self._cur_epoch_itr.n + elif self._next_epoch_itr is not None: + return self._next_epoch_itr.n + return 0 + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + if self.end_of_epoch(): + epoch = self.epoch + 1 + iter_in_epoch = 0 + else: + epoch = self.epoch + iter_in_epoch = self.iterations_in_epoch + return { + "version": 2, + "epoch": epoch, + "iterations_in_epoch": iter_in_epoch, + "shuffle": self.shuffle, + } + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + self.epoch = state_dict["epoch"] + itr_pos = state_dict.get("iterations_in_epoch", 0) + version = state_dict.get("version", 1) + if itr_pos > 0: + # fast-forward epoch iterator + self._next_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle=state_dict.get("shuffle", True), + offset=itr_pos, + ) + if self._next_epoch_itr is None: + if version == 1: + # legacy behavior: we finished the epoch, increment epoch counter + self.epoch += 1 + else: + raise RuntimeError( + "Cannot resume training due to dataloader mismatch, please " + "report this to the fairseq developers. You can relaunch " + "training with `--reset-dataloader` and it should work." 
+ ) + else: + self._next_epoch_itr = None + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + if self.reuse_dataloader and self.dataloader is not None: + self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset) + itr = self.dataloader + else: + self.epoch_batch_sampler = FrozenBatchSampler( + self.ordered_batches, + epoch, + fix_batches_to_gpus, + shuffle, + initial_offset=offset, + ) + + if offset > 0 and len(self.epoch_batch_sampler) == 0: + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=self.epoch_batch_sampler, + num_workers=self.num_workers, + timeout=self.timeout, + pin_memory=True, + persistent_workers=self.persistent_workers, + ) + + if self.reuse_dataloader: + self.dataloader = itr + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CountingIterator + itr = CountingIterator(itr, start=offset) + + if self.skip_remainder_batch: + # TODO: Below is a lazy implementation which discards the final batch regardless + # of whether it is a full batch or not. + + total_num_itrs = len(itr) - 1 + itr.take(total_num_itrs) + logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}") + + return itr + + def ordered_batches(self, epoch, fix_batches_to_gpus, shuffle): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + + if self.grouped_shuffling: + grouped_batches = [ + batches[(i * self.num_shards) : ((i + 1) * self.num_shards)] + for i in range((len(batches) // self.num_shards)) + ] + np.random.shuffle(grouped_batches) + batches = list(itertools.chain(*grouped_batches)) + else: + np.random.shuffle(batches) + + return batches + + if self._supports_prefetch: + batches = self.frozen_batches + + if shuffle and not fix_batches_to_gpus: + batches = shuffle_batches(list(batches), self.seed + epoch) + + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + self.dataset.prefetch([i for s in batches for i in s]) + + if shuffle and fix_batches_to_gpus: + batches = shuffle_batches(batches, self.seed + epoch + self.shard_id) + else: + if shuffle: + batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) + else: + batches = self.frozen_batches + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + return batches + + +class GroupedIterator(CountingIterator): + """Wrapper around an iterable that returns groups (chunks) of items. + + Args: + iterable (iterable): iterable to wrap + chunk_size (int): size of each chunk + skip_remainder_batch (bool, optional): if set, discard the last grouped batch in + each training epoch, as the last grouped batch is usually smaller than + local_batch_size * distributed_world_size * chunk_size (default: ``False``).
+ Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, chunk_size, skip_remainder_batch=False): + if skip_remainder_batch: + total_num_itrs = int(math.floor(len(iterable) / float(chunk_size))) + logger.info( + f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}" + ) + else: + total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size))) + logger.info(f"grouped total_num_itrs = {total_num_itrs}") + + itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), + total=total_num_itrs, + ) + self.chunk_size = chunk_size + + if skip_remainder_batch: + self.take(total_num_itrs) + # TODO: [Hack] Here the grouped iterator modifies the base iterator size so that + # training can move into the next epoch once the grouped iterator is exhausted. + # Double-check this implementation in case unexpected behavior occurs. + iterable.take(total_num_itrs * chunk_size) + + +def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False): + chunk = [] + for x in itr: + chunk.append(x) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if not skip_remainder_batch and len(chunk) > 0: + yield chunk + + +class ShardedIterator(CountingIterator): + """A sharded wrapper around an iterable, padded to length. + + Args: + iterable (iterable): iterable to wrap + num_shards (int): number of shards to split the iterable into + shard_id (int): which shard to iterate over + fill_value (Any, optional): padding value when the iterable doesn't + evenly divide *num_shards* (default: None). + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__( + self, iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None + ): + """ + Args: + skip_remainder_batch: ignored""" + if shard_id < 0 or shard_id >= num_shards: + raise ValueError("shard_id must be between 0 and num_shards") + sharded_len = int(math.ceil(len(iterable) / float(num_shards))) + itr = map( + operator.itemgetter(1), + itertools.zip_longest( + range(sharded_len), + itertools.islice(iterable, shard_id, len(iterable), num_shards), + fillvalue=fill_value, + ), + ) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))), + total=sharded_len, + ) + + +class BackgroundConsumer(Thread): + def __init__(self, queue, source, max_len, cuda_device): + Thread.__init__(self) + + self._queue = queue + self._source = source + self._max_len = max_len + self.count = 0 + self.cuda_device = cuda_device + + def run(self): + # set_device to avoid creation of GPU0 context when using pin_memory + if self.cuda_device is not None: + torch.cuda.set_device(self.cuda_device) + + try: + for item in self._source: + self._queue.put(item) + + # Stop if we reached the maximum length + self.count += 1 + if self._max_len is not None and self.count >= self._max_len: + break + + # Signal the consumer we are done.
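+ # (BufferedIterator.__next__ converts this sentinel into StopIteration on the consuming side)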
+ self._queue.put(_sentinel) + except Exception as e: + self._queue.put(e) + + +class BufferedIterator(object): + def __init__(self, size, iterable): + self._queue = queue.Queue(size) + self._iterable = iterable + self._consumer = None + + self.start_time = time.time() + self.warning_time = None + + self.total = len(iterable) + + def _create_consumer(self): + self._consumer = BackgroundConsumer( + self._queue, + self._iterable, + self.total, + torch.cuda.current_device() if torch.cuda.is_available() else None, + ) + self._consumer.daemon = True + self._consumer.start() + + def __iter__(self): + return self + + def __len__(self): + return self.total + + def take(self, n): + self.total = min(self.total, n) + # Propagate this change to the underlying iterator + if hasattr(self._iterable, "take"): + self._iterable.take(n) + return self + + def __next__(self): + # Create consumer if not created yet + if self._consumer is None: + self._create_consumer() + + # Notify the user if there is a data loading bottleneck + if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): + if time.time() - self.start_time > 5 * 60: + if ( + self.warning_time is None + or time.time() - self.warning_time > 15 * 60 + ): + logger.debug( + "Data loading buffer is empty or nearly empty. This may " + "indicate a data loading bottleneck, and increasing the " + "number of workers (--num-workers) may help." + ) + self.warning_time = time.time() + + # Get next example + item = self._queue.get(True) + if isinstance(item, Exception): + raise item + if item is _sentinel: + raise StopIteration() + return item + + +class GroupedEpochBatchIterator(EpochBatchIterator): + """Grouped version of EpochBatchIterator. + It takes several samplers from different datasets. + At each epoch, each dataset-wise sampler is shuffled individually with a + different random seed. Those sub-samplers are then combined into + one big sampler with a deterministic permutation that mixes batches from + different datasets. It acts like EpochBatchIterator but makes sure that + 1) each mini-batch contains data from only one dataset, and + 2) different workers fetch data in the same order, so they draw from + the same dataset every time. + mult_rate is used for the update_freq > 1 case, where we want to make sure that every + update_freq consecutive mini-batches come from the same source + """ + + def __init__( + self, + dataset, + collate_fn, + batch_samplers, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=0, + mult_rate=1, + buffer_size=0, + skip_remainder_batch=False, + ): + super().__init__( + dataset, + collate_fn, + batch_samplers, + seed, + num_shards, + shard_id, + num_workers, + epoch, + buffer_size, + skip_remainder_batch=skip_remainder_batch, + ) + # level 0: sub-samplers 1: batch_idx 2: batches + self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers]) + self.step_size = mult_rate * num_shards + + self.lengths = [ + (len(x) // self.step_size) * self.step_size for x in self.frozen_batches + ] + + def __len__(self): + return sum(self.lengths) + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset."
+ ) + + if self.dataset.supports_fetch_outside_dataloader: + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0][0]]) + else: + return "DUMMY" + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + np.random.shuffle(batches) + return batches + + def return_full_batches(batch_sets, seed, shuffle): + if shuffle: + batch_sets = [shuffle_batches(list(x), seed) for x in batch_sets] + + batch_sets = [ + batch_sets[i][: self.lengths[i]] for i in range(len(batch_sets)) + ] + batches = list(itertools.chain.from_iterable(batch_sets)) + + if shuffle: + with data_utils.numpy_seed(seed): + idx = np.random.permutation(len(batches) // self.step_size) + if len(idx) * self.step_size != len(batches): + raise ValueError( + "ERROR: %d %d %d %d" + % (len(idx), self.step_size, len(batches), self.shard_id), + ":".join(["%d" % x for x in self.lengths]), + ) + mini_shards = [ + batches[i * self.step_size : (i + 1) * self.step_size] + for i in idx + ] + batches = list(itertools.chain.from_iterable(mini_shards)) + + return batches + + if self._supports_prefetch: + raise NotImplementedError("To be implemented") + else: + batches = return_full_batches( + self.frozen_batches, self.seed + epoch, shuffle + ) + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + + if offset > 0 and offset >= len(batches): + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=batches[offset:], + num_workers=self.num_workers, + persistent_workers=self.persistent_workers, + ) + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + return CountingIterator(itr, start=offset) diff --git a/fairseq/fairseq/data/language_pair_dataset.py b/fairseq/fairseq/data/language_pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd356ddd044faebd7bb09aeb22499f6b70304216 --- /dev/null +++ b/fairseq/fairseq/data/language_pair_dataset.py @@ -0,0 +1,477 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
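A sketch of how the iterator classes above are typically driven during training (illustrative only; *dataset* stands for a FairseqDataset and *batches* for a list of index lists, e.g. produced by dataset.batch_by_size):

epoch_itr = EpochBatchIterator(dataset, dataset.collater, batches, seed=1, num_workers=2)
while epoch_itr.next_epoch_idx <= 3:
    itr = epoch_itr.next_epoch_itr(shuffle=True)
    for sample in itr:
        ...  # forward/backward on the collated sample
    state = epoch_itr.state_dict()  # serializable; load_state_dict() resumes mid-epoch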
+ +import logging + +import numpy as np +import torch +from fairseq.data import FairseqDataset, data_utils + + +logger = logging.getLogger(__name__) + + +def collate( + samples, + pad_idx, + eos_idx, + left_pad_source=True, + left_pad_target=False, + input_feeding=True, + pad_to_length=None, + pad_to_multiple=1, +): + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx, + left_pad, + move_eos_to_beginning, + pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, + ) + + def check_alignment(alignment, src_len, tgt_len): + if alignment is None or len(alignment) == 0: + return False + if ( + alignment[:, 0].max().item() >= src_len - 1 + or alignment[:, 1].max().item() >= tgt_len - 1 + ): + logger.warning("alignment size mismatch found, skipping alignment!") + return False + return True + + def compute_alignment_weights(alignments): + """ + Given a tensor of shape [:, 2] containing the source-target indices + corresponding to the alignments, a weight vector containing the + inverse frequency of each target index is computed. + For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then + a tensor containing [1., 0.5, 0.5, 1] should be returned (since target + index 3 is repeated twice) + """ + align_tgt = alignments[:, 1] + _, align_tgt_i, align_tgt_c = torch.unique( + align_tgt, return_inverse=True, return_counts=True + ) + align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] + return 1.0 / align_weights.float() + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = merge( + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + src_lengths = torch.LongTensor( + [s["source"].ne(pad_idx).long().sum() for s in samples] + ) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + src_tokens = src_tokens.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get("target", None) is not None: + target = merge( + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + target = target.index_select(0, sort_order) + tgt_lengths = torch.LongTensor( + [s["target"].ne(pad_idx).long().sum() for s in samples] + ).index_select(0, sort_order) + ntokens = tgt_lengths.sum().item() + + if samples[0].get("prev_output_tokens", None) is not None: + prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target) + elif input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + "target", + left_pad=left_pad_target, + move_eos_to_beginning=True, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + else: + ntokens = src_lengths.sum().item() + + batch = { + "id": id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + "target": target, + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select( + 0, sort_order + ) + + if samples[0].get("alignment", None) is not None: + bsz, tgt_sz = batch["target"].shape + src_sz = batch["net_input"]["src_tokens"].shape[1] + + offsets = 
torch.zeros((len(sort_order), 2), dtype=torch.long) + offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz + if left_pad_source: + offsets[:, 0] += src_sz - src_lengths + if left_pad_target: + offsets[:, 1] += tgt_sz - tgt_lengths + + alignments = [ + alignment + offset + for align_idx, offset, src_len, tgt_len in zip( + sort_order, offsets, src_lengths, tgt_lengths + ) + for alignment in [samples[align_idx]["alignment"].view(-1, 2)] + if check_alignment(alignment, src_len, tgt_len) + ] + + if len(alignments) > 0: + alignments = torch.cat(alignments, dim=0) + align_weights = compute_alignment_weights(alignments) + + batch["alignments"] = alignments + batch["align_weights"] = align_weights + + if samples[0].get("constraints", None) is not None: + # Collate the packed constraints across the samples, padding to + # the length of the longest sample. + lens = [sample.get("constraints").size(0) for sample in samples] + max_len = max(lens) + constraints = torch.zeros((len(samples), max(lens))).long() + for i, sample in enumerate(samples): + constraints[i, 0 : lens[i]] = samples[i].get("constraints") + batch["constraints"] = constraints.index_select(0, sort_order) + + return batch + + +class LanguagePairDataset(FairseqDataset): + """ + A pair of torch.utils.data.Datasets. + + Args: + src (torch.utils.data.Dataset): source dataset to wrap + src_sizes (List[int]): source sentence lengths + src_dict (~fairseq.data.Dictionary): source vocabulary + tgt (torch.utils.data.Dataset, optional): target dataset to wrap + tgt_sizes (List[int], optional): target sentence lengths + tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary + left_pad_source (bool, optional): pad source tensors on the left side + (default: True). + left_pad_target (bool, optional): pad target tensors on the left side + (default: False). + shuffle (bool, optional): shuffle dataset elements before batching + (default: True). + input_feeding (bool, optional): create a shifted version of the targets + to be passed into the model for teacher forcing (default: True). + remove_eos_from_source (bool, optional): if set, removes eos from end + of source if it's present (default: False). + append_eos_to_target (bool, optional): if set, appends eos to end of + target if it's absent (default: False). + align_dataset (torch.utils.data.Dataset, optional): dataset + containing alignments. + constraints (Tensor, optional): 2d tensor with a concatenated, zero- + delimited list of constraints for each sentence. + append_bos (bool, optional): if set, appends bos to the beginning of + source/target sentence. + num_buckets (int, optional): if set to a value greater than 0, then + batches will be bucketed into the given number of batch shapes. + src_lang_id (int, optional): source language ID, if set, the collated batch + will contain a field 'src_lang_id' in 'net_input' which indicates the + source language of the samples. + tgt_lang_id (int, optional): target language ID, if set, the collated batch + will contain a field 'tgt_lang_id' which indicates the target language + of the samples. 
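+ + Example (illustrative sketch; assumes pre-built token datasets *src*/*tgt* together with their sizes and dictionaries):: + + pairs = LanguagePairDataset(src, src_sizes, src_dict, tgt=tgt, tgt_sizes=tgt_sizes, tgt_dict=tgt_dict) + batch = pairs.collater([pairs[i] for i in range(8)])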
+ """ + + def __init__( + self, + src, + src_sizes, + src_dict, + tgt=None, + tgt_sizes=None, + tgt_dict=None, + left_pad_source=True, + left_pad_target=False, + shuffle=True, + input_feeding=True, + remove_eos_from_source=False, + append_eos_to_target=False, + align_dataset=None, + constraints=None, + append_bos=False, + eos=None, + num_buckets=0, + src_lang_id=None, + tgt_lang_id=None, + pad_to_multiple=1, + ): + if tgt_dict is not None: + assert src_dict.pad() == tgt_dict.pad() + assert src_dict.eos() == tgt_dict.eos() + assert src_dict.unk() == tgt_dict.unk() + if tgt is not None: + assert len(src) == len( + tgt + ), "Source and target must contain the same number of examples" + self.src = src + self.tgt = tgt + self.src_sizes = np.array(src_sizes) + self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) + self.src_dict = src_dict + self.tgt_dict = tgt_dict + self.left_pad_source = left_pad_source + self.left_pad_target = left_pad_target + self.shuffle = shuffle + self.input_feeding = input_feeding + self.remove_eos_from_source = remove_eos_from_source + self.append_eos_to_target = append_eos_to_target + self.align_dataset = align_dataset + if self.align_dataset is not None: + assert ( + self.tgt_sizes is not None + ), "Both source and target needed when alignments are provided" + self.constraints = constraints + self.append_bos = append_bos + self.eos = eos if eos is not None else src_dict.eos() + self.src_lang_id = src_lang_id + self.tgt_lang_id = tgt_lang_id + if num_buckets > 0: + from fairseq.data import BucketPadLengthDataset + + self.src = BucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=self.src_dict.pad(), + left_pad=self.left_pad_source, + ) + self.src_sizes = self.src.sizes + logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) + if self.tgt is not None: + self.tgt = BucketPadLengthDataset( + self.tgt, + sizes=self.tgt_sizes, + num_buckets=num_buckets, + pad_idx=self.tgt_dict.pad(), + left_pad=self.left_pad_target, + ) + self.tgt_sizes = self.tgt.sizes + logger.info( + "bucketing target lengths: {}".format(list(self.tgt.buckets)) + ) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to BucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + self.pad_to_multiple = pad_to_multiple + + def get_batch_shapes(self): + return self.buckets + + def __getitem__(self, index): + tgt_item = self.tgt[index] if self.tgt is not None else None + src_item = self.src[index] + # Append EOS to end of tgt sentence if it does not have an EOS and remove + # EOS from end of src sentence if it exists. 
This is useful when we + # use existing datasets for opposite directions, i.e., when we want to + # use tgt_dataset as src_dataset and vice versa + if self.append_eos_to_target: + eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() + if self.tgt and self.tgt[index][-1] != eos: + tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) + + if self.append_bos: + bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() + if self.tgt and self.tgt[index][0] != bos: + tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) + + bos = self.src_dict.bos() + if self.src[index][0] != bos: + src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) + + if self.remove_eos_from_source: + eos = self.src_dict.eos() + if self.src[index][-1] == eos: + src_item = self.src[index][:-1] + + example = { + "id": index, + "source": src_item, + "target": tgt_item, + } + if self.align_dataset is not None: + example["alignment"] = self.align_dataset[index] + if self.constraints is not None: + example["constraints"] = self.constraints[index] + return example + + def __len__(self): + return len(self.src) + + def collater(self, samples, pad_to_length=None): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {'source': source_pad_to_length, 'target': target_pad_to_length} + to indicate the max length to pad to in source and target respectively. + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in + the source sentence of shape `(bsz, src_len)`. Padding will + appear on the left if *left_pad_source* is ``True``. + - `src_lengths` (LongTensor): 1D Tensor of the unpadded + lengths of each source sentence of shape `(bsz)` + - `prev_output_tokens` (LongTensor): a padded 2D Tensor of + tokens in the target sentence, shifted right by one + position for teacher forcing, of shape `(bsz, tgt_len)`. + This key will not be present if *input_feeding* is + ``False``. Padding will appear on the left if + *left_pad_target* is ``True``. + - `src_lang_id` (LongTensor): a long Tensor which contains source + language IDs of each sample in the batch + + - `target` (LongTensor): a padded 2D Tensor of tokens in the + target sentence of shape `(bsz, tgt_len)`. Padding will appear + on the left if *left_pad_target* is ``True``. + - `tgt_lang_id` (LongTensor): a long Tensor which contains target language + IDs of each sample in the batch + """ + res = collate( + samples, + pad_idx=self.src_dict.pad(), + eos_idx=self.eos, + left_pad_source=self.left_pad_source, + left_pad_target=self.left_pad_target, + input_feeding=self.input_feeding, + pad_to_length=pad_to_length, + pad_to_multiple=self.pad_to_multiple, + ) + if self.src_lang_id is not None or self.tgt_lang_id is not None: + src_tokens = res["net_input"]["src_tokens"] + bsz = src_tokens.size(0) + if self.src_lang_id is not None: + res["net_input"]["src_lang_id"] = ( + torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens) + ) + if self.tgt_lang_id is not None: + res["tgt_lang_id"] = ( + torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens) + ) + return res + + def num_tokens(self, index): + """Return the number of tokens in a sample. 
This value is used to + enforce ``--max-tokens`` during batching.""" + return max( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + sizes = self.src_sizes[indices] + if self.tgt_sizes is not None: + sizes = np.maximum(sizes, self.tgt_sizes[indices]) + return sizes + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return ( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)).astype(np.int64) + else: + indices = np.arange(len(self), dtype=np.int64) + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + else: + # sort by bucketed_num_tokens, which is: + # max(padded_src_len, padded_tgt_len) + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") + ] + + @property + def supports_prefetch(self): + return getattr(self.src, "supports_prefetch", False) and ( + getattr(self.tgt, "supports_prefetch", False) or self.tgt is None + ) + + def prefetch(self, indices): + self.src.prefetch(indices) + if self.tgt is not None: + self.tgt.prefetch(indices) + if self.align_dataset is not None: + self.align_dataset.prefetch(indices) + + def filter_indices_by_size(self, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) diff --git a/fairseq/fairseq/data/list_dataset.py b/fairseq/fairseq/data/list_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..12f00aa43661d6bad701c9e72653ba8779136906 --- /dev/null +++ b/fairseq/fairseq/data/list_dataset.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . 
import BaseWrapperDataset + + +class ListDataset(BaseWrapperDataset): + def __init__(self, dataset, sizes=None): + super().__init__(dataset) + self._sizes = sizes + + def __iter__(self): + for x in self.dataset: + yield x + + def collater(self, samples): + return samples + + @property + def sizes(self): + return self._sizes + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + def set_epoch(self, epoch): + pass diff --git a/fairseq/fairseq/data/lm_context_window_dataset.py b/fairseq/fairseq/data/lm_context_window_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1a945927cf0d96719003685676a990737a3762b2 --- /dev/null +++ b/fairseq/fairseq/data/lm_context_window_dataset.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from typing import Dict + +from fairseq.data.monolingual_dataset import MonolingualDataset + +from . import FairseqDataset + + +class LMContextWindowDataset(FairseqDataset): + """ + Wraps a MonolingualDataset and provides more context for evaluation. + + Each item in the new dataset will have a maximum size of + ``tokens_per_sample + context_window``. + + Args: + dataset: dataset to wrap + tokens_per_sample (int): the max number of tokens in each dataset item + context_window (int): the number of accumulated tokens to add to each + dataset item + pad_idx (int): padding symbol + """ + + def __init__( + self, + dataset: MonolingualDataset, + tokens_per_sample: int, + context_window: int, + pad_idx: int, + ): + assert context_window > 0 + self.dataset = dataset + self.tokens_per_sample = tokens_per_sample + self.context_window = context_window + self.pad_idx = pad_idx + self.prev_tokens = np.empty([0]) + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples) -> Dict: + sample = self.dataset.collater(samples) + + pad = self.pad_idx + max_sample_len = self.tokens_per_sample + self.context_window + + bsz, tsz = sample["net_input"]["src_tokens"].shape + start_idxs = [0] * bsz + toks = sample["net_input"]["src_tokens"] + lengths = sample["net_input"]["src_lengths"] + tgt = sample["target"] + new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64) + new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64) + sample_lens = toks.ne(pad).long().sum(dim=1).cpu() + for i in range(bsz): + sample_len = sample_lens[i] + extra = len(self.prev_tokens) + sample_len - max_sample_len + if extra > 0: + self.prev_tokens = self.prev_tokens[extra:] + pads = np.full(self.context_window - len(self.prev_tokens), pad) + new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads]) + new_tgt[ + i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i]) + ] = tgt[i] + start_idxs[i] = len(self.prev_tokens) + lengths[i] += len(self.prev_tokens) + self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :] + sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks) + sample["target"] = torch.from_numpy(new_tgt) + sample["start_indices"] = start_idxs + return sample + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + def ordered_indices(self): + # NOTE we don't shuffle the data to retain access to the 
previous dataset elements + return np.arange(len(self.dataset)) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/fairseq/fairseq/data/lru_cache_dataset.py b/fairseq/fairseq/data/lru_cache_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a7854ac1701392754ce5795cafe9c634671aebdf --- /dev/null +++ b/fairseq/fairseq/data/lru_cache_dataset.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +from . import BaseWrapperDataset + + +class LRUCacheDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + + @lru_cache(maxsize=8) + def __getitem__(self, index): + return self.dataset[index] + + @lru_cache(maxsize=8) + def collater(self, samples): + return self.dataset.collater(samples) diff --git a/fairseq/fairseq/data/mask_tokens_dataset.py b/fairseq/fairseq/data/mask_tokens_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca9051c9ae9eaf4bb31a917ae115cbeb13879a7 --- /dev/null +++ b/fairseq/fairseq/data/mask_tokens_dataset.py @@ -0,0 +1,226 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +import numpy as np +import torch +from fairseq.data import Dictionary, data_utils + +from . import BaseWrapperDataset, LRUCacheDataset + + +class MaskTokensDataset(BaseWrapperDataset): + """ + A wrapper Dataset for masked language modeling. + + Input items are masked according to the specified masking probability. + + Args: + dataset: Dataset to wrap. + sizes: Sentence lengths + vocab: Dictionary with the vocabulary and special tokens. + pad_idx: Id of pad token in vocab + mask_idx: Id of mask token in vocab + return_masked_tokens: controls whether to return the non-masked tokens + (the default) or to return a tensor with the original masked token + IDs (and *pad_idx* elsewhere). The latter is useful as targets for + masked LM training. + seed: Seed for random number generator for reproducibility. + mask_prob: probability of replacing a token with *mask_idx*. + leave_unmasked_prob: probability that a masked token is unmasked. + random_token_prob: probability of replacing a masked token with a + random token from the vocabulary. + freq_weighted_replacement: sample random replacement words based on + word frequencies in the vocab. + mask_whole_words: only mask whole words. This should be a byte mask + over vocab indices, indicating whether it is the beginning of a + word. We will extend any mask to encompass the whole word. + bpe: BPE to use for whole-word masking. + mask_multiple_length : repeat each mask index multiple times. Default + value is 1. + mask_stdev : standard deviation of masks distribution in case of + multiple masking. Default value is 0. 
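
As a rough usage sketch of the class below (``token_ds`` is a placeholder for any indexable dataset of LongTensor token sequences; the rest follows the signatures defined here):

from fairseq.data import Dictionary, MaskTokensDataset

vocab = Dictionary()
mask_idx = vocab.add_symbol("<mask>")
src_ds, tgt_ds = MaskTokensDataset.apply_mask(
    token_ds,              # placeholder: dataset of LongTensors without <mask> tokens
    vocab=vocab,
    pad_idx=vocab.pad(),
    mask_idx=mask_idx,
    seed=1,
    mask_prob=0.15,
)
# src_ds[i]: tokens with roughly 15% of positions replaced by <mask>
# tgt_ds[i]: the original token IDs at the masked positions, pad elsewhere
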
+ """ + + @classmethod + def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs): + """Return the source and target datasets for masked LM training.""" + dataset = LRUCacheDataset(dataset) + return ( + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)), + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)), + ) + + def __init__( + self, + dataset: torch.utils.data.Dataset, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + return_masked_tokens: bool = False, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + freq_weighted_replacement: bool = False, + mask_whole_words: torch.Tensor = None, + mask_multiple_length: int = 1, + mask_stdev: float = 0.0, + skip_masking: bool = False, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + assert mask_multiple_length >= 1 + assert mask_stdev >= 0.0 + + self.dataset = dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.return_masked_tokens = return_masked_tokens + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + self.mask_whole_words = mask_whole_words + self.mask_multiple_length = mask_multiple_length + self.mask_stdev = mask_stdev + self.skip_masking = skip_masking + + if random_token_prob > 0.0: + if freq_weighted_replacement: + weights = np.array(self.vocab.count) + else: + weights = np.ones(len(self.vocab)) + weights[: self.vocab.nspecial] = 0 + self.weights = weights / weights.sum() + + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index: int): + return self.__getitem_cached__(self.seed, self.epoch, index) + + @lru_cache(maxsize=8) + def __getitem_cached__(self, seed: int, epoch: int, index: int): + seed = int(hash((seed, epoch, index)) % 1e6) + rng = np.random.default_rng(seed) + item = self.dataset[index] + sz = len(item) + + assert ( + self.mask_idx not in item + ), "Dataset contains mask_idx (={}), this is not expected!".format( + self.mask_idx, + ) + if self.skip_masking: + return torch.from_numpy(np.copy(item)) + + if self.mask_whole_words is not None: + word_begins_mask = self.mask_whole_words.gather(0, item) + word_begins_idx = word_begins_mask.nonzero().view(-1) + sz = len(word_begins_idx) + words = np.split(word_begins_mask, word_begins_idx)[1:] + assert len(words) == sz + word_lens = list(map(len, words)) + + # decide elements to mask + mask = np.full(sz, False) + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * sz / float(self.mask_multiple_length) + + rng.random() + ) + + # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) + mask_idc = rng.choice(sz, num_mask, replace=False) + if self.mask_stdev > 0.0: + lengths = rng.normal( + self.mask_multiple_length, self.mask_stdev, size=num_mask + ) + lengths = [max(0, int(round(x))) for x in lengths] + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ], + dtype=np.int64, + ) + else: + mask_idc = np.concatenate( + [mask_idc + i for i in range(self.mask_multiple_length)] + ) + 
mask_idc = mask_idc[mask_idc < len(mask)] + try: + mask[mask_idc] = True + except: # something wrong + print("Assigning mask indexes {} to mask {} failed!".format(mask_idc, mask)) + raise + + # if self.return_masked_tokens: + # print(( + # f"IDX={index}; seed={seed}; epoch={epoch}; is_tgt={self.return_masked_tokens}: " + # f"{np.nonzero(mask)[0].sum()}" + # )) + if self.return_masked_tokens: + # exit early if we're just returning the masked tokens + # (i.e., the targets for masked LM training) + if self.mask_whole_words is not None: + mask = np.repeat(mask, word_lens) + new_item = np.full(len(mask), self.pad_idx) + new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] + return torch.from_numpy(new_item) + + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (rng.random(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = rng.random(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + if self.mask_whole_words is not None: + mask = np.repeat(mask, word_lens) + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + if self.mask_whole_words is not None: + rand_mask = np.repeat(rand_mask, word_lens) + num_rand = rand_mask.sum() + + new_item[rand_mask] = rng.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + + return torch.from_numpy(new_item) diff --git a/fairseq/fairseq/data/monolingual_dataset.py b/fairseq/fairseq/data/monolingual_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..54fd583b64a3a475324ade6eaaeccf593d747fdc --- /dev/null +++ b/fairseq/fairseq/data/monolingual_dataset.py @@ -0,0 +1,253 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . 
import FairseqDataset, data_utils + + +def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None): + if len(samples) == 0: + return {} + + def merge(key, is_list=False): + if is_list: + res = [] + for i in range(len(samples[0][key])): + res.append( + data_utils.collate_tokens( + [s[key][i] for s in samples], + pad_idx, + eos_idx, + left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + ) + ) + return res + else: + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx, + left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + ) + + src_tokens = merge("source") + if samples[0]["target"] is not None: + is_target_list = isinstance(samples[0]["target"], list) + target = merge("target", is_target_list) + else: + target = src_tokens + + return { + "id": torch.LongTensor([s["id"] for s in samples]), + "nsentences": len(samples), + "ntokens": sum(len(s["source"]) for s in samples), + "net_input": { + "src_tokens": src_tokens, + "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]), + }, + "target": target, + } + + +class MonolingualDataset(FairseqDataset): + """ + A wrapper around torch.utils.data.Dataset for monolingual data. + + Args: + dataset (torch.utils.data.Dataset): dataset to wrap + sizes (List[int]): sentence lengths + vocab (~fairseq.data.Dictionary): vocabulary + shuffle (bool, optional): shuffle the elements before batching + (default: True). + """ + + def __init__( + self, + dataset, + sizes, + src_vocab, + tgt_vocab=None, + add_eos_for_other_targets=False, + shuffle=False, + targets=None, + add_bos_token=False, + fixed_pad_length=None, + pad_to_bsz=None, + src_lang_idx=None, + tgt_lang_idx=None, + ): + self.dataset = dataset + self.sizes = np.array(sizes) + self.vocab = src_vocab + self.tgt_vocab = tgt_vocab or src_vocab + self.add_eos_for_other_targets = add_eos_for_other_targets + self.shuffle = shuffle + self.add_bos_token = add_bos_token + self.fixed_pad_length = fixed_pad_length + self.pad_to_bsz = pad_to_bsz + self.src_lang_idx = src_lang_idx + self.tgt_lang_idx = tgt_lang_idx + + assert targets is None or all( + t in {"self", "future", "past"} for t in targets + ), "targets must be none or one of 'self', 'future', 'past'" + if targets is not None and len(targets) == 0: + targets = None + self.targets = targets + + def __getitem__(self, index): + if self.targets is not None: + # *future_target* is the original sentence + # *source* is shifted right by 1 (maybe left-padded with eos) + # *past_target* is shifted right by 2 (left-padded as needed) + # + # Left-to-right language models should condition on *source* and + # predict *future_target*. + # Right-to-left language models should condition on *source* and + # predict *past_target*. 
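
One plausible concrete rendering of those three views, using invented token IDs and the conventional special indices (pad = 1, eos = 2):

import torch

pad, eos = 1, 2
future_target = torch.LongTensor([10, 11, 12, eos])                            # the original sentence
source = torch.cat([future_target.new([eos]), future_target[:-1]])            # shifted right by 1
past_target = torch.cat([future_target.new([pad, eos]), future_target[:-2]])  # shifted right by 2
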
+ source, future_target, past_target = self.dataset[index] + source, target = self._make_source_target( + source, future_target, past_target + ) + else: + source = self.dataset[index] + target = None + source, target = self._maybe_add_bos(source, target) + return {"id": index, "source": source, "target": target} + + def __len__(self): + return len(self.dataset) + + def _make_source_target(self, source, future_target, past_target): + if self.targets is not None: + target = [] + + if ( + self.add_eos_for_other_targets + and (("self" in self.targets) or ("past" in self.targets)) + and source[-1] != self.vocab.eos() + ): + # append eos at the end of source + source = torch.cat([source, source.new([self.vocab.eos()])]) + + if "future" in self.targets: + future_target = torch.cat( + [future_target, future_target.new([self.vocab.pad()])] + ) + if "past" in self.targets: + # first token is before the start of sentence which is only used in "none" break mode when + # add_eos_for_other_targets is False + past_target = torch.cat( + [ + past_target.new([self.vocab.pad()]), + past_target[1:], + source[-2, None], + ] + ) + + for t in self.targets: + if t == "self": + target.append(source) + elif t == "future": + target.append(future_target) + elif t == "past": + target.append(past_target) + else: + raise Exception("invalid target " + t) + + if len(target) == 1: + target = target[0] + else: + target = future_target + + return source, self._filter_vocab(target) + + def _maybe_add_bos(self, source, target): + if self.add_bos_token: + source = torch.cat([source.new([self.vocab.bos()]), source]) + if target is not None: + target = torch.cat([target.new([self.tgt_vocab.bos()]), target]) + return source, target + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + return self.sizes[indices] + + def _filter_vocab(self, target): + if len(self.tgt_vocab) != len(self.vocab): + + def _filter(target): + mask = target.ge(len(self.tgt_vocab)) + if mask.any(): + target[mask] = self.tgt_vocab.unk() + return target + + if isinstance(target, list): + return [_filter(t) for t in target] + return _filter(target) + return target + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in + the source sentence of shape `(bsz, src_len)`. Padding will + appear on the right. + + - `target` (LongTensor): a padded 2D Tensor of tokens in the + target sentence of shape `(bsz, tgt_len)`. Padding will appear + on the right. + """ + return collate( + samples, + self.vocab.pad(), + self.vocab.eos(), + self.fixed_pad_length, + self.pad_to_bsz, + ) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. 
Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + order.append(self.sizes) + return np.lexsort(order) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch(indices) diff --git a/fairseq/fairseq/data/multi_corpus_dataset.py b/fairseq/fairseq/data/multi_corpus_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2fe074b2280c85706979614ba7abc5ad4c7bb5 --- /dev/null +++ b/fairseq/fairseq/data/multi_corpus_dataset.py @@ -0,0 +1,285 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +import time +from collections import OrderedDict +from typing import Dict, List, Optional + +import numpy as np + +from fairseq.data import data_utils + +from . import FairseqDataset + +logger = logging.getLogger(__name__) + + +class MultiCorpusDataset(FairseqDataset): + """ + Stores multiple instances of FairseqDataset together. + Unless batch_sample=True, requires each instance + to be of the same dataset type, as the collate method needs to work on batches with + samples from each dataset. + + Allows specifying a distribution over the datasets to use. Note that unlike + MultiCorpusSampledDataset, this distribution allows sampling for each item, + rather than on a batch level. Note that datasets with a sampling probability + of 0 will be skipped. + + Each time ordered_indices() is called, a new sample is generated with + the specified distribution. + + Args: + datasets: an OrderedDict of FairseqDataset instances. 
+ distribution: a List containing the probability of getting an utterance from + the corresponding dataset + seed: random seed for sampling the datasets + sort_indices: if true, will sort the ordered indices by size + batch_sample: if true, will ensure each batch is from a single dataset + """ + + def __init__( + self, + datasets: Dict[str, FairseqDataset], + distribution: List[float], + seed: int, + sort_indices: bool = False, + batch_sample: bool = False, + distributed_rank: Optional[int] = None, + ): + super().__init__() + assert isinstance(datasets, OrderedDict) + assert len(datasets) == len(distribution) + assert sum(distribution) == 1 + self.datasets = datasets + self.distribution = distribution + self.seed = seed + self.sort_indices = sort_indices + self.batch_sample = batch_sample + self.distributed_rank = distributed_rank + + # Avoid repeated conversions to list later + self.dataset_list = list(datasets.values()) + self.total_num_instances = 0 + + first_dataset = self.dataset_list[0] + + self.num_instances_per_dataset = [] + self.dataset_offsets = [] + for i, dataset in enumerate(self.dataset_list): + assert isinstance(dataset, FairseqDataset) + assert type(dataset) is type(first_dataset) + self.num_instances_per_dataset.append( + 0 if self.distribution[i] == 0 else len(dataset) + ) + self.dataset_offsets.append(self.total_num_instances) + self.total_num_instances += self.num_instances_per_dataset[i] + + def ordered_indices(self): + start = time.time() + with data_utils.numpy_seed(self.seed, self.epoch): + logger.info( + f"sampling new dataset with seed {self.seed} epoch {self.epoch}" + ) + sampled_indices = [] + num_selected_instances = 0 + + # For each dataset i, sample self.distribution[i] * self.total_num_instances + for i, key in enumerate(self.datasets): + if self.distribution[i] == 0: + # skip dataset if sampling probability is 0 + continue + + if i < len(self.datasets) - 1: + num_instances = int(self.distribution[i] * self.total_num_instances) + high = self.dataset_offsets[i + 1] + else: + num_instances = self.total_num_instances - num_selected_instances + high = self.total_num_instances + + logger.info(f"sampling {num_instances} from {key} dataset") + num_selected_instances += num_instances + + # First, add k copies of the dataset where k = num_instances // len(dataset). + # This ensures an equal distribution of the data points as much as possible. + # The remaining entries are sampled randomly. + dataset_size = len(self.datasets[key]) + num_copies = num_instances // dataset_size + dataset_indices = ( + np.random.permutation(high - self.dataset_offsets[i]) + + self.dataset_offsets[i] + )[: num_instances - num_copies * dataset_size] + if num_copies > 0: + sampled_indices += list( + np.concatenate( + ( + np.repeat( + np.arange(self.dataset_offsets[i], high), num_copies + ), + dataset_indices, + ) + ) + ) + else: + sampled_indices += list(dataset_indices) + + assert ( + len(sampled_indices) == self.total_num_instances + ), f"{len(sampled_indices)} vs {self.total_num_instances}" + + np.random.shuffle(sampled_indices) + if self.sort_indices: + sampled_indices.sort(key=lambda i: self.num_tokens(i)) + + logger.info( + "multi_corpus_dataset ordered_indices took {}s".format( + time.time() - start + ) + ) + return np.array(sampled_indices, dtype=np.int64) + + def _map_index(self, index: int): + """ + If dataset A has length N and dataset B has length M, + then index 1 maps to index 1 of dataset A, and index N + 1 + maps to index 1 of B. 
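
The same mapping restated standalone, as a minimal sketch (names invented) that mirrors the method body below:

def map_index(index, sizes, keys):
    # walk the corpora in order, subtracting each corpus length
    # until the global index falls inside one of them
    counter = 0
    for n, key in zip(sizes, keys):
        if index < counter + n:
            return index - counter, key
        counter += n
    raise ValueError(f"Invalid index: {index}")

assert map_index(0, [3, 2], ["a", "b"]) == (0, "a")  # first item of corpus "a"
assert map_index(3, [3, 2], ["a", "b"]) == (0, "b")  # first item of corpus "b"
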
+ """ + counter = 0 + for num_instances, key in zip(self.num_instances_per_dataset, self.datasets): + if index < counter + num_instances: + return index - counter, key + counter += num_instances + raise ValueError( + "Invalid index: {}, max: {}".format(index, self.total_num_instances) + ) + + def __len__(self): + """ + Length of this dataset is the sum of individual datasets + """ + return self.total_num_instances + + async def getitem(self, index): + new_index, key = self._map_index(index) + try: + if hasattr(self.datasets[key], "getitem"): + item = await self.datasets[key].getitem(new_index) + else: + item = self.datasets[key][new_index] + item["full_id"] = index + return item + except Exception as e: + e.args = (f"Error from {key} dataset", *e.args) + raise + + def __getitem__(self, index): + return asyncio.run(self.getitem(index)) + + async def getitems(self, indices): + # initialize a bunch of everstore read operations + # wait in the end to reduce overhead + # very helpful if io is latency bounded + + max_concurrency = 32 + sem = asyncio.Semaphore(max_concurrency) + + async def controlled_getitem(index): + async with sem: + return await self.getitem(index) + + coroutines = [] + for index in indices: + coroutines.append(controlled_getitem(index)) + results = await asyncio.gather(*coroutines) + return results + + def __getitems__(self, indices): + return asyncio.run(self.getitems(indices)) + + def collater(self, samples): + """ + If we are doing batch sampling, then pick the right collater to use. + + Otherwise we assume all collaters are the same. + """ + if len(samples) == 0: + return None + if "full_id" in samples[0]: + _, key = self._map_index(samples[0]["full_id"]) + try: + batch = self.datasets[key].collater(samples) + except Exception: + print(f"Collating failed for key {key}", flush=True) + raise + return batch + else: + # Subclasses may override __getitem__ to not specify full_id + return list(self.datasets.values())[0].collater(samples) + + def num_tokens(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].num_tokens(index) + + def size(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].size(index) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + logger.info(f"setting epoch of multi_corpus_dataset to {epoch}") + self.epoch = epoch + + @property + def supports_prefetch(self): + return False + + @property + def supports_fetch_outside_dataloader(self): + return all( + self.datasets[key].supports_fetch_outside_dataloader + for key in self.datasets + ) + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + if not self.batch_sample: + return super().batch_by_size( + indices, max_tokens, max_sentences, required_batch_size_multiple + ) + + dataset_indices = {key: [] for key in self.datasets} + for i in indices: + _, key = self._map_index(i) + dataset_indices[key].append(i) + + batches = [] + for key in dataset_indices: + cur_batches = super().batch_by_size( + np.array(dataset_indices[key], dtype=np.int64), + max_tokens, + max_sentences, + required_batch_size_multiple, + ) + logger.info(f"Created {len(cur_batches)} batches for dataset {key}") + batches += cur_batches + + # If this dataset is used in a distributed training setup, + # then shuffle such that the order is seeded by the distributed rank + # as well + if self.distributed_rank is not None: + 
with data_utils.numpy_seed(self.seed, self.epoch, self.distributed_rank): + np.random.shuffle(batches) + return batches diff --git a/fairseq/fairseq/data/multi_corpus_sampled_dataset.py b/fairseq/fairseq/data/multi_corpus_sampled_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e9fdf004dd1da519a170a5e8bc225775776f72 --- /dev/null +++ b/fairseq/fairseq/data/multi_corpus_sampled_dataset.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict +from typing import Callable, Dict, List + +import numpy as np + +from . import FairseqDataset + + +def uniform_sampler(x): + # Sample from uniform distribution + return np.random.choice(x, 1).item() + + +class MultiCorpusSampledDataset(FairseqDataset): + """ + Stores multiple instances of FairseqDataset together and in every iteration + creates a batch by first sampling a dataset according to a specified + probability distribution and then getting instances from that dataset. + + Args: + datasets: an OrderedDict of FairseqDataset instances. + sampling_func: A function for sampling over list of dataset keys. + The default strategy is to sample uniformly. + """ + + def __init__( + self, + datasets: Dict[str, FairseqDataset], + sampling_func: Callable[[List], int] = None, + ): + super().__init__() + assert isinstance(datasets, OrderedDict) + self.datasets = datasets + if sampling_func is None: + sampling_func = uniform_sampler + self.sampling_func = sampling_func + + self.total_num_instances = 0 + for _, dataset in datasets.items(): + assert isinstance(dataset, FairseqDataset) + self.total_num_instances += len(dataset) + + self._ordered_indices = None + + def __len__(self): + """ + Length of this dataset is the sum of individual datasets + """ + return self.total_num_instances + + def ordered_indices(self): + """ + Ordered indices for batching. Here we call the underlying + dataset's ordered_indices() so that we get the same random ordering + as we would have from using the underlying dataset directly. + """ + if self._ordered_indices is None: + self._ordered_indices = OrderedDict( + [ + (key, dataset.ordered_indices()) + for key, dataset in self.datasets.items() + ] + ) + return np.arange(len(self)) + + def _map_index_to_dataset(self, key: int, index: int): + """ + Different underlying datasets have different lengths. In order to ensure + we are not accessing an index outside the range of the current dataset + size, we wrap around. This function should be called after we have + created an ordering for this and all underlying datasets. + """ + assert ( + self._ordered_indices is not None + ), "Must call MultiCorpusSampledDataset.ordered_indices() first" + mapped_index = index % len(self.datasets[key]) + return self._ordered_indices[key][mapped_index] + + def __getitem__(self, index: int): + """ + Get the item associated with index from each underlying dataset. + Since index is in the range of [0, TotalNumInstances], we need to + map the index to the dataset before retrieving the item. + """ + return OrderedDict( + [ + (key, dataset[self._map_index_to_dataset(key, index)]) + for key, dataset in self.datasets.items() + ] + ) + + def collater(self, samples: List[Dict]): + """ + Generate a mini-batch for this dataset. + To convert this into a regular mini-batch we use the following + logic: + 1. 
Select a dataset using the specified probability distribution. + 2. Call the collater function of the selected dataset. + """ + if len(samples) == 0: + return None + + selected_key = self.sampling_func(list(self.datasets.keys())) + selected_samples = [sample[selected_key] for sample in samples] + return self.datasets[selected_key].collater(selected_samples) + + def num_tokens(self, index: int): + """ + Return an example's length (number of tokens), used for batching. Here + we return the max across all examples at index across all underlying + datasets. + """ + return max( + dataset.num_tokens(self._map_index_to_dataset(key, index)) + for key, dataset in self.datasets.items() + ) + + def size(self, index: int): + """ + Return an example's size as a float or tuple. Here we return the max + across all underlying datasets. This value is used when filtering a + dataset with max-positions. + """ + return max( + dataset.size(self._map_index_to_dataset(key, index)) + for key, dataset in self.datasets.items() + ) + + @property + def supports_prefetch(self): + return all( + getattr(dataset, "supports_prefetch", False) + for dataset in self.datasets.values() + ) + + def prefetch(self, indices): + for key, dataset in self.datasets.items(): + dataset.prefetch( + [self._map_index_to_dataset(key, index) for index in indices] + ) + + @property + def supports_fetch_outside_dataloader(self): + return all( + self.datasets[key].supports_fetch_outside_dataloader + for key in self.datasets + ) diff --git a/fairseq/fairseq/data/nested_dictionary_dataset.py b/fairseq/fairseq/data/nested_dictionary_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..52e74abddacc923c5e29b0a0c41d7efc85482d3b --- /dev/null +++ b/fairseq/fairseq/data/nested_dictionary_dataset.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +def _flatten(dico, prefix=None): + """Flatten a nested dictionary.""" + new_dico = OrderedDict() + if isinstance(dico, dict): + prefix = prefix + "." 
if prefix is not None else "" + for k, v in dico.items(): + if v is None: + continue + new_dico.update(_flatten(v, prefix + k)) + elif isinstance(dico, list): + for i, v in enumerate(dico): + new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]")) + else: + new_dico = OrderedDict({prefix: dico}) + return new_dico + + +def _unflatten(dico): + """Unflatten a flattened dictionary into a nested dictionary.""" + new_dico = OrderedDict() + for full_k, v in dico.items(): + full_k = full_k.split(".") + node = new_dico + for k in full_k[:-1]: + if k.startswith("[") and k.endswith("]"): + k = int(k[1:-1]) + if k not in node: + node[k] = OrderedDict() + node = node[k] + node[full_k[-1]] = v + return new_dico + + +class NestedDictionaryDataset(FairseqDataset): + def __init__(self, defn, sizes=None): + super().__init__() + self.defn = _flatten(defn) + self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes + + first = None + for v in self.defn.values(): + if not isinstance( + v, + ( + FairseqDataset, + torch.utils.data.Dataset, + ), + ): + raise ValueError("Expected Dataset but found: {}".format(v.__class__)) + first = first or v + if len(v) > 0: + assert len(v) == len(first), "dataset lengths must match" + + self._len = len(first) + + def __getitem__(self, index): + return OrderedDict((k, ds[index]) for k, ds in self.defn.items()) + + def __len__(self): + return self._len + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + if len(samples) == 0: + return {} + sample = OrderedDict() + for k, ds in self.defn.items(): + try: + sample[k] = ds.collater([s[k] for s in samples]) + except NotImplementedError: + sample[k] = default_collate([s[k] for s in samples]) + return _unflatten(sample) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return max(s[index] for s in self.sizes) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + if len(self.sizes) == 1: + return self.sizes[0][index] + else: + return (s[index] for s in self.sizes) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return any(ds.supports_prefetch for ds in self.defn.values()) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + for ds in self.defn.values(): + if getattr(ds, "supports_prefetch", False): + ds.prefetch(indices) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values()) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.defn.values(): + ds.set_epoch(epoch) diff --git a/fairseq/fairseq/data/noising.py b/fairseq/fairseq/data/noising.py new file mode 100644 index 0000000000000000000000000000000000000000..e92e83c2cd2e2950d387f93ae8a80acbc12f909f --- /dev/null +++ b/fairseq/fairseq/data/noising.py @@ -0,0 +1,334 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
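
Looking back at the ``_flatten``/``_unflatten`` pair above: the two helpers are meant to round-trip nested batch dictionaries. A small sketch with made-up keys (the underscore-prefixed helpers are module-private and imported here purely for illustration):

from collections import OrderedDict

from fairseq.data.nested_dictionary_dataset import _flatten, _unflatten

nested = OrderedDict(
    [("net_input", OrderedDict([("src_tokens", 1)])), ("target", 2)]
)
flat = _flatten(nested)  # OrderedDict([('net_input.src_tokens', 1), ('target', 2)])
assert _unflatten(flat) == nested
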
+ +import numpy as np +import torch +from fairseq.data import data_utils + + +class WordNoising(object): + """Generate a noisy version of a sentence, without changing words themselves.""" + + def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None): + self.dictionary = dictionary + self.bpe_end = None + if bpe_cont_marker: + self.bpe_end = np.array( + [ + not self.dictionary[i].endswith(bpe_cont_marker) + for i in range(len(self.dictionary)) + ] + ) + elif bpe_end_marker: + self.bpe_end = np.array( + [ + self.dictionary[i].endswith(bpe_end_marker) + for i in range(len(self.dictionary)) + ] + ) + + self.get_word_idx = ( + self._get_bpe_word_idx if self.bpe_end is not None else self._get_token_idx + ) + + def noising(self, x, lengths, noising_prob=0.0): + raise NotImplementedError() + + def _get_bpe_word_idx(self, x): + """ + Given a list of BPE tokens, for every index in the tokens list, + return the index of the word grouping that it belongs to. + For example, for input x corresponding to ["how", "are", "y@@", "ou"], + return [[0], [1], [2], [2]]. + """ + # x: (T x B) + bpe_end = self.bpe_end[x] + + if x.size(0) == 1 and x.size(1) == 1: + # Special case when we only have one word in x. If x = [[N]], + # bpe_end is a scalar (bool) instead of a 2-dim array of bools, + # which makes the sum operation below fail. + return np.array([[0]]) + + # do a reduce front sum to generate word ids + word_idx = bpe_end[::-1].cumsum(0)[::-1] + word_idx = word_idx.max(0)[None, :] - word_idx + return word_idx + + def _get_token_idx(self, x): + """ + This is to extend noising functions to be able to apply to non-bpe + tokens, e.g. word or characters. + """ + x = torch.t(x) + word_idx = np.array([range(len(x_i)) for x_i in x]) + return np.transpose(word_idx) + + +class WordDropout(WordNoising): + """Randomly drop input words. If not passing blank_idx (default is None), + then dropped words will be removed. Otherwise, it will be replaced by the + blank_idx.""" + + def __init__( + self, + dictionary, + default_dropout_prob=0.1, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): + super().__init__(dictionary, bpe_cont_marker, bpe_end_marker) + self.default_dropout_prob = default_dropout_prob + + def noising(self, x, lengths, dropout_prob=None, blank_idx=None): + if dropout_prob is None: + dropout_prob = self.default_dropout_prob + # x: (T x B), lengths: B + if dropout_prob == 0: + return x, lengths + + assert 0 < dropout_prob < 1 + + # be sure to drop entire words + word_idx = self.get_word_idx(x) + sentences = [] + modified_lengths = [] + for i in range(lengths.size(0)): + # Since dropout probabilities need to apply over non-pad tokens, + # it is not trivial to generate the keep mask without consider + # input lengths; otherwise, this could be done outside the loop + + # We want to drop whole words based on word_idx grouping + num_words = max(word_idx[:, i]) + 1 + + # ith example: [x0, x1, ..., eos, pad, ..., pad] + # We should only generate keep probs for non-EOS tokens. Thus if the + # input sentence ends in EOS, the last word idx is not included in + # the dropout mask generation and we append True to always keep EOS. + # Otherwise, just generate the dropout mask for all word idx + # positions. + has_eos = x[lengths[i] - 1, i] == self.dictionary.eos() + if has_eos: # has eos? 
+ keep = np.random.rand(num_words - 1) >= dropout_prob + keep = np.append(keep, [True]) # always keep the EOS symbol + else: + keep = np.random.rand(num_words) >= dropout_prob + + words = x[: lengths[i], i].tolist() + + # TODO: speed up the following loop + # drop words from the input according to keep + new_s = [ + w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words) + ] + new_s = [w for w in new_s if w is not None] + # we need to have at least one word in the sentence (more than the + # start / end sentence symbols) + if len(new_s) <= 1: + # insert at beginning in case the only token left is EOS + # EOS should be at end of list. + new_s.insert(0, words[np.random.randint(0, len(words))]) + assert len(new_s) >= 1 and ( + not has_eos # either there is no EOS at the end, or the last token is EOS + or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos()) + ), "New sentence is invalid." + sentences.append(new_s) + modified_lengths.append(len(new_s)) + # re-construct input + modified_lengths = torch.LongTensor(modified_lengths) + modified_x = torch.LongTensor( + modified_lengths.max(), modified_lengths.size(0) + ).fill_(self.dictionary.pad()) + for i in range(modified_lengths.size(0)): + modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i])) + + return modified_x, modified_lengths + + +class WordShuffle(WordNoising): + """Shuffle words by no more than k positions.""" + + def __init__( + self, + dictionary, + default_max_shuffle_distance=3, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): + super().__init__(dictionary, bpe_cont_marker, bpe_end_marker) + self.default_max_shuffle_distance = default_max_shuffle_distance + + def noising(self, x, lengths, max_shuffle_distance=None): + if max_shuffle_distance is None: + max_shuffle_distance = self.default_max_shuffle_distance + # x: (T x B), lengths: B + if max_shuffle_distance == 0: + return x, lengths + + # a max_shuffle_distance of 1 or less would leave the sequence unchanged + assert max_shuffle_distance > 1 + + # define noise word scores + noise = np.random.uniform( + 0, + max_shuffle_distance, + size=(x.size(0), x.size(1)), + ) + noise[0] = -1 # do not move start sentence symbol + # be sure to shuffle entire words + word_idx = self.get_word_idx(x) + x2 = x.clone() + for i in range(lengths.size(0)): + length_no_eos = lengths[i] + if x[lengths[i] - 1, i] == self.dictionary.eos(): + length_no_eos = lengths[i] - 1 + # generate a random permutation + scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i] + # ensure no reordering inside a word + scores += 1e-6 * np.arange(length_no_eos.item()) + permutation = scores.argsort() + # shuffle words + x2[:length_no_eos, i].copy_( + x2[:length_no_eos, i][torch.from_numpy(permutation)] + ) + return x2, lengths + + +class UnsupervisedMTNoising(WordNoising): + """ + Implements the default configuration for noising in UnsupervisedMT + (github.com/facebookresearch/UnsupervisedMT) + """ + + def __init__( + self, + dictionary, + max_word_shuffle_distance, + word_dropout_prob, + word_blanking_prob, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): + super().__init__(dictionary) + self.max_word_shuffle_distance = max_word_shuffle_distance + self.word_dropout_prob = word_dropout_prob + self.word_blanking_prob = word_blanking_prob + + self.word_dropout = WordDropout( + dictionary=dictionary, + bpe_cont_marker=bpe_cont_marker, + bpe_end_marker=bpe_end_marker, + ) + self.word_shuffle = WordShuffle( + dictionary=dictionary, + bpe_cont_marker=bpe_cont_marker, + bpe_end_marker=bpe_end_marker, + ) + + def noising(self, x, 
lengths): + # 1. Word Shuffle + noisy_src_tokens, noisy_src_lengths = self.word_shuffle.noising( + x=x, + lengths=lengths, + max_shuffle_distance=self.max_word_shuffle_distance, + ) + # 2. Word Dropout + noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising( + x=noisy_src_tokens, + lengths=noisy_src_lengths, + dropout_prob=self.word_dropout_prob, + ) + # 3. Word Blanking + noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising( + x=noisy_src_tokens, + lengths=noisy_src_lengths, + dropout_prob=self.word_blanking_prob, + blank_idx=self.dictionary.unk(), + ) + + return noisy_src_tokens + + +class NoisingDataset(torch.utils.data.Dataset): + def __init__( + self, + src_dataset, + src_dict, + seed, + noiser=None, + noising_class=UnsupervisedMTNoising, + **kwargs + ): + """ + Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the + samples based on the supplied noising configuration. + + Args: + src_dataset (~torch.utils.data.Dataset): dataset to wrap. + to build self.src_dataset -- + a LanguagePairDataset with src dataset as the source dataset and + None as the target dataset. Should NOT have padding so that + src_lengths are accurately calculated by language_pair_dataset + collate function. + We use language_pair_dataset here to encapsulate the tgt_dataset + so we can re-use the LanguagePairDataset collater to format the + batches in the structure that SequenceGenerator expects. + src_dict (~fairseq.data.Dictionary): source dictionary + seed (int): seed to use when generating random noise + noiser (WordNoising): a pre-initialized :class:`WordNoising` + instance. If this is None, a new instance will be created using + *noising_class* and *kwargs*. + noising_class (class, optional): class to use to initialize a + default :class:`WordNoising` instance. + kwargs (dict, optional): arguments to initialize the default + :class:`WordNoising` instance given by *noiser*. + """ + self.src_dataset = src_dataset + self.src_dict = src_dict + self.seed = seed + self.noiser = ( + noiser + if noiser is not None + else noising_class( + dictionary=src_dict, + **kwargs, + ) + ) + self.sizes = src_dataset.sizes + + def __getitem__(self, index): + """ + Returns a single noisy sample. Multiple samples are fed to the collater + create a noising dataset batch. + """ + src_tokens = self.src_dataset[index] + src_lengths = torch.LongTensor([len(src_tokens)]) + src_tokens = src_tokens.unsqueeze(0) + + # Transpose src tokens to fit expected shape of x in noising function + # (batch size, sequence length) -> (sequence length, batch size) + src_tokens_t = torch.t(src_tokens) + + with data_utils.numpy_seed(self.seed + index): + noisy_src_tokens = self.noiser.noising(src_tokens_t, src_lengths) + + # Transpose back to expected src_tokens format + # (sequence length, 1) -> (1, sequence length) + noisy_src_tokens = torch.t(noisy_src_tokens) + return noisy_src_tokens[0] + + def __len__(self): + """ + The length of the noising dataset is the length of src. 
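
A sketch of how this wrapper is typically wired up (``plain_ds`` and ``d`` are placeholders for an unpadded dataset of LongTensors exposing a ``sizes`` attribute, and its fairseq Dictionary):

noisy_ds = NoisingDataset(
    src_dataset=plain_ds,  # placeholder: unpadded token dataset with .sizes
    src_dict=d,            # placeholder: fairseq Dictionary
    seed=42,
    max_word_shuffle_distance=3,
    word_dropout_prob=0.1,
    word_blanking_prob=0.1,
)
noisy = noisy_ds[0]  # a shuffled, dropped and blanked version of plain_ds[0]
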
+ """ + return len(self.src_dataset) + + @property + def supports_prefetch(self): + return self.src_dataset.supports_prefetch + + def prefetch(self, indices): + if self.src_dataset.supports_prefetch: + self.src_dataset.prefetch(indices) diff --git a/fairseq/fairseq/data/num_samples_dataset.py b/fairseq/fairseq/data/num_samples_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..99a17495c701d8a05e0268f98bf453905e11d078 --- /dev/null +++ b/fairseq/fairseq/data/num_samples_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import FairseqDataset + + +class NumSamplesDataset(FairseqDataset): + def __getitem__(self, index): + return 1 + + def __len__(self): + return 0 + + def collater(self, samples): + return sum(samples) diff --git a/fairseq/fairseq/data/numel_dataset.py b/fairseq/fairseq/data/numel_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ac86dfd2f1d89055de909656d61d6aca85523f00 --- /dev/null +++ b/fairseq/fairseq/data/numel_dataset.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import BaseWrapperDataset + + +class NumelDataset(BaseWrapperDataset): + def __init__(self, dataset, reduce=False): + super().__init__(dataset) + self.reduce = reduce + + def __getitem__(self, index): + item = self.dataset[index] + if torch.is_tensor(item): + return torch.numel(item) + else: + return np.size(item) + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if self.reduce: + return sum(samples) + else: + return torch.tensor(samples) diff --git a/fairseq/fairseq/data/plasma_utils.py b/fairseq/fairseq/data/plasma_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..459fb8acd789e7b03c70201cb5cb2a9e7dc4f325 --- /dev/null +++ b/fairseq/fairseq/data/plasma_utils.py @@ -0,0 +1,197 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import hashlib +import json +import subprocess +import tempfile +from typing import Hashable + +try: + import pyarrow.plasma as plasma + + PYARROW_AVAILABLE = True +except ImportError: + plasma = None + PYARROW_AVAILABLE = False + + +class PlasmaArray: + """ + Wrapper around numpy arrays that automatically moves the data to shared + memory upon serialization. This is particularly helpful when passing numpy + arrays through multiprocessing, so that data is not unnecessarily + duplicated or pickled. 
+ """ + + def __init__(self, array): + super().__init__() + self.array = array + self.disable = array.nbytes < 134217728 # disable for arrays <128MB + self.object_id = None + self.path = None + + # variables with underscores shouldn't be pickled + self._client = None + self._server = None + self._server_tmp = None + self._plasma = None + + @property + def plasma(self): + if self._plasma is None and not self.disable: + self._plasma = plasma + return self._plasma + + def start_server(self): + if self.plasma is None or self._server is not None: + return + assert self.object_id is None + assert self.path is None + self._server_tmp = tempfile.NamedTemporaryFile() + self.path = self._server_tmp.name + self._server = subprocess.Popen( + ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] + ) + + @property + def client(self): + if self._client is None: + assert self.path is not None + self._client = self.plasma.connect(self.path, num_retries=200) + return self._client + + def __getstate__(self): + """Called on pickle load""" + if self.plasma is None: + return self.__dict__ + if self.object_id is None: + self.start_server() + self.object_id = self.client.put(self.array) + state = self.__dict__.copy() + del state["array"] + state["_client"] = None + state["_server"] = None + state["_server_tmp"] = None + state["_plasma"] = None + return state + + def __setstate__(self, state): + """Called on pickle save""" + self.__dict__.update(state) + if self.plasma is None: + return + self.array = self.client.get(self.object_id) + + def __del__(self): + if self._server is not None: + self._server.kill() + self._server = None + self._server_tmp.close() + self._server_tmp = None + + +DEFAULT_PLASMA_PATH = "/tmp/plasma" + + +class PlasmaView: + """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, + PlasmaView writes to shared memory on instantiation.""" + + def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): + """ + Args: + array: numpy array to store. This can be read with ``PlasmaView().array`` + split_path: the path whence the data was read, used for hashing + hash_data: other metadata about the array that can be used to create a unique key. + as of writing, the 3 callers in ``TokenBlockDataset`` use:: + + hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) + + + """ + assert PYARROW_AVAILABLE + assert split_path is not None + if plasma_path is None: + plasma_path = DEFAULT_PLASMA_PATH + + self.path = plasma_path + self.split_path = split_path + self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. 
+ self._n = None + + self.object_id = self.get_object_id(self.split_path, hash_data) + try: + self.client.put(array, object_id=self.object_id) + except plasma.PlasmaObjectExists: + pass + + @property + def client(self): + if self._client is None: + self._client = plasma.connect(self.path, num_retries=200) + return self._client + + @property + def array(self): + """Fetch a read only view of an np.array, stored in plasma.""" + ret = self.client.get(self.object_id) + return ret + + @staticmethod + def get_object_id(split_path: str, hash_data: Hashable): + """Returns plasma.ObjectID from hashing split_path and object_num.""" + hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) + harg = json.dumps(hash_data).encode("utf-8") + hash.update(harg) + return plasma.ObjectID(hash.digest()) + + def __getstate__(self): + """Called on pickle save""" + self.disconnect() + state = self.__dict__.copy() + assert state["_client"] is None + assert "object_id" in state + return state + + def __setstate__(self, state): + """Called on pickle load""" + self.__dict__.update(state) + + def __del__(self): + self.disconnect() + + def disconnect(self): + if self._client is not None: + self._client.disconnect() + self._client = None + + def __len__(self): + """Save reads by caching len""" + if self._n is None: + self._n = len(self.array) + return self._n + + +GB100 = (1024**3) * 100 + + +class PlasmaStore: + def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): + + self.server = self.start(path, nbytes) + + def __del__(self): + self.server.kill() + + @staticmethod + def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: + if not PYARROW_AVAILABLE: + raise ImportError("please run pip install pyarrow to use --use_plasma_view") + # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm + _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) + plasma.connect(path, num_retries=200) # If we can't connect we fail immediately + return _server diff --git a/fairseq/fairseq/data/prepend_dataset.py b/fairseq/fairseq/data/prepend_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ad74784d2d7920e4a6225282d95543ce16ea50d9 --- /dev/null +++ b/fairseq/fairseq/data/prepend_dataset.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import BaseWrapperDataset + + +class PrependDataset(BaseWrapperDataset): + def __init__(self, dataset, prepend_getter, ensure_first_token_is=None): + super().__init__(dataset) + self.prepend_getter = prepend_getter + self.ensure_first_token = ensure_first_token_is + + def __getitem__(self, idx): + item = self.dataset[idx] + is_tuple = isinstance(item, tuple) + src = item[0] if is_tuple else item + + assert self.ensure_first_token is None or src[0] == self.ensure_first_token + prepend_idx = self.prepend_getter(self.dataset, idx) + assert isinstance(prepend_idx, int) + src[0] = prepend_idx + item = tuple((src,) + item[1:]) if is_tuple else src + return item diff --git a/fairseq/fairseq/data/roll_dataset.py b/fairseq/fairseq/data/roll_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a2915eeb3e8fb4dfb4b2bb33e0464ad0783d854c --- /dev/null +++ b/fairseq/fairseq/data/roll_dataset.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. 
diff --git a/fairseq/fairseq/data/prepend_dataset.py b/fairseq/fairseq/data/prepend_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad74784d2d7920e4a6225282d95543ce16ea50d9
--- /dev/null
+++ b/fairseq/fairseq/data/prepend_dataset.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from . import BaseWrapperDataset
+
+
+class PrependDataset(BaseWrapperDataset):
+    def __init__(self, dataset, prepend_getter, ensure_first_token_is=None):
+        super().__init__(dataset)
+        self.prepend_getter = prepend_getter
+        self.ensure_first_token = ensure_first_token_is
+
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+        is_tuple = isinstance(item, tuple)
+        src = item[0] if is_tuple else item
+
+        assert self.ensure_first_token is None or src[0] == self.ensure_first_token
+        prepend_idx = self.prepend_getter(self.dataset, idx)
+        assert isinstance(prepend_idx, int)
+        src[0] = prepend_idx
+        item = tuple((src,) + item[1:]) if is_tuple else src
+        return item
diff --git a/fairseq/fairseq/data/roll_dataset.py b/fairseq/fairseq/data/roll_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2915eeb3e8fb4dfb4b2bb33e0464ad0783d854c
--- /dev/null
+++ b/fairseq/fairseq/data/roll_dataset.py
@@ -0,0 +1,18 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from . import BaseWrapperDataset
+
+
+class RollDataset(BaseWrapperDataset):
+    def __init__(self, dataset, shifts):
+        super().__init__(dataset)
+        self.shifts = shifts
+
+    def __getitem__(self, index):
+        item = self.dataset[index]
+        return torch.roll(item, self.shifts)
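For reference, `torch.roll` wraps elements around rather than shifting zeros in, so a `RollDataset` keeps every token and only rotates positions. A minimal illustration with a toy tensor:

    import torch

    t = torch.tensor([1, 2, 3, 4])
    print(torch.roll(t, 1))   # tensor([4, 1, 2, 3]) -- last element wraps to the front
    print(torch.roll(t, -1))  # tensor([2, 3, 4, 1]) -- first element wraps to the back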
diff --git a/fairseq/fairseq/data/round_robin_zip_datasets.py b/fairseq/fairseq/data/round_robin_zip_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cb7447ea955a7c3ae7372f09ee426c08acd430e
--- /dev/null
+++ b/fairseq/fairseq/data/round_robin_zip_datasets.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from collections import OrderedDict
+from typing import Dict, Sequence
+
+import numpy as np
+
+from . import FairseqDataset, LanguagePairDataset
+
+logger = logging.getLogger(__name__)
+
+
+class RoundRobinZipDatasets(FairseqDataset):
+    """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
+
+    Shorter datasets are repeated in a round-robin fashion to match the length
+    of the longest one.
+
+    Args:
+        datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
+            :class:`~fairseq.data.FairseqDataset` instances.
+        eval_key (str, optional): a key used at evaluation time that causes
+            this instance to pass through batches from *datasets[eval_key]*.
+    """
+
+    def __init__(self, datasets, eval_key=None):
+        super().__init__()
+        if isinstance(datasets, dict):
+            datasets = OrderedDict(datasets)
+        assert isinstance(datasets, OrderedDict)
+        assert datasets, "Can't make a RoundRobinZipDatasets out of nothing"
+        for dataset in datasets.values():
+            assert isinstance(dataset, FairseqDataset)
+
+        self.datasets = datasets
+        self.eval_key = eval_key
+
+        self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k]))
+        self.longest_dataset = datasets[self.longest_dataset_key]
+        self._ordered_indices: Dict[str, Sequence[int]] = None
+
+    def _map_index(self, key, index):
+        assert (
+            self._ordered_indices is not None
+        ), "Must call RoundRobinZipDatasets.ordered_indices() first"
+        o = self._ordered_indices[key]
+        return o[index % len(o)]
+
+    def __getitem__(self, index):
+        if self.eval_key is None:
+            return OrderedDict(
+                [
+                    (key, dataset[self._map_index(key, index)])
+                    for key, dataset in self.datasets.items()
+                ]
+            )
+        else:
+            # at evaluation time it's useful to pass through batches from a single key
+            return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]
+
+    def __len__(self):
+        if self._ordered_indices is not None:
+            return len(self._ordered_indices[self.longest_dataset_key])
+        return len(self.longest_dataset)
+
+    def collater(self, samples):
+        """Merge a list of samples to form a mini-batch."""
+        if len(samples) == 0:
+            return None
+        if self.eval_key is None:
+            return OrderedDict(
+                [
+                    (key, dataset.collater([sample[key] for sample in samples]))
+                    for key, dataset in self.datasets.items()
+                ]
+            )
+        else:
+            # at evaluation time it's useful to pass through batches from a single key
+            return self.datasets[self.eval_key].collater(samples)
+
+    def num_tokens(self, index):
+        """Return an example's length (number of tokens), used for batching."""
+        # TODO make it configurable whether to use max() or sum() here
+        return max(
+            dataset.num_tokens(self._map_index(key, index))
+            for key, dataset in self.datasets.items()
+        )
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return {
+            key: dataset.size(self._map_index(key, index))
+            for key, dataset in self.datasets.items()
+        }
+
+    def ordered_indices(self):
+        """Ordered indices for batching."""
+        if self._ordered_indices is None:
+            # Call the underlying dataset's ordered_indices() here, so that we
+            # get the same random ordering as we would have from using the
+            # underlying sub-datasets directly.
+            self._ordered_indices = OrderedDict(
+                [
+                    (key, dataset.ordered_indices())
+                    for key, dataset in self.datasets.items()
+                ]
+            )
+        return np.arange(len(self))
+
+    def filter_indices_by_size(self, indices, max_positions=None):
+        """
+        Filter each sub-dataset independently, then update the round robin to work
+        on the filtered sub-datasets.
+        """
+
+        def _deep_until_language_pair(dataset):
+            if isinstance(dataset, LanguagePairDataset):
+                return dataset
+            if hasattr(dataset, "tgt_dataset"):
+                return _deep_until_language_pair(dataset.tgt_dataset)
+            if hasattr(dataset, "dataset"):
+                return _deep_until_language_pair(dataset.dataset)
+            raise Exception(f"Don't know how to unwrap this dataset: {dataset}")
+
+        if not isinstance(max_positions, dict):
+            max_positions = {k: max_positions for k in self.datasets.keys()}
+        ignored_some = False
+        for key, dataset in self.datasets.items():
+            dataset = _deep_until_language_pair(dataset)
+            self._ordered_indices[key], ignored = dataset.filter_indices_by_size(
+                self._ordered_indices[key], max_positions[key]
+            )
+            if len(ignored) > 0:
+                ignored_some = True
+                logger.warning(
+                    f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
+                    f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
+                )
+        # Since we are modifying _ordered_indices in place,
+        # it's no longer possible to return valid ignored indices.
+        # Hopefully the warning above gives enough information to debug.
+        # Ideally we would receive ignore_invalid_inputs so that we could have
+        # a proper error message.
+        return (np.arange(len(self)), [0] if ignored_some else [])
+
+    @property
+    def supports_prefetch(self):
+        return all(
+            getattr(dataset, "supports_prefetch", False)
+            for dataset in self.datasets.values()
+        )
+
+    def prefetch(self, indices):
+        for key, dataset in self.datasets.items():
+            dataset.prefetch([self._map_index(key, index) for index in indices])
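The round-robin behaviour comes entirely from the modulo in `_map_index`: the zipped dataset is as long as its longest member, and shorter members wrap around. A toy re-enactment in plain Python (hypothetical dataset names and sizes, no fairseq classes):

    lengths = {"en-fr": 5, "en-de": 2}  # hypothetical sub-dataset sizes
    longest = max(lengths.values())
    for index in range(longest):
        print(index, {key: index % n for key, n in lengths.items()})
    # index 2 -> {'en-fr': 2, 'en-de': 0}: the 2-example dataset has wrapped around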
diff --git a/fairseq/fairseq/data/speech_dlm_dataset.py b/fairseq/fairseq/data/speech_dlm_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..06c4808f0aaacd3191eadfdb1e03d49add2c3827
--- /dev/null
+++ b/fairseq/fairseq/data/speech_dlm_dataset.py
@@ -0,0 +1,307 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import OrderedDict
+
+import numpy as np
+import torch
+
+from fairseq.data import FairseqDataset, MonolingualDataset, data_utils
+
+
+class SpeechDLMDataset(FairseqDataset):
+    """The dataset used to train the SpeechDLM model as described in the paper:
+    https://arxiv.org/pdf/2203.16502.pdf
+
+    The input datasets are expected to be a dict over channel names, with the
+    values being instances of :class:`~fairseq.data.MonolingualDataset`.
+
+    Each element of SpeechDLMDataset is a dictionary with the following keys:
+        - `id` (int) : index of the item
+        - `source` (OrderedDict[str, Tensor of shape (seq_len,)]) : dictionary over
+            channels with the values containing the input unit tokens
+        - `target_next` (OrderedDict[str, Tensor of shape (seq_len,)]) : dictionary
+            over channels with the values containing the next unit tokens (input
+            tokens shifted by 1).
+            Its value is None if 'next' not in self.targets
+        - `target_edge` (OrderedDict[str, Tensor of shape (dedup_seq_len,)]) : dictionary
+            over channels with the values containing the edge unit tokens (input tokens
+            deduplicated).
+            Its value is None if 'edge' not in self.targets
+        - `target_duration` (OrderedDict[str, Tensor of shape (dedup_seq_len,)]) :
+            dictionary over channels with the values being the durations of the edge units.
+            Its value is None if 'duration' not in targets.
+        - `target_edge_indices` (OrderedDict[str, Tensor of shape (dedup_seq_len,)]) :
+            dictionary over channels with the values being the indices of the edge units
+            in the source sequence.
+            Its value is None if neither 'edge' nor 'duration' is in targets.
+
+    Args:
+        datasets (Dict[str, ~fairseq.data.MonolingualDataset]): a dictionary of
+            :class:`~fairseq.data.MonolingualDataset` instances.
+        targets (List[str]): list of the target types that the SpeechDLM model
+            should predict. Can be one of "next", "edge", "duration".
+        shuffle (bool, optional): shuffle the elements before batching
+            (default: False).
+    """
+
+    def __init__(
+        self, datasets, targets=None, max_target_durations=None, shuffle=False
+    ):
+        super().__init__()
+        if isinstance(datasets, dict):
+            datasets = OrderedDict(datasets)
+        assert isinstance(
+            datasets, OrderedDict
+        ), "datasets is expected to be an instance of Dictionary or OrderedDict"
+        assert datasets, "datasets is None"
+        for dataset in datasets.values():
+            assert isinstance(
+                dataset, MonolingualDataset
+            ), "Each value of datasets is expected to be an instance of MonolingualDataset"
+
+        self.datasets = datasets
+        self.targets = targets
+        if max_target_durations is not None and max_target_durations > 0:
+            self.max_target_durations = max_target_durations
+        else:
+            self.max_target_durations = float("inf")
+        self.sizes = next(iter(datasets.values())).sizes
+        self.vocab = next(iter(datasets.values())).vocab
+        self.length = len(next(iter(datasets.values())))
+        self.shuffle = shuffle
+
+        for channel, dataset in datasets.items():
+            assert (
+                len(dataset) == self.length
+            ), "[{}] length mismatch ({} vs {})".format(
+                channel, len(dataset), self.length
+            )
+            assert (dataset.sizes == self.sizes).all(), "[{}] sizes mismatch".format(
+                channel
+            )
+
+            assert (
+                dataset.vocab.pad() == self.vocab.pad()
+            ), "pad token is expected to be the same"
+            assert (
+                dataset.vocab.eos() == self.vocab.eos()
+            ), "eos token is expected to be the same"
+            assert (
+                dataset.vocab.bos() == self.vocab.bos()
+            ), "bos token is expected to be the same"
+            assert (
+                dataset.vocab.unk() == self.vocab.unk()
+            ), "unk token is expected to be the same"
+
+    def __getitem__(self, index):
+        source = OrderedDict(
+            [
+                (key, dataset[index]["source"])
+                for (key, dataset) in self.datasets.items()
+            ]
+        )
+
+        item = {
+            "id": index,
+            "source": source,
+            "target_next": None,
+            "target_edge": None,
+            "target_duration": None,
+            "target_edge_indices": None,
+        }
+
+        if self.targets is not None:
+            for channel in self.datasets:
+                target = self._get_target(index, channel)
+                for t in target:
+                    if item[f"target_{t}"] is None:
+                        item[f"target_{t}"] = OrderedDict()
+                    item[f"target_{t}"][channel] = target[t]
+
+        return item
+
+    def __len__(self):
+        return self.length
+
+    def _get_target(self, index, channel):
+        """Get target in one of ['next', 'edge', 'duration']
+        - 'next' is the future unit
+        - 'edge' is the edge unit
+        - 'duration' is the duration of the edge unit
+        """
+        if self.targets is not None:
+            target = {}
+            pad_idx = self.vocab.pad()
+            max_dur = self.max_target_durations
+            future_target = self.datasets[channel][index]["target"]
+            if "edge" in self.targets or "duration" in self.targets:
+                edge_units, edge_unit_counts = torch.unique_consecutive(
+                    future_target, return_counts=True
+                )
+                padding_end = edge_units[-1] == pad_idx
+                if padding_end:
+                    edge_units = edge_units[:-1]
+                    edge_unit_counts = edge_unit_counts[:-1]
+                edge_indices = torch.cumsum(edge_unit_counts, 0)
+                edge_indices = torch.cat([torch.tensor([0]), edge_indices[:-1]])
+                target["edge_indices"] = edge_indices
+
+            for t in self.targets:
+                if t == "next":
+                    target[t] = future_target
+                elif t == "edge":
+                    target[t] = edge_units
+                elif t == "duration":
+                    # count the remaining duration of the last edge unit in the next sentence
+                    if not padding_end and index < len(self.datasets[channel]) - 1:
+                        i = 0
+                        next_sentence_target = self.datasets[channel][index + 1][
+                            "target"
+                        ]
+                        while (
+                            next_sentence_target[i] == edge_units[-1]
+                            and edge_unit_counts[-1] + i < max_dur
+                        ):
+                            i += 1
+                        edge_unit_counts[-1] += i
+
+                    # cut off to the maximal threshold
+                    if max_dur:
+                        edge_unit_counts[edge_unit_counts > max_dur] = max_dur
+
+                    target[t] = edge_unit_counts
+                else:
+                    raise Exception("invalid target " + t)
+
+            return target
+
+    def collater(self, samples):
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch with the following keys:
+
+                - `id` (LongTensor): example IDs in the original input order
+                - `ntokens` (int): total number of tokens in the batch
+                - `net_input` (dict): the input to the Model, containing keys:
+
+                    - `src_tokens` (OrderedDict[str, LongTensor]): dictionary
+                        over channel with the values being padded 2D Tensor of
+                        samples `source` of shape `(bsz, src_len)`.
+                        Padding will appear on the right.
+                    - `src_lengths` (LongTensor): lengths of source sentences
+                        in the mini-batch
+
+                - `target` (dict): the target of the Model, containing keys:
+
+                    - `next` (OrderedDict[str, LongTensor]): dictionary
+                        over channel with the values being padded 2D Tensor of
+                        batch samples' `target_next` of shape `(bsz, tgt_len)`.
+                        Padding will appear on the right.
+                    - `edge` (OrderedDict[str, LongTensor]): dictionary
+                        over channel with the values being the concatenated
+                        1D Tensor of batch samples' `target_edge` of shape
+                        `(sum of dedup_tgt_len,)`
+                    - `duration` (OrderedDict[str, LongTensor]): dictionary
+                        over channel with the values being the concatenated
+                        1D Tensor of batch samples' `target_duration` of shape
+                        `(sum of dedup_tgt_len,)`
+                    - `edge_indices` (OrderedDict[str, LongTensor]): dictionary
+                        over channel with the values being the concatenated
+                        1D Tensor of batch samples' `target_edge_indices` of
+                        shape `(sum of dedup_tgt_len,)`.
+                        The indices are offset by multiples of the padded sequence
+                        length so that they are the actual indices in the flattened
+                        `src_tokens` Tensor
+        """
+        if len(samples) == 0:
+            return {}
+
+        pad_idx = self.vocab.pad()
+        eos_idx = self.vocab.eos()
+
+        def merge(key, max_size=None):
+            if samples[0][key] is None:
+                return None
+            res = OrderedDict()
+            for channel in samples[0][key]:
+                if key in ["source", "target_next"]:
+                    # fill batch of shape: (batch_size, max_size)
+                    res[channel] = data_utils.collate_tokens(
+                        [s[key][channel] for s in samples],
+                        pad_idx,
+                        eos_idx,
+                        left_pad=False,
+                    )
+                elif key in ["target_edge", "target_duration"]:
+                    # concatenate the edge units/duration
+                    res[channel] = torch.cat([s[key][channel] for s in samples])
+                elif key == "target_edge_indices":
+                    # offset the edge indices to the indices in the flattened batch
+                    res[channel] = torch.cat(
+                        [s[key][channel] + i * max_size for i, s in enumerate(samples)]
+                    )
+
+            return res
+
+        src_tokens = merge("source")
+        tgt_next = merge("target_next")
+        tgt_edge = merge("target_edge")
+        tgt_duration = merge("target_duration")
+        tgt_edge_indices = merge(
+            "target_edge_indices", max_size=next(iter(src_tokens.values())).size(-1)
+        )
+        return {
+            "id": torch.LongTensor([s["id"] for s in samples]),
+            "nsentences": len(samples),
+            "ntokens": sum(len(item) for s in samples for item in s["source"].values()),
+            "net_input": {
+                "src_tokens": src_tokens,
+                "src_lengths": torch.LongTensor(
+                    [next(iter(s["source"].values())).numel() for s in samples]
+                ),
+            },
+            "target": {
+                "next": tgt_next,
+                "edge": tgt_edge,
+                "duration": tgt_duration,
+                "edge_indices": tgt_edge_indices,
+            },
+        }
+
+    def num_tokens(self, index):
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return self.sizes[index]
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+        order.append(self.sizes)
+        return np.lexsort(order)
+
+    @property
+    def supports_prefetch(self):
+        return all(
+            getattr(dataset, "supports_prefetch", False)
+            for dataset in self.datasets.values()
+        )
+
+    def prefetch(self, indices):
+        for key, dataset in self.datasets.items():
+            dataset.prefetch(indices)
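The 'edge' and 'duration' targets above boil down to `torch.unique_consecutive`. A minimal re-enactment of the core of `_get_target` on a made-up unit sequence:

    import torch

    target = torch.tensor([7, 7, 7, 3, 3, 9])  # hypothetical unit tokens
    edge_units, edge_unit_counts = torch.unique_consecutive(target, return_counts=True)
    # edge_units       -> tensor([7, 3, 9])  the 'edge' target (deduplicated units)
    # edge_unit_counts -> tensor([3, 2, 1])  the 'duration' target
    edge_indices = torch.cumsum(edge_unit_counts, 0)
    edge_indices = torch.cat([torch.tensor([0]), edge_indices[:-1]])
    # edge_indices     -> tensor([0, 3, 5])  where each edge unit starts in `target`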
diff --git a/fairseq/fairseq/data/strip_token_dataset.py b/fairseq/fairseq/data/strip_token_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cae39ba4d2f8106398eccd7eb0cf5c2194ec0db5
--- /dev/null
+++ b/fairseq/fairseq/data/strip_token_dataset.py
@@ -0,0 +1,20 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from . import BaseWrapperDataset
+
+
+class StripTokenDataset(BaseWrapperDataset):
+    def __init__(self, dataset, id_to_strip):
+        super().__init__(dataset)
+        self.id_to_strip = id_to_strip
+
+    def __getitem__(self, index):
+        item = self.dataset[index]
+        while len(item) > 0 and item[-1] == self.id_to_strip:
+            item = item[:-1]
+        while len(item) > 0 and item[0] == self.id_to_strip:
+            item = item[1:]
+        return item
diff --git a/fairseq/fairseq/data/subsample_dataset.py b/fairseq/fairseq/data/subsample_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5c7e2ac864613a10b1886bca78cbc53f5bfd64
--- /dev/null
+++ b/fairseq/fairseq/data/subsample_dataset.py
@@ -0,0 +1,75 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import logging
+
+import numpy as np
+from fairseq.data.data_utils import numpy_seed
+
+from . import BaseWrapperDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class SubsampleDataset(BaseWrapperDataset):
+    """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples.
+
+    Args:
+        dataset (~torch.utils.data.Dataset): dataset to subsample
+        size_ratio (float): the ratio to subsample to. Must be between 0 and 1 (exclusive)
+    """
+
+    def __init__(self, dataset, size_ratio, shuffle=False, seed=None):
+        super().__init__(dataset)
+        assert size_ratio < 1
+        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
+        with numpy_seed(seed) if seed is not None else contextlib.ExitStack():
+            self.indices = np.random.choice(
+                list(range(len(self.dataset))), self.actual_size, replace=False
+            )
+        self.shuffle = shuffle
+        logger.info(
+            "subsampled dataset from {} to {} (ratio={})".format(
+                len(self.dataset), self.actual_size, size_ratio
+            )
+        )
+
+    def __getitem__(self, index):
+        return self.dataset[self.indices[index]]
+
+    def __len__(self):
+        return self.actual_size
+
+    def collater(self, samples):
+        return self.dataset.collater(samples)
+
+    @property
+    def sizes(self):
+        return self.dataset.sizes[self.indices]
+
+    @property
+    def name(self):
+        return self.dataset.name
+
+    def num_tokens(self, index):
+        return self.dataset.num_tokens(self.indices[index])
+
+    def size(self, index):
+        return self.dataset.size(self.indices[index])
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+        order.append(self.sizes)
+        return np.lexsort(order)
+
+    def prefetch(self, indices):
+        self.dataset.prefetch(self.indices[indices])
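The selection in `SubsampleDataset.__init__` is just a seeded `np.random.choice` without replacement. A minimal sketch of the same arithmetic, using a local RNG instead of fairseq's `numpy_seed` context manager (toy numbers):

    import numpy as np

    n, size_ratio = 10, 0.3                     # toy dataset of 10 examples
    actual_size = int(np.ceil(n * size_ratio))  # 3 examples survive
    rng = np.random.default_rng(seed=42)
    indices = rng.choice(n, actual_size, replace=False)
    print(actual_size, indices)                 # 3 and three distinct indices in [0, 10)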
diff --git a/fairseq/fairseq/data/text_compressor.py b/fairseq/fairseq/data/text_compressor.py
new file mode 100644
index 0000000000000000000000000000000000000000..d699f2ea296f33cdc37ca152ab225d09cb04b5ea
--- /dev/null
+++ b/fairseq/fairseq/data/text_compressor.py
@@ -0,0 +1,58 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+
+
+class TextCompressionLevel(Enum):
+    none = 0
+    low = 1
+    high = 2
+
+
+class TextCompressor(object):
+    def __init__(
+        self, level: TextCompressionLevel, max_input_byte_length: int = 2**16
+    ):
+        self.level = level
+        self.max_input_length = max_input_byte_length
+
+    def compress(self, text: str) -> bytes:
+        if self.level == TextCompressionLevel.low:
+            import zlib
+
+            # zlib: built-in, fast
+            return zlib.compress(text.encode(), level=0)
+        elif self.level == TextCompressionLevel.high:
+            try:
+                import unishox2
+
+                # unishox2: optimized for short text but slower
+            except ImportError:
+                raise ImportError(
+                    "Please install unishox2 for the text compression feature: "
+                    "pip install unishox2-py3"
+                )
+            assert len(text.encode()) <= self.max_input_length
+            return unishox2.compress(text)[0]
+        else:
+            return text.encode()
+
+    def decompress(self, compressed: bytes) -> str:
+        if self.level == TextCompressionLevel.low:
+            import zlib
+
+            return zlib.decompress(compressed).decode()
+        elif self.level == TextCompressionLevel.high:
+            try:
+                import unishox2
+            except ImportError:
+                raise ImportError(
+                    "Please install unishox2 for the text compression feature: "
+                    "pip install unishox2-py3"
+                )
+            return unishox2.decompress(compressed, self.max_input_length)
+        else:
+            return compressed.decode()
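A minimal round-trip through the compressor defined above (the 'high' level would additionally require `pip install unishox2-py3`):

    from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel

    tc = TextCompressor(level=TextCompressionLevel.low)
    blob = tc.compress("hello world")          # bytes, zlib-framed
    assert isinstance(blob, bytes)
    assert tc.decompress(blob) == "hello world"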
+ """ + + def __init__( + self, + dataset, + sizes, + block_size, + pad, + eos, + break_mode=None, + include_targets=False, + document_sep_len=1, + use_plasma_view=False, + split_path=None, + plasma_path=None, + ): + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) > 0 + + assert len(dataset) == len(sizes) + _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) + if use_plasma_view: + plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset)) + self._slice_indices = plasma_utils.PlasmaView( + slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path + ) + self._sizes = plasma_utils.PlasmaView( + _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path + ) + self._block_to_dataset_index = plasma_utils.PlasmaView( + block_to_dataset_index, + split_path, + (plasma_id, 2), + plasma_path=plasma_path, + ) + else: + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = plasma_utils.PlasmaArray(_sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index + ) + + @staticmethod + def _build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) -> Tuple[np.ndarray]: + """Use token_block_utils_fast to build arrays for indexing into self.dataset""" + try: + from fairseq.data.token_block_utils_fast import ( + _get_slice_indices_fast, + _get_block_to_dataset_index_fast, + ) + except ImportError: + raise ImportError( + "Please build Cython components with: `pip install --editable .` " + "or `python setup.py build_ext --inplace`" + ) + + if isinstance(sizes, list): + sizes = np.array(sizes, dtype=np.int64) + else: + if torch.is_tensor(sizes): + sizes = sizes.numpy() + sizes = sizes.astype(np.int64) + + break_mode = break_mode if break_mode is not None else "none" + + # For "eos" break-mode, block_size is not required parameters. 
+ if break_mode == "eos" and block_size is None: + block_size = 0 + + slice_indices = _get_slice_indices_fast( + sizes, str(break_mode), block_size, document_sep_len + ) + _sizes = slice_indices[:, 1] - slice_indices[:, 0] + + # build index mapping block indices to the underlying dataset indices + if break_mode == "eos": + # much faster version for eos break mode + block_to_dataset_index = np.stack( + [ + np.arange(len(sizes)), # starting index in dataset + np.zeros( + len(sizes), dtype=np.compat.long + ), # starting offset within starting index + np.arange(len(sizes)), # ending index in dataset + ], + 1, + ) + else: + block_to_dataset_index = _get_block_to_dataset_index_fast( + sizes, + slice_indices, + ) + size_dtype = np.uint16 if block_size < 65535 else np.uint32 + num_tokens = slice_indices[-1].max() + slice_indices_dtype = best_fitting_int_dtype(num_tokens) + slice_indices = slice_indices.astype(slice_indices_dtype) + _sizes = _sizes.astype(size_dtype) + block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype) + return _sizes, block_to_dataset_index, slice_indices + + @property + def slice_indices(self): + return self._slice_indices.array + + @property + def sizes(self): + return self._sizes.array + + @property + def block_to_dataset_index(self): + return self._block_to_dataset_index.array + + def attr(self, attr: str, index: int): + start_ds_idx, _, _ = self.block_to_dataset_index[index] + return self.dataset.attr(attr, start_ds_idx) + + def __getitem__(self, index): + start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index] + + buffer = torch.cat( + [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] + ) + slice_s, slice_e = self.slice_indices[index] + length = slice_e - slice_s + s, e = start_offset, start_offset + length + item = buffer[s:e] + + if self.include_targets: + # *target* is the original sentence (=item) + # *source* is shifted right by 1 (maybe left-padded with eos) + # *past_target* is shifted right by 2 (left-padded as needed) + if s == 0: + source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]]) + past_target = torch.cat( + [item.new([self.pad, self.eos]), buffer[0 : e - 2]] + ) + else: + source = buffer[s - 1 : e - 1] + if s == 1: + past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]]) + else: + past_target = buffer[s - 2 : e - 2] + + return source, item, past_target + + return item + + def __len__(self): + return len(self.slice_indices) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch( + { + ds_idx + for index in indices + for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]] + for ds_idx in range(start_ds_idx, end_ds_idx + 1) + } + ) diff --git a/fairseq/fairseq/data/transform_eos_concat_langpair_dataset.py b/fairseq/fairseq/data/transform_eos_concat_langpair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..effa127d50c63546c7eeac952053930dd0a4f2b1 --- /dev/null +++ b/fairseq/fairseq/data/transform_eos_concat_langpair_dataset.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/fairseq/fairseq/data/transform_eos_concat_langpair_dataset.py b/fairseq/fairseq/data/transform_eos_concat_langpair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..effa127d50c63546c7eeac952053930dd0a4f2b1
--- /dev/null
+++ b/fairseq/fairseq/data/transform_eos_concat_langpair_dataset.py
@@ -0,0 +1,139 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from torch.utils.data.dataloader import default_collate
+
+from fairseq.data import ConcatDataset
+
+logger = logging.getLogger(__name__)
+
+
+class TransformEosConcatLangPairDataset(ConcatDataset):
+    """
+    A combination of TransformEosLangPairDataset and ConcatDataset for multiple LangPairDataset datasets.
+    Assumes all datasets share the same src_eos, tgt_bos, left_pad_source and left_pad_target.
+    """
+
+    def __init__(
+        self,
+        datasets,
+        src_eos,
+        tgt_bos,
+        new_src_eos=None,
+        new_tgt_bos=None,
+    ):
+        super().__init__(datasets)
+        if new_src_eos is not None and new_src_eos != []:
+            assert len(new_src_eos) == len(datasets)
+        else:
+            new_src_eos = []
+        if new_tgt_bos is not None and new_tgt_bos != []:
+            assert len(new_tgt_bos) == len(datasets)
+        else:
+            new_tgt_bos = []
+        self.src_eos = src_eos
+        self.tgt_bos = tgt_bos
+        self.new_src_eos = (
+            torch.LongTensor(new_src_eos).cpu() if len(new_src_eos) > 0 else []
+        )
+        self.new_tgt_bos = (
+            torch.LongTensor(new_tgt_bos).cpu() if len(new_tgt_bos) > 0 else []
+        )
+        self.left_pad_source = self.is_left_pad_source(datasets)
+        self.left_pad_target = self.is_left_pad_target(datasets)
+        self.pad_idx = self.src_dict_pad()
+
+    def src_dict_pad(self):
+        if hasattr(self.datasets[0], "src_dict"):
+            return self.datasets[0].src_dict.pad()
+        if hasattr(self.datasets[0], "dataset"):
+            return self.datasets[0].dataset.src_dict.pad()
+        raise NotImplementedError("No src_dict is found")
+
+    def __getitem__(self, idx):
+        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
+        return dataset_idx, self.datasets[dataset_idx][sample_idx]
+
+    def is_left_pad_source(self, datasets):
+        def _left_pad_source(ds):
+            if hasattr(ds, "left_pad_source"):
+                return ds.left_pad_source
+            if hasattr(ds, "dataset"):
+                return _left_pad_source(ds.dataset)
+            logger.warning(f"{type(ds)} has no left_pad_source, using default True")
+            return True
+
+        left_pad_source = _left_pad_source(datasets[0])
+        for ds in datasets:
+            if left_pad_source != _left_pad_source(ds):
+                raise ValueError("Different left_pad_source setting detected!")
+        return left_pad_source
+
+    def is_left_pad_target(self, datasets):
+        def _left_pad_target(ds):
+            if hasattr(ds, "left_pad_target"):
+                return ds.left_pad_target
+            if hasattr(ds, "dataset"):
+                return _left_pad_target(ds.dataset)
+            logger.warning(f"{type(ds)} has no left_pad_target, using default False")
+            return False
+
+        left_pad_target = _left_pad_target(datasets[0])
+        for ds in datasets:
+            if left_pad_target != _left_pad_target(ds):
+                raise ValueError("Different left_pad_target setting detected!")
+        return left_pad_target
+
+    def collater(self, samples, **extra_args):
+        if len(samples) == 0:
+            return samples
+
+        dataset_ids = [s[0] for s in samples]
+        samples = [s[1] for s in samples]
+
+        if hasattr(self.datasets[0], "collater"):
+            samples = self.datasets[0].collater(samples, **extra_args)
+        else:
+            samples = default_collate(samples, **extra_args)
+
+        if len(self.new_src_eos) > 0:
+            if self.left_pad_source:
+                assert (
+                    samples["net_input"]["src_tokens"][:, -1] != self.src_eos
+                ).sum() == 0
+                samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos[
+                    dataset_ids
+                ]
+
+            else:
+                eos_idx = samples["net_input"]["src_lengths"] - 1
+                assert (
+                    samples["net_input"]["src_tokens"][
+                        torch.arange(eos_idx.size(0)), eos_idx
+                    ]
+                    != self.src_eos
+                ).sum() == 0
+                samples["net_input"]["src_tokens"].scatter_(
+                    1, eos_idx.view(-1, 1), self.new_src_eos[dataset_ids].view(-1, 1)
+                )
+
+        if len(self.new_tgt_bos) > 0 and "prev_output_tokens" in samples["net_input"]:
+            if self.left_pad_target:
+                # TODO: support different padding direction on target side
+                raise NotImplementedError(
+                    "TransformEosConcatLangPairDataset does not implement --left-pad-target True option"
+                )
+            else:
+                assert (
+                    samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
+                ).sum() == 0
+                samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos[
+                    dataset_ids
+                ]
+
+        return samples
diff --git a/fairseq/fairseq/data/transform_eos_dataset.py b/fairseq/fairseq/data/transform_eos_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb14ff018edf13b20f5d0e486692dfb0a37ec6d1
--- /dev/null
+++ b/fairseq/fairseq/data/transform_eos_dataset.py
@@ -0,0 +1,120 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from . import FairseqDataset
+
+
+class TransformEosDataset(FairseqDataset):
+    """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
+
+    Note that the transformation is applied in :func:`collater`.
+
+    Args:
+        dataset (~fairseq.data.FairseqDataset): dataset to wrap
+        eos (int): index of the end-of-sentence symbol
+        append_eos_to_src (bool, optional): append EOS to the end of src
+        remove_eos_from_src (bool, optional): remove EOS from the end of src
+        append_eos_to_tgt (bool, optional): append EOS to the end of tgt
+        remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
+    """
+
+    def __init__(
+        self,
+        dataset,
+        eos,
+        append_eos_to_src=False,
+        remove_eos_from_src=False,
+        append_eos_to_tgt=False,
+        remove_eos_from_tgt=False,
+        has_target=True,
+    ):
+        if not isinstance(dataset, FairseqDataset):
+            raise ValueError("dataset must be an instance of FairseqDataset")
+        if append_eos_to_src and remove_eos_from_src:
+            raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src")
+        if append_eos_to_tgt and remove_eos_from_tgt:
+            raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt")
+
+        self.dataset = dataset
+        self.eos = torch.LongTensor([eos])
+        self.append_eos_to_src = append_eos_to_src
+        self.remove_eos_from_src = remove_eos_from_src
+        self.append_eos_to_tgt = append_eos_to_tgt
+        self.remove_eos_from_tgt = remove_eos_from_tgt
+        self.has_target = has_target
+
+        # precompute how we should adjust the reported sizes
+        self._src_delta = 0
+        self._src_delta += 1 if append_eos_to_src else 0
+        self._src_delta -= 1 if remove_eos_from_src else 0
+        self._tgt_delta = 0
+        self._tgt_delta += 1 if append_eos_to_tgt else 0
+        self._tgt_delta -= 1 if remove_eos_from_tgt else 0
+
+        self._checked_src = False
+        self._checked_tgt = False
+
+    def _check_src(self, src, expect_eos):
+        if not self._checked_src:
+            assert (src[-1] == self.eos[0]) == expect_eos
+            self._checked_src = True
+
+    def _check_tgt(self, tgt, expect_eos):
+        if self.has_target and not self._checked_tgt:
+            assert (tgt[-1] == self.eos[0]) == expect_eos
+            self._checked_tgt = True
+
+    def __getitem__(self, index):
+        return self.dataset[index]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collater(self, samples):
+        def transform(item):
+            if self.append_eos_to_src:
+                self.eos = self.eos.to(device=item["source"].device)
+                self._check_src(item["source"], expect_eos=False)
+                item["source"] = torch.cat([item["source"], self.eos])
+            if self.remove_eos_from_src:
+                self.eos = self.eos.to(device=item["source"].device)
+                self._check_src(item["source"], expect_eos=True)
+                item["source"] = item["source"][:-1]
+            if self.append_eos_to_tgt:
+                self.eos = self.eos.to(device=item["target"].device)
+                self._check_tgt(item["target"], expect_eos=False)
+                item["target"] = torch.cat([item["target"], self.eos])
+            if self.remove_eos_from_tgt:
+                self.eos = self.eos.to(device=item["target"].device)
+                self._check_tgt(item["target"], expect_eos=True)
+                item["target"] = item["target"][:-1]
+            return item
+
+        samples = list(map(transform, samples))
+        return self.dataset.collater(samples)
+
+    def num_tokens(self, index):
+        return self.dataset.num_tokens(index)
+
+    def size(self, index):
+        if self.has_target:
+            src_len, tgt_len = self.dataset.size(index)
+            return (src_len + self._src_delta, tgt_len + self._tgt_delta)
+        else:
+            return self.dataset.size(index)
+
+    def ordered_indices(self):
+        # NOTE: we assume that the ordering does not change based on the
+        # addition or removal of eos
+        return self.dataset.ordered_indices()
+
+    @property
+    def supports_prefetch(self):
+        return getattr(self.dataset, "supports_prefetch", False)
+
+    def prefetch(self, indices):
+        return self.dataset.prefetch(indices)
diff --git a/fairseq/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/fairseq/data/transform_eos_lang_pair_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8b21090144bfc975d4d5a3ee2c21b2e8acde03d
--- /dev/null
+++ b/fairseq/fairseq/data/transform_eos_lang_pair_dataset.py
@@ -0,0 +1,113 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+import torch
+
+from . import FairseqDataset
+
+
+class TransformEosLangPairDataset(FairseqDataset):
+    """A :class:`~fairseq.data.FairseqDataset` wrapper that transforms bos on
+    collated samples of a language pair dataset.
+
+    Note that the transformation is applied in :func:`collater`.
+
+    Args:
+        dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
+            LanguagePairDataset schema
+        src_eos (int): original source end-of-sentence symbol index to be replaced
+        new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
+        tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
+        new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
+            beginning of 'prev_output_tokens'
+    """
+
+    def __init__(
+        self,
+        dataset: FairseqDataset,
+        src_eos: int,
+        new_src_eos: Optional[int] = None,
+        tgt_bos: Optional[int] = None,
+        new_tgt_bos: Optional[int] = None,
+    ):
+        self.dataset = dataset
+        self.src_eos = src_eos
+        self.new_src_eos = new_src_eos
+        self.tgt_bos = tgt_bos
+        self.new_tgt_bos = new_tgt_bos
+
+    def __getitem__(self, index):
+        return self.dataset[index]
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collater(self, samples, **extra_args):
+        samples = self.dataset.collater(samples, **extra_args)
+        if len(samples) == 0:
+            return samples
+
+        if "net_input" not in samples:
+            return samples
+
+        if self.new_src_eos is not None:
+            if self.dataset.left_pad_source:
+                assert (
+                    samples["net_input"]["src_tokens"][:, -1] != self.src_eos
+                ).sum() == 0
+                samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos
+            else:
+                eos_idx = samples["net_input"]["src_lengths"] - 1
+                assert (
+                    samples["net_input"]["src_tokens"][
+                        torch.arange(eos_idx.size(0)), eos_idx
+                    ]
+                    != self.src_eos
+                ).sum() == 0
+                eos_idx = eos_idx.resize_(len(samples["net_input"]["src_lengths"]), 1)
+                samples["net_input"]["src_tokens"].scatter_(
+                    1, eos_idx, self.new_src_eos
+                )
+
+        if (
+            self.new_tgt_bos is not None
+            and "prev_output_tokens" in samples["net_input"]
+        ):
+            if self.dataset.left_pad_target:
+                # TODO: support different padding direction on target side
+                raise NotImplementedError(
+                    "TransformEosLangPairDataset does not implement --left-pad-target True option"
+                )
+            else:
+                assert (
+                    samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos
+                ).sum() == 0
+                samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos
+
+        return samples
+
+    def num_tokens(self, index):
+        return self.dataset.num_tokens(index)
+
+    def size(self, index):
+        return self.dataset.size(index)
+
+    @property
+    def sizes(self):
+        # dataset.sizes can be dynamically computed
+        return self.dataset.sizes
+
+    def ordered_indices(self):
+        return self.dataset.ordered_indices()
+
+    @property
+    def supports_prefetch(self):
+        return getattr(self.dataset, "supports_prefetch", False)
+
+    def prefetch(self, indices):
+        return self.dataset.prefetch(indices)
diff --git a/fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so
new file mode 100644
index 0000000000000000000000000000000000000000..81b8f7301c9b4a076c7bbf918a63eb27bce6d419
--- /dev/null
+++ b/fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afaa6dcec349c5fd7161fa86858f1216df0a7a9710ff701289684141ac62f6ab
+size 177760