diff --git a/.gitattributes b/.gitattributes index 34cd0b7d9ed3676541c1cfaa61e9c4e7257e7b42..1d38984a2211bce7158255ab92c06a870855ac02 100644 --- a/.gitattributes +++ b/.gitattributes @@ -45,3 +45,4 @@ fairseq/fairseq/libnat.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge fairseq/fairseq/ngram_repeat_block_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text fairseq/fairseq/data/data_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text +fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text diff --git a/fairseq/fairseq/data/__pycache__/iterators.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/iterators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f7435afc164d03b1576541adbe67726efafdeaf Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/iterators.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28e0530b97a129b3491e20729a48fc3a7d5346a2 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53645ad3db3b548e786ddb1fa286bd8c037b8b02 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af43e0cd53b5032f71ad7186e9182828bcdd303c Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8d72f36405c83a4ad4165a6e82d02e9ac136d24 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a098f7f5d6338577827f4a1481bfffe340312cd4 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a4fb47bd82bf65883d013e5dbb77ebd76558dd7 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51f50da37a94a0a5a185aa4054baf12a7dbbf125 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc b/fairseq/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7089cef5416f80ad723252cfb69fe09c48359290 Binary files /dev/null and b/fairseq/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc differ diff --git a/fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so b/fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..f1711ba23abe18eccbcbf5bf0e79453a1fb2870d --- /dev/null +++ b/fairseq/fairseq/data/token_block_utils_fast.cpython-310-x86_64-linux-gnu.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d4d6c9907358e6cb6d6061abd137909131f1a687a5df6ceb49bdc6ae061b54f +size 285696 diff --git a/fairseq/fairseq/dataclass/__init__.py b/fairseq/fairseq/dataclass/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25408d28ec44cee56eb5fb3ab0c817dc04159e95 --- /dev/null +++ b/fairseq/fairseq/dataclass/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .configs import FairseqDataclass +from .constants import ChoiceEnum + + +__all__ = [ + "FairseqDataclass", + "ChoiceEnum", +] diff --git a/fairseq/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b653dc0fc7ff59debaaf3de3cbb5cb62ad94671 Binary files /dev/null and b/fairseq/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/dataclass/__pycache__/configs.cpython-310.pyc b/fairseq/fairseq/dataclass/__pycache__/configs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64215895594c42c8432eb52b839af2211d3cbc7d Binary files /dev/null and b/fairseq/fairseq/dataclass/__pycache__/configs.cpython-310.pyc differ diff --git a/fairseq/fairseq/dataclass/__pycache__/constants.cpython-310.pyc b/fairseq/fairseq/dataclass/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa703e83aff72bc78694951c92bc0e2a5ab4ab9e Binary files /dev/null and b/fairseq/fairseq/dataclass/__pycache__/constants.cpython-310.pyc differ diff --git a/fairseq/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc b/fairseq/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..708ce6cb11d4178c77a0bae53a8a34966743a2c3 Binary files /dev/null and b/fairseq/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc differ diff --git a/fairseq/fairseq/dataclass/__pycache__/utils.cpython-310.pyc b/fairseq/fairseq/dataclass/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab2f2887ff9a71e2f5d6ce3b5d75696fa580f9c Binary files /dev/null and b/fairseq/fairseq/dataclass/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/fairseq/dataclass/configs.py b/fairseq/fairseq/dataclass/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..af957fec64711c697da6840969da305e412783df --- /dev/null +++ b/fairseq/fairseq/dataclass/configs.py @@ -0,0 +1,1147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +from dataclasses import _MISSING_TYPE, dataclass, field +from typing import Any, List, Optional + +import torch +from omegaconf import II, MISSING + +from fairseq.dataclass.constants import ( + DATASET_IMPL_CHOICES, + DDP_BACKEND_CHOICES, + DDP_COMM_HOOK_CHOICES, + GENERATION_CONSTRAINTS_CHOICES, + GENERATION_DECODING_FORMAT_CHOICES, + LOG_FORMAT_CHOICES, + PIPELINE_CHECKPOINT_CHOICES, + PRINT_ALIGNMENT_CHOICES, + ZERO_SHARDING_CHOICES, +) + + +@dataclass +class FairseqDataclass: + """fairseq base dataclass that supported fetching attributes and metas""" + + _name: Optional[str] = None + + @staticmethod + def name(): + return None + + def _get_all_attributes(self) -> List[str]: + return [k for k in self.__dataclass_fields__.keys()] + + def _get_meta( + self, attribute_name: str, meta: str, default: Optional[Any] = None + ) -> Any: + return self.__dataclass_fields__[attribute_name].metadata.get(meta, default) + + def _get_name(self, attribute_name: str) -> str: + return self.__dataclass_fields__[attribute_name].name + + def _get_default(self, attribute_name: str) -> Any: + if hasattr(self, attribute_name): + if str(getattr(self, attribute_name)).startswith("${"): + return str(getattr(self, attribute_name)) + elif str(self.__dataclass_fields__[attribute_name].default).startswith( + "${" + ): + return str(self.__dataclass_fields__[attribute_name].default) + elif ( + getattr(self, attribute_name) + != self.__dataclass_fields__[attribute_name].default + ): + return getattr(self, attribute_name) + + f = self.__dataclass_fields__[attribute_name] + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + def _get_type(self, attribute_name: str) -> Any: + return self.__dataclass_fields__[attribute_name].type + + def _get_help(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "help") + + def _get_argparse_const(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_const") + + def _get_argparse_alias(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_alias") + + def _get_choices(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "choices") + + @classmethod + def from_namespace(cls, args): + if isinstance(args, cls): + return args + else: + config = cls() + for k in config.__dataclass_fields__.keys(): + if k.startswith("_"): + # private member, skip + continue + if hasattr(args, k): + setattr(config, k, getattr(args, k)) + + return config + + +@dataclass +class CommonConfig(FairseqDataclass): + # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were + # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. + no_progress_bar: bool = field( + default=False, metadata={"help": "disable progress bar"} + ) + log_interval: int = field( + default=100, + metadata={ + "help": "log progress every N batches (when progress bar is disabled)" + }, + ) + log_format: Optional[LOG_FORMAT_CHOICES] = field( + default=None, metadata={"help": "log format to use"} + ) + log_file: Optional[str] = field( + default=None, metadata={"help": "log file to copy metrics to."} + ) + aim_repo: Optional[str] = field( + default=None, + metadata={"help": "path to Aim repository"}, + ) + aim_run_hash: Optional[str] = field( + default=None, + metadata={ + "help": "Aim run hash. If skipped, creates or continues run " + "based on save_dir" + }, + ) + tensorboard_logdir: Optional[str] = field( + default=None, + metadata={ + "help": "path to save logs for tensorboard, should match --logdir " + "of running tensorboard (default: no tensorboard logging)" + }, + ) + wandb_project: Optional[str] = field( + default=None, + metadata={"help": "Weights and Biases project name to use for logging"}, + ) + azureml_logging: Optional[bool] = field( + default=False, + metadata={"help": "Log scalars to AzureML context"}, + ) + seed: int = field( + default=1, metadata={"help": "pseudo random number generator seed"} + ) + cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"}) + tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"}) + bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"}) + memory_efficient_bf16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of BF16 training; implies --bf16" + }, + ) + fp16: bool = field(default=False, metadata={"help": "use FP16"}) + memory_efficient_fp16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of FP16 training; implies --fp16" + }, + ) + fp16_no_flatten_grads: bool = field( + default=False, metadata={"help": "don't flatten FP16 grads tensor"} + ) + fp16_init_scale: int = field( + default=2**7, metadata={"help": "default FP16 loss scale"} + ) + fp16_scale_window: Optional[int] = field( + default=None, + metadata={"help": "number of updates before increasing loss scale"}, + ) + fp16_scale_tolerance: float = field( + default=0.0, + metadata={ + "help": "pct of updates that can overflow before decreasing the loss scale" + }, + ) + on_cpu_convert_precision: bool = field( + default=False, + metadata={ + "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. " + "This reduces bus transfer time and GPU memory usage." + }, + ) + min_loss_scale: float = field( + default=1e-4, + metadata={ + "help": "minimum FP16/AMP loss scale, after which training is stopped" + }, + ) + threshold_loss_scale: Optional[float] = field( + default=None, metadata={"help": "threshold FP16 loss scale from below"} + ) + amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"}) + amp_batch_retries: int = field( + default=2, + metadata={ + "help": "number of retries of same batch after reducing loss scale with AMP" + }, + ) + amp_init_scale: int = field( + default=2**7, metadata={"help": "default AMP loss scale"} + ) + amp_scale_window: Optional[int] = field( + default=None, + metadata={"help": "number of updates before increasing AMP loss scale"}, + ) + user_dir: Optional[str] = field( + default=None, + metadata={ + "help": "path to a python module containing custom extensions (tasks and/or architectures)" + }, + ) + empty_cache_freq: int = field( + default=0, + metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"}, + ) + all_gather_list_size: int = field( + default=16384, + metadata={"help": "number of bytes reserved for gathering stats from workers"}, + ) + model_parallel_size: int = field( + default=1, metadata={"help": "total number of GPUs to parallelize model over"} + ) + quantization_config_path: Optional[str] = field( + default=None, metadata={"help": "path to quantization config file"} + ) + profile: bool = field( + default=False, metadata={"help": "enable autograd profiler emit_nvtx"} + ) + reset_logging: bool = field( + default=False, + metadata={ + "help": "when using Hydra, reset the logging at the beginning of training" + }, + ) + suppress_crashes: bool = field( + default=False, + metadata={ + "help": "suppress crashes when training with the hydra_train entry point so that the " + "main method can return a value (useful for sweeps)" + }, + ) + use_plasma_view: bool = field( + default=False, metadata={"help": "Store indices and sizes in shared memory"} + ) + plasma_path: Optional[str] = field( + default="/tmp/plasma", + metadata={ + "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail." + }, + ) + + +@dataclass +class DistributedTrainingConfig(FairseqDataclass): + distributed_world_size: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "total number of GPUs across all nodes (default: all visible GPUs)" + }, + ) + distributed_num_procs: Optional[int] = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "total number of processes to fork (default: all visible GPUs)" + }, + ) + distributed_rank: Optional[int] = field( + default=0, metadata={"help": "rank of the current worker"} + ) + distributed_backend: str = field( + default="nccl", metadata={"help": "distributed backend"} + ) + distributed_init_method: Optional[str] = field( + default=None, + metadata={ + "help": "typically tcp://hostname:port that will be used to " + "establish initial connetion" + }, + ) + distributed_port: int = field( + default=-1, + metadata={ + "help": "port number (not required if using --distributed-init-method)" + }, + ) + device_id: int = field( + default=os.getenv("LOCAL_RANK", 0), + metadata={ + "help": "which GPU to use (by default looks for $LOCAL_RANK, usually configured automatically)", + "argparse_alias": "--local_rank", + }, + ) + distributed_no_spawn: bool = field( + default=False, + metadata={ + "help": "do not spawn multiple processes even if multiple GPUs are visible" + }, + ) + ddp_backend: DDP_BACKEND_CHOICES = field( + default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} + ) + ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( + default="none", metadata={"help": "communication hook"} + ) + bucket_cap_mb: int = field( + default=25, metadata={"help": "bucket size for reduction"} + ) + fix_batches_to_gpus: bool = field( + default=False, + metadata={ + "help": "don't shuffle batches between GPUs; this reduces overall " + "randomness and may affect precision but avoids the cost of re-reading the data" + }, + ) + find_unused_parameters: bool = field( + default=False, + metadata={ + "help": "disable unused parameter detection (not applicable to " + "--ddp-backend=legacy_ddp)" + }, + ) + gradient_as_bucket_view: bool = field( + default=False, + metadata={ + "help": "when set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. " + "--gradient-as-bucket-view=gradient_as_bucket_view)" + }, + ) + fast_stat_sync: bool = field( + default=False, + metadata={"help": "[deprecated] this is now defined per Criterion"}, + ) + heartbeat_timeout: int = field( + default=-1, + metadata={ + "help": "kill the job if no progress is made in N seconds; " + "set to -1 to disable" + }, + ) + broadcast_buffers: bool = field( + default=False, + metadata={ + "help": "Copy non-trainable parameters between GPUs, such as " + "batchnorm population statistics" + }, + ) + slowmo_momentum: Optional[float] = field( + default=None, + metadata={ + "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, " + "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs" + }, + ) + slowmo_base_algorithm: str = field( + default="localsgd", + metadata={ + "help": "Base algorithm. Either 'localsgd' or 'sgp'. Please refer " + "to the documentation of 'slowmo_base_algorithm' parameter in " + "https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html " + "for more details" + }, + ) + localsgd_frequency: int = field( + default=3, metadata={"help": "Local SGD allreduce frequency"} + ) + nprocs_per_node: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "number of GPUs in each node. An allreduce operation across GPUs in " + "a node is very fast. Hence, we do allreduce across GPUs in a node, " + "and gossip across different nodes" + }, + ) + pipeline_model_parallel: bool = field( + default=False, + metadata={"help": "if set, use pipeline model parallelism across GPUs"}, + ) + pipeline_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the model into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_balance) " + "should equal the total number of layers in the model" + }, + ) + pipeline_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-balance argument" + }, + ) + pipeline_chunks: Optional[int] = field( + default=0, metadata={"help": "microbatch count for pipeline model parallelism"} + ) + pipeline_encoder_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the pipeline parallel encoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_encoder_balance) " + "should equal the total number of encoder layers in the model" + }, + ) + pipeline_encoder_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-encoder-balance argument" + }, + ) + pipeline_decoder_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the pipeline parallel decoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_decoder_balance) " + "should equal the total number of decoder layers in the model" + }, + ) + pipeline_decoder_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-decoder-balance argument" + }, + ) + pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( + default="never", + metadata={"help": "checkpointing mode for pipeline model parallelism"}, + ) + zero_sharding: ZERO_SHARDING_CHOICES = field( + default="none", metadata={"help": "ZeRO sharding"} + ) + fp16: bool = II("common.fp16") + memory_efficient_fp16: bool = II("common.memory_efficient_fp16") + tpu: bool = II("common.tpu") + # configuration for --ddp-backend=fully_sharded + no_reshard_after_forward: bool = field( + default=False, + metadata={"help": "don't reshard parameters after forward pass"}, + ) + fp32_reduce_scatter: bool = field( + default=False, + metadata={"help": "reduce-scatter grads in FP32"}, + ) + cpu_offload: bool = field( + default=False, metadata={"help": "offload FP32 params to CPU"} + ) + use_sharded_state: bool = field( + default=False, + metadata={"help": "use sharded checkpoint files"}, + ) + not_fsdp_flatten_parameters: bool = field( + default=False, + metadata={"help": "not flatten parameter param for fsdp"}, + ) + + +@dataclass +class DatasetConfig(FairseqDataclass): + num_workers: int = field( + default=1, metadata={"help": "how many subprocesses to use for data loading"} + ) + skip_invalid_size_inputs_valid_test: bool = field( + default=False, + metadata={"help": "ignore too long or too short lines in valid and test set"}, + ) + max_tokens: Optional[int] = field( + default=None, metadata={"help": "maximum number of tokens in a batch"} + ) + batch_size: Optional[int] = field( + default=None, + metadata={ + "help": "number of examples in a batch", + "argparse_alias": "--max-sentences", + }, + ) + required_batch_size_multiple: int = field( + default=8, metadata={"help": "batch size will be a multiplier of this value"} + ) + required_seq_len_multiple: int = field( + default=1, + metadata={ + "help": "maximum sequence length in batch will be a multiplier of this value" + }, + ) + dataset_impl: Optional[DATASET_IMPL_CHOICES] = field( + default=None, metadata={"help": "output dataset implementation"} + ) + data_buffer_size: int = field( + default=10, metadata={"help": "Number of batches to preload"} + ) + train_subset: str = field( + default="train", + metadata={"help": "data subset to use for training (e.g. train, valid, test)"}, + ) + valid_subset: str = field( + default="valid", + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. train, valid, test)" + }, + ) + combine_valid_subsets: Optional[bool] = field( + default=None, + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. train, valid, test)", + "argparse_alias": "--combine-val", + }, + ) + ignore_unused_valid_subsets: Optional[bool] = field( + default=False, + metadata={"help": "do not raise error if valid subsets are ignored"}, + ) + + validate_interval: int = field( + default=1, metadata={"help": "validate every N epochs"} + ) + validate_interval_updates: int = field( + default=0, metadata={"help": "validate every N updates"} + ) + validate_after_updates: int = field( + default=0, metadata={"help": "dont validate until reaching this many updates"} + ) + fixed_validation_seed: Optional[int] = field( + default=None, metadata={"help": "specified random seed for validation"} + ) + disable_validation: bool = field( + default=False, metadata={"help": "disable validation"} + ) + max_tokens_valid: Optional[int] = field( + default=II("dataset.max_tokens"), + metadata={ + "help": "maximum number of tokens in a validation batch" + " (defaults to --max-tokens)" + }, + ) + batch_size_valid: Optional[int] = field( + default=II("dataset.batch_size"), + metadata={ + "help": "batch size of the validation batch (defaults to --batch-size)", + "argparse_alias": "--max-sentences-valid", + }, + ) + max_valid_steps: Optional[int] = field( + default=None, + metadata={"help": "How many batches to evaluate", "argparse_alias": "--nval"}, + ) + curriculum: int = field( + default=0, metadata={"help": "don't shuffle batches for first N epochs"} + ) + gen_subset: str = field( + default="test", + metadata={"help": "data subset to generate (train, valid, test)"}, + ) + num_shards: int = field( + default=1, metadata={"help": "shard generation over N shards"} + ) + shard_id: int = field( + default=0, metadata={"help": "id of the shard to generate (id < num_shards)"} + ) + grouped_shuffling: bool = field( + default=False, + metadata={ + "help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length", + }, + ) + update_epoch_batch_itr: bool = field( + default=II("dataset.grouped_shuffling"), + metadata={ + "help": "if true then prevents the reuse the epoch batch iterator by setting can_reuse_epoch_itr to false, defaults to --grouped-shuffling )", + }, + ) + update_ordered_indices_seed: bool = field( + default=False, + metadata={ + "help": "if true then increment seed with epoch for getting batch iterators, defautls to False.", + }, + ) + + +@dataclass +class OptimizationConfig(FairseqDataclass): + max_epoch: int = field( + default=0, metadata={"help": "force stop training at specified epoch"} + ) + max_update: int = field( + default=0, metadata={"help": "force stop training at specified update"} + ) + stop_time_hours: float = field( + default=0, + metadata={ + "help": "force stop training after specified cumulative time (if >0)" + }, + ) + clip_norm: float = field( + default=0.0, metadata={"help": "clip threshold of gradients"} + ) + sentence_avg: bool = field( + default=False, + metadata={ + "help": "normalize gradients by the number of sentences in a batch" + " (default is to normalize by number of tokens)" + }, + ) + update_freq: List[int] = field( + default_factory=lambda: [1], + metadata={"help": "update parameters every N_i batches, when in epoch i"}, + ) + lr: List[float] = field( + default_factory=lambda: [0.25], + metadata={ + "help": "learning rate for the first N epochs; all epochs >N using LR_N" + " (note: this may be interpreted differently depending on --lr-scheduler)" + }, + ) + stop_min_lr: float = field( + default=-1.0, + metadata={"help": "stop training when the learning rate reaches this minimum"}, + ) + use_bmuf: bool = field( + default=False, + metadata={ + "help": "specify global optimizer for syncing models on different GPUs/shards" + }, + ) + skip_remainder_batch: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, include the last (partial) batch of each epoch in training" + " (default is to skip it)." + }, + ) + debug_param_names: bool = False + + +@dataclass +class CheckpointConfig(FairseqDataclass): + save_dir: str = field( + default="checkpoints", metadata={"help": "path to save checkpoints"} + ) + restore_file: str = field( + default="checkpoint_last.pt", + metadata={ + "help": "filename from which to load checkpoint " + "(default: /checkpoint_last.pt" + }, + ) + continue_once: Optional[str] = field( + default=None, + metadata={ + "help": "continues from this checkpoint, unless a checkpoint indicated in 'restore_file' option is present" + }, + ) + finetune_from_model: Optional[str] = field( + default=None, + metadata={ + "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset" + }, + ) + reset_dataloader: bool = field( + default=False, + metadata={ + "help": "if set, does not reload dataloader state from the checkpoint" + }, + ) + reset_lr_scheduler: bool = field( + default=False, + metadata={ + "help": "if set, does not load lr scheduler state from the checkpoint" + }, + ) + reset_meters: bool = field( + default=False, + metadata={"help": "if set, does not load meters from the checkpoint"}, + ) + reset_optimizer: bool = field( + default=False, + metadata={"help": "if set, does not load optimizer state from the checkpoint"}, + ) + optimizer_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary used to override optimizer args when loading a checkpoint" + }, + ) + save_interval: int = field( + default=1, metadata={"help": "save a checkpoint every N epochs"} + ) + save_interval_updates: int = field( + default=0, metadata={"help": "save a checkpoint (and validate) every N updates"} + ) + keep_interval_updates: int = field( + default=-1, + metadata={ + "help": "keep the last N checkpoints saved with --save-interval-updates" + }, + ) + keep_interval_updates_pattern: int = field( + default=-1, + metadata={ + "help": "when used with --keep-interval-updates, skips deleting " + "any checkpoints with update X where " + "X %% keep_interval_updates_pattern == 0" + }, + ) + keep_last_epochs: int = field( + default=-1, metadata={"help": "keep last N epoch checkpoints"} + ) + keep_best_checkpoints: int = field( + default=-1, metadata={"help": "keep best N checkpoints based on scores"} + ) + no_save: bool = field( + default=False, metadata={"help": "don't save models or checkpoints"} + ) + no_epoch_checkpoints: bool = field( + default=False, metadata={"help": "only store last and best checkpoints"} + ) + no_last_checkpoints: bool = field( + default=False, metadata={"help": "don't store last checkpoints"} + ) + no_save_optimizer_state: bool = field( + default=False, + metadata={"help": "don't save optimizer-state as part of checkpoint"}, + ) + best_checkpoint_metric: str = field( + default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'} + ) + maximize_best_checkpoint_metric: bool = field( + default=False, + metadata={ + "help": 'select the largest metric value for saving "best" checkpoints' + }, + ) + patience: int = field( + default=-1, + metadata={ + "help": ( + "early stop training if valid performance doesn't " + "improve for N consecutive validation runs; note " + "that this is influenced by --validate-interval" + ) + }, + ) + checkpoint_suffix: str = field( + default="", metadata={"help": "suffix to add to the checkpoint file name"} + ) + checkpoint_shard_count: int = field( + default=1, + metadata={ + "help": "Number of shards containing the checkpoint - " + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + }, + ) + load_checkpoint_on_all_dp_ranks: bool = field( + default=False, + metadata={ + "help": "load checkpoints on all data parallel devices " + "(default: only load on rank 0 and broadcast to other devices)" + }, + ) + write_checkpoints_asynchronously: bool = field( + default=False, + metadata={ + "help": ( + "Write checkpoints asynchronously in a separate " + "thread. NOTE: This feature is currently being tested." + ), + "argparse_alias": "--save-async", + }, + ) + model_parallel_size: int = II("common.model_parallel_size") + + +@dataclass +class FairseqBMUFConfig(FairseqDataclass): + block_lr: float = field( + default=1, metadata={"help": "block learning rate for bmuf"} + ) + block_momentum: float = field( + default=0.875, metadata={"help": "block momentum for bmuf"} + ) + global_sync_iter: int = field( + default=50, metadata={"help": "Iteration for syncing global model"} + ) + warmup_iterations: int = field( + default=500, metadata={"help": "warmup iterations for model to broadcast"} + ) + use_nbm: bool = field( + default=False, + metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"}, + ) + average_sync: bool = field( + default=False, + metadata={ + "help": "Specify whether you want to average the local momentum after each sync" + }, + ) + distributed_world_size: int = II("distributed_training.distributed_world_size") + + +@dataclass +class GenerationConfig(FairseqDataclass): + beam: int = field( + default=5, + metadata={"help": "beam size"}, + ) + beam_mt: int = field( + default=0, + metadata={"help": "beam size for the first-pass decoder"}, + ) + nbest: int = field( + default=1, + metadata={"help": "number of hypotheses to output"}, + ) + max_len_a: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_b: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_a_mt: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder" + }, + ) + max_len_b_mt: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder" + }, + ) + min_len: int = field( + default=1, + metadata={"help": "minimum generation length"}, + ) + match_source_len: bool = field( + default=False, + metadata={"help": "generations should match the source length"}, + ) + unnormalized: bool = field( + default=False, + metadata={"help": "compare unnormalized hypothesis scores"}, + ) + no_early_stop: bool = field( + default=False, + metadata={"help": "deprecated"}, + ) + no_beamable_mm: bool = field( + default=False, + metadata={"help": "don't use BeamableMM in attention layers"}, + ) + lenpen: float = field( + default=1, + metadata={ + "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + lenpen_mt: float = field( + default=1, + metadata={ + "help": "length penalty for the first-pass decoder: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + unkpen: float = field( + default=0, + metadata={ + "help": "unknown word penalty: <0 produces more unks, >0 produces fewer" + }, + ) + replace_unk: Optional[str] = field( + default=None, + metadata={ + "help": "perform unknown replacement (optionally with alignment dictionary)", + "argparse_const": "@@ ", + }, + ) + sacrebleu: bool = field( + default=False, + metadata={"help": "score with sacrebleu"}, + ) + score_reference: bool = field( + default=False, + metadata={"help": "just score the reference translation"}, + ) + prefix_size: int = field( + default=0, + metadata={"help": "initialize generation by target prefix of given length"}, + ) + no_repeat_ngram_size: int = field( + default=0, + metadata={ + "help": "ngram blocking such that this size ngram cannot be repeated in the generation" + }, + ) + sampling: bool = field( + default=False, + metadata={"help": "sample hypotheses instead of using beam search"}, + ) + sampling_topk: int = field( + default=-1, + metadata={"help": "sample from top K likely next words instead of all words"}, + ) + sampling_topp: float = field( + default=-1.0, + metadata={ + "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words" + }, + ) + constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( + default=None, + metadata={ + "help": "enables lexically constrained decoding", + "argparse_const": "ordered", + }, + ) + temperature: float = field( + default=1.0, + metadata={"help": "temperature for generation"}, + ) + diverse_beam_groups: int = field( + default=-1, + metadata={"help": "number of groups for Diverse Beam Search"}, + ) + diverse_beam_strength: float = field( + default=0.5, + metadata={"help": "strength of diversity penalty for Diverse Beam Search"}, + ) + diversity_rate: float = field( + default=-1.0, + metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, + ) + print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field( + default=None, + metadata={ + "help": "if set, uses attention feedback to compute and print alignment to source tokens " + "(valid options are: hard, soft, otherwise treated as hard alignment)", + "argparse_const": "hard", + }, + ) + print_step: bool = field( + default=False, + metadata={"help": "print steps"}, + ) + lm_path: Optional[str] = field( + default=None, + metadata={"help": "path to lm checkpoint for lm fusion"}, + ) + lm_weight: float = field( + default=0.0, + metadata={"help": "weight for lm probs for lm fusion"}, + ) + + # arguments for iterative refinement generator + iter_decode_eos_penalty: float = field( + default=0.0, + metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, + ) + iter_decode_max_iter: int = field( + default=10, + metadata={"help": "maximum iterations for iterative refinement."}, + ) + iter_decode_force_max_iter: bool = field( + default=False, + metadata={ + "help": "if set, run exact the maximum number of iterations without early stop" + }, + ) + iter_decode_with_beam: int = field( + default=1, + metadata={ + "help": "if > 1, model will generate translations varying by the lengths." + }, + ) + iter_decode_with_external_reranker: bool = field( + default=False, + metadata={ + "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations" + }, + ) + retain_iter_history: bool = field( + default=False, + metadata={ + "help": "if set, decoding returns the whole history of iterative refinement" + }, + ) + retain_dropout: bool = field( + default=False, + metadata={"help": "Use dropout at inference time"}, + ) + # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed + # retain_dropout_modules: Optional[List[str]] = field( + retain_dropout_modules: Any = field( + default=None, + metadata={ + "help": "if set, only retain dropout for the specified modules; " + "if not set, then dropout will be retained for all modules" + }, + ) + # special decoding format for advanced decoding. + decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( + default=None, + metadata={"help": "special decoding format for advanced decoding."}, + ) + no_seed_provided: bool = field( + default=False, + metadata={"help": "if set, dont use seed for initializing random generators"}, + ) + eos_token: Optional[str] = field( + default=None, + metadata={"help": "EOS token"}, + ) + + +@dataclass +class CommonEvalConfig(FairseqDataclass): + path: Optional[str] = field( + default=None, + metadata={"help": "path(s) to model file(s), colon separated"}, + ) + post_process: Optional[str] = field( + default=None, + metadata={ + "help": ( + "post-process text by removing BPE, letter segmentation, etc. " + "Valid options can be found in fairseq.data.utils.post_process." + ), + "argparse_const": "subword_nmt", + "argparse_alias": "--remove-bpe", + }, + ) + quiet: bool = field(default=False, metadata={"help": "only print final scores"}) + model_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary used to override model args at generation that were used during model training" + }, + ) + results_path: Optional[str] = field( + default=None, metadata={"help": "path to save eval results (optional)"} + ) + + +@dataclass +class EvalLMConfig(FairseqDataclass): + output_word_probs: bool = field( + default=False, + metadata={ + "help": "if set, outputs words and their predicted log probabilities to standard output" + }, + ) + output_word_stats: bool = field( + default=False, + metadata={ + "help": "if set, outputs word statistics such as word count, average probability, etc" + }, + ) + context_window: int = field( + default=0, + metadata={ + "help": "ensures that every evaluated token has access to a context of at least this size, if possible" + }, + ) + softmax_batch: int = field( + default=sys.maxsize, + metadata={ + "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory" + }, + ) + + +@dataclass +class InteractiveConfig(FairseqDataclass): + buffer_size: int = field( + default=0, + metadata={ + "help": "read this many sentences into a buffer before processing them" + }, + ) + input: str = field( + default="-", + metadata={"help": "file to read from; use - for stdin"}, + ) + + +@dataclass +class EMAConfig(FairseqDataclass): + store_ema: bool = field( + default=False, metadata={help: "store exponential moving average shadow model"} + ) + ema_decay: float = field( + default=0.9999, metadata={"help": "decay for exponential moving average model"} + ) + ema_start_update: int = field( + default=0, metadata={"help": "start EMA update after this many model updates"} + ) + ema_seed_model: Optional[str] = field( + default=None, + metadata={ + "help": "Seed to load EMA model from. " + "Used to load EMA model separately from the actual model." + }, + ) + ema_update_freq: int = field( + default=1, metadata={"help": "Do EMA update every this many model updates"} + ) + ema_fp32: bool = field( + default=False, + metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"}, + ) + + +@dataclass +class FairseqConfig(FairseqDataclass): + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + optimization: OptimizationConfig = OptimizationConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + bmuf: FairseqBMUFConfig = FairseqBMUFConfig() + generation: GenerationConfig = GenerationConfig() + eval_lm: EvalLMConfig = EvalLMConfig() + interactive: InteractiveConfig = InteractiveConfig() + model: Any = MISSING + task: Any = None + criterion: Any = None + optimizer: Any = None + lr_scheduler: Any = None + scoring: Any = None + bpe: Any = None + tokenizer: Any = None + ema: EMAConfig = EMAConfig() diff --git a/fairseq/fairseq/dataclass/constants.py b/fairseq/fairseq/dataclass/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..5af92f2b3aa51e460f0b045a348d3766f93eb90b --- /dev/null +++ b/fairseq/fairseq/dataclass/constants.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum, EnumMeta +from typing import List + + +class StrEnumMeta(EnumMeta): + # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see + # https://github.com/facebookresearch/hydra/issues/1156 + @classmethod + def __instancecheck__(cls, other): + return "enum" in str(type(other)) + + +class StrEnum(Enum, metaclass=StrEnumMeta): + def __str__(self): + return self.value + + def __eq__(self, other: str): + return self.value == other + + def __repr__(self): + return self.value + + def __hash__(self): + return hash(str(self)) + + +def ChoiceEnum(choices: List[str]): + """return the Enum class used to enforce list of choices""" + return StrEnum("Choices", {k: k for k in choices}) + + +LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) +DDP_BACKEND_CHOICES = ChoiceEnum( + [ + "c10d", # alias for pytorch_ddp + "fully_sharded", # FullyShardedDataParallel from fairscale + "legacy_ddp", + "no_c10d", # alias for legacy_ddp + "pytorch_ddp", + "slowmo", + ] +) +DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"]) +DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"]) +GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) +GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum( + ["unigram", "ensemble", "vote", "dp", "bs"] +) +ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) +PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) +PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"]) diff --git a/fairseq/fairseq/dataclass/initialize.py b/fairseq/fairseq/dataclass/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7784bad194761b6d60ccfa5aed2fc01aa123c0 --- /dev/null +++ b/fairseq/fairseq/dataclass/initialize.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import logging +from hydra.core.config_store import ConfigStore +from fairseq.dataclass.configs import FairseqConfig +from omegaconf import DictConfig, OmegaConf + + +logger = logging.getLogger(__name__) + + +def hydra_init(cfg_name="config") -> None: + + cs = ConfigStore.instance() + cs.store(name=f"{cfg_name}", node=FairseqConfig) + + for k in FairseqConfig.__dataclass_fields__: + v = FairseqConfig.__dataclass_fields__[k].default + try: + cs.store(name=k, node=v) + except BaseException: + logger.error(f"{k} - {v}") + raise + + +def add_defaults(cfg: DictConfig) -> None: + """This function adds default values that are stored in dataclasses that hydra doesn't know about""" + + from fairseq.registry import REGISTRIES + from fairseq.tasks import TASK_DATACLASS_REGISTRY + from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY + from fairseq.dataclass.utils import merge_with_parent + from typing import Any + + OmegaConf.set_struct(cfg, False) + + for k, v in FairseqConfig.__dataclass_fields__.items(): + field_cfg = cfg.get(k) + if field_cfg is not None and v.type == Any: + dc = None + + if isinstance(field_cfg, str): + field_cfg = DictConfig({"_name": field_cfg}) + field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"] + + name = getattr(field_cfg, "_name", None) + + if k == "task": + dc = TASK_DATACLASS_REGISTRY.get(name) + elif k == "model": + name = ARCH_MODEL_NAME_REGISTRY.get(name, name) + dc = MODEL_DATACLASS_REGISTRY.get(name) + elif k in REGISTRIES: + dc = REGISTRIES[k]["dataclass_registry"].get(name) + + if dc is not None: + cfg[k] = merge_with_parent(dc, field_cfg) diff --git a/fairseq/fairseq/dataclass/utils.py b/fairseq/fairseq/dataclass/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f6467d5f402f3904dd2adf67101a248e89bba887 --- /dev/null +++ b/fairseq/fairseq/dataclass/utils.py @@ -0,0 +1,510 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import inspect +import logging +import os +import re +from argparse import ArgumentError, ArgumentParser, Namespace +from dataclasses import _MISSING_TYPE, MISSING, is_dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Type + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import FairseqConfig +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict, _utils + +logger = logging.getLogger(__name__) + + +def eval_str_list(x, x_type=float): + if x is None: + return None + if isinstance(x, str): + if len(x) == 0: + return [] + x = ast.literal_eval(x) + try: + return list(map(x_type, x)) + except TypeError: + return [x_type(x)] + + +def interpret_dc_type(field_type): + if isinstance(field_type, str): + raise RuntimeError("field should be a type") + + if field_type == Any: + return str + + typestring = str(field_type) + if re.match( + r"(typing.|^)Union\[(.*), NoneType\]$", typestring + ) or typestring.startswith("typing.Optional"): + return field_type.__args__[0] + return field_type + + +def gen_parser_from_dataclass( + parser: ArgumentParser, + dataclass_instance: FairseqDataclass, + delete_default: bool = False, + with_prefix: Optional[str] = None, +) -> None: + """ + convert a dataclass instance to tailing parser arguments. + + If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are + building a flat namespace from a structured dataclass (see transformer_config.py for example). + """ + + def argparse_name(name: str): + if name == "data" and (with_prefix is None or with_prefix == ""): + # normally data is positional args, so we don't add the -- nor the prefix + return name + if name == "_name": + # private member, skip + return None + full_name = "--" + name.replace("_", "-") + if with_prefix is not None and with_prefix != "": + # if a prefix is specified, construct the prefixed arg name + full_name = with_prefix + "-" + full_name[2:] # strip -- when composing + return full_name + + def get_kwargs_from_dc( + dataclass_instance: FairseqDataclass, k: str + ) -> Dict[str, Any]: + """k: dataclass attributes""" + + kwargs = {} + + field_type = dataclass_instance._get_type(k) + inter_type = interpret_dc_type(field_type) + + field_default = dataclass_instance._get_default(k) + + if isinstance(inter_type, type) and issubclass(inter_type, Enum): + field_choices = [t.value for t in list(inter_type)] + else: + field_choices = None + + field_help = dataclass_instance._get_help(k) + field_const = dataclass_instance._get_argparse_const(k) + + if isinstance(field_default, str) and field_default.startswith("${"): + kwargs["default"] = field_default + else: + if field_default is MISSING: + kwargs["required"] = True + if field_choices is not None: + kwargs["choices"] = field_choices + if ( + isinstance(inter_type, type) + and (issubclass(inter_type, List) or issubclass(inter_type, Tuple)) + ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)): + if "int" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, int) + elif "float" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, float) + elif "str" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, str) + else: + raise NotImplementedError( + "parsing of type " + str(inter_type) + " is not implemented" + ) + if field_default is not MISSING: + kwargs["default"] = ( + ",".join(map(str, field_default)) + if field_default is not None + else None + ) + elif ( + isinstance(inter_type, type) and issubclass(inter_type, Enum) + ) or "Enum" in str(inter_type): + kwargs["type"] = str + if field_default is not MISSING: + if isinstance(field_default, Enum): + kwargs["default"] = field_default.value + else: + kwargs["default"] = field_default + elif inter_type is bool: + kwargs["action"] = ( + "store_false" if field_default is True else "store_true" + ) + kwargs["default"] = field_default + else: + kwargs["type"] = inter_type + if field_default is not MISSING: + kwargs["default"] = field_default + + # build the help with the hierarchical prefix + if with_prefix is not None and with_prefix != "" and field_help is not None: + field_help = with_prefix[2:] + ": " + field_help + + kwargs["help"] = field_help + if field_const is not None: + kwargs["const"] = field_const + kwargs["nargs"] = "?" + + return kwargs + + for k in dataclass_instance._get_all_attributes(): + field_name = argparse_name(dataclass_instance._get_name(k)) + field_type = dataclass_instance._get_type(k) + if field_name is None: + continue + elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass): + # for fields that are of type FairseqDataclass, we can recursively + # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace) + prefix = None + if with_prefix is not None: + # if a prefix is specified, then we don't want to copy the subfields directly to the root namespace + # but we prefix them with the name of the current field. + prefix = field_name + gen_parser_from_dataclass(parser, field_type(), delete_default, prefix) + continue + + kwargs = get_kwargs_from_dc(dataclass_instance, k) + + field_args = [field_name] + alias = dataclass_instance._get_argparse_alias(k) + if alias is not None: + field_args.append(alias) + + if "default" in kwargs: + if isinstance(kwargs["default"], str) and kwargs["default"].startswith( + "${" + ): + if kwargs["help"] is None: + # this is a field with a name that will be added elsewhere + continue + else: + del kwargs["default"] + if delete_default and "default" in kwargs: + del kwargs["default"] + try: + parser.add_argument(*field_args, **kwargs) + except ArgumentError: + pass + + +def _set_legacy_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, "add_args"): + return + + import argparse + + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) + + +def _override_attr( + sub_node: str, data_class: Type[FairseqDataclass], args: Namespace +) -> List[str]: + overrides = [] + + if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass): + return overrides + + def get_default(f): + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): + # private member, skip + continue + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + field_type = interpret_dc_type(v.type) + if ( + isinstance(val, str) + and not val.startswith("${") # not interpolation + and field_type != str + and ( + not inspect.isclass(field_type) or not issubclass(field_type, Enum) + ) # not choices enum + ): + # upgrade old models that stored complex parameters as string + val = ast.literal_eval(val) + + if isinstance(val, tuple): + val = list(val) + + v_type = getattr(v.type, "__origin__", None) + if ( + (v_type is List or v_type is list or v_type is Optional) + # skip interpolation + and not (isinstance(val, str) and val.startswith("${")) + ): + # if type is int but val is float, then we will crash later - try to convert here + if hasattr(v.type, "__args__"): + t_args = v.type.__args__ + if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int): + val = list(map(t_args[0], val)) + elif val is not None and ( + field_type is int or field_type is bool or field_type is float + ): + try: + val = field_type(val) + except: + pass # ignore errors here, they are often from interpolation args + + if val is None: + overrides.append("{}.{}=null".format(sub_node, k)) + elif val == "": + overrides.append("{}.{}=''".format(sub_node, k)) + elif isinstance(val, str): + val = val.replace("'", r"\'") + overrides.append("{}.{}='{}'".format(sub_node, k, val)) + elif isinstance(val, FairseqDataclass): + overrides += _override_attr(f"{sub_node}.{k}", type(val), args) + elif isinstance(val, Namespace): + sub_overrides, _ = override_module_args(val) + for so in sub_overrides: + overrides.append(f"{sub_node}.{k}.{so}") + else: + overrides.append("{}.{}={}".format(sub_node, k, val)) + + return overrides + + +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) + + +def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: + """use the field in args to overrides those in cfg""" + overrides = [] + deletes = [] + + for k in FairseqConfig.__dataclass_fields__.keys(): + overrides.extend( + _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args) + ) + + if args is not None: + if hasattr(args, "task"): + from fairseq.tasks import TASK_DATACLASS_REGISTRY + + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes + ) + else: + deletes.append("task") + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + from fairseq.registry import REGISTRIES + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], + args, + overrides, + deletes, + use_name_as_val=k not in CORE_REGISTRIES, + ) + else: + deletes.append(k) + + no_dc = True + if hasattr(args, "arch"): + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY + + if args.arch in ARCH_MODEL_REGISTRY: + m_cls = ARCH_MODEL_REGISTRY[args.arch] + dc = getattr(m_cls, "__dataclass", None) + if dc is not None: + m_name = ARCH_MODEL_NAME_REGISTRY[args.arch] + overrides.append("model={}".format(m_name)) + overrides.append("model._name={}".format(args.arch)) + # override model params with those exist in args + overrides.extend(_override_attr("model", dc, args)) + no_dc = False + if no_dc: + deletes.append("model") + + return overrides, deletes + + +class omegaconf_no_object_check: + def __init__(self): + # Changed in https://github.com/omry/omegaconf/pull/911 - both are kept for back compat. + if hasattr(_utils, "is_primitive_type"): + self.old_is_primitive = _utils.is_primitive_type + else: + self.old_is_primitive = _utils.is_primitive_type_annotation + + def __enter__(self): + if hasattr(_utils, "is_primitive_type"): + _utils.is_primitive_type = lambda _: True + else: + _utils.is_primitive_type_annotation = lambda _: True + + def __exit__(self, type, value, traceback): + if hasattr(_utils, "is_primitive_type"): + _utils.is_primitive_type = self.old_is_primitive + else: + _utils.is_primitive_type_annotation = self.old_is_primitive + + +def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + """Convert a flat argparse.Namespace to a structured DictConfig.""" + + # Here we are using field values provided in args to override counterparts inside config object + overrides, deletes = override_module_args(args) + + # configs will be in fairseq/config after installation + config_path = os.path.join("..", "config") + + GlobalHydra.instance().clear() + + with initialize(config_path=config_path): + try: + composed_cfg = compose("config", overrides=overrides, strict=False) + except: + logger.error("Error when composing. Overrides: " + str(overrides)) + raise + + for k in deletes: + composed_cfg[k] = None + + cfg = OmegaConf.create( + OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) + ) + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + with omegaconf_no_object_check(): + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults( + cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler] + ) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + + OmegaConf.set_struct(cfg, True) + return cfg + + +def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): + # this will be deprecated when we get rid of argparse and model_overrides logic + + from fairseq.registry import REGISTRIES + + with open_dict(cfg): + for k in cfg.keys(): + # "k in cfg" will return false if its a "mandatory value (e.g. ???)" + if k in cfg and isinstance(cfg[k], DictConfig): + if k in overrides and isinstance(overrides[k], dict): + for ok, ov in overrides[k].items(): + if isinstance(ov, dict) and cfg[k][ok] is not None: + overwrite_args_by_name(cfg[k][ok], ov) + else: + cfg[k][ok] = ov + else: + overwrite_args_by_name(cfg[k], overrides) + elif k in cfg and isinstance(cfg[k], Namespace): + for override_key, val in overrides.items(): + setattr(cfg[k], override_key, val) + elif k in overrides: + if ( + k in REGISTRIES + and overrides[k] in REGISTRIES[k]["dataclass_registry"] + ): + cfg[k] = DictConfig( + REGISTRIES[k]["dataclass_registry"][overrides[k]] + ) + overwrite_args_by_name(cfg[k], overrides) + cfg[k]._name = overrides[k] + else: + cfg[k] = overrides[k] + + +def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False): + if remove_missing: + + def remove_missing_rec(src_keys, target_cfg): + if is_dataclass(target_cfg): + target_keys = set(target_cfg.__dataclass_fields__.keys()) + else: + target_keys = set(target_cfg.keys()) + + for k in list(src_keys.keys()): + if k not in target_keys: + del src_keys[k] + elif OmegaConf.is_config(src_keys[k]): + tgt = getattr(target_cfg, k) + if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")): + remove_missing_rec(src_keys[k], tgt) + + with open_dict(cfg): + remove_missing_rec(cfg, dc) + + merged_cfg = OmegaConf.merge(dc, cfg) + merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg diff --git a/fairseq/fairseq/distributed/__init__.py b/fairseq/fairseq/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9130db8f5d039519d663ee16c7ff2c102f5481f5 --- /dev/null +++ b/fairseq/fairseq/distributed/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .distributed_timeout_wrapper import DistributedTimeoutWrapper +from .fully_sharded_data_parallel import ( + fsdp_enable_wrap, + fsdp_wrap, + FullyShardedDataParallel, +) +from .legacy_distributed_data_parallel import LegacyDistributedDataParallel +from .module_proxy_wrapper import ModuleProxyWrapper +from .tpu_distributed_data_parallel import TPUDistributedDataParallel + + +__all__ = [ + "DistributedTimeoutWrapper", + "fsdp_enable_wrap", + "fsdp_wrap", + "FullyShardedDataParallel", + "LegacyDistributedDataParallel", + "ModuleProxyWrapper", + "TPUDistributedDataParallel", +] diff --git a/fairseq/fairseq/distributed/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/distributed/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19cf93240034137221db85d67c32570836df0c18 Binary files /dev/null and b/fairseq/fairseq/distributed/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc b/fairseq/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..966be944d7d3212c5d8ba9be0df111ff0bdcc96a Binary files /dev/null and b/fairseq/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc differ diff --git a/fairseq/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc b/fairseq/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b12e90e91d2ded067d5eb86ad9d98c66cbe41a67 Binary files /dev/null and b/fairseq/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc differ diff --git a/fairseq/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc b/fairseq/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f867f0fa389c23f1dc9fb30de4a54cb3e892da1 Binary files /dev/null and b/fairseq/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc differ diff --git a/fairseq/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc b/fairseq/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf422c636f61b198cd976176ea9d0d6bf799fbff Binary files /dev/null and b/fairseq/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc differ diff --git a/fairseq/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc b/fairseq/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9976fae7b88651cec5fe66cd2c0a07f175afd26b Binary files /dev/null and b/fairseq/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc differ diff --git a/fairseq/fairseq/distributed/__pycache__/utils.cpython-310.pyc b/fairseq/fairseq/distributed/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7b9fa2ce9a45114114a64e3f78148e328768860 Binary files /dev/null and b/fairseq/fairseq/distributed/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/fairseq/distributed/distributed_timeout_wrapper.py b/fairseq/fairseq/distributed/distributed_timeout_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6e06b4b6dd9a5fedd5d72bde02ceb7aaf74833d7 --- /dev/null +++ b/fairseq/fairseq/distributed/distributed_timeout_wrapper.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import signal +import threading + +from torch import nn + + +logger = logging.getLogger(__name__) + + +class DistributedTimeoutWrapper(nn.Module): + """ + A wrapper that kills the process if no progress is made within a given + *timeout*. The timer is reset every time :func:`forward` is called. + + Usage:: + + module = DistributedTimeoutWrapper(module, timeout=30) + x = module(input) + time.sleep(20) # safe + x = module(input) + time.sleep(45) # job will be killed before this returns + + Args: + module (nn.Module): module to wrap + timeout (int): number of seconds before killing the process + (set to a value <= 0 to disable the timeout) + signal (Optional): signal to send once timeout is triggered + """ + + def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT): + super().__init__() + self.module = module + self.timeout = timeout + self.signal = signal + + if timeout > 0: + self._heartbeat = threading.Event() + self._heartbeat_thread = threading.Thread( + target=self._check_heartbeat, + args=(os.getpid(),), + daemon=True, + ) + self._heartbeat_thread.start() + self._terminated = False + else: + self._heartbeat = None + self._heartbeat_thread = None + + def __del__(self): + self.stop_timeout() + + def __getattr__(self, name): + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name) + + def stop_timeout(self): + if self._heartbeat_thread is not None: + self._terminated = True + self._heartbeat_thread.join() + + def state_dict(self, *args, **kwargs): + return self.module.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + return self.module.load_state_dict(*args, **kwargs) + + def forward(self, *args, **kwargs): + if self._heartbeat is not None: + self._heartbeat.set() + return self.module(*args, **kwargs) + + def _check_heartbeat(self, parent_pid): + self._heartbeat.wait() # wait for the first forward pass + while True: + self._heartbeat.clear() + success = self._heartbeat.wait(timeout=self.timeout) + if self._terminated: + break + elif not success: + logger.error( + ( + "Killing job for not making progress in {} seconds. " + "Set --heartbeat-timeout=-1 to disable this timeout." + ).format(int(self.timeout)) + ) + os.kill(parent_pid, self.signal) + return diff --git a/fairseq/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/fairseq/distributed/fully_sharded_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..1c508b05dd2c5aa4a3aa586a6998e04dbbbbb918 --- /dev/null +++ b/fairseq/fairseq/distributed/fully_sharded_data_parallel.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Optional + +import torch +from fairseq.dataclass.configs import DistributedTrainingConfig +from fairseq.distributed import utils as dist_utils + + +try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + + has_FSDP = True +except ImportError: + FSDP = torch.nn.Module + has_FSDP = False + + +class FullyShardedDataParallel(FSDP): + """ + A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some + fairseq-specific checkpoint saving/loading logic. + + Args: + use_sharded_state (bool): if True, then ``state_dict`` will return + ``FSDP.local_state_dict`` and ``load_state_dict`` will call + ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will + return the full model weights on data parallel rank 0 (empty on + other ranks) and ``load_state_dict`` will broadcast model weights + from rank 0 to other ranks. + """ + + def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + super().__init__(*args, **kwargs) + self.use_sharded_state = use_sharded_state + + @property + def unwrapped_module(self) -> torch.nn.Module: + if self.flatten_parameters: + return self.module.module + else: + return self.module + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if self.use_sharded_state: + return super().local_state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + if self.rank == 0: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + # We must call state_dict() due to use of communication + # primitives. But we don't use the result. + super().state_dict() + return destination or {} + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + if self.use_sharded_state: + return super().load_local_state_dict(state_dict, strict=strict) + else: + state_dict = dist_utils.broadcast_object( + state_dict, src_rank=0, group=self.process_group + ) + return super().load_state_dict(state_dict, strict=strict) + + +class DummyProcessGroup: + def __init__(self, rank: int, size: int): + self._rank = rank + self._size = size + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._size + + +@contextlib.contextmanager +def fsdp_enable_wrap(cfg: DistributedTrainingConfig): + try: + from fairscale.nn import enable_wrap + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if cfg.memory_efficient_fp16: + assert cfg.fp16 # memory_efficient_fp16 should imply fp16 + group = dist_utils.get_data_parallel_group() + if group is None and cfg.distributed_world_size == 1: + group = DummyProcessGroup(rank=0, size=1) + fsdp_config = { + "process_group": group, + "reshard_after_forward": not cfg.no_reshard_after_forward, + "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16, + "fp32_reduce_scatter": cfg.fp32_reduce_scatter, + "flatten_parameters": not cfg.not_fsdp_flatten_parameters, + "cpu_offload": cfg.cpu_offload, + "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, + "bucket_cap_mb": cfg.bucket_cap_mb, + "state_dict_device": torch.device("cpu"), # reduce GPU mem usage + } + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + use_sharded_state=cfg.use_sharded_state, + **fsdp_config, + ): + yield + + +def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): + """ + Helper to wrap layers/modules in FSDP. This falls back to a no-op if + fairscale is not available. + + Args: + module (nn.Module): module to (maybe) wrap + min_num_params (int, Optional): minimum number of layer params to wrap + """ + try: + from fairscale.nn import wrap + + if min_num_params is not None: + num_params = sum(p.numel() for p in module.parameters()) + if num_params >= min_num_params: + return wrap(module, **kwargs) + else: + return module + else: + return wrap(module, **kwargs) + except ImportError: + return module diff --git a/fairseq/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/fairseq/distributed/legacy_distributed_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..cd434c7372ba30ea0e6f87e084230448f53480e9 --- /dev/null +++ b/fairseq/fairseq/distributed/legacy_distributed_data_parallel.py @@ -0,0 +1,165 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +A modified version of the legacy DistributedDataParallel module that uses c10d +communication primitives. This version is simpler than the latest PyTorch +version and is useful for debugging. Notably it does not overlap gradient +communication with the backward pass, which makes it slower but more robust +than the PyTorch version. + +This version also supports the *no_sync* context manager, which allows faster +training with `--update-freq`. +""" + +from collections import OrderedDict +from contextlib import contextmanager + +import torch +from torch import nn + +from fairseq.distributed import utils + + +class LegacyDistributedDataParallel(nn.Module): + """Implements distributed data parallelism at the module level. + + A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`. + This version uses a c10d process group for communication and does not + broadcast buffers. + + Args: + module (~torch.nn.Module): module to be parallelized + process_group: the c10d process group to be used for distributed data + parallel all-reduction. + buffer_size (int, optional): number of elements to buffer before + performing all-reduce (default: 256M). + """ + + def __init__(self, module, process_group, buffer_size=2**28): + super().__init__() + + self.module = module + self.process_group = process_group + self.world_size = utils.get_world_size(self.process_group) + + # Never use a bigger buffer than the number of model params + self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) + self.buffer = None + + # We can also forcibly accumulate grads locally and only do the + # all-reduce at some later time + self.accumulate_grads = False + + # make per-device lists of parameters + paramlists = OrderedDict() + for param in self.module.parameters(): + device = param.device + if paramlists.get(device) is None: + paramlists[device] = [] + paramlists[device] += [param] + self.per_device_params = list(paramlists.values()) + + @contextmanager + def no_sync(self): + """A context manager to disable gradient synchronization.""" + old_accumulate_grads = self.accumulate_grads + self.accumulate_grads = True + yield + self.accumulate_grads = old_accumulate_grads + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce_grads(self): + """ + This function must be called explicitly after backward to reduce + gradients. There is no automatic hook like c10d. + """ + + def all_reduce_params(params): + buffer = self.buffer + nonzero_buffer = False + if len(params) > 1: + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) + nonzero_buffer = True + else: + buffer[offset : offset + sz].zero_() + offset += sz + else: + # we only have a single grad to all-reduce + p = params[0] + if p.grad is not None: + buffer = p.grad.data + nonzero_buffer = True + elif p.numel() <= self.buffer.numel(): + buffer = buffer[: p.numel()] + buffer.zero_() + else: + buffer = torch.zeros_like(p) + + if nonzero_buffer: + buffer.div_(self.world_size) + + utils.all_reduce(buffer, self.process_group) + + # copy all-reduced grads back into their original place + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) + else: + p.grad = buffer[offset : offset + sz].view_as(p).clone() + offset += sz + + def reduction_fn(): + # This function only needs to be called once + if self.accumulate_grads: + return + + if self.buffer is None: + self.buffer = next(self.module.parameters()).new(self.buffer_size) + + for params in self.per_device_params: + # All-reduce the gradients in buckets + offset = 0 + buffered_params = [] + for param in params: + if not param.requires_grad: + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + + if hasattr(param, "expert"): + # Skip gradient sync for unshared parameters + continue + + if param.grad.requires_grad: + raise RuntimeError( + "DistributedDataParallel only works " + "with gradients that don't require " + "grad" + ) + sz = param.numel() + if sz > self.buffer.numel(): + # all-reduce big params directly + all_reduce_params([param]) + else: + if offset + sz > self.buffer.numel(): + all_reduce_params(buffered_params) + offset = 0 + buffered_params.clear() + buffered_params.append(param) + offset += sz + + if len(buffered_params) > 0: + all_reduce_params(buffered_params) + + reduction_fn() diff --git a/fairseq/fairseq/distributed/module_proxy_wrapper.py b/fairseq/fairseq/distributed/module_proxy_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..904dc0c202e09db244518836c0f061e0850cad61 --- /dev/null +++ b/fairseq/fairseq/distributed/module_proxy_wrapper.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn + + +class ModuleProxyWrapper(nn.Module): + """ + Wrap a DistributedDataParallel module and forward requests for missing + attributes to the module wrapped by DDP (the twice-wrapped module). + Also forward calls to :func:`state_dict` and :func:`load_state_dict`. + + Usage:: + + module.xyz = "hello world" + wrapped_module = DistributedDataParallel(module, **ddp_args) + wrapped_module = ModuleProxyWrapper(wrapped_module) + assert wrapped_module.xyz == "hello world" + assert wrapped_module.state_dict().keys() == module.state_dict().keys() + + Args: + module (nn.Module): module to wrap + """ + + def __init__(self, module: nn.Module): + super().__init__() + assert hasattr( + module, "module" + ), "ModuleProxyWrapper expects input to wrap another module" + self.module = module + + def __getattr__(self, name): + """Forward missing attributes to twice-wrapped module.""" + try: + # defer to nn.Module's logic + return super().__getattr__(name) + except AttributeError: + try: + # forward to the once-wrapped module + return getattr(self.module, name) + except AttributeError: + # forward to the twice-wrapped module + return getattr(self.module.module, name) + + def state_dict(self, *args, **kwargs): + """Forward to the twice-wrapped module.""" + return self.module.module.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + """Forward to the twice-wrapped module.""" + return self.module.module.load_state_dict(*args, **kwargs) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) diff --git a/fairseq/fairseq/distributed/tpu_distributed_data_parallel.py b/fairseq/fairseq/distributed/tpu_distributed_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9e1033011db87100c64ec39845e81228a26381 --- /dev/null +++ b/fairseq/fairseq/distributed/tpu_distributed_data_parallel.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + +from fairseq.distributed import utils + + +class TPUDistributedDataParallel(nn.Module): + def __init__(self, module, process_group): + super().__init__() + self.module = module + self.process_group = process_group + self.world_size = utils.get_world_size(self.process_group) + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce_grads(self): + gradients = [] + for p in self.parameters(): + if not p.requires_grad: + continue + if p.grad is None: + p.grad = torch.zeros_like(p) + if p.grad.requires_grad: + raise RuntimeError( + "TPUDistributedDataParallel only works with gradients that don't " + "require grad" + ) + gradients.append(p.grad) + + import torch_xla.core.xla_model as xm + + xm.all_reduce( + "sum", + gradients, + scale=1.0 / self.world_size, + groups=self.process_group[1], + ) diff --git a/fairseq/fairseq/distributed/utils.py b/fairseq/fairseq/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..968830d58582e436386111d90896bf95889c736e --- /dev/null +++ b/fairseq/fairseq/distributed/utils.py @@ -0,0 +1,843 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import io +import logging +import os +import pickle +import random +import socket +import struct +import subprocess +import warnings +from argparse import Namespace +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict, List, Mapping, Optional + +import torch +import torch.distributed as dist +from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig +from omegaconf import open_dict + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +# Flag to indicate if we're using Megatron +# NOTE: this is a temporary hack until we move away from Megatron's model parallel init +_USE_MEGATRON = False + +# Whether to use XLA ops (e.g., on TPUs) instead of CUDA ops. +_USE_XLA = False + + +logger = logging.getLogger(__name__) + + +def is_master(cfg: DistributedTrainingConfig): + return cfg.distributed_rank == 0 + + +def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): + if cfg.distributed_init_method is not None or cfg.tpu: + return + + num_pipelines_per_node = None + if cfg.pipeline_model_parallel: + num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg) + + if cfg.distributed_world_size == 1: + return + if all( + key in os.environ + for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] + ): + # support torch.distributed.launch + _infer_torch_distributed_launch_init(cfg) + else: + # we can determine the init method automatically for Slurm + if not _infer_slurm_init(cfg, num_pipelines_per_node): + if cfg.distributed_port <= 0 or force_distributed: + _infer_single_node_init(cfg) + elif cfg.distributed_port <= 0: + _infer_single_node_init(cfg) + + if cfg.pipeline_model_parallel: + _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node) + elif not cfg.distributed_no_spawn: + with open_dict(cfg): + cfg.distributed_num_procs = min( + torch.cuda.device_count(), cfg.distributed_world_size + ) + else: + if cfg.device_id > 0: + logger.info( + "setting CUDA device={} on rank {}".format( + cfg.device_id, cfg.distributed_rank + ) + ) + torch.cuda.set_device(cfg.device_id) + + +def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig): + cfg.distributed_init_method = "env://" + cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) + cfg.distributed_rank = int(os.environ["RANK"]) + cfg.device_id = cfg.distributed_rank % torch.cuda.device_count() + # processes are created by torch.distributed.launch + cfg.distributed_no_spawn = True + + +def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node): + node_list = os.environ.get("SLURM_STEP_NODELIST") + if node_list is None: + node_list = os.environ.get("SLURM_JOB_NODELIST") + if node_list is not None: + try: + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", node_list] + ) + cfg.distributed_init_method = "tcp://{host}:{port}".format( + host=hostnames.split()[0].decode("utf-8"), + port=cfg.distributed_port, + ) + nnodes = int(os.environ.get("SLURM_NNODES")) + ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") + if ntasks_per_node is not None: + ntasks_per_node = int(ntasks_per_node) + else: + ntasks = int(os.environ.get("SLURM_NTASKS")) + nnodes = int(os.environ.get("SLURM_NNODES")) + assert ntasks % nnodes == 0 + ntasks_per_node = int(ntasks / nnodes) + if ntasks_per_node == 1: + gpus_per_node = torch.cuda.device_count() + node_id = int(os.environ.get("SLURM_NODEID")) + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: + assert ntasks_per_node == num_pipelines_per_node, ( + "SLURM --ntasks-per-node must match number of pipelines per " + "node (={})".format(num_pipelines_per_node) + ) + cfg.distributed_no_spawn = True + # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on + # the first node, [1, 2] on the second node, etc. This + # matches torch.distributed.launch. + node_id = int(os.environ.get("SLURM_NODEID")) + local_id = int(os.environ.get("SLURM_LOCALID")) + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id + # In the above example, device_id will always be in [0, 1], + # which also matches torch.distributed.launch. + cfg.device_id = local_id + # We also want to set distributed_world_size to be the total + # number of pipelines across all nodes. + cfg.distributed_world_size = nnodes * num_pipelines_per_node + else: + assert ( + ntasks_per_node == cfg.distributed_world_size // nnodes + ), f"{ntasks_per_node}, {cfg.distributed_world_size}, {nnodes}" + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.device_id = int(os.environ.get("SLURM_LOCALID")) + logger.info(f"Rank {cfg.distributed_rank}, device_id: {cfg.device_id}") + return True + except subprocess.CalledProcessError as e: # scontrol failed + raise e + except FileNotFoundError: # Slurm is not installed + pass + + return False + + +def _infer_single_node_init(cfg: DistributedTrainingConfig): + assert ( + cfg.distributed_world_size <= torch.cuda.device_count() + ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" + + if cfg.distributed_port <= 0: + jobid = os.environ.get("SLURM_JOB_ID") + task_id = os.environ.get("SLURM_ARRAY_TASK_ID") + + if jobid is not None: + if task_id is not None: + jobid += str(task_id) + jobid = int(jobid) + rng = random.Random(jobid) + port = rng.randint(10000, 60000) + else: + port = random.randint(10000, 60000) + + cfg.distributed_port = port + cfg.distributed_init_method = "tcp://localhost:{port}".format( + port=cfg.distributed_port + ) + + +def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig): + from fairseq import utils + + balance_exists = ( + cfg.pipeline_balance is not None + or cfg.pipeline_encoder_balance is not None + or cfg.pipeline_decoder_balance is not None + ) + devices_exist = ( + cfg.pipeline_devices is not None + or cfg.pipeline_encoder_devices is not None + or cfg.pipeline_decoder_devices is not None + ) + if not balance_exists: + raise ValueError( + "--pipeline-balance is currently required for pipeline model parallelism" + ) + if not devices_exist: + raise ValueError( + "--pipeline-devices is currently required for pipeline model parallelism" + ) + + cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) + if cfg.pipeline_devices is not None: + cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) + num_pipeline_devices = len(set(cfg.pipeline_devices)) + else: + cfg.pipeline_encoder_devices = utils.eval_str_list( + cfg.pipeline_encoder_devices, type=int + ) + cfg.pipeline_decoder_devices = utils.eval_str_list( + cfg.pipeline_decoder_devices, type=int + ) + num_pipeline_devices = len( + set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) + ) + gpus_per_node = torch.cuda.device_count() + assert ( + gpus_per_node >= num_pipeline_devices + and gpus_per_node % num_pipeline_devices == 0 + ), ( + "the number of unique device IDs in --pipeline-devices must evenly divide " + "the number of GPUs per node (multi-node pipelining is not yet supported)" + ) + num_pipelines_per_node = gpus_per_node // num_pipeline_devices + return num_pipeline_devices, num_pipelines_per_node + + +def _pipeline_parallel_post_init( + cfg: DistributedTrainingConfig, num_pipeline_devices, num_pipelines_per_node +): + if not cfg.distributed_no_spawn: + # When distributed_no_spawn is False, we expect distributed_rank and + # distributed_world_size to be based on the total number of GPUs, so + # we need to correct them to be based on the number of pipelines. + assert cfg.distributed_world_size % num_pipeline_devices == 0 + cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices + # In the case of 4-way MP on nodes with 8 GPUs, we want + # distributed_rank to be the starting GPU index for each pipeline + # i.e., 0, 2, ... + gpus_per_node = torch.cuda.device_count() + assert cfg.distributed_rank % gpus_per_node == 0 + assert cfg.distributed_rank % num_pipeline_devices == 0 + + with open_dict(cfg): + cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices + # launch one process per pipeline + cfg.distributed_num_procs = num_pipelines_per_node + + # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 + # and 4, indicating the starting device IDs for each pipeline + cfg.device_id *= num_pipeline_devices + + if cfg.device_id > 0: + # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 + # GPU node), we need to adjust pipeline_devices accordingly + logger.debug( + "setting CUDA device={} on rank {}".format( + cfg.device_id, cfg.distributed_rank + ) + ) + torch.cuda.set_device(cfg.device_id) + with open_dict(cfg): + cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices] + logger.info( + "setting pipeline_devices={} on rank {}".format( + cfg.pipeline_devices, cfg.distributed_rank + ) + ) + + +def distributed_init(cfg: FairseqConfig): + if isinstance(cfg, Namespace): + from fairseq.dataclass.utils import convert_namespace_to_omegaconf + + cfg = convert_namespace_to_omegaconf(cfg) + + if not cfg.common.tpu: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + warnings.warn( + "Distributed is already initialized, cannot initialize twice!" + ) + else: + logger.info( + "distributed init (rank {}): {}".format( + cfg.distributed_training.distributed_rank, + cfg.distributed_training.distributed_init_method, + ) + ) + dist.init_process_group( + backend=cfg.distributed_training.distributed_backend, + init_method=cfg.distributed_training.distributed_init_method, + world_size=cfg.distributed_training.distributed_world_size, + rank=cfg.distributed_training.distributed_rank, + ) + logger.info( + "initialized host {} as rank {}".format( + socket.gethostname(), + cfg.distributed_training.distributed_rank, + ) + ) + + # perform a dummy all-reduce to initialize the NCCL communicator + if torch.cuda.is_available(): + dist.all_reduce(torch.zeros(1).cuda()) + + cfg.distributed_training.distributed_rank = torch.distributed.get_rank() + else: + assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size + global _USE_XLA + _USE_XLA = True + cfg.distributed_training.device_id = xm.get_local_ordinal() + cfg.distributed_training.distributed_rank = xm.get_ordinal() + xm.rendezvous("distributed_init") # wait for all workers + + if is_master(cfg.distributed_training): + logging.getLogger().setLevel(logging.INFO) + else: + logging.getLogger().setLevel(logging.WARNING) + + if cfg.common.model_parallel_size > 1: + try: + from fairseq.model_parallel.megatron.mpu import ( + initialize_model_parallel, + model_parallel_cuda_manual_seed, + ) + except ImportError: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + global _USE_MEGATRON + _USE_MEGATRON = True + initialize_model_parallel(cfg.common.model_parallel_size) + model_parallel_cuda_manual_seed(cfg.common.seed) + model_part_number = get_model_parallel_rank() + cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + + if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0: + cfg.checkpoint.checkpoint_suffix = ( + f"-rank-{cfg.distributed_training.distributed_rank}" + ) + + return cfg.distributed_training.distributed_rank + + +def distributed_main(i, main, cfg: FairseqConfig, kwargs): + cfg.distributed_training.device_id = i + if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu: + torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn + cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i + + cfg.distributed_training.distributed_rank = distributed_init(cfg) + + after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) + if after_distributed_init_fn: + cfg = after_distributed_init_fn(cfg) + + main(cfg, **kwargs) + + if torch.distributed.is_initialized(): + torch.distributed.barrier(get_global_group()) + + +def call_main(cfg: FairseqConfig, main, **kwargs): + if cfg.distributed_training.distributed_init_method is None: + infer_init_method(cfg.distributed_training) + + if cfg.distributed_training.distributed_init_method is not None: + # distributed training + if not cfg.distributed_training.distributed_no_spawn: + start_rank = cfg.distributed_training.distributed_rank + cfg.distributed_training.distributed_rank = None # assign automatically + kwargs["start_rank"] = start_rank + + torch.multiprocessing.spawn( + fn=distributed_main, + args=(main, cfg, kwargs), + nprocs=min( + torch.cuda.device_count(), + cfg.distributed_training.distributed_world_size, + ), + join=True, + ) + else: + distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) + elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1: + import torch_xla.distributed.xla_multiprocessing as xmp + + torch.multiprocessing.set_sharing_strategy("file_system") + xmp.spawn( + fn=distributed_main, + args=(main, cfg, kwargs), + # tpu-comment: + # 8 devices in one TPU VM, is the max processes to be spawned. + # The rest is driven by xm.distributed.xla_dist + nprocs=min(cfg.distributed_training.distributed_world_size, 8), + ) + else: + # single GPU main + main(cfg, **kwargs) + + +def use_xla(): + global _USE_XLA + return _USE_XLA + + +def new_groups(grouped_ranks: List[List[int]]): + if use_xla(): + return ("tpu", grouped_ranks) + else: + groups = [dist.new_group(g) for g in grouped_ranks] + my_group_idx = _find_my_group_index(grouped_ranks) + return groups[my_group_idx] + + +def _find_my_group_index(grouped_ranks): + my_rank = get_global_rank() + for i, group in enumerate(grouped_ranks): + if my_rank in group: + return i + raise RuntimeError + + +def _find_my_group(grouped_ranks): + index = _find_my_group_index(grouped_ranks) + return grouped_ranks[index] + + +def get_rank(group): + if use_xla(): + assert group[0] == "tpu" + my_group = _find_my_group(group[1]) + return my_group.index(get_global_rank()) + else: + return dist.get_rank(group=group) + + +def get_world_size(group): + if use_xla(): + assert group[0] == "tpu" + my_group = _find_my_group(group[1]) + return len(my_group) + elif torch.distributed.is_initialized(): + return dist.get_world_size(group=group) + else: + return 1 + + +def get_global_group(): + if use_xla(): + return new_groups([list(range(get_global_world_size()))]) + elif torch.distributed.is_initialized(): + if not hasattr(get_global_group, "_global_group"): + # ideally we could use torch.distributed.group.WORLD, but it seems + # to cause random NCCL hangs in some cases + get_global_group._global_group = dist.new_group() + return get_global_group._global_group + else: + return None + + +def get_global_rank(): + if use_xla(): + return xm.get_ordinal() + elif torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def get_global_world_size(): + if use_xla(): + return xm.xrt_world_size() + elif torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + global _USE_MEGATRON + if _USE_MEGATRON: + from fairseq.model_parallel.megatron import mpu + + return mpu.get_data_parallel_group() + else: + return get_global_group() + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return get_rank(get_data_parallel_group()) + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return get_world_size(get_data_parallel_group()) + + +def get_model_parallel_group(): + global _USE_MEGATRON + if _USE_MEGATRON: + from fairseq.model_parallel.megatron import mpu + + return mpu.get_model_parallel_group() + else: + return None + + +def get_model_parallel_rank(): + """Return my rank for the model parallel group.""" + return get_rank(get_model_parallel_group()) + + +def get_model_parallel_world_size(): + """Return world size for the model parallel group.""" + return get_world_size(get_model_parallel_group()) + + +def all_reduce(tensor, group, op="sum"): + if use_xla(): + assert isinstance(group, tuple) and group[0] == "tpu" + tensor = [tensor] # wrap in a list to make xm.all_reduce in-place + return xm.all_reduce(op, tensor, groups=group[1])[0] + else: + if op == "sum": + op = dist.ReduceOp.SUM + elif op == "max": + op = dist.ReduceOp.MAX + else: + raise NotImplementedError + dist.all_reduce(tensor, op=op, group=group) + return tensor + + +def broadcast(tensor, src, group): + if use_xla(): + # XLA doesn't support broadcast, hack it with all_reduce + if get_rank(group) != src: + tensor.zero_() + all_reduce(tensor, group) + else: + dist.broadcast(tensor, src=src, group=group) + + +def all_to_all(tensor, group): + """Perform an all-to-all operation on a 1D Tensor.""" + assert tensor.dim() == 1 + split_count = get_world_size(group=group) + assert tensor.numel() % split_count == 0 + if use_xla(): + assert isinstance(group, tuple) and group[0] == "tpu" + return xm.all_to_all( + tensor, + split_dimension=0, + concat_dimension=0, + split_count=split_count, + groups=group[1], + ) + else: + output = torch.zeros_like(tensor) + dist.all_to_all_single(output, tensor, group=group) + return output + + +def all_gather(tensor, group, return_tensor=False): + """Perform an all-gather operation.""" + if use_xla(): + result = xm.all_gather(tensor, groups=group[1]) + world_size = get_world_size(group=group) + result = result.view(world_size, *tensor.size()) + if return_tensor: + return result + else: + return [result[i] for i in range(world_size)] + else: + world_size = get_world_size(group=group) + rank = get_rank(group=group) + tensor_list = [ + tensor if i == rank else torch.empty_like(tensor) for i in range(world_size) + ] + dist.all_gather(tensor_list, tensor, group=group) + if return_tensor: + return torch.stack(tensor_list, dim=0) + else: + return tensor_list + + +def all_gather_list(data, group=None, max_size=16384): + """Gathers arbitrary data from all nodes into a list. + + Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python + data. Note that *data* must be picklable and any CUDA tensors will be moved + to CPU and returned on CPU as well. + + Args: + data (Any): data from the local worker to be gathered on other workers + group: group of the collective + max_size (int, optional): maximum size of the data to be gathered + across workers + """ + from fairseq import utils + + if group is None: + group = get_global_group() + rank = get_rank(group=group) + world_size = get_world_size(group=group) + + buffer_size = max_size * world_size + if ( + not hasattr(all_gather_list, "_buffer") + or all_gather_list._buffer.numel() < buffer_size + ): + all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) + all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() + buffer = all_gather_list._buffer + buffer.zero_() + cpu_buffer = all_gather_list._cpu_buffer + + data = utils.move_to_cpu(data) + enc = pickle.dumps(data) + enc_size = len(enc) + header_size = 4 # size of header that contains the length of the encoded data + size = header_size + enc_size + if size > max_size: + raise ValueError( + "encoded data size ({}) exceeds max_size ({})".format(size, max_size) + ) + + header = struct.pack(">I", enc_size) + cpu_buffer[:size] = torch.ByteTensor(list(header + enc)) + start = rank * max_size + buffer[start : start + size].copy_(cpu_buffer[:size]) + + all_reduce(buffer, group=group) + + buffer = buffer.cpu() + try: + result = [] + for i in range(world_size): + out_buffer = buffer[i * max_size : (i + 1) * max_size] + (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist())) + if enc_size > 0: + result.append( + pickle.loads( + bytes(out_buffer[header_size : header_size + enc_size].tolist()) + ) + ) + return result + except pickle.UnpicklingError: + raise Exception( + "Unable to unpickle data from other workers. all_gather_list requires all " + "workers to enter the function together, so this error usually indicates " + "that the workers have fallen out of sync somehow. Workers can fall out of " + "sync if one of them runs out of memory, or if there are other conditions " + "in your training script that can cause one worker to finish an epoch " + "while other workers are still iterating over their portions of the data. " + "Try rerunning with --ddp-backend=legacy_ddp and see if that helps." + ) + + +def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]: + """ + AllReduce a dictionary of values across workers. We separately + reduce items that are already on the device and items on CPU for + better performance. + + Args: + data (Mapping[str, Any]): dictionary of data to all-reduce, but + cannot be a nested dictionary + device (torch.device): device for the reduction + group: group of the collective + """ + data_keys = list(data.keys()) + + # We want to separately reduce items that are already on the + # device and items on CPU for performance reasons. + cpu_data = OrderedDict() + device_data = OrderedDict() + for k in data_keys: + t = data[k] + if not torch.is_tensor(t): + cpu_data[k] = torch.tensor(t, dtype=torch.double) + elif t.device.type != device.type: + cpu_data[k] = t.to(dtype=torch.double) + else: + device_data[k] = t.to(dtype=torch.double) + + def _all_reduce_dict(data: OrderedDict): + if len(data) == 0: + return data + buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device) + all_reduce(buf, group=group) + split_buf = torch.split(buf.clone(), [t.numel() for t in data.values()]) + reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())] + return OrderedDict(zip(data.keys(), reduced_data)) + + cpu_data = _all_reduce_dict(cpu_data) + device_data = _all_reduce_dict(device_data) + + def get_from_stack(key): + if key in cpu_data: + return cpu_data[key] + elif key in device_data: + return device_data[key] + raise KeyError + + return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) + + +def broadcast_tensors( + tensors: Optional[List[torch.Tensor]], + src_rank: int, + group: object, + dist_device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + """ + Broadcasts a list of tensors without other (non-src) ranks needing to know + the dtypes/shapes of the tensors. + """ + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + # share metadata first to simplify transfer + is_src_rank = get_rank(group) == src_rank + if is_src_rank: + metadata = [ + {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors + ] + metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) + else: + metadata = _broadcast_object_slow(None, src_rank, group, dist_device) + + out_tensors = [] + for i, meta in enumerate(metadata): + if is_src_rank: + tensor = tensors[i] + broadcast(tensors[i].to(dist_device), src=src_rank, group=group) + else: + tensor = torch.zeros( + [meta["size"].numel()], dtype=meta["dtype"], device=dist_device + ) + broadcast(tensor, src=src_rank, group=group) + tensor = tensor.view(meta["size"]).to(meta["device"]) + out_tensors.append(tensor) + return out_tensors + + +def broadcast_object( + obj: Any, + src_rank: int, + group: object, + dist_device: Optional[torch.device] = None, +) -> Any: + """Broadcast an arbitrary Python object to other workers.""" + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + if get_rank(group) == src_rank: + # split the tensors from the non-tensors so we can broadcast them + # directly, avoiding unnecessary serialization/deserialization + tensors = [] + obj = _split_tensors_from_obj(obj, tensors) + obj = _broadcast_object_slow(obj, src_rank, group, dist_device) + tensors = broadcast_tensors(tensors, src_rank, group, dist_device) + else: + obj = _broadcast_object_slow(None, src_rank, group, dist_device) + tensors = broadcast_tensors(None, src_rank, group, dist_device) + return _put_tensors_in_obj(obj, tensors) + + +def _broadcast_object_slow( + obj: Any, + src_rank: int, + group: object, + dist_device: torch.device, +) -> Any: + if get_rank(group) == src_rank: + # Emit data + buffer = io.BytesIO() + torch.save(obj, buffer) + buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device) + length = torch.LongTensor([len(buffer)]).to(dist_device) + broadcast(length, src=src_rank, group=group) + broadcast(buffer, src=src_rank, group=group) + else: + # Fetch from the source + length = torch.LongTensor([0]).to(dist_device) + broadcast(length, src=src_rank, group=group) + buffer = torch.ByteTensor(int(length.item())).to(dist_device) + broadcast(buffer, src=src_rank, group=group) + buffer = io.BytesIO(buffer.cpu().numpy()) + obj = torch.load(buffer, map_location="cpu") + return obj + + +@dataclass(frozen=True) +class _TensorPlaceholder: + index: int + + +def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if torch.is_tensor(obj): + placeholder = _TensorPlaceholder(index=len(tensors)) + tensors.append(obj) + return placeholder + elif isinstance(obj, dict): + return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_split_tensors_from_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_split_tensors_from_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_split_tensors_from_obj(v, tensors) for v in obj} + else: + return obj + + +def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if isinstance(obj, _TensorPlaceholder): + return tensors[obj.index] + elif isinstance(obj, dict): + return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_put_tensors_in_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_put_tensors_in_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_put_tensors_in_obj(v, tensors) for v in obj} + else: + return obj diff --git a/fairseq/fairseq/logging/__init__.py b/fairseq/fairseq/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/fairseq/logging/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/logging/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c41ecf7835849460dbc60ff066cc0c38260db6e Binary files /dev/null and b/fairseq/fairseq/logging/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/logging/__pycache__/meters.cpython-310.pyc b/fairseq/fairseq/logging/__pycache__/meters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..470e8021afdb6c47516a370c22d6ba93d0747894 Binary files /dev/null and b/fairseq/fairseq/logging/__pycache__/meters.cpython-310.pyc differ diff --git a/fairseq/fairseq/logging/__pycache__/metrics.cpython-310.pyc b/fairseq/fairseq/logging/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..622561047926beb35fc5a4f12d395b328a7bb98f Binary files /dev/null and b/fairseq/fairseq/logging/__pycache__/metrics.cpython-310.pyc differ diff --git a/fairseq/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc b/fairseq/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb65db4e98737f922ec8dec560cf8a1fe32d788e Binary files /dev/null and b/fairseq/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc differ diff --git a/fairseq/fairseq/logging/meters.py b/fairseq/fairseq/logging/meters.py new file mode 100644 index 0000000000000000000000000000000000000000..495bd083000de9e4a05f1470228c1171c8c8bb9c --- /dev/null +++ b/fairseq/fairseq/logging/meters.py @@ -0,0 +1,351 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect +import time +from collections import OrderedDict +from typing import Dict, Optional + +try: + import torch + + def type_as(a, b): + if torch.is_tensor(a) and torch.is_tensor(b): + return a.to(b) + else: + return a + +except ImportError: + torch = None + + def type_as(a, b): + return a + + +try: + import numpy as np +except ImportError: + np = None + + +class Meter(object): + """Base class for Meters.""" + + def __init__(self): + pass + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass + + def reset(self): + raise NotImplementedError + + @property + def smoothed_value(self) -> float: + """Smoothed value used for logging.""" + raise NotImplementedError + + +def safe_round(number, ndigits): + if hasattr(number, "__round__"): + return round(number, ndigits) + elif torch is not None and torch.is_tensor(number) and number.numel() == 1: + return safe_round(number.item(), ndigits) + elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): + return safe_round(number.item(), ndigits) + else: + return number + + +class AverageMeter(Meter): + """Computes and stores the average and current value""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.val = None # most recent update + self.sum = 0 # sum from all updates + self.count = 0 # total n from all updates + + def update(self, val, n=1): + if val is not None: + self.val = val + if n > 0: + self.sum = type_as(self.sum, val) + (val * n) + self.count = type_as(self.count, n) + n + + def state_dict(self): + return { + "val": self.val, + "sum": self.sum, + "count": self.count, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.val = state_dict["val"] + self.sum = state_dict["sum"] + self.count = state_dict["count"] + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.count if self.count > 0 else self.val + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class SumMeter(Meter): + """Computes and stores the sum""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.sum = 0 # sum from all updates + + def update(self, val): + if val is not None: + self.sum = type_as(self.sum, val) + val + + def state_dict(self): + return { + "sum": self.sum, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.round = state_dict.get("round", None) + + @property + def smoothed_value(self) -> float: + val = self.sum + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class ConcatTensorMeter(Meter): + """Concatenates tensors""" + + def __init__(self, dim=0): + super().__init__() + self.reset() + self.dim = dim + + def reset(self): + self.tensor = None + + def update(self, val): + if self.tensor is None: + self.tensor = val + else: + self.tensor = torch.cat([self.tensor, val], dim=self.dim) + + def state_dict(self): + return { + "tensor": self.tensor, + } + + def load_state_dict(self, state_dict): + self.tensor = state_dict["tensor"] + + @property + def smoothed_value(self) -> float: + return [] # return a dummy value + + +class TimeMeter(Meter): + """Computes the average occurrence of some event per second""" + + def __init__( + self, + init: int = 0, + n: int = 0, + round: Optional[int] = None, + ): + self.round = round + self.reset(init, n) + + def reset(self, init=0, n=0): + self.init = init + self.start = time.perf_counter() + self.n = n + self.i = 0 + + def update(self, val=1): + self.n = type_as(self.n, val) + val + self.i += 1 + + def state_dict(self): + return { + "init": self.elapsed_time, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + if "start" in state_dict: + # backwards compatibility for old state_dicts + self.reset(init=state_dict["init"]) + else: + self.reset(init=state_dict["init"], n=state_dict["n"]) + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.n / self.elapsed_time + + @property + def elapsed_time(self): + return self.init + (time.perf_counter() - self.start) + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class StopwatchMeter(Meter): + """Computes the sum/avg duration of some event in seconds""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.sum = 0 + self.n = 0 + self.start_time = None + + def start(self): + self.start_time = time.perf_counter() + + def stop(self, n=1, prehook=None): + if self.start_time is not None: + if prehook is not None: + prehook() + delta = time.perf_counter() - self.start_time + self.sum = self.sum + delta + self.n = type_as(self.n, n) + n + + def reset(self): + self.sum = 0 # cumulative time during which stopwatch was active + self.n = 0 # total n across all start/stop + self.start() + + def state_dict(self): + return { + "sum": self.sum, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.n = state_dict["n"] + self.start_time = None + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.n if self.n > 0 else self.sum + + @property + def elapsed_time(self): + if self.start_time is None: + return 0.0 + return time.perf_counter() - self.start_time + + @property + def smoothed_value(self) -> float: + val = self.avg if self.sum > 0 else self.elapsed_time + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class MetersDict(OrderedDict): + """A sorted dictionary of :class:`Meters`. + + Meters are sorted according to a priority that is given when the + meter is first added to the dictionary. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.priorities = [] + + def __setitem__(self, key, value): + assert key not in self, "MetersDict doesn't support reassignment" + priority, value = value + bisect.insort(self.priorities, (priority, len(self.priorities), key)) + super().__setitem__(key, value) + for _, _, key in self.priorities: # reorder dict to match priorities + self.move_to_end(key) + + def add_meter(self, key, meter, priority): + self.__setitem__(key, (priority, meter)) + + def state_dict(self): + return [ + (pri, key, self[key].__class__.__name__, self[key].state_dict()) + for pri, _, key in self.priorities + # can't serialize DerivedMeter instances + if not isinstance(self[key], MetersDict._DerivedMeter) + ] + + def load_state_dict(self, state_dict): + self.clear() + self.priorities.clear() + for pri, key, meter_cls, meter_state in state_dict: + meter = globals()[meter_cls]() + meter.load_state_dict(meter_state) + self.add_meter(key, meter, pri) + + def get_smoothed_value(self, key: str) -> float: + """Get a single smoothed value.""" + meter = self[key] + if isinstance(meter, MetersDict._DerivedMeter): + return meter.fn(self) + else: + return meter.smoothed_value + + def get_smoothed_values(self) -> Dict[str, float]: + """Get all smoothed values.""" + return OrderedDict( + [ + (key, self.get_smoothed_value(key)) + for key in self.keys() + if not key.startswith("_") + ] + ) + + def reset(self): + """Reset Meter instances.""" + for meter in self.values(): + if isinstance(meter, MetersDict._DerivedMeter): + continue + meter.reset() + + class _DerivedMeter(Meter): + """A Meter whose values are derived from other Meters.""" + + def __init__(self, fn): + self.fn = fn + + def reset(self): + pass diff --git a/fairseq/fairseq/logging/metrics.py b/fairseq/fairseq/logging/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..49301f27f84351b83b8c869bac86c78ec9f126e6 --- /dev/null +++ b/fairseq/fairseq/logging/metrics.py @@ -0,0 +1,336 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +A standalone module for aggregating metrics. + +Metrics can be logged from anywhere using the `log_*` functions defined +in this module. The logged values will be aggregated dynamically based +on the aggregation context in which the logging occurs. See the +:func:`aggregate` context manager for more details. +""" + +import contextlib +import uuid +from collections import defaultdict +from typing import Callable, List, Optional + +from .meters import * + + +# Aggregation contexts are considered "active" when inside the scope +# created by the :func:`aggregate` context manager. +_aggregators = OrderedDict() +_active_aggregators = OrderedDict() +_active_aggregators_cnt = defaultdict(lambda: 0) + + +def reset() -> None: + """Reset all metrics aggregators.""" + _aggregators.clear() + _active_aggregators.clear() + _active_aggregators_cnt.clear() + + # The "default" aggregator observes all logged values. + _aggregators["default"] = MetersDict() + _active_aggregators["default"] = _aggregators["default"] + _active_aggregators_cnt["default"] = 1 + + +reset() + + +@contextlib.contextmanager +def aggregate(name: Optional[str] = None, new_root: bool = False): + """Context manager to aggregate metrics under a given name. + + Aggregations can be nested. If *new_root* is ``False``, then logged + metrics will be recorded along the entire stack of nested + aggregators, including a global "default" aggregator. If *new_root* + is ``True``, then this aggregator will be the root of a new + aggregation stack, thus bypassing any parent aggregators. + + Note that aggregation contexts are uniquely identified by their + *name* (e.g., train, valid). Creating a context with an existing + name will reuse the corresponding :class:`MetersDict` instance. + If no name is given, then a temporary aggregator will be created. + + Usage:: + + with metrics.aggregate("train"): + for step, batch in enumerate(epoch): + with metrics.aggregate("train_inner") as agg: + metrics.log_scalar("loss", get_loss(batch)) + if step % log_interval == 0: + print(agg.get_smoothed_value("loss")) + agg.reset() + print(metrics.get_smoothed_values("train")["loss"]) + + Args: + name (str): name of the aggregation. Defaults to a + random/temporary name if not given explicitly. + new_root (bool): make this aggregation the root of a new + aggregation stack. + """ + if name is None: + # generate a temporary name + name = str(uuid.uuid4()) + assert name not in _aggregators + agg = MetersDict() + else: + assert name != "default" + agg = _aggregators.setdefault(name, MetersDict()) + + if new_root: + backup_aggregators = _active_aggregators.copy() + _active_aggregators.clear() + backup_aggregators_cnt = _active_aggregators_cnt.copy() + _active_aggregators_cnt.clear() + + _active_aggregators[name] = agg + _active_aggregators_cnt[name] += 1 + + yield agg + + _active_aggregators_cnt[name] -= 1 + if _active_aggregators_cnt[name] == 0 and name in _active_aggregators: + del _active_aggregators[name] + + if new_root: + _active_aggregators.clear() + _active_aggregators.update(backup_aggregators) + _active_aggregators_cnt.clear() + _active_aggregators_cnt.update(backup_aggregators_cnt) + + +def get_active_aggregators() -> List[MetersDict]: + return list(_active_aggregators.values()) + + +def log_scalar( + key: str, + value: float, + weight: float = 1, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value. + + Args: + key (str): name of the field to log + value (float): value to log + weight (float): weight that this value contributes to the average. + A weight of 0 will always log the latest value. + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, AverageMeter(round=round), priority) + agg[key].update(value, weight) + + +def log_scalar_sum( + key: str, + value: float, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, SumMeter(round=round), priority) + agg[key].update(value) + + +def log_concat_tensor( + key: str, + value: torch.Tensor, + priority: int = 10, + dim: int = 0, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, ConcatTensorMeter(dim=dim), priority) + agg[key].update(value) + + +def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): + """Log a scalar value derived from other meters. + + Args: + key (str): name of the field to log + fn (Callable[[MetersDict], float]): function that takes a single + argument *meters* and returns the derived value + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, MetersDict._DerivedMeter(fn), priority) + + +def log_speed( + key: str, + value: float, + priority: int = 30, + round: Optional[int] = None, +): + """Log the rate of some quantity per second. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, TimeMeter(round=round), priority) + agg[key].reset() # reset meter on the first call + else: + agg[key].update(value) + + +def log_start_time(key: str, priority: int = 40, round: Optional[int] = None): + """Log the duration of some event in seconds. + + The duration will be computed once :func:`log_stop_time` is called. + + Args: + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, StopwatchMeter(round=round), priority) + agg[key].start() + + +def log_stop_time(key: str, weight: float = 0.0, prehook=None): + """Log the duration of some event in seconds. + + The duration will be computed since :func:`log_start_time` was called. + Set weight > 0 to report the average time instead of the sum. + + Args: + key (str): name of the field to log + weight (float): weight that this time contributes to the average + prehook (function, no arguments): will be called before the timer + is stopped. For example, use prehook=torch.cuda.synchronize to + make sure all gpu operations are done before timer is stopped. + """ + for agg in get_active_aggregators(): + if key in agg: + agg[key].stop(weight, prehook) + + +def log_custom( + new_meter_fn: Callable[[], Meter], + key: str, + *args, + priority: int = 50, + **kwargs, +): + """Log using a custom Meter. + + Any extra *args* or *kwargs* will be passed through to the Meter's + *update* method. + + Args: + new_meter_fn (Callable[[], Meter]): function that returns a new + Meter instance + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, new_meter_fn(), priority) + agg[key].update(*args, **kwargs) + + +def reset_meter(name: str, key: str) -> None: + """Reset Meter instance aggregated under a given *name* and *key*.""" + meter = get_meter(name, key) + if meter is not None: + meter.reset() + + +def reset_meters(name: str) -> None: + """Reset Meter instances aggregated under a given *name*.""" + meters = get_meters(name) + if meters is not None: + meters.reset() + + +def get_meter(name: str, key: str) -> Meter: + """Get a single Meter instance aggregated under *name* and *key*. + + Returns: + Meter or None if no metrics have been logged under *name* and *key*. + """ + if name not in _aggregators: + return None + return _aggregators[name].get(key, None) + + +def get_meters(name: str) -> MetersDict: + """Get Meter instances aggregated under a given *name*. + + Returns: + MetersDict or None if no metrics have been logged under *name*. + """ + return _aggregators.get(name, None) + + +def get_smoothed_value(name: str, key: str) -> float: + """Get a single smoothed value. + + Raises: + KeyError: if no metrics have been logged under *name* and *key*. + """ + return _aggregators[name].get_smoothed_value(key) + + +def get_smoothed_values(name: str) -> Dict[str, float]: + """Get smoothed values aggregated under a given *name*. + + Raises: + KeyError: if no metrics have been logged under *name*. + """ + return _aggregators[name].get_smoothed_values() + + +def state_dict(): + return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()]) + + +def load_state_dict(state_dict): + for name, agg_state in state_dict.items(): + _aggregators[name] = MetersDict() + _aggregators[name].load_state_dict(agg_state) + + +def xla_metrics_report(): + try: + import torch_xla.debug.metrics as met + + print(met.metrics_report()) + except ImportError: + return diff --git a/fairseq/fairseq/logging/progress_bar.py b/fairseq/fairseq/logging/progress_bar.py new file mode 100644 index 0000000000000000000000000000000000000000..4c64b61bad6edbf4b9ff5bcc2f26952e8b1bfc9c --- /dev/null +++ b/fairseq/fairseq/logging/progress_bar.py @@ -0,0 +1,582 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Wrapper around various loggers and progress bars (e.g., tqdm). +""" + +import atexit +import json +import logging +import os +import sys +from collections import OrderedDict +from contextlib import contextmanager +from numbers import Number +from typing import Optional + +import torch + +from .meters import AverageMeter, StopwatchMeter, TimeMeter + +logger = logging.getLogger(__name__) + + +def progress_bar( + iterator, + log_format: Optional[str] = None, + log_interval: int = 100, + log_file: Optional[str] = None, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + aim_repo: Optional[str] = None, + aim_run_hash: Optional[str] = None, + aim_param_checkpoint_dir: Optional[str] = None, + tensorboard_logdir: Optional[str] = None, + default_log_format: str = "tqdm", + wandb_project: Optional[str] = None, + wandb_run_name: Optional[str] = None, + azureml_logging: Optional[bool] = False, +): + if log_format is None: + log_format = default_log_format + if log_file is not None: + handler = logging.FileHandler(filename=log_file) + logger.addHandler(handler) + + if log_format == "tqdm" and not sys.stderr.isatty(): + log_format = "simple" + + if log_format == "json": + bar = JsonProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "none": + bar = NoopProgressBar(iterator, epoch, prefix) + elif log_format == "simple": + bar = SimpleProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "tqdm": + bar = TqdmProgressBar(iterator, epoch, prefix) + else: + raise ValueError("Unknown log format: {}".format(log_format)) + + if aim_repo: + bar = AimProgressBarWrapper( + bar, + aim_repo=aim_repo, + aim_run_hash=aim_run_hash, + aim_param_checkpoint_dir=aim_param_checkpoint_dir, + ) + + if tensorboard_logdir: + try: + # [FB only] custom wrapper for TensorBoard + import palaas # noqa + + from .fb_tbmf_wrapper import FbTbmfWrapper + + bar = FbTbmfWrapper(bar, log_interval) + except ImportError: + bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) + + if wandb_project: + bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name) + + if azureml_logging: + bar = AzureMLProgressBarWrapper(bar) + + return bar + + +def build_progress_bar( + args, + iterator, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + default: str = "tqdm", + no_progress_bar: str = "none", +): + """Legacy wrapper that takes an argparse.Namespace.""" + if getattr(args, "no_progress_bar", False): + default = no_progress_bar + if getattr(args, "distributed_rank", 0) == 0: + tensorboard_logdir = getattr(args, "tensorboard_logdir", None) + else: + tensorboard_logdir = None + return progress_bar( + iterator, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch, + prefix=prefix, + tensorboard_logdir=tensorboard_logdir, + default_log_format=default, + ) + + +def format_stat(stat): + if isinstance(stat, Number): + stat = "{:g}".format(stat) + elif isinstance(stat, AverageMeter): + stat = "{:.3f}".format(stat.avg) + elif isinstance(stat, TimeMeter): + stat = "{:g}".format(round(stat.avg)) + elif isinstance(stat, StopwatchMeter): + stat = "{:g}".format(round(stat.sum)) + elif torch.is_tensor(stat): + stat = stat.tolist() + return stat + + +class BaseProgressBar(object): + """Abstract class for progress bars.""" + + def __init__(self, iterable, epoch=None, prefix=None): + self.iterable = iterable + self.n = getattr(iterable, "n", 0) + self.epoch = epoch + self.prefix = "" + if epoch is not None: + self.prefix += "epoch {:03d}".format(epoch) + if prefix is not None: + self.prefix += (" | " if self.prefix != "" else "") + prefix + + def __len__(self): + return len(self.iterable) + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def __iter__(self): + raise NotImplementedError + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + raise NotImplementedError + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + raise NotImplementedError + + def update_config(self, config): + """Log latest configuration.""" + pass + + def _str_commas(self, stats): + return ", ".join(key + "=" + stats[key].strip() for key in stats.keys()) + + def _str_pipes(self, stats): + return " | ".join(key + " " + stats[key].strip() for key in stats.keys()) + + def _format_stats(self, stats): + postfix = OrderedDict(stats) + # Preprocess stats according to datatype + for key in postfix.keys(): + postfix[key] = str(format_stat(postfix[key])) + return postfix + + +@contextmanager +def rename_logger(logger, new_name): + old_name = logger.name + if new_name is not None: + logger.name = new_name + yield logger + logger.name = old_name + + +class JsonProgressBar(BaseProgressBar): + """Log output in JSON format.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + update = ( + self.epoch - 1 + (self.i + 1) / float(self.size) + if self.epoch is not None + else None + ) + stats = self._format_stats(stats, epoch=self.epoch, update=update) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self.stats = stats + if tag is not None: + self.stats = OrderedDict( + [(tag + "_" + k, v) for k, v in self.stats.items()] + ) + stats = self._format_stats(self.stats, epoch=self.epoch) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def _format_stats(self, stats, epoch=None, update=None): + postfix = OrderedDict() + if epoch is not None: + postfix["epoch"] = epoch + if update is not None: + postfix["update"] = round(update, 3) + # Preprocess stats according to datatype + for key in stats.keys(): + postfix[key] = format_stat(stats[key]) + return postfix + + +class NoopProgressBar(BaseProgressBar): + """No logging.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + + def __iter__(self): + for obj in self.iterable: + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + pass + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + pass + + +class SimpleProgressBar(BaseProgressBar): + """A minimal logger for non-TTY environments.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + stats = self._format_stats(stats) + postfix = self._str_commas(stats) + with rename_logger(logger, tag): + logger.info( + "{}: {:5d} / {:d} {}".format( + self.prefix, self.i + 1, self.size, postfix + ) + ) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +class TqdmProgressBar(BaseProgressBar): + """Log to tqdm.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + from tqdm import tqdm + + self.tqdm = tqdm( + iterable, + self.prefix, + leave=False, + disable=(logger.getEffectiveLevel() > logging.INFO), + ) + + def __iter__(self): + return iter(self.tqdm) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + self.tqdm.set_postfix(self._format_stats(stats), refresh=False) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +try: + import functools + + from aim import Repo as AimRepo + + @functools.lru_cache() + def get_aim_run(repo, run_hash): + from aim import Run + + return Run(run_hash=run_hash, repo=repo) + +except ImportError: + get_aim_run = None + AimRepo = None + + +class AimProgressBarWrapper(BaseProgressBar): + """Log to Aim.""" + + def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir): + self.wrapped_bar = wrapped_bar + + if get_aim_run is None: + self.run = None + logger.warning("Aim not found, please install with: pip install aim") + else: + logger.info(f"Storing logs at Aim repo: {aim_repo}") + + if not aim_run_hash: + # Find run based on save_dir parameter + query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'" + try: + runs_generator = AimRepo(aim_repo).query_runs(query) + run = next(runs_generator.iter_runs()) + aim_run_hash = run.run.hash + except Exception: + pass + + if aim_run_hash: + logger.info(f"Appending to run: {aim_run_hash}") + + self.run = get_aim_run(aim_repo, aim_run_hash) + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to Aim.""" + self._log_to_aim(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_aim(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + if self.run is not None: + for key in config: + self.run.set(key, config[key], strict=False) + self.wrapped_bar.update_config(config) + + def _log_to_aim(self, stats, tag=None, step=None): + if self.run is None: + return + + if step is None: + step = stats["num_updates"] + + if "train" in tag: + context = {"tag": tag, "subset": "train"} + elif "val" in tag: + context = {"tag": tag, "subset": "val"} + else: + context = {"tag": tag} + + for key in stats.keys() - {"num_updates"}: + self.run.track(stats[key], name=key, step=step, context=context) + + +try: + _tensorboard_writers = {} + from torch.utils.tensorboard import SummaryWriter +except ImportError: + try: + from tensorboardX import SummaryWriter + except ImportError: + SummaryWriter = None + + +def _close_writers(): + for w in _tensorboard_writers.values(): + w.close() + + +atexit.register(_close_writers) + + +class TensorboardProgressBarWrapper(BaseProgressBar): + """Log to tensorboard.""" + + def __init__(self, wrapped_bar, tensorboard_logdir): + self.wrapped_bar = wrapped_bar + self.tensorboard_logdir = tensorboard_logdir + + if SummaryWriter is None: + logger.warning( + "tensorboard not found, please install with: pip install tensorboard" + ) + + def _writer(self, key): + if SummaryWriter is None: + return None + _writers = _tensorboard_writers + if key not in _writers: + _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) + _writers[key].add_text("sys.argv", " ".join(sys.argv)) + return _writers[key] + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to tensorboard.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + # TODO add hparams to Tensorboard + self.wrapped_bar.update_config(config) + + def _log_to_tensorboard(self, stats, tag=None, step=None): + writer = self._writer(tag or "") + if writer is None: + return + if step is None: + step = stats["num_updates"] + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + writer.add_scalar(key, stats[key].val, step) + elif isinstance(stats[key], Number): + writer.add_scalar(key, stats[key], step) + elif torch.is_tensor(stats[key]) and stats[key].numel() == 1: + writer.add_scalar(key, stats[key].item(), step) + writer.flush() + + +try: + import wandb +except ImportError: + wandb = None + + +class WandBProgressBarWrapper(BaseProgressBar): + """Log to Weights & Biases.""" + + def __init__(self, wrapped_bar, wandb_project, run_name=None): + self.wrapped_bar = wrapped_bar + if wandb is None: + logger.warning("wandb not found, pip install wandb") + return + + # reinit=False to ensure if wandb.init() is called multiple times + # within one process it still references the same run + wandb.init(project=wandb_project, reinit=False, name=run_name) + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to tensorboard.""" + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + if wandb is not None: + wandb.config.update(config) + self.wrapped_bar.update_config(config) + + def _log_to_wandb(self, stats, tag=None, step=None): + if wandb is None: + return + if step is None: + step = stats["num_updates"] + + prefix = "" if tag is None else tag + "/" + + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + wandb.log({prefix + key: stats[key].val}, step=step) + elif isinstance(stats[key], Number): + wandb.log({prefix + key: stats[key]}, step=step) + + +try: + from azureml.core import Run +except ImportError: + Run = None + + +class AzureMLProgressBarWrapper(BaseProgressBar): + """Log to Azure ML""" + + def __init__(self, wrapped_bar): + self.wrapped_bar = wrapped_bar + if Run is None: + logger.warning("azureml.core not found, pip install azureml-core") + return + self.run = Run.get_context() + + def __exit__(self, *exc): + if Run is not None: + self.run.complete() + return False + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to AzureML""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + self.wrapped_bar.update_config(config) + + def _log_to_azureml(self, stats, tag=None, step=None): + if Run is None: + return + if step is None: + step = stats["num_updates"] + + prefix = "" if tag is None else tag + "/" + + for key in stats.keys() - {"num_updates"}: + name = prefix + key + if isinstance(stats[key], AverageMeter): + self.run.log_row(name=name, **{"step": step, key: stats[key].val}) + elif isinstance(stats[key], Number): + self.run.log_row(name=name, **{"step": step, key: stats[key]}) diff --git a/fairseq/fairseq/model_parallel/__init__.py b/fairseq/fairseq/model_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69f21684872f72ae8ee26d9ff7d2d2b6e6d526c3 --- /dev/null +++ b/fairseq/fairseq/model_parallel/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import criterions, models, modules # noqa diff --git a/fairseq/fairseq/model_parallel/criterions/__init__.py b/fairseq/fairseq/model_parallel/criterions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fae7bd4c2cfa7b4f64ad62dd9b9082f59f0e50d --- /dev/null +++ b/fairseq/fairseq/model_parallel/criterions/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the criterions/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.model_parallel.criterions." + module) diff --git a/fairseq/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c59b5ad984087f9c567039789cad58e434fdd73 Binary files /dev/null and b/fairseq/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc b/fairseq/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8db04acc1e562dc88b2e00a81c87876d3a206dc Binary files /dev/null and b/fairseq/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py b/fairseq/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..5ffbaa87640973317e3cac4c396cdc11af2fa380 --- /dev/null +++ b/fairseq/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion + + +try: + from fairseq.model_parallel.megatron.mpu.cross_entropy import ( + vocab_parallel_cross_entropy, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +@register_criterion("vocab_parallel_cross_entropy") +class VocabParallelCrossEntropyCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + target = sample["target"] + + loss = vocab_parallel_cross_entropy(net_output[0].float(), target) + loss = (loss * (target != self.padding_idx)).sum() + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/fairseq/model_parallel/megatron_trainer.py b/fairseq/fairseq/model_parallel/megatron_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..aedf608bce21d11d0a1e9646d9c373aae198dce6 --- /dev/null +++ b/fairseq/fairseq/model_parallel/megatron_trainer.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Train a network across multiple GPUs. +""" + +from fairseq.dataclass.configs import FairseqConfig +from fairseq.distributed import utils as distributed_utils +from fairseq.trainer import Trainer + +try: + from fairseq.model_parallel.megatron.mpu import ( + get_data_parallel_rank, + get_data_parallel_world_size, + get_model_parallel_src_rank, + get_cuda_rng_tracker, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class MegatronTrainer(Trainer): + """Main class for model parallel with data parallel training.""" + + def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + super().__init__(cfg, task, model, criterion, **kwargs) + + def clip_grad_norm(self, clip_norm): + def _aggregate_model_parallel_grad_norm(total_norm): + total_norm = total_norm**2 + distributed_utils.all_reduce( + total_norm, group=distributed_utils.get_model_parallel_group() + ) + total_norm = total_norm**0.5 + return total_norm + + return self.optimizer.clip_grad_norm( + clip_norm, + aggregate_norm_fn=_aggregate_model_parallel_grad_norm, + ) + + def save_checkpoint(self, filename, extra_state): + """Save all training state in a checkpoint file.""" + extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states() + super().save_checkpoint(filename, extra_state) + + def load_checkpoint( + self, + filename, + reset_optimizer=False, + reset_lr_scheduler=False, + optimizer_overrides=None, + reset_meters=False, + ): + extra_state = super().load_checkpoint( + filename, + reset_optimizer=reset_optimizer, + reset_lr_scheduler=reset_lr_scheduler, + optimizer_overrides=optimizer_overrides, + reset_meters=reset_meters, + ) + if extra_state is not None and "rng_tracker_states" in extra_state: + get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"]) + return extra_state diff --git a/fairseq/fairseq/model_parallel/models/__init__.py b/fairseq/fairseq/model_parallel/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3532479e52a0e1f1ba204c6f5d51c71c98ee5df0 --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.model_parallel.models." + model_name) diff --git a/fairseq/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9444705919733d1ba578cff67d6d28fbccade232 Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..608e33d6c804f25dc025ee5a0e0dd9b90906a0c4 Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/__pycache__/transformer_lm.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/__pycache__/transformer_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaf84ec5b0a69c783ac87a7528a2ab2a4a239330 Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/__pycache__/transformer_lm.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..117827c3e9c176477f33e3a6fd7fe19a922411a2 --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa diff --git a/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5949b3b1041cd5c9f3f9bea7f52803e36e1cb14d Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/layers.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d19615dd59c1dbdd0a464f2794dd7a53919c995 Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/layers.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/model.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148c302a910238c9966daf20657d792ea64617bb Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/model.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..85dbd44b3c7f762048ff21808313d0317f8da7a4 --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py @@ -0,0 +1,600 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils +from fairseq.modules import ( + AdaptiveSoftmax, + LayerNorm, + MultiheadAttention, + PositionalEmbedding, +) + +EncoderOut = namedtuple( + "TransformerEncoderOut", + [ + "encoder_out", # T x B x C + "encoder_padding_mask", # B x T + "encoder_embedding", # B x T x C + "encoder_states", # List[T x B x C] + ], +) + + +class TransformerEncoderEmbedding(nn.Module): + """Encoder Embedding + Positional Embedding""" + + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.max_source_positions = args.max_source_positions + self.embed_tokens = embed_tokens + if isinstance(embed_tokens, nn.ModuleList): + self.padding_idx = embed_tokens[0].padding_idx + embed_dim = sum(e.embedding_dim for e in embed_tokens) + else: + self.padding_idx = embed_tokens.padding_idx + embed_dim = embed_tokens.embedding_dim + self.embed_scale = math.sqrt(embed_dim) + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + def forward(self, input): + # embed tokens and positions + src_tokens = input[0] + prev_output_tokens = input[2] + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(src_tokens)) + + embedded = torch.cat(x_embed_list, dim=-1) + else: + embedded = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * embedded + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding: + x = self.layernorm_embedding(x) + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerEncoderLayerNorm(nn.Module): + """ + Layer norm at the the end of all encoder layers if + args.encoder_enormalize_before = True + """ + + def __init__(self, args, embed_dim): + super().__init__() + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input): + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + if self.layer_norm: + x = self.layer_norm(x) + # keeping track of the incremental_state is not supported yet + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerDecoderEmbedding(nn.Module): + """Decoder Embedding + Positional Embedding""" + + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.share_input_output_embed = args.share_decoder_input_output_embed + input_embed_dim = ( + sum(e.embedding_dim for e in embed_tokens) + if isinstance(embed_tokens, nn.ModuleList) + else embed_tokens.embedding_dim + ) + embed_dim = args.decoder_embed_dim + self.output_embed_dim = args.decoder_output_dim + + padding_idx = ( + embed_tokens[0].padding_idx + if isinstance(embed_tokens, nn.ModuleList) + else embed_tokens.padding_idx + ) + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + def forward(self, input): + mt_task = False + if isinstance(input, tuple): + if len(input) == 3: + encoder_out = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + incremental_state = None # Hardcoding to avoid passing of None objects + mt_task = True + else: + # HACK for now, need to fix (TODO sidgoyal) + prev_output_tokens = input[0] + # discard "src_lengths" + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + else: + prev_output_tokens = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + positions = ( + self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(prev_output_tokens)) + + x = self.embed_scale * torch.cat(x_embed_list, dim=-1) + else: + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + +class TransformerDecoderOutputLayer(nn.Module): + def __init__(self, args, embed_tokens, dictionary): + super().__init__() + self.share_input_output_embed = args.share_decoder_input_output_embed + self.embed_tokens = embed_tokens + self.output_embed_dim = args.decoder_output_dim + embed_dim = args.decoder_embed_dim + + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + else None + ) + self.adaptive_softmax = None + if args.adaptive_softmax_cutoff is not None: + assert not isinstance(embed_tokens, nn.ModuleList) + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + options.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif not self.share_input_output_embed: + self.embed_tokens = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_( + self.embed_tokens, mean=0, std=self.output_embed_dim**-0.5 + ) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input, apply_final_proj=True): + if isinstance(input, tuple): + x = input[0] + else: + x = input + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + if apply_final_proj: + x = self.output_layer(x) + return x + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + if isinstance(self.embed_tokens, nn.ModuleList): + output = None + for i, emb in enumerate(self.embed_tokens): + sidx = i * emb.embedding_dim + eidx = (i + 1) * emb.embedding_dim + if output is None: + output = F.linear(features[:, :, sidx:eidx], emb.weight) + else: + output += F.linear(features[:, :, sidx:eidx], emb.weight) + + return output + else: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_tokens) + else: + return features + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer block. + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.encoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, args): + super().__init__() + self.embed_dim = args.encoder_embed_dim + self.self_attn = MultiheadAttention( + self.embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + ) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, "activation_fn", "relu") + ) + self.activation_dropout = getattr(args, "activation_dropout", 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, "relu_dropout", 0) + self.normalize_before = args.encoder_normalize_before + self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) + self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def upgrade_state_dict_named(self, state_dict, name): + """ + Rename layer norm states from `...layer_norms.0.weight` to + `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to + `...final_layer_norm.weight` + """ + layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layer_norms.{}.{}".format(name, old, m) + if k in state_dict: + state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] + del state_dict[k] + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. + input[2] (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing) + Returns: + output (Tuple): + output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)` + output[1] (ByteTensor/FloatTensor): encoder padding mask + output[2] (LongTensor): previous decoder outputs + """ + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + x, _ = self.self_attn( + query=x, key=x, value=x, key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + return (x, encoder_padding_mask, prev_output_tokens) + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + +class TransformerDecoderLayer(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.decoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.self_attn = MultiheadAttention( + embed_dim=self.embed_dim, + num_heads=args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=True, + ) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, "activation_fn", "relu") + ) + self.activation_dropout = getattr(args, "activation_dropout", 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, "relu_dropout", 0) + self.normalize_before = args.decoder_normalize_before + + # use layerNorm rather than FusedLayerNorm for exporting. + # char_inputs can be used to determint this. + # TODO remove this once we update apex with the fix + export = getattr(args, "char_inputs", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = MultiheadAttention( + self.embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) + self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)` + input[2] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. + Returns: + output (Tuple): + output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)` + output[1] (ByteTensor/FloatTensor): encoder padding mask + output[2] (LongTensor): previous decoder outputs + """ + # Note: incremental state is not yet supported + mt_task = False + if isinstance(input, tuple): + x = input[0] + encoder_out = input[1] + encoder_padding_mask = input[2] + incremental_state = None + mt_task = True + else: + x = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + # TODO: add back prev_self_attn_state, prev_attn_state, + # self_attn_padding_mask + prev_self_attn_state = None + prev_attn_state = None + self_attn_padding_mask = None + + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + if prev_self_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_self_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.self_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + if self.encoder_attn is not None: + residual = x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) + if prev_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=(not self.training and self.need_attn), + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7873ac679170d2647f0491747a75f60364e248dc --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -0,0 +1,779 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import ( + Embedding, + TransformerDecoderEmbedding, + TransformerDecoderLayer, + TransformerDecoderOutputLayer, + TransformerEncoderEmbedding, + TransformerEncoderLayer, + TransformerEncoderLayerNorm, +) +from fairseq.models import ( + BaseFairseqModel, + FairseqDecoder, + FairseqEncoder, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.transformer import ( + base_architecture, + transformer_iwslt_de_en, + transformer_wmt_en_de_big, +) +from fairseq.modules import SinusoidalPositionalEmbedding + + +logger = logging.getLogger(__name__) + + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 +TORCH_PIPE = False +RPC_INIT = False + + +def import_pipe(): + global TORCH_PIPE + global RPC_INIT + try: + from torch.distributed.pipeline.sync import Pipe # noqa + + global Pipe + from torch.distributed.pipeline.sync.utils import partition_model + + global partition_model + from torch.distributed import rpc + import tempfile + + TORCH_PIPE = True + # Initialize single process RPC agent since TORCH_PIPE requires + # RRef. RRef depends on RPC being initialized and as a result we initialize + # RPC with a single node. + tmpfile = tempfile.NamedTemporaryFile() + if not RPC_INIT: + rpc.init_rpc( + name="worker", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(tmpfile.name), + ), + ) + RPC_INIT = True + logger.info("Using torch pipe") + except ImportError: + try: + from fairscale.nn import Pipe # noqa + + logger.info("Using fairscale pipe") + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") + + +@register_model("pipeline_parallel_transformer") +class PipelineParallelTransformerModel(BaseFairseqModel): + def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): + import_pipe() + super().__init__() + assert isinstance(encoder, FairseqEncoder) + assert isinstance(decoder, FairseqDecoder) + encoder_module_list = ( + [encoder.embedding_layer] + + list(encoder.encoder_layers) + + [encoder.final_layer_norm] + ) + self.num_encoder_modules = len(encoder_module_list) + decoder_module_list = ( + [decoder.embedding_layer] + + list(decoder.decoder_layers) + + [decoder.decoder_output_layer] + ) + self.num_decoder_modules = len(decoder_module_list) + module_list = encoder_module_list + decoder_module_list + self.devices = devices + if TORCH_PIPE: + self.model = Pipe( + partition_model(nn.Sequential(*module_list), balance, devices), + chunks=chunks, + checkpoint=checkpoint, + ) + else: + self.model = Pipe( + nn.Sequential(*module_list), + balance=balance, + devices=devices, + chunks=chunks, + checkpoint=checkpoint, + ) + self.encoder_max_positions = self.max_positions_helper( + encoder.embedding_layer, "max_source_positions" + ) + self.decoder_max_positions = self.max_positions_helper( + decoder.embedding_layer, "max_target_positions" + ) + self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None) + # Note: To be populated during inference + self.encoder = None + self.decoder = None + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + if self.training: + input_lst = [src_tokens, src_lengths, prev_output_tokens] + input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) + if TORCH_PIPE: + return self.model(input).local_value() + else: + return self.model(input) + else: + assert self.encoder is not None and self.decoder is not None, ( + "encoder and decoder need to be initialized by " + + "calling the `prepare_for_inference_()` method" + ) + encoder_output_tuple = self.encoder(input) + return self.decoder(encoder_output_tuple) + + def prepare_for_inference_(self, cfg): + if self.encoder is not None and self.decoder is not None: + logger.info("Encoder and Decoder already initialized") + return + encoder_module_list = [] + decoder_module_list = [] + module_count = 0 + for partition in self.model.partitions: + for module in partition: + if module_count < self.num_encoder_modules: + encoder_module_list.append(module) + else: + decoder_module_list.append(module) + module_count += 1 + self.model = None + self.encoder = TransformerEncoder( + cfg.distributed_training, None, None, encoder_module_list + ) + self.decoder = TransformerDecoder( + cfg.distributed_training, + None, + None, + decoder_module_list=decoder_module_list, + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--activation-fn', + choices=utils.get_available_activation_fns(), + help='activation function to use') + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', + help='dropout probability after activation in FFN.') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', + help='encoder embedding dimension for FFN') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='num encoder layers') + parser.add_argument('--encoder-attention-heads', type=int, metavar='N', + help='num encoder attention heads') + parser.add_argument('--encoder-normalize-before', action='store_true', + help='apply layernorm before each encoder block') + parser.add_argument('--encoder-learned-pos', action='store_true', + help='use learned positional embeddings in the encoder') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + parser.add_argument('--decoder-normalize-before', action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--share-decoder-input-output-embed', action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', + help='if set, disables positional embeddings (outside self attention)') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion'), + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1, + help='Number of embedding layer chunks (enables more even distribution' + 'of optimizer states across data parallel nodes' + 'when using optimizer state sharding and' + 'a big embedding vocabulary)') + # fmt: on + + @classmethod + def build_model_base(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, "max_source_positions"): + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if not hasattr(args, "max_target_positions"): + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1): + assert embed_dim % num_embed_chunks == 0, ( + f"Number of embedding chunks = {num_embed_chunks} should be " + + f"divisible by the embedding dimension = {embed_dim}" + ) + assert path is None or num_embed_chunks == 1, ( + "Loading embedding from a path with number of embedding chunks > 1" + + " is not yet supported" + ) + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + # if provided, load from preloaded dictionaries + if path: + emb = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + else: + embed_chunk_dim = embed_dim // num_embed_chunks + emb = nn.ModuleList() + for i in range(num_embed_chunks): + emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx)) + return emb + + num_embed_chunks = args.num_embedding_chunks + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = build_embedding( + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + assert args.share_decoder_input_output_embed or num_embed_chunks == 1, ( + "Not sharing decoder I/O embeddings is not yet supported with number of " + + "embedding chunks > 1" + ) + encoder_embed_tokens = build_embedding( + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, + ) + decoder_embed_tokens = build_embedding( + tgt_dict, + args.decoder_embed_dim, + args.decoder_embed_path, + num_embed_chunks, + ) + + encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return (encoder, decoder) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoder(args, tgt_dict, embed_tokens) + + @classmethod + def build_model(cls, args, task): + encoder, decoder = cls.build_model_base(args, task) + return PipelineParallelTransformerModel( + encoder=encoder, + decoder=decoder, + balance=utils.eval_str_list(args.pipeline_balance, type=int), + devices=utils.eval_str_list(args.pipeline_devices, type=int), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder_max_positions, self.decoder_max_positions) + + def max_positions_helper( + self, embedding_layer, max_positions_field="max_source_positions" + ): + """Maximum input length supported by the encoder or decoder.""" + if embedding_layer.embed_positions is None: + return getattr(embedding_layer, max_positions_field) + return min( + getattr(embedding_layer, max_positions_field), + embedding_layer.embed_positions.max_positions, + ) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: + if sample is not None: + assert "target" in sample + target = sample["target"] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output, target=target) + return out.exp_() if not log_probs else out + + # A Pipe() module returns a tuple of tensors as the output. + # In this case, the tuple has one element - the output tensor of logits + logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1, onnx_trace=False) + else: + return utils.softmax(logits, dim=-1, onnx_trace=False) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder_max_positions + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + self.upgrade_state_dict(state_dict) + is_regular_transformer = not any("model.partitions" in k for k in state_dict) + if is_regular_transformer: + state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict) + return super().load_state_dict(state_dict, strict) + + def convert_to_pipeline_parallel_state_dict(self, state_dict): + new_state_dict = self.state_dict() + encoder_layer_idx = 0 + decoder_layer_idx = 0 + encoder_key_suffixes = [ + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", + ] + decoder_key_suffixes = [ + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "encoder_attn.k_proj.weight", + "encoder_attn.k_proj.bias", + "encoder_attn.v_proj.weight", + "encoder_attn.v_proj.bias", + "encoder_attn.q_proj.weight", + "encoder_attn.q_proj.bias", + "encoder_attn.out_proj.weight", + "encoder_attn.out_proj.bias", + "encoder_attn_layer_norm.weight", + "encoder_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", + ] + for pid, partition in enumerate(self.model.partitions): + logger.info(f"Begin Partition {pid}") + for mid, module in enumerate(partition): + # fmt: off + if isinstance(module, TransformerEncoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'] + if isinstance(module, TransformerEncoderLayer): + for suffix in encoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}'] + encoder_layer_idx += 1 + if isinstance(module, TransformerDecoderLayer): + for suffix in decoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}'] + decoder_layer_idx += 1 + if isinstance(module, TransformerEncoderLayerNorm): + if 'encoder.layer_norm.weight' in state_dict: + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias'] + if isinstance(module, TransformerDecoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'] + if isinstance(module, TransformerDecoderOutputLayer): + new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight'] + # fmt: on + return new_state_dict + + +class TransformerEncoder(FairseqEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + import_pipe() + self.use_pipeline = encoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) + self.encoder_layers = nn.Sequential( + *[TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + ) + if isinstance(embed_tokens, nn.ModuleList): + emb_dim = sum(e.embedding_dim for e in embed_tokens) + else: + emb_dim = embed_tokens.embedding_dim + self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim) + else: + encoder_balance = utils.eval_str_list( + args.pipeline_encoder_balance, type=int + ) + encoder_devices = utils.eval_str_list( + args.pipeline_encoder_devices, type=int + ) + assert sum(encoder_balance) == len(encoder_module_list), ( + f"Sum of encoder_balance={encoder_balance} is not equal " + + f"to num_encoder_modules={len(encoder_module_list)}" + ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model( + nn.Sequential(*encoder_module_list), + encoder_balance, + encoder_devices, + ), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*encoder_module_list), + balance=encoder_balance, + devices=encoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def forward(self, src_tokens, src_lengths): + """ + Args: + input_tuple( + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + ) + + Returns: + output_tuple( + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - prev_output_tokens + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + ) + """ + dummy_prev_output_tokens = torch.zeros( + 1, dtype=src_tokens.dtype, device=src_tokens.device + ) + input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + if TORCH_PIPE: + encoder_out = self.model(input_tuple).local_value() + else: + encoder_out = self.model(input_tuple) + else: + encoder_embed_output_tuple = self.embedding_layer(input_tuple) + encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) + encoder_out = self.final_layer_norm(encoder_layers_output) + # first element is the encoder output + # second element is the encoder padding mask + # the remaining elements of EncoderOut are not computed by + # the PipelineParallelTransformer + return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if encoder_out.encoder_out is not None: + encoder_out = encoder_out._replace( + encoder_out=encoder_out.encoder_out.index_select(1, new_order) + ) + if encoder_out.encoder_padding_mask is not None: + encoder_out = encoder_out._replace( + encoder_padding_mask=encoder_out.encoder_padding_mask.index_select( + 0, new_order + ) + ) + if encoder_out.encoder_embedding is not None: + encoder_out = encoder_out._replace( + encoder_embedding=encoder_out.encoder_embedding.index_select( + 0, new_order + ) + ) + if encoder_out.encoder_states is not None: + for idx, state in enumerate(encoder_out.encoder_states): + encoder_out.encoder_states[idx] = state.index_select(1, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_source_positions + return min( + self.embedding_layer.max_source_positions, + self.embedding_layer.embed_positions.max_positions, + ) + + +class TransformerDecoder(FairseqDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + decoder_module_list=None, + ): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + import_pipe() + self.use_pipeline = decoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) + self.decoder_layers = nn.Sequential( + *[ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) + self.decoder_output_layer = TransformerDecoderOutputLayer( + args, embed_tokens, dictionary + ) + else: + decoder_balance = utils.eval_str_list( + args.pipeline_decoder_balance, type=int + ) + decoder_devices = utils.eval_str_list( + args.pipeline_decoder_devices, type=int + ) + assert sum(decoder_balance) == len(decoder_module_list), ( + f"Sum of decoder_balance={decoder_balance} is not equal " + + f"to num_decoder_modules={len(decoder_module_list)}" + ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model( + nn.Sequential(*decoder_module_list), + decoder_balance, + decoder_devices, + ), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*decoder_module_list), + balance=decoder_balance, + devices=decoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def forward( + self, + prev_output_tokens, + encoder_out=None, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + input_tuple = ( + encoder_out.encoder_out, + encoder_out.encoder_padding_mask, + prev_output_tokens, + ) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + if TORCH_PIPE: + return (self.model(input_tuple).local_value(),) + else: + return (self.model(input_tuple),) + else: + embed_layer_output = self.embedding_layer(input_tuple) + state = self.decoder_layers(embed_layer_output) + return (self.decoder_output_layer(state),) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_target_positions + return min( + self.embedding_layer.max_target_positions, + self.embedding_layer.embed_positions.max_positions, + ) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + for i in range(len(self.layers)): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel" +) +def transformer_iwslt_de_en_dist(args): + transformer_iwslt_de_en(args) + + +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel" +) +def transformer_wmt_en_de_big_dist(args): + transformer_wmt_en_de_big(args) diff --git a/fairseq/fairseq/model_parallel/models/roberta/__init__.py b/fairseq/fairseq/model_parallel/models/roberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..117827c3e9c176477f33e3a6fd7fe19a922411a2 --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/roberta/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa diff --git a/fairseq/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b621c6635b5568591e94affba212c888f8195e91 Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/roberta/__pycache__/model.cpython-310.pyc b/fairseq/fairseq/model_parallel/models/roberta/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e2af3cb59a00b5019c562bbe4a9cc041bb2c20d Binary files /dev/null and b/fairseq/fairseq/model_parallel/models/roberta/__pycache__/model.cpython-310.pyc differ diff --git a/fairseq/fairseq/model_parallel/models/roberta/model.py b/fairseq/fairseq/model_parallel/models/roberta/model.py new file mode 100644 index 0000000000000000000000000000000000000000..77a80ef72057219110b34678a38705549910edd3 --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/roberta/model.py @@ -0,0 +1,225 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +RoBERTa: A Robustly Optimized BERT Pretraining Approach. +""" + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder +from fairseq.models import register_model, register_model_architecture +from fairseq.models.roberta import ( + roberta_base_architecture, + roberta_prenorm_architecture, + RobertaEncoder, + RobertaModel, +) +from fairseq.modules import LayerNorm + + +try: + from fairseq.model_parallel.megatron.mpu import ( + copy_to_model_parallel_region, + gather_from_model_parallel_region, + ColumnParallelLinear, + VocabParallelEmbedding, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + +logger = logging.getLogger(__name__) + + +@register_model("model_parallel_roberta") +class ModelParallelRobertaModel(RobertaModel): + def __init__(self, args, encoder): + super().__init__(args, encoder) + + self.classification_heads = nn.ModuleDict() + + @staticmethod + def add_args(parser): + RobertaModel.add_args(parser) + parser.add_argument( + "--no-final-layer-norm", + action="store_true", + help=( + "don't add final layernorm (only applicable when " + "--encoder-normalize-before=True" + ), + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present + base_architecture(args) + + task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + + if not hasattr(args, "max_positions"): + args.max_positions = args.tokens_per_sample + + if getattr(args, "untie_weights_roberta", False): + raise NotImplementedError( + "--untie-weights-roberta is not supported in model parallel mode" + ) + + encoder = ModelParallelRobertaEncoder(args, task.source_dictionary) + return cls(args, encoder) + + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + classification_head_name=None, + **kwargs + ): + if classification_head_name is not None: + features_only = True + + x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs) + + if classification_head_name is not None: + x = self.classification_heads[classification_head_name](x) + return x, extra + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = ModelParallelRobertaClassificationHead( + self.args.encoder_embed_dim, + inner_dim or self.args.encoder_embed_dim, + num_classes, + self.args.pooler_activation_fn, + self.args.pooler_dropout, + ) + + +class ModelParallelRobertaLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, bias=False).weight + self.weight = weight + self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + # Only project the unmasked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + + x = copy_to_model_parallel_region(x) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + x = gather_from_model_parallel_region(x).contiguous() + x = x + self.bias + return x + + +class ModelParallelRobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout + ): + super().__init__() + self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ModelParallelRobertaEncoder(RobertaEncoder): + """RoBERTa encoder.""" + + def __init__(self, args, dictionary): + super().__init__(args, dictionary) + assert not self.args.untie_weights_roberta + + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) + + def build_encoder(self, args, dictionary, embed_tokens): + return ModelParallelTransformerEncoder(args, dictionary, embed_tokens) + + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight) + + +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta") +def base_architecture(args): + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False) + # model parallel RoBERTa defaults to "Pre-LN" formulation + roberta_prenorm_architecture(args) + + +# earlier versions of model parallel RoBERTa removed the final layer norm +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1") +def model_parallel_roberta_v1_architecture(args): + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True) + base_architecture(args) + + +@register_model_architecture( + "model_parallel_roberta", "model_parallel_roberta_postnorm" +) +def model_parallel_roberta_postnorm_architecture(args): + # the original BERT/RoBERTa uses the "Post-LN" formulation + roberta_base_architecture(args) + + +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base") +def model_parallel_roberta_base_architecture(args): + base_architecture(args) + + +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large") +def model_parallel_roberta_large_architecture(args): + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + base_architecture(args) diff --git a/fairseq/fairseq/model_parallel/models/transformer.py b/fairseq/fairseq/model_parallel/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3b2e8baf01389a34056cc68cbf6ad1d4475707 --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/transformer.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch.nn as nn + +from fairseq.model_parallel.modules import ( + ModelParallelTransformerDecoderLayer, + ModelParallelTransformerEncoderLayer, +) +from fairseq.models import register_model +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, +) + +try: + from fairseq.model_parallel.megatron.mpu import ( + VocabParallelEmbedding, + copy_to_model_parallel_region, + gather_from_model_parallel_region, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +logger = logging.getLogger(__name__) + + +@register_model("model_parallel_transformer") +class ModelParallelTransformerModel(TransformerModel): + """ + Model parallel Transformer model. + """ + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + dictionary.pad_to_multiple_(args.model_parallel_size * 8) + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + def _vocab_init(tensor, **kwargs): + nn.init.normal_(tensor, mean=0, std=num_embeddings**-0.5) + nn.init.constant_(tensor[1], 0) + + emb = VocabParallelEmbedding( + num_embeddings, embed_dim, padding_idx, init_method=_vocab_init + ) + # if provided, load from preloaded dictionaries + if path: + raise NotImplementedError( + "Loading of embedding from path is not supported for model parallel" + ) + return emb + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return ModelParallelTransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return ModelParallelTransformerDecoder( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + ) + + +class ModelParallelTransformerEncoder(TransformerEncoder): + """ + Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`ModelParallelTransformerEncoderLayer`. + """ + + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + + if args.no_final_layer_norm: + self.layer_norm = None + + def build_encoder_layer(self, args): + return ModelParallelTransformerEncoderLayer(args) + + +class ModelParallelTransformerDecoder(TransformerDecoder): + """ + Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`ModelParallelTransformerDecoderLayer`. + """ + + def build_decoder_layer(self, args, no_encoder_attn=False): + return ModelParallelTransformerDecoderLayer(args, no_encoder_attn) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if not self.share_input_output_embed: + raise NotImplementedError( + "Model parallel training currently requires --share-decoder-input-output-embed" + ) + + features = copy_to_model_parallel_region(features) + + # project back to size of vocabulary + x = self.output_projection(features) + + if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy": + x = gather_from_model_parallel_region(x).contiguous() + return x diff --git a/fairseq/fairseq/model_parallel/models/transformer_lm.py b/fairseq/fairseq/model_parallel/models/transformer_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..03e4dbe26393eedfb71da94d6675b08cbdb8626d --- /dev/null +++ b/fairseq/fairseq/model_parallel/models/transformer_lm.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer_lm import TransformerLanguageModel + +try: + from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@register_model("model_parallel_transformer_lm") +class ModelParallelTransformerLanguageModel(TransformerLanguageModel): + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + + # make sure all arguments are present in older models + base_lm_architecture(args) + + task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + if args.character_embeddings: + raise NotImplementedError( + "Character embeddings is not supported for model parallel" + ) + elif args.adaptive_input: + raise NotImplementedError( + "Adaptive input is not supported for model parallel" + ) + else: + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) + + decoder = ModelParallelTransformerDecoder( + args, + task.target_dictionary, + embed_tokens, + no_encoder_attn=True, + ) + return cls(decoder) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + def _vocab_init(tensor, **kwargs): + nn.init.normal_(tensor, mean=0, std=embed_dim**-0.5) + nn.init.constant_(tensor[1], 0) + + embed_tokens = VocabParallelEmbedding( + len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init + ) + return embed_tokens + + +def base_lm_architecture(args): + # backward compatibility for older model checkpoints + if hasattr(args, "no_tie_adaptive_proj"): + # previous models defined --no-tie-adaptive-proj, so use the existence of + # that option to determine if this is an "old" model checkpoint + args.no_decoder_final_norm = True # old models always set this to True + if args.no_tie_adaptive_proj is False: + args.tie_adaptive_proj = True + if hasattr(args, "decoder_final_norm"): + args.no_decoder_final_norm = not args.decoder_final_norm + + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + # Model training is not stable without this + args.decoder_normalize_before = True + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.character_embeddings = getattr(args, "character_embeddings", False) + args.character_filters = getattr( + args, + "character_filters", + "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + ) + args.character_embedding_dim = getattr(args, "character_embedding_dim", 4) + args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0) + args.add_bos_token = getattr(args, "add_bos_token", False) + + +@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron") +def transformer_lm_megatron(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture( + "model_parallel_transformer_lm", "transformer_lm_megatron_11b" +) +def transformer_lm_megatron_11b(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) diff --git a/fairseq/fairseq/model_parallel/modules/__init__.py b/fairseq/fairseq/model_parallel/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11603217a188f420ea849ae0fde19979736ba208 --- /dev/null +++ b/fairseq/fairseq/model_parallel/modules/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .multihead_attention import ModelParallelMultiheadAttention +from .transformer_layer import ( + ModelParallelTransformerEncoderLayer, + ModelParallelTransformerDecoderLayer, +) + +__all__ = [ + "ModelParallelMultiheadAttention", + "ModelParallelTransformerEncoderLayer", + "ModelParallelTransformerDecoderLayer", +] diff --git a/fairseq/fairseq/model_parallel/modules/__pycache__/transformer_layer.cpython-310.pyc b/fairseq/fairseq/model_parallel/modules/__pycache__/transformer_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a589e7a37c699494efa303527df23adc46d78627 Binary files /dev/null and b/fairseq/fairseq/model_parallel/modules/__pycache__/transformer_layer.cpython-310.pyc differ diff --git a/fairseq/fairseq/models/bart/hub_interface.py b/fairseq/fairseq/models/bart/hub_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..6b647c9642147bd1bedf56af1b7180a1d39fec98 --- /dev/null +++ b/fairseq/fairseq/models/bart/hub_interface.py @@ -0,0 +1,211 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging +from typing import Dict, List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.data import encoders +from fairseq.hub_utils import GeneratorHubInterface +from omegaconf import open_dict + + +logger = logging.getLogger(__name__) + + +class BARTHubInterface(GeneratorHubInterface): + """A simple PyTorch Hub interface to BART. + + Usage: https://github.com/pytorch/fairseq/tree/main/examples/bart + """ + + def __init__(self, cfg, task, model): + super().__init__(cfg, task, [model]) + self.model = self.models[0] + + def encode( + self, sentence: str, *addl_sentences, no_separator=True + ) -> torch.LongTensor: + """ + BPE-encode a sentence (or multiple sentences). + + Every sequence begins with a beginning-of-sentence (``) symbol. + Every sentence ends with an end-of-sentence (``). + + Example (single sentence): ` a b c ` + Example (sentence pair): ` d e f 1 2 3 ` + + The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE + requires leading spaces. For example:: + + >>> bart.encode('Hello world').tolist() + [0, 31414, 232, 2] + >>> bart.encode(' world').tolist() + [0, 232, 2] + >>> bart.encode('world').tolist() + [0, 8331, 2] + """ + tokens = self.bpe.encode(sentence) + if len(tokens.split(" ")) > min(self.max_positions) - 2: + tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2]) + bpe_sentence = " " + tokens + " " + for s in addl_sentences: + bpe_sentence += " " if not no_separator else "" + bpe_sentence += " " + self.bpe.encode(s) + " " + tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False) + return tokens.long() + + def decode(self, tokens: torch.LongTensor): + assert tokens.dim() == 1 + tokens = tokens.cpu().numpy() + if tokens[0] == self.task.source_dictionary.bos(): + tokens = tokens[1:] # remove + eos_mask = tokens == self.task.source_dictionary.eos() + doc_mask = eos_mask[1:] & eos_mask[:-1] + sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) + sentences = [ + self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences + ] + if len(sentences) == 1: + return sentences[0] + return sentences + + def _build_sample(self, src_tokens: List[torch.LongTensor]): + # assert torch.is_tensor(src_tokens) + dataset = self.task.build_dataset_for_inference( + src_tokens, + [x.numel() for x in src_tokens], + ) + sample = dataset.collater(dataset) + sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample) + return sample + + def generate( + self, + tokenized_sentences: List[torch.LongTensor], + *args, + inference_step_args=None, + skip_invalid_size_inputs=False, + **kwargs + ) -> List[List[Dict[str, torch.Tensor]]]: + inference_step_args = inference_step_args or {} + if "prefix_tokens" in inference_step_args: + raise NotImplementedError("prefix generation not implemented for BART") + res = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + src_tokens = batch["net_input"]["src_tokens"] + inference_step_args["prefix_tokens"] = src_tokens.new_full( + (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos() + ).to(device=self.device) + results = super().generate( + src_tokens, + *args, + inference_step_args=inference_step_args, + skip_invalid_size_inputs=skip_invalid_size_inputs, + **kwargs + ) + for id, hypos in zip(batch["id"].tolist(), results): + res.append((id, hypos)) + res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])] + return res + + def extract_features( + self, tokens: torch.LongTensor, return_all_hiddens: bool = False + ) -> torch.Tensor: + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + if tokens.size(-1) > min(self.model.max_positions()): + raise ValueError( + "tokens exceeds maximum length: {} > {}".format( + tokens.size(-1), self.model.max_positions() + ) + ) + tokens.to(device=self.device), + prev_output_tokens = tokens.clone() + + prev_output_tokens[:, 0] = tokens.gather( + 1, + (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1), + ).squeeze() + + prev_output_tokens[:, 1:] = tokens[:, :-1] + features, extra = self.model( + src_tokens=tokens, + src_lengths=None, + prev_output_tokens=prev_output_tokens, + features_only=True, + return_all_hiddens=return_all_hiddens, + ) + if return_all_hiddens: + # convert from T x B x C -> B x T x C + inner_states = extra["inner_states"] + return [inner_state.transpose(0, 1) for inner_state in inner_states] + else: + return features # just the last layer's features + + def register_classification_head( + self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs + ): + self.model.register_classification_head( + name, num_classes=num_classes, embedding_size=embedding_size, **kwargs + ) + + def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False): + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + features = self.extract_features(tokens.to(device=self.device)) + sentence_representation = features[ + tokens.eq(self.task.source_dictionary.eos()), : + ].view(features.size(0), -1, features.size(-1))[:, -1, :] + + logits = self.model.classification_heads[head](sentence_representation) + if return_logits: + return logits + return F.log_softmax(logits, dim=-1) + + def fill_mask( + self, + masked_inputs: List[str], + topk: int = 5, + match_source_len: bool = True, + **generate_kwargs + ): + masked_token = "" + batch_tokens = [] + for masked_input in masked_inputs: + assert ( + masked_token in masked_input + ), "please add one {} token for the input".format(masked_token) + + text_spans = masked_input.split(masked_token) + text_spans_bpe = ( + (" {0} ".format(masked_token)) + .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans]) + .strip() + ) + tokens = self.task.source_dictionary.encode_line( + " " + text_spans_bpe + " ", + append_eos=False, + add_if_not_exist=False, + ).long() + batch_tokens.append(tokens) + + # ensure beam size is at least as big as topk + generate_kwargs["beam"] = max( + topk, + generate_kwargs.get("beam", -1), + ) + generate_kwargs["match_source_len"] = match_source_len + batch_hypos = self.generate(batch_tokens, **generate_kwargs) + + return [ + [(self.decode(hypo["tokens"]), hypo["score"]) for hypo in hypos[:topk]] + for hypos in batch_hypos + ] diff --git a/fairseq/fairseq/models/composite_encoder.py b/fairseq/fairseq/models/composite_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4e20fe3a833a2d87876cbec294ad2bebfba7f591 --- /dev/null +++ b/fairseq/fairseq/models/composite_encoder.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .fairseq_encoder import FairseqEncoder + + +class CompositeEncoder(FairseqEncoder): + """ + A wrapper around a dictionary of :class:`FairseqEncoder` objects. + + We run forward on each encoder and return a dictionary of outputs. The first + encoder's dictionary is used for initialization. + + Args: + encoders (dict): a dictionary of :class:`FairseqEncoder` objects. + """ + + def __init__(self, encoders): + super().__init__(next(iter(encoders.values())).dictionary) + self.encoders = encoders + for key in self.encoders: + self.add_module(key, self.encoders[key]) + + def forward(self, src_tokens, src_lengths): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + + Returns: + dict: + the outputs from each Encoder + """ + encoder_out = {} + for key in self.encoders: + encoder_out[key] = self.encoders[key](src_tokens, src_lengths) + return encoder_out + + def reorder_encoder_out(self, encoder_out, new_order): + """Reorder encoder output according to new_order.""" + for key in self.encoders: + encoder_out[key] = self.encoders[key].reorder_encoder_out( + encoder_out[key], new_order + ) + return encoder_out + + def max_positions(self): + return min(self.encoders[key].max_positions() for key in self.encoders) + + def upgrade_state_dict(self, state_dict): + for key in self.encoders: + self.encoders[key].upgrade_state_dict(state_dict) + return state_dict diff --git a/fairseq/fairseq/models/distributed_fairseq_model.py b/fairseq/fairseq/models/distributed_fairseq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd76bcd4bfdba4dce83fb5a6ef01b15f8de1fe67 --- /dev/null +++ b/fairseq/fairseq/models/distributed_fairseq_model.py @@ -0,0 +1,147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import signal +import threading + +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel + +from fairseq.distributed import ( + DistributedTimeoutWrapper, + LegacyDistributedDataParallel, + ModuleProxyWrapper, + TPUDistributedDataParallel, +) + +logger = logging.getLogger(__name__) + + +_SLOWMO_DDP_DISABLED = False +try: + from fairscale.experimental.nn.data_parallel import ( + SlowMoBaseAlgorithm, + SlowMoDistributedDataParallel, + ) +except ImportError: + _SLOWMO_DDP_DISABLED = True + + +def DistributedFairseqModel(args, model, process_group, device): + """ + Wrap a *model* to support distributed data parallel training. + + This is similar to the built-in DistributedDataParallel, but allows + additional configuration of the DistributedDataParallel class to + use, and also provides easier access to the wrapped model by + forwarding requests for missing attributes to the wrapped model. + + Args: + args (argparse.Namespace): fairseq args + model (BaseFairseqModel): model to wrap + process_group: the c10d process group to be used for distributed data + parallel all-reduction. + device: device to move model to + """ + assert isinstance(model, nn.Module) + if args.tpu: + wrapped_model = TPUDistributedDataParallel( + module=model.to(device), + process_group=process_group, + ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend in {"c10d", "pytorch_ddp"}: + wrapped_model = DistributedDataParallel( + module=model.to(device), + device_ids=[args.device_id], + output_device=args.device_id, + broadcast_buffers=args.broadcast_buffers, + bucket_cap_mb=args.bucket_cap_mb, + process_group=process_group, + find_unused_parameters=args.find_unused_parameters, + gradient_as_bucket_view=args.gradient_as_bucket_view, + ) + if args.ddp_comm_hook == "fp16": + logger.info("enable fp16 communication hook in DDP") + try: + from torch.distributed.algorithms.ddp_comm_hooks import ( + DDPCommHookType, + register_ddp_comm_hook, + ) + except: + logger.error( + "Could not import from torch.distributed.algorithms.ddp_comm_hooks; you may need to update your pytorch version" + ) + raise + + register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: + wrapped_model = LegacyDistributedDataParallel( + module=model.to(device), + buffer_size=2**28, + process_group=process_group, + ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "slowmo": + if _SLOWMO_DDP_DISABLED: + raise ImportError( + "Cannot find SlowMoDistributedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + + # The values of slowmo_momentum below were obtained by tuning on the + # En-De 16 dataset by training the transformer_wmt_en_de_large model + if args.slowmo_momentum is None: + if args.distributed_world_size <= 16: + args.slowmo_momentum = 0.0 + elif args.distributed_world_size <= 32: + args.slowmo_momentum = 0.2 + elif args.distributed_world_size <= 64: + args.slowmo_momentum = 0.5 + else: + args.slowmo_momentum = 0.6 + slowmo_base_algorithm = SlowMoBaseAlgorithm[args.slowmo_base_algorithm.upper()] + + wrapped_model = SlowMoDistributedDataParallel( + module=model.to(device), + broadcast_buffers=args.broadcast_buffers, + nprocs_per_node=args.nprocs_per_node, + slowmo_momentum=args.slowmo_momentum, + slowmo_base_algorithm=slowmo_base_algorithm, + localsgd_frequency=args.localsgd_frequency, + ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "fully_sharded": + try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP" + wrapped_model = model + if args.memory_efficient_fp16: + wrapped_model = wrapped_model.half() + if not args.cpu_offload: + wrapped_model = wrapped_model.to(device=device) + else: + raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) + + # kill hung distributed jobs after a timeout + if getattr(args, "heartbeat_timeout", -1) > 0: + wrapped_model = DistributedTimeoutWrapper( + wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) + ) + + return wrapped_model diff --git a/fairseq/fairseq/models/fairseq_encoder.py b/fairseq/fairseq/models/fairseq_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..08cbde15a46e9b6d58e11c2f6052e7cf2d0cc8b2 --- /dev/null +++ b/fairseq/fairseq/models/fairseq_encoder.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, NamedTuple, Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +EncoderOut = NamedTuple( + "EncoderOut", + [ + ("encoder_out", Tensor), # T x B x C + ("encoder_padding_mask", Optional[Tensor]), # B x T + ("encoder_embedding", Optional[Tensor]), # B x T x C + ("encoder_states", Optional[List[Tensor]]), # List[T x B x C] + ("src_tokens", Optional[Tensor]), # B x T + ("src_lengths", Optional[Tensor]), # B x 1 + ], +) + + +class FairseqEncoder(nn.Module): + """Base class for encoders.""" + + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + + def forward(self, src_tokens, src_lengths=None, **kwargs): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + """ + raise NotImplementedError + + def forward_torchscript(self, net_input: Dict[str, Tensor]): + """A TorchScript-compatible version of forward. + + Encoders which use additional arguments may want to override + this method for TorchScript compatibility. + """ + if torch.jit.is_scripting(): + return self.forward( + src_tokens=net_input["src_tokens"], + src_lengths=net_input["src_lengths"], + ) + else: + return self.forward_non_torchscript(net_input) + + @torch.jit.unused + def forward_non_torchscript(self, net_input: Dict[str, Tensor]): + encoder_input = { + k: v for k, v in net_input.items() if k != "prev_output_tokens" + } + return self.forward(**encoder_input) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to `new_order`. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + `encoder_out` rearranged according to `new_order` + """ + raise NotImplementedError + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return 1e6 # an arbitrary large number + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code.""" + return state_dict + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + self.apply(_apply) diff --git a/fairseq/fairseq/models/fconv_lm.py b/fairseq/fairseq/models/fconv_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..4b243d6669cb57880353b45a01843ec22010fb5f --- /dev/null +++ b/fairseq/fairseq/models/fconv_lm.py @@ -0,0 +1,136 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.fconv import FConvDecoder +from fairseq.utils import safe_hasattr + + +@register_model("fconv_lm") +class FConvLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-layers", + type=str, + metavar="EXPR", + help="decoder layers [(dim, kernel_size), ...]", + ) + parser.add_argument( + "--decoder-out-embed-dim", + type=int, + metavar="N", + help="decoder output embedding dimension", + ) + parser.add_argument( + "--adaptive-softmax-cutoff", + metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion", + ) + parser.add_argument( + "--adaptive-softmax-dropout", + type=float, + metavar="D", + help="sets adaptive softmax dropout for the tail projections", + ) + parser.add_argument( + "--decoder-attention", + type=str, + metavar="EXPR", + help="decoder attention [True, ...]", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_lm_architecture(args) + + if safe_hasattr(args, "max_target_positions") and not safe_hasattr( + args, "tokens_per_sample" + ): + args.tokens_per_sample = args.max_target_positions + + decoder = FConvDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + convolutions=eval(args.decoder_layers), + out_embed_dim=args.decoder_embed_dim, + attention=eval(args.decoder_attention), + dropout=args.dropout, + max_positions=args.tokens_per_sample, + share_embed=False, + positional_embeddings=False, + adaptive_softmax_cutoff=( + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == "adaptive_loss" + else None + ), + adaptive_softmax_dropout=args.adaptive_softmax_dropout, + ) + return FConvLanguageModel(decoder) + + +@register_model_architecture("fconv_lm", "fconv_lm") +def base_lm_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128) + args.decoder_layers = getattr(args, "decoder_layers", "[(1268, 4)] * 13") + args.decoder_attention = getattr(args, "decoder_attention", "False") + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + + +@register_model_architecture("fconv_lm", "fconv_lm_dauphin_wikitext103") +def fconv_lm_dauphin_wikitext103(args): + layers = "[(850, 6)] * 3" + layers += " + [(850, 1)] * 1" + layers += " + [(850, 5)] * 4" + layers += " + [(850, 1)] * 1" + layers += " + [(850, 4)] * 3" + layers += " + [(1024, 4)] * 1" + layers += " + [(2048, 4)] * 1" + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 280) + args.decoder_layers = getattr(args, "decoder_layers", layers) + args.decoder_attention = getattr(args, "decoder_attention", "False") + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,20000,200000" + ) + base_lm_architecture(args) + + +@register_model_architecture("fconv_lm", "fconv_lm_dauphin_gbw") +def fconv_lm_dauphin_gbw(args): + layers = "[(512, 5)]" + layers += " + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3" + layers += " + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3" + layers += " + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6" + layers += " + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]" + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128) + args.decoder_layers = getattr(args, "decoder_layers", layers) + args.decoder_attention = getattr(args, "decoder_attention", "False") + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + base_lm_architecture(args) diff --git a/fairseq/fairseq/models/fconv_self_att.py b/fairseq/fairseq/models/fconv_self_att.py new file mode 100644 index 0000000000000000000000000000000000000000..8357ef7847ed25a62345e219c41906156828c233 --- /dev/null +++ b/fairseq/fairseq/models/fconv_self_att.py @@ -0,0 +1,674 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import checkpoint_utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.models import ( + CompositeEncoder, + FairseqDecoder, + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + DownsampledMultiHeadAttention, + FairseqDropout, + GradMultiply, + LayerNorm, + LearnedPositionalEmbedding, + LinearizedConvolution, +) + + +logger = logging.getLogger(__name__) + + +@register_model("fconv_self_att") +class FConvModelSelfAtt(FairseqEncoderDecoderModel): + @classmethod + def hub_models(cls): + return { + "conv.stories.pretrained": { + "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz", + "checkpoint_file": "pretrained_checkpoint.pt", + "tokenizer": "nltk", + }, + "conv.stories": { + "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz", + "checkpoint_file": "fusion_checkpoint.pt", + "tokenizer": "nltk", + "pretrained": "True", + "pretrained_checkpoint": "./pretrained_checkpoint.pt", + }, + # Test set containing dictionaries + "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2", + } + + def __init__(self, encoder, decoder, pretrained_encoder=None): + super().__init__(encoder, decoder) + self.encoder.num_attention_layers = sum( + layer is not None for layer in decoder.attention + ) + self.pretrained_encoder = pretrained_encoder + if self.pretrained_encoder is None: + encoders = {"encoder": encoder} + else: + encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder} + # for fusion model, CompositeEncoder contains both pretrained and training encoders + # these are forwarded and then combined in the decoder + self.encoder = CompositeEncoder(encoders) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-layers', type=str, metavar='EXPR', + help='encoder layers [(dim, kernel_size), ...]') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-layers', type=str, metavar='EXPR', + help='decoder layers [(dim, kernel_size), ...]') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--decoder-attention', type=str, metavar='EXPR', + help='decoder attention [True, ...]') + parser.add_argument('--self-attention', type=str, metavar='EXPR', + help='decoder self-attention layers, ex: [True] + [False]*5') + parser.add_argument('--multihead-attention-nheads', type=int, + help='Number of heads to use in attention') + parser.add_argument('--multihead-self-attention-nheads', type=int, + help='Number of heads to use in self-attention') + parser.add_argument('--encoder-attention', type=str, metavar='EXPR', + help='encoder attention [True, ...]') + parser.add_argument('--encoder-attention-nheads', type=int, + help='Number of heads to use in encoder attention') + parser.add_argument('--project-input', type=str, metavar='EXPR', + help='Use projections in self-attention [True, ...]') + parser.add_argument('--gated-attention', type=str, metavar='EXPR', + help='Use GLU layers in self-attention projections [True, ...]') + parser.add_argument('--downsample', type=str, metavar='EXPR', + help='Use downsampling in self-attention [True, ...]') + parser.add_argument('--pretrained-checkpoint', metavar='DIR', + help='path to load checkpoint from pretrained model') + parser.add_argument('--pretrained', type=str, metavar='EXPR', + help='use pretrained model when training [True, ...]') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + trained_encoder, trained_decoder = None, None + pretrained = eval(args.pretrained) + if pretrained: + logger.info("loading pretrained model") + if not os.path.exists(args.pretrained_checkpoint): + new_pretrained_checkpoint = os.path.join( + args.data, args.pretrained_checkpoint + ) + if os.path.exists(new_pretrained_checkpoint): + args.pretrained_checkpoint = new_pretrained_checkpoint + trained_model = checkpoint_utils.load_model_ensemble( + filenames=[args.pretrained_checkpoint], + task=task, + )[0][0] + trained_decoder = list(trained_model.children())[1] + trained_encoder = list(trained_model.children())[0] + + # freeze pretrained model + for param in trained_decoder.parameters(): + param.requires_grad = False + for param in trained_encoder.parameters(): + param.requires_grad = False + + encoder = FConvEncoder( + task.source_dictionary, + embed_dim=args.encoder_embed_dim, + convolutions=eval(args.encoder_layers), + dropout=args.dropout, + max_positions=args.max_source_positions, + attention=eval(args.encoder_attention), + attention_nheads=args.encoder_attention_nheads, + ) + + decoder = FConvDecoder( + task.target_dictionary, + embed_dim=args.decoder_embed_dim, + convolutions=eval(args.decoder_layers), + out_embed_dim=args.decoder_out_embed_dim, + attention=eval(args.decoder_attention), + dropout=args.dropout, + max_positions=args.max_target_positions, + selfattention=eval(args.self_attention), + attention_nheads=args.multihead_attention_nheads, + selfattention_nheads=args.multihead_self_attention_nheads, + project_input=eval(args.project_input), + gated_attention=eval(args.gated_attention), + downsample=eval(args.downsample), + pretrained=pretrained, + trained_decoder=trained_decoder, + ) + model = FConvModelSelfAtt(encoder, decoder, trained_encoder) + + return model + + @property + def pretrained(self): + return self.pretrained_encoder is not None + + +class FConvEncoder(FairseqEncoder): + """Convolutional encoder""" + + def __init__( + self, + dictionary, + embed_dim=512, + max_positions=1024, + convolutions=((512, 3),) * 20, + dropout=0.1, + attention=False, + attention_nheads=1, + ): + super().__init__(dictionary) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.num_attention_layers = None + + num_embeddings = len(dictionary) + self.padding_idx = dictionary.pad() + self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) + self.embed_positions = PositionalEmbedding( + max_positions, + embed_dim, + self.padding_idx, + ) + + def expand_bool_array(val): + if isinstance(val, bool): + # expand True into [True, True, ...] and do the same with False + return [val] * len(convolutions) + return val + + attention = expand_bool_array(attention) + + in_channels = convolutions[0][0] + self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) + self.projections = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.attention = nn.ModuleList() + self.attproj = nn.ModuleList() + for i, (out_channels, kernel_size) in enumerate(convolutions): + self.projections.append( + Linear(in_channels, out_channels) + if in_channels != out_channels + else None + ) + self.convolutions.append( + ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout) + ) + + self.attention.append( + SelfAttention(out_channels, embed_dim, attention_nheads) + if attention[i] + else None + ) + in_channels = out_channels + + self.fc2 = Linear(in_channels, embed_dim) + + def forward(self, src_tokens, src_lengths): + # embed tokens and positions + x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens) + x = self.dropout_module(x) + input_embedding = x.transpose(0, 1) + + # project to size of convolution + x = self.fc1(x) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() # -> T x B + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # temporal convolutions + for proj, conv, attention in zip( + self.projections, self.convolutions, self.attention + ): + residual = x if proj is None else proj(x) + + if encoder_padding_mask is not None: + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + x = self.dropout_module(x) + padding_l = (conv.kernel_size[0] - 1) // 2 + padding_r = conv.kernel_size[0] // 2 + x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r)) + x = conv(x) + x = F.glu(x, dim=2) + if attention is not None: + x = attention(x) + x = (x + residual) * math.sqrt(0.5) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + # project back to size of embedding + x = self.fc2(x) + + if encoder_padding_mask is not None: + encoder_padding_mask = encoder_padding_mask.t() # -> B x T + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + # scale gradients (this only affects backward, not forward) + x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers)) + + # add output to input embedding for attention + y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5) + + return { + "encoder_out": (x, y), + "encoder_padding_mask": encoder_padding_mask, # B x T + } + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = tuple( + eo.index_select(0, new_order) for eo in encoder_out["encoder_out"] + ) + + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + + if "pretrained" in encoder_out: + encoder_out["pretrained"]["encoder_out"] = tuple( + eo.index_select(0, new_order) + for eo in encoder_out["pretrained"]["encoder_out"] + ) + + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return self.embed_positions.max_positions + + +@with_incremental_state +class FConvDecoder(FairseqDecoder): + """Convolutional decoder""" + + def __init__( + self, + dictionary, + embed_dim=512, + out_embed_dim=256, + max_positions=1024, + convolutions=((512, 3),) * 8, + attention=True, + dropout=0.1, + selfattention=False, + attention_nheads=1, + selfattention_nheads=1, + project_input=False, + gated_attention=False, + downsample=False, + pretrained=False, + trained_decoder=None, + ): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([2])) + self.pretrained = pretrained + self.pretrained_decoder = trained_decoder + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.need_attn = True + in_channels = convolutions[0][0] + + def expand_bool_array(val): + if isinstance(val, bool): + # expand True into [True, True, ...] and do the same with False + return [val] * len(convolutions) + return val + + attention = expand_bool_array(attention) + selfattention = expand_bool_array(selfattention) + + if not isinstance(attention, list) or len(attention) != len(convolutions): + raise ValueError( + "Attention is expected to be a list of booleans of " + "length equal to the number of layers." + ) + + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + + self.embed_positions = PositionalEmbedding( + max_positions, + embed_dim, + padding_idx, + ) + + self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) + self.projections = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.attention = nn.ModuleList() + self.selfattention = nn.ModuleList() + self.attproj = nn.ModuleList() + for i, (out_channels, kernel_size) in enumerate(convolutions): + self.projections.append( + Linear(in_channels, out_channels) + if in_channels != out_channels + else None + ) + self.convolutions.append( + LinearizedConv1d( + in_channels, + out_channels * 2, + kernel_size, + padding=(kernel_size - 1), + dropout=dropout, + ) + ) + + self.attention.append( + DownsampledMultiHeadAttention( + out_channels, + embed_dim, + attention_nheads, + project_input=project_input, + gated=False, + downsample=False, + ) + if attention[i] + else None + ) + + self.attproj.append( + Linear(out_channels, embed_dim, dropout=dropout) + if attention[i] + else None + ) + self.selfattention.append( + SelfAttention( + out_channels, + embed_dim, + selfattention_nheads, + project_input=project_input, + gated=gated_attention, + downsample=downsample, + ) + if selfattention[i] + else None + ) + in_channels = out_channels + + self.fc2 = Linear(in_channels, out_embed_dim) + self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) + + # model fusion + if self.pretrained: + # independent gates are learned from the concatenated input + self.gate1 = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid() + ) + self.gate2 = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid() + ) + # pretrained and trained models are joined + self.joining = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim * 2), + LayerNorm(out_embed_dim * 2), + nn.GLU(), + Linear(out_embed_dim, out_embed_dim * 2), + LayerNorm(out_embed_dim * 2), + nn.GLU(), + Linear(out_embed_dim, out_embed_dim), + LayerNorm(out_embed_dim), + ) + # pretrained model contains an output layer that is nhid -> vocab size + # but the models are combined in their hidden state + # the hook stores the output of the pretrained model forward + self.pretrained_outputs = {} + + def save_output(): + def hook(a, b, output): + self.pretrained_outputs["out"] = output + + return hook + + self.pretrained_decoder.fc2.register_forward_hook(save_output()) + + def forward(self, prev_output_tokens, encoder_out): + trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None + encoder_out = encoder_out["encoder"]["encoder_out"] + + encoder_a, encoder_b = self._split_encoder_out(encoder_out) + + # embed positions + positions = self.embed_positions(prev_output_tokens) + + # embed tokens and positions + x = self.embed_tokens(prev_output_tokens) + positions + x = self.dropout_module(x) + target_embedding = x.transpose(0, 1) + + # project to size of convolution + x = self.fc1(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # temporal convolutions + avg_attn_scores = None + for proj, conv, attention, selfattention, attproj in zip( + self.projections, + self.convolutions, + self.attention, + self.selfattention, + self.attproj, + ): + residual = x if proj is None else proj(x) + + x = self.dropout_module(x) + x = conv(x) + x = F.glu(x, dim=2) + + # attention + if attention is not None: + r = x + x, attn_scores = attention( + attproj(x) + target_embedding, encoder_a, encoder_b + ) + x = x + r + if not self.training and self.need_attn: + if avg_attn_scores is None: + avg_attn_scores = attn_scores + else: + avg_attn_scores.add_(attn_scores) + + if selfattention is not None: + x = selfattention(x) + + x = (x + residual) * math.sqrt(0.5) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + # project back to size of vocabulary + x = self.fc2(x) + x = self.dropout_module(x) + if not self.pretrained: + x = self.fc3(x) + + # fusion gating + if self.pretrained: + trained_x, _ = self.pretrained_decoder.forward( + prev_output_tokens, trained_encoder_out + ) + y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1) + gate1 = self.gate1(y) + gate2 = self.gate2(y) + gated_x1 = gate1 * x + gated_x2 = gate2 * self.pretrained_outputs["out"] + fusion = torch.cat([gated_x1, gated_x2], dim=-1) + fusion = self.joining(fusion) + fusion_output = self.fc3(fusion) + return fusion_output, avg_attn_scores + else: + return x, avg_attn_scores + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return self.embed_positions.max_positions + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + def _split_encoder_out(self, encoder_out): + """Split and transpose encoder outputs.""" + # transpose only once to speed up attention layers + encoder_a, encoder_b = encoder_out + encoder_a = encoder_a.transpose(0, 1).contiguous() + encoder_b = encoder_b.transpose(0, 1).contiguous() + result = (encoder_a, encoder_b) + return result + + +class SelfAttention(nn.Module): + def __init__( + self, + out_channels, + embed_dim, + num_heads, + project_input=False, + gated=False, + downsample=False, + ): + super().__init__() + self.attention = DownsampledMultiHeadAttention( + out_channels, + embed_dim, + num_heads, + dropout=0, + bias=True, + project_input=project_input, + gated=gated, + downsample=downsample, + ) + self.in_proj_q = Linear(out_channels, embed_dim) + self.in_proj_k = Linear(out_channels, embed_dim) + self.in_proj_v = Linear(out_channels, embed_dim) + self.ln = LayerNorm(out_channels) + + def forward(self, x): + residual = x + query = self.in_proj_q(x) + key = self.in_proj_k(x) + value = self.in_proj_v(x) + x, _ = self.attention( + query, key, value, mask_future_timesteps=True, use_scalar_bias=True + ) + return self.ln(x + residual) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + m.weight.data.normal_(0, 0.1) + return m + + +def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx): + m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) + m.weight.data.normal_(0, 0.1) + return m + + +def Linear(in_features, out_features, dropout=0.0): + """Weight-normalized Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features) + m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) + m.bias.data.zero_() + return m + + +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): + """Weight-normalized Conv1d layer optimized for decoding""" + m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + m.weight.data.normal_(mean=0, std=std) + m.bias.data.zero_() + return m + + +def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): + """Weight-normalized Conv1d layer""" + from fairseq.modules import ConvTBC + + m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + m.weight.data.normal_(mean=0, std=std) + m.bias.data.zero_() + return m + + +@register_model_architecture("fconv_self_att", "fconv_self_att") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 3") + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 8") + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_attention = getattr(args, "decoder_attention", "True") + args.self_attention = getattr(args, "self_attention", "False") + args.encoder_attention = getattr(args, "encoder_attention", "False") + args.multihead_attention_nheads = getattr(args, "multihead_attention_nheads", 1) + args.multihead_self_attention_nheads = getattr( + args, "multihead_self_attention_nheads", 1 + ) + args.encoder_attention_nheads = getattr(args, "encoder_attention_nheads", 1) + args.project_input = getattr(args, "project_input", "False") + args.gated_attention = getattr(args, "gated_attention", "False") + args.downsample = getattr(args, "downsample", "False") + args.pretrained_checkpoint = getattr(args, "pretrained_checkpoint", "") + args.pretrained = getattr(args, "pretrained", "False") + + +@register_model_architecture("fconv_self_att", "fconv_self_att_wp") +def fconv_self_att_wp(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr( + args, "encoder_layers", "[(128, 3)] * 2 + [(512,3)] * 1" + ) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_layers = getattr( + args, "decoder_layers", "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1" + ) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.self_attention = getattr(args, "self_attention", "True") + args.multihead_self_attention_nheads = getattr( + args, "multihead_self_attention_nheads", 4 + ) + args.project_input = getattr(args, "project_input", "True") + args.gated_attention = getattr(args, "gated_attention", "True") + args.downsample = getattr(args, "downsample", "True") + base_architecture(args) diff --git a/fairseq/fairseq/models/lightconv_lm.py b/fairseq/fairseq/models/lightconv_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9efc4e42a5ecc1b83338055f18ade5a83ea666 --- /dev/null +++ b/fairseq/fairseq/models/lightconv_lm.py @@ -0,0 +1,306 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.lightconv import Embedding, LightConvDecoder +from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder + + +@register_model("lightconv_lm") +class LightConvLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", + default=0.1, + type=float, + metavar="D", + help="dropout probability", + ) + parser.add_argument( + "--attention-dropout", + default=0.0, + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--relu-dropout", + default=0.0, + type=float, + metavar="D", + help="dropout probability after ReLU in FFN", + ) + parser.add_argument( + "--input-dropout", + type=float, + metavar="D", + help="dropout probability of the inputs", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-output-dim", + type=int, + metavar="N", + help="decoder output dimension", + ) + parser.add_argument( + "--decoder-input-dim", type=int, metavar="N", help="decoder input dimension" + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads or LightConv/DynamicConv heads", + ) + parser.add_argument( + "--decoder-normalize-before", + default=False, + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--adaptive-softmax-cutoff", + metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion", + ) + parser.add_argument( + "--adaptive-softmax-dropout", + type=float, + metavar="D", + help="sets adaptive softmax dropout for the tail projections", + ) + parser.add_argument( + "--adaptive-softmax-factor", + type=float, + metavar="N", + help="adaptive input factor", + ) + parser.add_argument( + "--no-token-positional-embeddings", + default=False, + action="store_true", + help="if set, disables positional embeddings (outside self attention)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + default=False, + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--character-embeddings", + default=False, + action="store_true", + help="if set, uses character embedding convolutions to produce token embeddings", + ) + parser.add_argument( + "--character-filters", + type=str, + metavar="LIST", + default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + help="size of character embeddings", + ) + parser.add_argument( + "--character-embedding-dim", + type=int, + metavar="N", + default=4, + help="size of character embeddings", + ) + parser.add_argument( + "--char-embedder-highway-layers", + type=int, + metavar="N", + default=2, + help="number of highway layers for character token embeddder", + ) + parser.add_argument( + "--adaptive-input", + default=False, + action="store_true", + help="if set, uses adaptive input", + ) + parser.add_argument( + "--adaptive-input-factor", + type=float, + metavar="N", + help="adaptive input factor", + ) + parser.add_argument( + "--adaptive-input-cutoff", + metavar="EXPR", + help="comma separated list of adaptive input cutoff points.", + ) + parser.add_argument( + "--tie-adaptive-weights", + action="store_true", + help="if set, ties the weights of adaptive softmax and adaptive input", + ) + parser.add_argument( + "--tie-adaptive-proj", + action="store_true", + help="if set, ties the projection weights of adaptive softmax and adaptive input", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the decoder", + ) + + """LightConv and DynamicConv arguments""" + parser.add_argument( + "--decoder-kernel-size-list", + type=lambda x: utils.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31]")', + ) + parser.add_argument( + "--decoder-glu", type=utils.eval_bool, help="glu after in proj" + ) + parser.add_argument( + "--decoder-conv-type", + default="dynamic", + type=str, + choices=["dynamic", "lightweight"], + help="type of convolution", + ) + parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool) + parser.add_argument( + "--weight-dropout", + type=float, + metavar="D", + help="dropout probability for conv weights", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_lm_architecture(args) + + if getattr(args, "max_source_positions", None) is None: + args.max_source_positions = args.tokens_per_sample + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = args.tokens_per_sample + + if args.character_embeddings: + embed_tokens = CharacterTokenEmbedder( + task.dictionary, + eval(args.character_filters), + args.character_embedding_dim, + args.decoder_embed_dim, + args.char_embedder_highway_layers, + ) + elif args.adaptive_input: + embed_tokens = AdaptiveInput( + len(task.dictionary), + task.dictionary.pad(), + args.decoder_input_dim, + args.adaptive_input_factor, + args.decoder_embed_dim, + utils.eval_str_list(args.adaptive_input_cutoff, type=int), + ) + else: + embed_tokens = Embedding( + len(task.dictionary), args.decoder_input_dim, task.dictionary.pad() + ) + + if args.tie_adaptive_weights: + assert args.adaptive_input + assert args.adaptive_input_factor == args.adaptive_softmax_factor + assert ( + args.adaptive_softmax_cutoff == args.adaptive_input_cutoff + ), "{} != {}".format( + args.adaptive_softmax_cutoff, args.adaptive_input_cutoff + ) + assert args.decoder_input_dim == args.decoder_output_dim + + decoder = LightConvDecoder( + args, + task.output_dictionary, + embed_tokens, + no_encoder_attn=True, + final_norm=False, + ) + return LightConvLanguageModel(decoder) + + +@register_model_architecture("lightconv_lm", "lightconv_lm") +def base_lm_architecture(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + + args.character_embeddings = getattr(args, "character_embeddings", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim) + + # The model training is not stable without this + args.decoder_normalize_before = True + + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) + + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) + + args.decoder_kernel_size_list = getattr( + args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31] + ) + if len(args.decoder_kernel_size_list) == 1: + args.decoder_kernel_size_list = ( + args.decoder_kernel_size_list * args.decoder_layers + ) + assert ( + len(args.decoder_kernel_size_list) == args.decoder_layers + ), "decoder_kernel_size_list doesn't match decoder_layers" + args.decoder_glu = getattr(args, "decoder_glu", True) + args.input_dropout = getattr(args, "input_dropout", 0.1) + args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout) + + +@register_model_architecture("lightconv_lm", "lightconv_lm_gbw") +def lightconv_lm_gbw(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_lm_architecture(args) diff --git a/fairseq/fairseq/models/lstm.py b/fairseq/fairseq/models/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a29156270f05f72500b9142bfb5e613a4d7a19e --- /dev/null +++ b/fairseq/fairseq/models/lstm.py @@ -0,0 +1,755 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import AdaptiveSoftmax, FairseqDropout +from torch import Tensor + + +DEFAULT_MAX_SOURCE_POSITIONS = 1e5 +DEFAULT_MAX_TARGET_POSITIONS = 1e5 + + +@register_model("lstm") +class LSTMModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-freeze-embed', action='store_true', + help='freeze encoder embeddings') + parser.add_argument('--encoder-hidden-size', type=int, metavar='N', + help='encoder hidden size') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='number of encoder layers') + parser.add_argument('--encoder-bidirectional', action='store_true', + help='make all layers of encoder bidirectional') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-freeze-embed', action='store_true', + help='freeze decoder embeddings') + parser.add_argument('--decoder-hidden-size', type=int, metavar='N', + help='decoder hidden size') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='number of decoder layers') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--decoder-attention', type=str, metavar='BOOL', + help='decoder attention') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion') + parser.add_argument('--share-decoder-input-output-embed', default=False, + action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', default=False, action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument('--encoder-dropout-in', type=float, metavar='D', + help='dropout probability for encoder input embedding') + parser.add_argument('--encoder-dropout-out', type=float, metavar='D', + help='dropout probability for encoder output') + parser.add_argument('--decoder-dropout-in', type=float, metavar='D', + help='dropout probability for decoder input embedding') + parser.add_argument('--decoder-dropout-out', type=float, metavar='D', + help='dropout probability for decoder output') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + if args.encoder_layers != args.decoder_layers: + raise ValueError("--encoder-layers must match --decoder-layers") + + max_source_positions = getattr( + args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS + ) + max_target_positions = getattr( + args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS + ) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + if args.encoder_embed_path: + pretrained_encoder_embed = load_pretrained_embedding_from_file( + args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim + ) + else: + num_embeddings = len(task.source_dictionary) + pretrained_encoder_embed = Embedding( + num_embeddings, args.encoder_embed_dim, task.source_dictionary.pad() + ) + + if args.share_all_embeddings: + # double check all parameters combinations are valid + if task.source_dictionary != task.target_dictionary: + raise ValueError("--share-all-embeddings requires a joint dictionary") + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embed not compatible with --decoder-embed-path" + ) + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to " + "match --decoder-embed-dim" + ) + pretrained_decoder_embed = pretrained_encoder_embed + args.share_decoder_input_output_embed = True + else: + # separate decoder input embeddings + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, + task.target_dictionary, + args.decoder_embed_dim, + ) + # one last double check of parameter combinations + if args.share_decoder_input_output_embed and ( + args.decoder_embed_dim != args.decoder_out_embed_dim + ): + raise ValueError( + "--share-decoder-input-output-embeddings requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" + ) + + if args.encoder_freeze_embed: + pretrained_encoder_embed.weight.requires_grad = False + if args.decoder_freeze_embed: + pretrained_decoder_embed.weight.requires_grad = False + + encoder = LSTMEncoder( + dictionary=task.source_dictionary, + embed_dim=args.encoder_embed_dim, + hidden_size=args.encoder_hidden_size, + num_layers=args.encoder_layers, + dropout_in=args.encoder_dropout_in, + dropout_out=args.encoder_dropout_out, + bidirectional=args.encoder_bidirectional, + pretrained_embed=pretrained_encoder_embed, + max_source_positions=max_source_positions, + ) + decoder = LSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + attention=utils.eval_bool(args.decoder_attention), + encoder_output_units=encoder.output_units, + pretrained_embed=pretrained_decoder_embed, + share_input_output_embed=args.share_decoder_input_output_embed, + adaptive_softmax_cutoff=( + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == "adaptive_loss" + else None + ), + max_target_positions=max_target_positions, + residuals=False, + ) + return cls(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + ) + return decoder_out + + +class LSTMEncoder(FairseqEncoder): + """LSTM encoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + left_pad=True, + pretrained_embed=None, + padding_idx=None, + max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, + ): + super().__init__(dictionary) + self.num_layers = num_layers + self.dropout_in_module = FairseqDropout( + dropout_in * 1.0, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out * 1.0, module_name=self.__class__.__name__ + ) + self.bidirectional = bidirectional + self.hidden_size = hidden_size + self.max_source_positions = max_source_positions + + num_embeddings = len(dictionary) + self.padding_idx = padding_idx if padding_idx is not None else dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.lstm = LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=self.dropout_out_module.p if num_layers > 1 else 0.0, + bidirectional=bidirectional, + ) + self.left_pad = left_pad + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + enforce_sorted: bool = True, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of + shape `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of + shape `(batch)` + enforce_sorted (bool, optional): if True, `src_tokens` is + expected to contain sequences sorted by length in a + decreasing order. If False, this condition is not + required. Default: True. + """ + if self.left_pad: + # nn.utils.rnn.pack_padded_sequence requires right-padding; + # convert left-padding to right-padding + src_tokens = utils.convert_padding_direction( + src_tokens, + torch.zeros_like(src_tokens).fill_(self.padding_idx), + left_to_right=True, + ) + + bsz, seqlen = src_tokens.size() + + # embed tokens + x = self.embed_tokens(src_tokens) + x = self.dropout_in_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + packed_x = nn.utils.rnn.pack_padded_sequence( + x, src_lengths.cpu(), enforce_sorted=enforce_sorted + ) + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.new_zeros(*state_size) + c0 = x.new_zeros(*state_size) + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_idx * 1.0 + ) + x = self.dropout_out_module(x) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + final_hiddens = self.combine_bidir(final_hiddens, bsz) + final_cells = self.combine_bidir(final_cells, bsz) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() + + return tuple( + ( + x, # seq_len x batch x hidden + final_hiddens, # num_layers x batch x num_directions*hidden + final_cells, # num_layers x batch x num_directions*hidden + encoder_padding_mask, # seq_len x batch + ) + ) + + def combine_bidir(self, outs, bsz: int): + out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() + return out.view(self.num_layers, bsz, -1) + + def reorder_encoder_out( + self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order + ): + return tuple( + ( + encoder_out[0].index_select(1, new_order), + encoder_out[1].index_select(1, new_order), + encoder_out[2].index_select(1, new_order), + encoder_out[3].index_select(1, new_order), + ) + ) + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return self.max_source_positions + + +class AttentionLayer(nn.Module): + def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False): + super().__init__() + + self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias) + self.output_proj = Linear( + input_embed_dim + source_embed_dim, output_embed_dim, bias=bias + ) + + def forward(self, input, source_hids, encoder_padding_mask): + # input: bsz x input_embed_dim + # source_hids: srclen x bsz x source_embed_dim + + # x: bsz x source_embed_dim + x = self.input_proj(input) + + # compute attention + attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2) + + # don't attend over padding + if encoder_padding_mask is not None: + attn_scores = ( + attn_scores.float() + .masked_fill_(encoder_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back + + attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz + + # sum weighted sources + x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0) + + x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) + return x, attn_scores + + +class LSTMDecoder(FairseqIncrementalDecoder): + """LSTM decoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + attention=True, + encoder_output_units=512, + pretrained_embed=None, + share_input_output_embed=False, + adaptive_softmax_cutoff=None, + max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, + residuals=False, + ): + super().__init__(dictionary) + self.dropout_in_module = FairseqDropout( + dropout_in * 1.0, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out * 1.0, module_name=self.__class__.__name__ + ) + self.hidden_size = hidden_size + self.share_input_output_embed = share_input_output_embed + self.need_attn = True + self.max_target_positions = max_target_positions + self.residuals = residuals + self.num_layers = num_layers + + self.adaptive_softmax = None + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.encoder_output_units = encoder_output_units + if encoder_output_units != hidden_size and encoder_output_units != 0: + self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size) + self.encoder_cell_proj = Linear(encoder_output_units, hidden_size) + else: + self.encoder_hidden_proj = self.encoder_cell_proj = None + + # disable input feeding if there is no encoder + # input feeding is described in arxiv.org/abs/1508.04025 + input_feed_size = 0 if encoder_output_units == 0 else hidden_size + self.layers = nn.ModuleList( + [ + LSTMCell( + input_size=input_feed_size + embed_dim + if layer == 0 + else hidden_size, + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) + + if attention: + # TODO make bias configurable + self.attention = AttentionLayer( + hidden_size, encoder_output_units, hidden_size, bias=False + ) + else: + self.attention = None + + if hidden_size != out_embed_dim: + self.additional_fc = Linear(hidden_size, out_embed_dim) + + if adaptive_softmax_cutoff is not None: + # setting adaptive_softmax dropout to dropout_out for now but can be redefined + self.adaptive_softmax = AdaptiveSoftmax( + num_embeddings, + hidden_size, + adaptive_softmax_cutoff, + dropout=dropout_out, + ) + elif not self.share_input_output_embed: + self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Tuple[Tensor, Tensor, Tensor, Tensor]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + src_lengths: Optional[Tensor] = None, + ): + x, attn_scores = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + return self.output_layer(x), attn_scores + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Tuple[Tensor, Tensor, Tensor, Tensor]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): + """ + Similar to *forward* but only return features. + """ + # get outputs from encoder + if encoder_out is not None: + encoder_outs = encoder_out[0] + encoder_hiddens = encoder_out[1] + encoder_cells = encoder_out[2] + encoder_padding_mask = encoder_out[3] + else: + encoder_outs = torch.empty(0) + encoder_hiddens = torch.empty(0) + encoder_cells = torch.empty(0) + encoder_padding_mask = torch.empty(0) + srclen = encoder_outs.size(0) + + if incremental_state is not None and len(incremental_state) > 0: + prev_output_tokens = prev_output_tokens[:, -1:] + + bsz, seqlen = prev_output_tokens.size() + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + x = self.dropout_in_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # initialize previous states (or get from cache during incremental generation) + if incremental_state is not None and len(incremental_state) > 0: + prev_hiddens, prev_cells, input_feed = self.get_cached_state( + incremental_state + ) + elif encoder_out is not None: + # setup recurrent cells + prev_hiddens = [encoder_hiddens[i] for i in range(self.num_layers)] + prev_cells = [encoder_cells[i] for i in range(self.num_layers)] + if self.encoder_hidden_proj is not None: + prev_hiddens = [self.encoder_hidden_proj(y) for y in prev_hiddens] + prev_cells = [self.encoder_cell_proj(y) for y in prev_cells] + input_feed = x.new_zeros(bsz, self.hidden_size) + else: + # setup zero cells, since there is no encoder + zero_state = x.new_zeros(bsz, self.hidden_size) + prev_hiddens = [zero_state for i in range(self.num_layers)] + prev_cells = [zero_state for i in range(self.num_layers)] + input_feed = None + + assert ( + srclen > 0 or self.attention is None + ), "attention is not supported if there are no encoder outputs" + attn_scores: Optional[Tensor] = ( + x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None + ) + outs = [] + for j in range(seqlen): + # input feeding: concatenate context vector from previous time step + if input_feed is not None: + input = torch.cat((x[j, :, :], input_feed), dim=1) + else: + input = x[j] + + for i, rnn in enumerate(self.layers): + # recurrent cell + hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) + + # hidden state becomes the input to the next layer + input = self.dropout_out_module(hidden) + if self.residuals: + input = input + prev_hiddens[i] + + # save state for next time step + prev_hiddens[i] = hidden + prev_cells[i] = cell + + # apply attention using the last layer's hidden state + if self.attention is not None: + assert attn_scores is not None + out, attn_scores[:, j, :] = self.attention( + hidden, encoder_outs, encoder_padding_mask + ) + else: + out = hidden + out = self.dropout_out_module(out) + + # input feeding + if input_feed is not None: + input_feed = out + + # save final output + outs.append(out) + + # Stack all the necessary tensors together and store + prev_hiddens_tensor = torch.stack(prev_hiddens) + prev_cells_tensor = torch.stack(prev_cells) + cache_state = torch.jit.annotate( + Dict[str, Optional[Tensor]], + { + "prev_hiddens": prev_hiddens_tensor, + "prev_cells": prev_cells_tensor, + "input_feed": input_feed, + }, + ) + self.set_incremental_state(incremental_state, "cached_state", cache_state) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + if hasattr(self, "additional_fc") and self.adaptive_softmax is None: + x = self.additional_fc(x) + x = self.dropout_out_module(x) + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen + if not self.training and self.need_attn and self.attention is not None: + assert attn_scores is not None + attn_scores = attn_scores.transpose(0, 2) + else: + attn_scores = None + return x, attn_scores + + def output_layer(self, x): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + if self.share_input_output_embed: + x = F.linear(x, self.embed_tokens.weight) + else: + x = self.fc_out(x) + return x + + def get_cached_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + ) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]: + cached_state = self.get_incremental_state(incremental_state, "cached_state") + assert cached_state is not None + prev_hiddens_ = cached_state["prev_hiddens"] + assert prev_hiddens_ is not None + prev_cells_ = cached_state["prev_cells"] + assert prev_cells_ is not None + prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] + prev_cells = [prev_cells_[j] for j in range(self.num_layers)] + input_feed = cached_state[ + "input_feed" + ] # can be None for decoder-only language models + return prev_hiddens, prev_cells, input_feed + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + if incremental_state is None or len(incremental_state) == 0: + return + prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) + prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens] + prev_cells = [p.index_select(0, new_order) for p in prev_cells] + if input_feed is not None: + input_feed = input_feed.index_select(0, new_order) + cached_state_new = torch.jit.annotate( + Dict[str, Optional[Tensor]], + { + "prev_hiddens": torch.stack(prev_hiddens), + "prev_cells": torch.stack(prev_cells), + "input_feed": input_feed, + }, + ) + self.set_incremental_state(incremental_state, "cached_state", cached_state_new), + return + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return self.max_target_positions + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.uniform_(m.weight, -0.1, 0.1) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def LSTM(input_size, hidden_size, **kwargs): + m = nn.LSTM(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def LSTMCell(input_size, hidden_size, **kwargs): + m = nn.LSTMCell(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def Linear(in_features, out_features, bias=True, dropout=0.0): + """Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.uniform_(-0.1, 0.1) + if bias: + m.bias.data.uniform_(-0.1, 0.1) + return m + + +@register_model_architecture("lstm", "lstm") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_freeze_embed = getattr(args, "encoder_freeze_embed", False) + args.encoder_hidden_size = getattr( + args, "encoder_hidden_size", args.encoder_embed_dim + ) + args.encoder_layers = getattr(args, "encoder_layers", 1) + args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_freeze_embed = getattr(args, "decoder_freeze_embed", False) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_attention = getattr(args, "decoder_attention", "1") + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + + +@register_model_architecture("lstm", "lstm_wiseman_iwslt_de_en") +def lstm_wiseman_iwslt_de_en(args): + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", 0) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", 0) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + base_architecture(args) + + +@register_model_architecture("lstm", "lstm_luong_wmt_en_de") +def lstm_luong_wmt_en_de(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000) + args.encoder_layers = getattr(args, "encoder_layers", 4) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1000) + args.decoder_layers = getattr(args, "decoder_layers", 4) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1000) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", 0) + base_architecture(args) diff --git a/fairseq/fairseq/models/lstm_lm.py b/fairseq/fairseq/models/lstm_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..454f0ac36fab78bf02a8e2f07ed9607d1da87e34 --- /dev/null +++ b/fairseq/fairseq/models/lstm_lm.py @@ -0,0 +1,142 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.lstm import Embedding, LSTMDecoder + + +DEFAULT_MAX_TARGET_POSITIONS = 1e5 + + +@register_model("lstm_lm") +class LSTMLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-hidden-size', type=int, metavar='N', + help='decoder hidden size') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='number of decoder layers') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--decoder-attention', type=str, metavar='BOOL', + help='decoder attention') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion') + parser.add_argument('--residuals', default=False, + action='store_true', + help='applying residuals between LSTM layers') + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument('--decoder-dropout-in', type=float, metavar='D', + help='dropout probability for decoder input embedding') + parser.add_argument('--decoder-dropout-out', type=float, metavar='D', + help='dropout probability for decoder output') + parser.add_argument('--share-decoder-input-output-embed', default=False, + action='store_true', + help='share decoder input and output embeddings') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if getattr(args, "max_target_positions", None) is not None: + max_target_positions = args.max_target_positions + else: + max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim + ) + + if args.share_decoder_input_output_embed: + # double check all parameters combinations are valid + if task.source_dictionary != task.target_dictionary: + raise ValueError( + "--share-decoder-input-output-embeddings requires a joint dictionary" + ) + + if args.decoder_embed_dim != args.decoder_out_embed_dim: + raise ValueError( + "--share-decoder-input-output-embeddings requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" + ) + + decoder = LSTMDecoder( + dictionary=task.dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + attention=False, # decoder-only language model doesn't support attention + encoder_output_units=0, + pretrained_embed=pretrained_decoder_embed, + share_input_output_embed=args.share_decoder_input_output_embed, + adaptive_softmax_cutoff=( + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == "adaptive_loss" + else None + ), + max_target_positions=max_target_positions, + residuals=args.residuals, + ) + + return cls(decoder) + + +@register_model_architecture("lstm_lm", "lstm_lm") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_attention = getattr(args, "decoder_attention", "0") + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + args.residuals = getattr(args, "residuals", False) diff --git a/fairseq/fairseq/models/multilingual_transformer.py b/fairseq/fairseq/models/multilingual_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e722b647edd92c95a3e93489031ae331f90e0463 --- /dev/null +++ b/fairseq/fairseq/models/multilingual_transformer.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +from fairseq import utils +from fairseq.models import ( + FairseqMultiModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + Embedding, + TransformerDecoder, + TransformerEncoder, + TransformerModel, + base_architecture, +) +from fairseq.utils import safe_hasattr + + +@register_model("multilingual_transformer") +class MultilingualTransformerModel(FairseqMultiModel): + """Train Transformer models for multiple language pairs simultaneously. + + Requires `--task multilingual_translation`. + + We inherit all arguments from TransformerModel and assume that all language + pairs use a single Transformer architecture. In addition, we provide several + options that are specific to the multilingual setting. + + Args: + --share-encoder-embeddings: share encoder embeddings across all source languages + --share-decoder-embeddings: share decoder embeddings across all target languages + --share-encoders: share all encoder params (incl. embeddings) across all source languages + --share-decoders: share all decoder params (incl. embeddings) across all target languages + """ + + def __init__(self, encoders, decoders): + super().__init__(encoders, decoders) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + TransformerModel.add_args(parser) + parser.add_argument( + "--share-encoder-embeddings", + action="store_true", + help="share encoder embeddings across languages", + ) + parser.add_argument( + "--share-decoder-embeddings", + action="store_true", + help="share decoder embeddings across languages", + ) + parser.add_argument( + "--share-encoders", + action="store_true", + help="share encoders across languages", + ) + parser.add_argument( + "--share-decoders", + action="store_true", + help="share decoders across languages", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + from fairseq.tasks.multilingual_translation import MultilingualTranslationTask + + assert isinstance(task, MultilingualTranslationTask) + + # make sure all arguments are present in older models + base_multilingual_architecture(args) + + if not safe_hasattr(args, "max_source_positions"): + args.max_source_positions = 1024 + if not safe_hasattr(args, "max_target_positions"): + args.max_target_positions = 1024 + + src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs] + tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs] + + if args.share_encoders: + args.share_encoder_embeddings = True + if args.share_decoders: + args.share_decoder_embeddings = True + + def build_embedding(dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + # build shared embeddings (if applicable) + shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None + if args.share_all_embeddings: + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=task.langs, + embed_dim=args.encoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.encoder_embed_path, + ) + shared_decoder_embed_tokens = shared_encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + if args.share_encoder_embeddings: + shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=src_langs, + embed_dim=args.encoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.encoder_embed_path, + ) + if args.share_decoder_embeddings: + shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=tgt_langs, + embed_dim=args.decoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.decoder_embed_path, + ) + + # encoders/decoders for each language + lang_encoders, lang_decoders = {}, {} + + def get_encoder(lang): + if lang not in lang_encoders: + if shared_encoder_embed_tokens is not None: + encoder_embed_tokens = shared_encoder_embed_tokens + else: + encoder_embed_tokens = build_embedding( + task.dicts[lang], + args.encoder_embed_dim, + args.encoder_embed_path, + ) + lang_encoders[lang] = cls._get_module_class( + True, args, task.dicts[lang], encoder_embed_tokens, src_langs + ) + return lang_encoders[lang] + + def get_decoder(lang): + if lang not in lang_decoders: + if shared_decoder_embed_tokens is not None: + decoder_embed_tokens = shared_decoder_embed_tokens + else: + decoder_embed_tokens = build_embedding( + task.dicts[lang], + args.decoder_embed_dim, + args.decoder_embed_path, + ) + lang_decoders[lang] = cls._get_module_class( + False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs + ) + return lang_decoders[lang] + + # shared encoders/decoders (if applicable) + shared_encoder, shared_decoder = None, None + if args.share_encoders: + shared_encoder = get_encoder(src_langs[0]) + if args.share_decoders: + shared_decoder = get_decoder(tgt_langs[0]) + + encoders, decoders = OrderedDict(), OrderedDict() + for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs): + encoders[lang_pair] = ( + shared_encoder if shared_encoder is not None else get_encoder(src) + ) + decoders[lang_pair] = ( + shared_decoder if shared_decoder is not None else get_decoder(tgt) + ) + + return MultilingualTransformerModel(encoders, decoders) + + @classmethod + def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): + module_class = TransformerEncoder if is_encoder else TransformerDecoder + return module_class(args, lang_dict, embed_tokens) + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + state_dict_subset = state_dict.copy() + for k, _ in state_dict.items(): + assert k.startswith("models.") + lang_pair = k.split(".")[1] + if lang_pair not in self.models: + del state_dict_subset[k] + super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg) + + +@register_model_architecture("multilingual_transformer", "multilingual_transformer") +def base_multilingual_architecture(args): + base_architecture(args) + args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False) + args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False) + args.share_encoders = getattr(args, "share_encoders", False) + args.share_decoders = getattr(args, "share_decoders", False) + + +@register_model_architecture( + "multilingual_transformer", "multilingual_transformer_iwslt_de_en" +) +def multilingual_transformer_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + base_multilingual_architecture(args)