diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..330615f4245edb126f9e19a7dc7fbb1eddce9b12 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +fairseq/examples/MMPT/vlm.png filter=lfs diff=lfs merge=lfs -text +fairseq/examples/MMPT/videoclip.png filter=lfs diff=lfs merge=lfs -text diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8dc716815d8d4812237c814fc17a90dd7ac3cf1e --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml @@ -0,0 +1,31 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: DiDeMoAligner + bert_name: bert-base-uncased + meta_processor: DiDeMoMetaProcessor + test_path: data/didemo/test_data.json + vfeat_dir: data/feat/feat_didemo_s3d + text_processor: DiDeMoTextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/checkpoint_best.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/didemo_zs/eval +metric: DiDeMoMetric +predictor: DiDeMoPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0566d784ab53e0159a1db6b82222ad11ad12add --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml @@ -0,0 +1,49 @@ +dataset: + video_processor: VideoProcessor + bert_name: bert-base-uncased + meta_processor: MSRVTTMetaProcessor + train_path: data/msrvtt/MSRVTT_train.csv + dup: 20 + val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + json_path: data/msrvtt/MSRVTT_data.json + aligner: DSAligner + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 128 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 5 + checkpoint: + restore_file: runs/retri/videoclip/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/retri/videoclip/vttqa +task_type: sweep_small +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +loss: + loss_cls: V2TContraLoss diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2b13e5519fe0839c6b3bc681f7ee97b522077d8 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml @@ -0,0 +1,49 @@ +dataset: + video_processor: YoucookVideoProcessor + bert_name: bert-base-uncased + meta_processor: YoucookMetaProcessor + train_path: data/youcook/youcook_train.pkl + val_path: data/youcook/youcook_val.pkl + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + use_annotation_text: true + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: TextProcessor + aligner: DSAligner + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 128 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 10 + checkpoint: + restore_file: runs/retri/videoclip/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/retri/videoclip/youcook +task_type: sweep_small +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +loss: + loss_cls: T2VContraLoss diff --git a/fairseq/examples/MMPT/projects/task/coin.yaml b/fairseq/examples/MMPT/projects/task/coin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e7772486e166348deff1f7c9deceb68c7e1bb443 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/coin.yaml @@ -0,0 +1,25 @@ +includes: projects/task/ft.yaml +task_type: sweep_big +dataset: + meta_processor: COINActionSegmentationMetaProcessor + train_path: data/coin/COIN.json + val_path: data/coin/COIN.json + vfeat_dir: data/feat/feat_coin_s3d + video_processor: VideoProcessor + text_processor: COINActionSegmentationTextProcessor + aligner: COINActionSegmentationAligner + num_iso_layer: 12 + sliding_window: 8 + sliding_window_size: 32 +model: + model_cls: MMFusionActionSegmentation + mm_encoder_cls: MMBertForTokenClassification +loss: + loss_cls: CrossEntropy +fairseq: + dataset: + batch_size: 1 + optimization: + max_epoch: 8 + checkpoint: + save_dir: runs/task/coin diff --git a/fairseq/examples/MMPT/projects/task/coin_videoclip.yaml b/fairseq/examples/MMPT/projects/task/coin_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..69988bc18a03a8f9bf594266f3bf06aa54f11b6a --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/coin_videoclip.yaml @@ -0,0 +1,7 @@ +includes: projects/task/coin.yaml +model: + model_cls: MMFusionSeparateActionSegmentation + mm_encoder_cls: + video_encoder_cls: MMBertForTokenClassification + text_encoder_cls: BertModel # dummy, not used. + num_hidden_video_layers: 6 diff --git a/fairseq/examples/MMPT/projects/task/test.yaml b/fairseq/examples/MMPT/projects/task/test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a98445241a17bbfda84a3b3be0687d31e49d12f --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test.yaml @@ -0,0 +1,13 @@ +# this yaml cannot be run alone: implement a test_${dataset}.yaml +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: DSAligner + bert_name: bert-base-uncased +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 diff --git a/fairseq/examples/MMPT/projects/task/test_vtt.yaml b/fairseq/examples/MMPT/projects/task/test_vtt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2f809b306d5d2b4d50d83ca3f9fa2db3f64530a5 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_vtt.yaml @@ -0,0 +1,19 @@ +includes: projects/task/test.yaml +dataset: + meta_processor: MSRVTTMetaProcessor + test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + video_processor: VideoProcessor + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + num_iso_layer: 12 +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint +eval: + save_path: runs/task/vtt/eval +fairseq: + # read code and find what is the checkpoint arg. + common_eval: + path: runs/task/vtt/checkpoint_last.pt +metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_youcook.yaml b/fairseq/examples/MMPT/projects/task/test_youcook.yaml new file mode 100644 index 0000000000000000000000000000000000000000..092b680fa6400458c447abe4f5cc5e452ffdb54f --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_youcook.yaml @@ -0,0 +1,22 @@ +includes: projects/task/test.yaml +dataset: + meta_processor: YoucookMetaProcessor + test_path: data/youcook/youcook_val.pkl + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + use_annotation_text: True + video_processor: YoucookVideoProcessor + vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz + text_processor: TextProcessor + aligner: DSAligner + num_iso_layer: 12 +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint +eval: + save_path: runs/task/youcook/eval +fairseq: + # read code and find what is the checkpoint arg. + common_eval: + path: runs/task/youcook/checkpoint_last.pt +metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_youcookcap.yaml b/fairseq/examples/MMPT/projects/task/test_youcookcap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..24f6518b7b4e210289f5200edafdb436d569adf7 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_youcookcap.yaml @@ -0,0 +1,23 @@ +includes: projects/task/test.yaml +dataset: + meta_processor: YoucookNLGMetaProcessor + test_path: data/youcook/val_list.txt + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + video_processor: YoucookVideoProcessor + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: NLGTextProcessor + aligner: DSNLGAligner +model: + model_cls: MMFusionNLG + mm_encoder_cls: MMBertForNLG + max_decode_length: 24 +eval: + save_path: runs/task/youcookcap/eval +fairseq: + # read code and find what is the checkpoint arg. + common_eval: + path: runs/task/youcookcap/checkpoint_best.pt +metric: NLGMetric +predictor: NLGPredictor +gen_param: + num_beams: 5 diff --git a/fairseq/examples/MMPT/projects/task/vtt.yaml b/fairseq/examples/MMPT/projects/task/vtt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..395e2ee6fe96899f9c47e39ad33b558358431554 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/vtt.yaml @@ -0,0 +1,25 @@ +includes: projects/task/ft.yaml +dataset: + meta_processor: MSRVTTMetaProcessor + train_path: data/msrvtt/MSRVTT_train.csv + jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + full_test_path: data/msrvtt/MSRVTT_FULL_test.csv + dup: 20 + val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + json_path: data/msrvtt/MSRVTT_data.json + aligner: DSAligner + num_iso_layer: 12 +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint +loss: + loss_cls: T2VContraLoss +fairseq: + dataset: + batch_size: 256 + optimization: + max_epoch: 10 + checkpoint: + save_dir: runs/task/vtt diff --git a/fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml b/fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9892cab01b810175dcef2ee91f0e64e96bd36d9 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml @@ -0,0 +1,12 @@ +includes: projects/task/vtt.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +fairseq: + dataset: + batch_size: 224 +# model_cls: MMFusionShare +# mm_encoder_cls: MMBertForEncoder diff --git a/fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml b/fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d484ca8a5de0e96172b49b746f93d052f50dea7 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml @@ -0,0 +1,10 @@ +includes: projects/task/vttqa.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 + +# model_cls: MMFusionShare +# mm_encoder_cls: MMBertForEncoder diff --git a/fairseq/examples/MMPT/projects/task/youcook.yaml b/fairseq/examples/MMPT/projects/task/youcook.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0cd841747a5cfa4c77a94d2478bfcb029517678 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/youcook.yaml @@ -0,0 +1,25 @@ +includes: projects/task/ft.yaml +dataset: + meta_processor: YoucookMetaProcessor + train_path: data/youcook/youcook_train.pkl + val_path: data/youcook/youcook_val.pkl + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + use_annotation_text: True + video_processor: YoucookVideoProcessor + vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz + text_processor: TextProcessor + aligner: DSAligner + num_iso_layer: 12 +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint +loss: + loss_cls: T2VContraLoss +fairseq: + dataset: + batch_size: 128 + optimization: + max_epoch: 10 + checkpoint: + save_dir: runs/task/youcook + diff --git a/fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml b/fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3e901c30ca8ef0e858b9fb0a317aea0b55f6e34 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml @@ -0,0 +1,9 @@ +includes: projects/task/youcook.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 + # model_cls: MMFusionShare + # mm_encoder_cls: MMBertForEncoder diff --git a/fairseq/examples/MMPT/projects/task/youcookcap.yaml b/fairseq/examples/MMPT/projects/task/youcookcap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..047735f21716777354101ebaa230772244ac79e5 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/youcookcap.yaml @@ -0,0 +1,23 @@ +# finetuning for youcook captioning. +includes: projects/task/ft.yaml +dataset: + meta_processor: YoucookNLGMetaProcessor + train_path: data/youcook/train_list.txt + val_path: data/youcook/val_list.txt + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + video_processor: YoucookVideoProcessor + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: NLGTextProcessor + aligner: DSNLGAligner +model: + model_cls: MMFusionNLG + mm_encoder_cls: MMBertForNLG +loss: + loss_cls: NLGLoss +fairseq: + dataset: + batch_size: 128 + optimization: + max_epoch: 10 + checkpoint: + save_dir: runs/task/youcookcap diff --git a/fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml b/fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml new file mode 100644 index 0000000000000000000000000000000000000000..473dd9b45b78d6d4443921fb5f355356db4f5531 --- /dev/null +++ b/fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml @@ -0,0 +1,5 @@ +dataset: + bert_name: bert-base-uncased + caption_pkl_path: data/how2/raw_caption_dedup.pkl + use_fast: true + target_dir: data/feat/feat_how2_s3d_shard_small diff --git a/fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py b/fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..29ae5dc151085cc59de03e4e8bfc6e5341defb8b --- /dev/null +++ b/fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import pickle +import os +import argparse +import numpy as np + +from torch.utils.data import Dataset, DataLoader +from mmpt.processors import PKLJSONStrTextProcessor +from mmpt.utils import ShardedTensor, recursive_config + + +class TokenizerDataset(Dataset): + def __init__(self, config): + self.text_processor = PKLJSONStrTextProcessor(config) + self.video_ids = list(self.text_processor.data.keys()) + + def __getitem__(self, idx): + video_id = self.video_ids[idx] + return video_id, self.text_processor(video_id) + + def __len__(self): + return len(self.video_ids) + + +def numpify(shard_idx, video_ids, captions, target_dir, split, prefix, max_cap_len=32): + startends = [] + caps_ids = [] + for video_id in video_ids: + caption = captions[video_id] + startend = [] + cap_ids = [] + for start, end, cap in zip( + caption["start"], caption["end"], caption["cap"]): + startend.append(np.array([start, end]).astype("float32")) + cap_id = np.full((max_cap_len,), -1, dtype=np.int32) + cap = cap[:max_cap_len] + cap_id[:len(cap)] = cap + cap_ids.append(cap_id) + startends.append(np.stack(startend)) + caps_ids.append(np.stack(cap_ids)) + + startends = ShardedTensor.from_list(startends) + target_path = os.path.join( + target_dir, + prefix + split + "_" + str(shard_idx) + ) + print("save to", target_path) + startends.save(target_path + ".startends") + caps_ids = ShardedTensor.from_list(caps_ids) + caps_ids.save(target_path + ".caps_ids") + + +def sharding(config, out_file): + with open(out_file, "rb") as fr: + captions = pickle.load(fr) + target_dir = config.target_dir + prefix = os.path.basename( + os.path.splitext(config.caption_pkl_path)[0] + ) + "." + config.bert_name + "." + for split in ["train", "val"]: + target_path = os.path.join(target_dir, split + "_meta") + with open(target_path + ".pkl", "rb") as fr: + meta = pickle.load(fr) + print("load meta", target_path, len(meta)) + for shard_id in meta: + numpify( + shard_id, meta[shard_id], captions, + target_dir, split, prefix + ) + + +def tokenize(config, out_file): + def collator(samples): + return samples + dataset = TokenizerDataset(config) + data = {} + for idx, batch in enumerate( + DataLoader(dataset, collate_fn=collator, num_workers=16)): + for video_id, caption in batch: + data[video_id] = caption + if idx % 5000 == 0: + print(idx) + with open(out_file, "wb") as fw: + pickle.dump(data, fw, pickle.HIGHEST_PROTOCOL) + + +def main(args): + config = recursive_config(args.config).dataset + + out_file = os.path.splitext(config.caption_pkl_path)[0] \ + + "." + config.bert_name + ".pkl" + if not os.path.isfile(out_file): + tokenize(config, out_file) + sharding(config, out_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="pretokenize (raw_)caption.json into pkl.") + parser.add_argument('config', type=str) + args = parser.parse_args() + main(args) diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py b/fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..b5ee7b778890994aa87d463b896f653792c073f2 --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py @@ -0,0 +1,157 @@ +# Copyright Howto100M authors. +# Copyright (c) Facebook, Inc. All Rights Reserved + +import torch as th +import torch.nn.functional as F +import math +import numpy as np +import argparse + +from torch.utils.data import DataLoader +from model import get_model +from preprocessing import Preprocessing +from random_sequence_shuffler import RandomSequenceSampler + +from tqdm import tqdm +from pathbuilder import PathBuilder +from videoreader import VideoLoader + + +parser = argparse.ArgumentParser(description='Easy video feature extractor') + +parser.add_argument('--vdir', type=str) +parser.add_argument('--fdir', type=str) +parser.add_argument('--hflip', type=int, default=0) + +parser.add_argument('--batch_size', type=int, default=64, + help='batch size') +parser.add_argument('--type', type=str, default='2d', + help='CNN type') +parser.add_argument('--half_precision', type=int, default=0, + help='output half precision float') +parser.add_argument('--num_decoding_thread', type=int, default=4, + help='Num parallel thread for video decoding') +parser.add_argument('--l2_normalize', type=int, default=1, + help='l2 normalize feature') +parser.add_argument('--resnext101_model_path', type=str, default='model/resnext101.pth', + help='Resnext model path') +parser.add_argument('--vmz_model_path', type=str, default='model/r2plus1d_34_clip8_ig65m_from_scratch-9bae36ae.pth', + help='vmz model path') + +args = parser.parse_args() + + +# TODO: refactor all args into config. (current code is from different people.) +CONFIGS = { + "2d": { + "fps": 1, + "size": 224, + "centercrop": False, + "shards": 0, + }, + "3d": { + "fps": 24, + "size": 112, + "centercrop": True, + "shards": 0, + }, + "s3d": { + "fps": 30, + "size": 224, + "centercrop": True, + "shards": 0, + }, + "vmz": { + "fps": 24, + "size": 112, + "centercrop": True, + "shards": 0, + }, + "vae": { + "fps": 2, + "size": 256, + "centercrop": True, + "shards": 100, + } +} + +config = CONFIGS[args.type] + + +video_dirs = args.vdir +feature_dir = args.fdir + +video_dict = PathBuilder.build(video_dirs, feature_dir, ".npy", config["shards"]) + +dataset = VideoLoader( + video_dict=video_dict, + framerate=config["fps"], + size=config["size"], + centercrop=config["centercrop"], + hflip=args.hflip +) +n_dataset = len(dataset) +sampler = RandomSequenceSampler(n_dataset, 10) +loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + num_workers=args.num_decoding_thread, + sampler=sampler if n_dataset > 10 else None, +) +preprocess = Preprocessing(args.type) +model = get_model(args) + +with th.no_grad(): + for k, data in tqdm(enumerate(loader), total=loader.__len__(), ascii=True): + input_file = data['input'][0] + output_file = data['output'][0] + if len(data['video'].shape) > 3: + video = data['video'].squeeze() + if len(video.shape) == 4: + video = preprocess(video) + n_chunk = len(video) + if args.type == 'vmz': + n_chunk = math.ceil(n_chunk/float(3)) + features = th.cuda.FloatTensor(n_chunk, 512).fill_(0) + elif args.type == 's3d': + features = th.cuda.FloatTensor(n_chunk, 512).fill_(0) + elif args.type == "vae": + features = th.cuda.LongTensor(n_chunk, 1024).fill_(0) + else: + features = th.cuda.FloatTensor(n_chunk, 2048).fill_(0) + n_iter = int(math.ceil(n_chunk / float(args.batch_size))) + for i in range(n_iter): + factor = 1 + if args.type == 'vmz': + factor = 3 + min_ind = factor * i * args.batch_size + max_ind = factor * (i + 1) * args.batch_size + video_batch = video[min_ind:max_ind:factor].cuda() + if args.type == '2d': + batch_features = model(video_batch) # (51, 487), (51, 512) + elif args.type == 's3d': + batch_features = model(video_batch) + batch_features = batch_features['video_embedding'] + elif args.type == "vae": + # image_code. + batch_features = model(video_batch) + else: + batch_pred, batch_features = model(video_batch) # (51, 487), (51, 512) + if args.l2_normalize: + batch_features = F.normalize(batch_features, dim=1) + features[i*args.batch_size:(i+1)*args.batch_size] = batch_features + features = features.cpu().numpy() + if args.half_precision: + if args.type == "vae": + features = features.astype(np.int16) + else: + features = features.astype('float16') + else: + if args.type == "vae": + features = features.astype(np.int32) + else: + features = features.astype('float32') + np.save(output_file, features) + else: + print('Video {} error.'.format(input_file)) diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh b/fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh new file mode 100644 index 0000000000000000000000000000000000000000..90102c89fbb4fb413951dfd049b30b8dfe7cbe99 --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh @@ -0,0 +1,8 @@ +#!/bin/bash + + +python scripts/video_feature_extractor/extract.py \ + --vdir \ + --fdir data/feat/feat_how2_s3d \ + --type=s3d --num_decoding_thread=4 \ + --batch_size 32 --half_precision 1 diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/model.py b/fairseq/examples/MMPT/scripts/video_feature_extractor/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ac266e844c86246bbfce02b9e6a2999353661df9 --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/model.py @@ -0,0 +1,58 @@ +# Copyright (c) Howto100M authors and Facebook, Inc. All Rights Reserved + +import torch as th + +from torch import nn + + +class GlobalAvgPool(nn.Module): + def __init__(self): + super(GlobalAvgPool, self).__init__() + + def forward(self, x): + return th.mean(x, dim=[-2, -1]) + + +def get_model(args): + assert args.type in ['2d', '3d', 'vmz', 's3d', 'vae'] + if args.type == '2d': + print('Loading 2D-ResNet-152 ...') + import torchvision.models as models + model = models.resnet152(pretrained=True) + model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool()) + model = model.cuda() + elif args.type == 'vmz': + print('Loading VMZ ...') + from vmz34 import r2plus1d_34 + model = r2plus1d_34(pretrained_path=args.vmz_model_path, pretrained_num_classes=487) + model = model.cuda() + elif args.type == 's3d': + # we use one copy of s3d instead of dup another one for feature extraction. + from mmpt.processors.models.s3dg import S3D + model = S3D('pretrained_models/s3d_dict.npy', 512) + model.load_state_dict(th.load('pretrained_models/s3d_howto100m.pth')) + model = model.cuda() + + elif args.type == '3d': + print('Loading 3D-ResneXt-101 ...') + from videocnn.models import resnext + model = resnext.resnet101( + num_classes=400, + shortcut_type='B', + cardinality=32, + sample_size=112, + sample_duration=16, + last_fc=False) + model = model.cuda() + model_data = th.load(args.resnext101_model_path) + model.load_state_dict(model_data) + elif args.type == 'vae': + from openaivae import OpenAIParallelDiscreteVAE + model = OpenAIParallelDiscreteVAE() + model = model.cuda() + else: + raise ValueError("model not supported yet.") + + model.eval() + print('loaded') + return model diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py b/fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py new file mode 100644 index 0000000000000000000000000000000000000000..2392d6d63bc293fb24068e33f707ee4ce0e74a6f --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os +import urllib.parse +import json +import pandas as pd + +from tqdm import tqdm + + +# TODO: extending to other datasets. +supported_formats = {} + + +class PathBuilder(object): + @classmethod + def build(cls, video_dirs, feature_dir, ext, shards=0, split=None): + meta_fn = os.path.join(feature_dir, "meta_plan.json") + os.makedirs(feature_dir, exist_ok=True) + if os.path.isfile(meta_fn): + with open(meta_fn) as fr: + meta = json.load(fr) + return meta + print("searching videos...") + + video_id_to_path = {} + for video_dir in video_dirs.split(","): + # TODO: add supports of recursive listdir. + if video_dir in supported_formats: + supported_formats[video_dir].load(video_dir, video_id_to_path) + else: + for idx, fn in enumerate(tqdm(os.listdir(video_dir))): + video_fn = os.path.join(video_dir, fn) + if os.path.isfile(video_fn): + video_id = os.path.splitext(fn)[0] + video_id_to_path[video_id] = video_fn + elif os.path.isdir(video_fn): + # shards of folders. + shard_dir = video_fn + for idx, fn in enumerate(os.listdir(shard_dir)): + video_fn = os.path.join(shard_dir, fn) + if os.path.isfile(video_fn): + video_id = os.path.splitext(fn)[0] + video_id_to_path[video_id] = video_fn + + video_path, feature_path = [], [] + valid_ext = set() + for idx, video_id in enumerate(video_id_to_path): + video_path.append(video_id_to_path[video_id]) + if ext is None: + # use original file ext for format compatibility. + video_id_to_path[video_id] + path = urllib.parse.urlparse(video_id_to_path[video_id]).path + ext = os.path.splitext(path)[1] + if ext not in valid_ext: + valid_ext.add(ext) + print("adding", ext) + if shards: + shard_id = str(idx % shards) + feature_fn = os.path.join( + feature_dir, shard_id, video_id + ext) + else: + feature_fn = os.path.join( + feature_dir, video_id + ext) + feature_path.append(feature_fn) + + print("targeting", len(feature_path), "videos") + meta = { + "video_path": video_path, "feature_path": feature_path} + with open(meta_fn, "w") as fw: + json.dump(meta, fw) + + if split is not None: + splits = split.split("/") + assert len(splits) == 2 + cur, total = int(splits[0]), int(splits[1]) + assert cur < total + import math + chunk = math.ceil(len(meta["video_path"]) / total) + start = cur * chunk + end = (cur + 1) * chunk + meta = { + "video_path": meta["video_path"][start:end], + "feature_path": meta["feature_path"][start:end] + } + + return meta diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py b/fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0cec3a76b4665fae57d1feadd5a52469234226 --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py @@ -0,0 +1,57 @@ +# Copyright Howto100m authors. +# Copyright (c) Facebook, Inc. All Rights Reserved + +import torch as th + +class Normalize(object): + + def __init__(self, mean, std): + self.mean = th.FloatTensor(mean).view(1, 3, 1, 1) + self.std = th.FloatTensor(std).view(1, 3, 1, 1) + + def __call__(self, tensor): + tensor = (tensor - self.mean) / (self.std + 1e-8) + return tensor + +class Preprocessing(object): + + def __init__(self, type): + self.type = type + if type == '2d': + self.norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + elif type == '3d': + self.norm = Normalize(mean=[110.6, 103.2, 96.3], std=[1.0, 1.0, 1.0]) + elif type == 'vmz': + self.norm = Normalize(mean=[110.201, 100.64, 95.997], std=[58.1489, 56.4701, 55.3324]) + + def _zero_pad(self, tensor, size): + n = size - len(tensor) % size + if n == size: + return tensor + else: + z = th.zeros(n, tensor.shape[1], tensor.shape[2], tensor.shape[3]) + return th.cat((tensor, z), 0) + + def __call__(self, tensor): + if self.type == '2d': + tensor = tensor / 255.0 + tensor = self.norm(tensor) + elif self.type == 'vmz': + #tensor = self._zero_pad(tensor, 8) + tensor = self._zero_pad(tensor, 10) + tensor = self.norm(tensor) + #tensor = tensor.view(-1, 8, 3, 112, 112) + tensor = tensor.view(-1, 10, 3, 112, 112) + tensor = tensor.transpose(1, 2) + elif self.type == '3d': + tensor = self._zero_pad(tensor, 16) + tensor = self.norm(tensor) + tensor = tensor.view(-1, 16, 3, 112, 112) + tensor = tensor.transpose(1, 2) + elif self.type == 's3d': + tensor = tensor / 255.0 + tensor = self._zero_pad(tensor, 30) + tensor = tensor.view(-1, 30, 3, 224, 224) # N x 30 x 3 x H x W + tensor = tensor.transpose(1, 2) # N x 3 x 30 x H x W + # for vae do nothing + return tensor diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py b/fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3e4aceaa8f930c9eeac6283b15924e7eed5dc3 --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. All Rights Reserved + +import numpy as np + +from torch.utils.data.sampler import Sampler + + +class RandomSequenceSampler(Sampler): + + def __init__(self, n_sample, seq_len): + self.n_sample = n_sample + self.seq_len = seq_len + + def _pad_ind(self, ind): + zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len) + ind = np.concatenate((ind, zeros)) + return ind + + def __iter__(self): + idx = np.arange(self.n_sample) + if self.n_sample % self.seq_len != 0: + idx = self._pad_ind(idx) + idx = np.reshape(idx, (-1, self.seq_len)) + np.random.shuffle(idx) + idx = np.reshape(idx, (-1)) + return iter(idx.astype(int)) + + def __len__(self): + return self.n_sample + (self.seq_len - self.n_sample % self.seq_len) diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py b/fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..f75e1dfae558fa366d15fa19e89d353003e728b3 --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np +import os +import pickle + +from mmpt.utils import ShardedTensor + + +class Shard(object): + def __init__( + self, + vfeat_dir, + tfeat_dir, + target_dir, + file_paths, + shard_size=4096 + ): + self.vfeat_dir = vfeat_dir + self.tfeat_dir = tfeat_dir + self.target_dir = target_dir + self.video_ids = {} + for split, file_path in zip(["train", "val"], file_paths): + with open(file_path) as fr: + self.video_ids[split] = [ + line.strip() for line in fr.readlines()] + self.shard_size = shard_size + + def __call__(self, split="train"): + for split in ["train", "val"]: + meta = {} + for shard_idx, shard_offset in enumerate( + range(0, len(self.video_ids[split]), self.shard_size) + ): + print(shard_idx) + meta_shard = [] + video_shard = [] + for video_id in self.video_ids[split][shard_offset:shard_offset+self.shard_size]: + meta_shard.append(video_id) + npy_file = os.path.join(self.vfeat_dir, video_id + ".npy") + video_shard.append(np.load(npy_file)) + + meta[shard_idx] = meta_shard + video_shard = ShardedTensor.from_list(video_shard) + target_path = os.path.join( + self.target_dir, split + "_" + str(shard_idx)) + video_shard.save(target_path) + + target_path = os.path.join(self.target_dir, split + "_meta") + with open(target_path + ".pkl", "wb") as fw: + pickle.dump(meta, fw, pickle.HIGHEST_PROTOCOL) + + +if __name__ == "__main__": + shard = Shard( + "data/feat/feat_how2_s3d", + "data/how2/raw_caption_dedup.bert-base-uncased", + "data/feat/feat_how2_s3d_shard_small", + ["data/how2/how2_s3d_train.lst", "data/how2/how2_s3d_val.lst"] + ) + + shard() diff --git a/fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py b/fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py new file mode 100644 index 0000000000000000000000000000000000000000..429e05f8bc8667408b8c2057578b1e8c0e98638c --- /dev/null +++ b/fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py @@ -0,0 +1,242 @@ +# Copyright Howto100M authors. +# Copyright (c) Facebook, Inc. All Rights Reserved + +import torch as th +import pandas as pd +import os +import numpy as np +import ffmpeg +import random + +from torch.utils.data import Dataset + + +class VideoLoader(Dataset): + """modified from how2's video_feature_extractor.""" + def __init__( + self, + csv=None, + video_dict=None, + framerate=1, + size=112, + centercrop=False, + hflip=False, + **kwargs + ): + if csv is None and video_dict is None: + raise ValueError("csv and video_dict cannot be both None.") + if csv is not None: + self.csv = pd.read_csv(csv) + if video_dict is not None: + self.csv = pd.DataFrame.from_dict(video_dict) + + self.centercrop = centercrop + self.size = size + self.framerate = framerate + self.hflip = hflip + + def __len__(self): + return len(self.csv) + + def _get_video_dim(self, video_path): + probe = ffmpeg.probe(video_path) + video_stream = next((stream for stream in probe['streams'] + if stream['codec_type'] == 'video'), None) + width = int(video_stream['width']) + height = int(video_stream['height']) + return height, width + + def _get_video_info(self, video_path): + probe = ffmpeg.probe(video_path) + video_stream = next((stream for stream in probe['streams'] + if stream['codec_type'] == 'video'), None) + return video_stream + + def _get_output_dim(self, h, w): + if isinstance(self.size, tuple) and len(self.size) == 2: + return self.size + elif h >= w: + return int(h * self.size / w), self.size + else: + return self.size, int(w * self.size / h) + + def __getitem__(self, idx): + video_path = self.csv['video_path'].values[idx] + output_file = self.csv['feature_path'].values[idx] + return self._decode(output_file, video_path) + + def _decode(self, output_file, video_path): + if not(os.path.isfile(output_file)) and os.path.isfile(video_path): + try: + h, w = self._get_video_dim(video_path) + except Exception: + print('ffprobe failed at: {}'.format(video_path)) + return {'video': th.zeros(1), 'input': video_path, + 'output': output_file} + try: + os.makedirs(os.path.dirname(output_file), exist_ok=True) + height, width = self._get_output_dim(h, w) + + cmd = ( + ffmpeg + .input(video_path) + .filter('fps', fps=self.framerate) + .filter('scale', width, height) + ) + if self.hflip: + cmd = cmd.filter('hflip') + + if self.centercrop: + x = int((width - self.size) / 2.0) + y = int((height - self.size) / 2.0) + cmd = cmd.crop(x, y, self.size, self.size) + video = self._run(cmd, output_file) + except Exception: + video = th.zeros(1) + else: + video = th.zeros(1) + + return {'video': video, 'input': video_path, 'output': output_file} + + def _run(self, cmd, output_file): + out, _ = ( + cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) + if self.centercrop and isinstance(self.size, int): + height, width = self.size, self.size + video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3]) + video = th.from_numpy(video.astype('float32')) + return video.permute(0, 3, 1, 2) + + +class VideoVerifier(VideoLoader): + def __getitem__(self, idx): + video_path = self.csv['video_path'].values[idx] + try: + return self._get_video_info(video_path) + except Exception: + # print('ffprobe failed at: {}'.format(video_path)) + return None + + +class VideoCompressor(VideoLoader): + def __init__( + self, + csv=None, + video_dict=None, + framerate=1, + size=112, + centercrop=False, + hflip=False, + crf=32, + **kwargs + ): + super().__init__( + csv, + video_dict, + framerate, + size, + centercrop, + hflip + ) + self.crf = crf + + def _run(self, cmd, output_file): + out, _ = ( + cmd.output(filename=output_file, crf=self.crf) + .run(quiet=True) + ) + video = None + return video + + +class VideoDownloader(VideoCompressor): + """download""" + def __getitem__(self, idx): + video_path = self.csv['video_path'].values[idx] + output_file = self.csv['feature_path'].values[idx] + if not(os.path.isfile(output_file)): + os.makedirs(os.path.dirname(output_file), exist_ok=True) + cmd = "wget -O" + output_file + " " + video_path + # import subprocess + # subprocess.check_output( + # cmd, + # stderr=subprocess.STDOUT, shell=True) + os.system(cmd) + return {'video': None, 'input': video_path, 'output': output_file} + + +class AvKeyframeVideoCompressor(VideoLoader): + """extract keyframes from a video and save it as jpg. + TODO: consider to merge with `CodecProcessor`. + """ + def __init__( + self, + csv=None, + video_dict=None, + framerate=1, + size=112, + centercrop=False, + max_num_frames=5, + **kwargs + ): + super().__init__(csv, video_dict, framerate, size, centercrop) + self.max_num_frames = max_num_frames + + def _get_video_dim(self, video_fn): + """decord cannot probe the size of a video, we use pyav instead.""" + import av + with av.open(video_fn) as container: + height = container.streams.video[0].codec_context.height + width = container.streams.video[0].codec_context.width + return height, width + + def _get_output_dim(self, height, width): + """ + keep the shorter side be `self.size`, strech the other. + """ + if height >= width: + return int(height * self.size / width), self.size + else: + return self.size, int(width * self.size / height) + + def __getitem__(self, idx): + import av + video_path = self.csv['video_path'].values[idx] + output_file = self.csv['feature_path'].values[idx] + if not(os.path.isdir(output_file)) and os.path.isfile(video_path): + try: + h, w = self._get_video_dim(video_path) + except Exception: + print('probe failed at: {}'.format(video_path)) + return {'video': th.zeros(1), 'input': video_path, + 'output': output_file} + + try: + height, width = self._get_output_dim(h, w) + + # new for av. + with av.open(video_path) as container: + container.streams.video[0].thread_type = "AUTO" + container.streams.video[0].codec_context.height = height + container.streams.video[0].codec_context.width = width + if self.framerate == 0: # keyframe. + container.streams.video[0].codec_context.skip_frame = 'NONKEY' + frames = [] + for frame in container.decode(video=0): + frames.append(frame) + frames = random.sample(frames, self.max_num_frames) + + os.makedirs(output_file, exist_ok=True) + for frame in frames: + frame.to_image().save( + os.path.join( + output_file, + "%04d.jpg" % frame.index)) + except Exception: + print('extract failed at: {}'.format(video_path)) + return {'video': th.zeros(1), 'input': video_path, + 'output': output_file} + video = th.zeros(1) + return {'video': video, 'input': video_path, 'output': output_file} diff --git a/fairseq/examples/MMPT/videoclip.png b/fairseq/examples/MMPT/videoclip.png new file mode 100644 index 0000000000000000000000000000000000000000..2b6f8d959fcfe3ce229df16a52c5c8230950b582 --- /dev/null +++ b/fairseq/examples/MMPT/videoclip.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d54fe18d1259ade9332e78fdb74f834fdfbdb0b0486517e6a7cd48956b30663 +size 385871 diff --git a/fairseq/examples/MMPT/vlm.png b/fairseq/examples/MMPT/vlm.png new file mode 100644 index 0000000000000000000000000000000000000000..c702b32b12d7d630c25af96a98ded4f83378bcb4 --- /dev/null +++ b/fairseq/examples/MMPT/vlm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:722852ed6258ac9f7ffd3e3913fa1a370702c4d989ef6d881847432d59ade4e5 +size 418405 diff --git a/fairseq/examples/adaptive_span/README.md b/fairseq/examples/adaptive_span/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d5224fb2894606a2a8027e01e224be190776ecfe --- /dev/null +++ b/fairseq/examples/adaptive_span/README.md @@ -0,0 +1,90 @@ +# Adaptive Span + +Adaptive Span is a novel self-attention mechanism that can learn its optimal +attention span. This allows us to extend significantly the maximum context size +used in Transformer, while maintaining control over their memory footprint +and computational time. It uses the Truncated BPTT technique for training, +as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md). + +Adaptive Span was introduced by paper: +[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799), +which achieved state-of-the-art language modeling results at the time of publication. + +We manage to reproduce their result in fairseq and keep most of the +[original implementation](https://github.com/facebookresearch/adaptive-span) untouched. +You can refer to the their sweep file as well if any combination of hyperparameter is not clear. + +##### 0. Setup + +First you need to process the Enwik8 dataset, we use the pre-tokenized dataset +from [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh). +You can download the dataset, and then run: +```bash +fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \ + --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \ + --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20 +``` + +##### 1. Train a Adaptive Span model on Enwik8 + +We will train a 12-layer Adaptive Span model following the [hyperparameters +used in the original +paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh). + +The following command assumes 4 GPUs, so that the total batch size is 64 +sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs: +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/adaptive_span \ + --data ~/data/enwik8/data-bin/ \ + --fp16 --fp16-no-flatten-grads --max-update 600000 \ + --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \ + --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \ + --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \ + --validate-interval-updates 1000 \ + --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \ + --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \ + --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07 +``` +This should land around 1.05 on validation, 1.03 on test. You can lower the +--aux-loss-scaler for better performance (longer span). It gives ~0.03 bpc +improvement to the transformerXL baseline here. +If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients +and simulate training on 4 GPUs. +You can also reproduce the transformerXL result on enwik8 using this code base. +It should land around 1.06 on test,matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh). +You can try by +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/truncated_bptt \ + ~/data/enwik8/data-bin/ \ + --task truncated_bptt_lm --fp16 --max-update 400000 \ + --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \ + --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \ + --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \ + --lr-scheduler cosine --warmup-updates 0 \ + --lr 0.0 --lr 0.00025 --batch-size 15 \ + --update-freq 1 --seed 2 --log-format json --log-interval 25 \ + --fp16 +``` + +##### 2. Evaluate +For Adaptive Span: +```bash +fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \ + --user-dir examples/adaptive_span \ + --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test +``` +For Transformer-XL evaluation: +```bash +fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \ + --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \ + --tokens-per-sample 80 \ + --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \ + --gen-subset valid +``` + +*Note:* During training the model saw 512 tokens of context +(``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation +settings from [the original +paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh). diff --git a/fairseq/examples/adaptive_span/__init__.py b/fairseq/examples/adaptive_span/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a142a769360e1140bf814c532eaf841f1d52d8 --- /dev/null +++ b/fairseq/examples/adaptive_span/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +# automatically import any Python files in the current directory +cur_dir = os.path.dirname(__file__) +for file in os.listdir(cur_dir): + path = os.path.join(cur_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + mod_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module(__name__ + "." + mod_name) diff --git a/fairseq/examples/adaptive_span/adagrad_with_grad_clip.py b/fairseq/examples/adaptive_span/adagrad_with_grad_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..585ce184ab2d6bbde0d2f7fcafd6536fa8f6d8b6 --- /dev/null +++ b/fairseq/examples/adaptive_span/adagrad_with_grad_clip.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.optim import Adagrad + +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adagrad_with_grad_clip") +class FairseqAdagradWithGradClip(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = AdagradWithGradClip(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D', + help='internal grad clip') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "weight_decay": self.args.weight_decay, + "grad_clip": self.args.adagrad_clip, + } + + @property + def supports_flat_params(self): + return False + + +def _clip_grad(clr, grad, group_grad_clip): + if group_grad_clip > 0: + norm = grad.norm(2).item() + if norm > group_grad_clip: + clr *= group_grad_clip / (norm + 1e-10) + return clr + + +class AdagradWithGradClip(Adagrad): + """Adagrad algorithm with custom gradient clipping""" + + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + grad_clip=0, + ): + Adagrad.__init__( + self, + params, + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + ) + self.defaults["grad_clip"] = grad_clip + self.param_groups[0].setdefault("grad_clip", grad_clip) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad.data + state = self.state[p] + + state["step"] += 1 + + if group["weight_decay"] != 0: + if p.grad.data.is_sparse: + raise RuntimeError( + "weight_decay option is " + "not compatible with sparse " + "gradients" + ) + grad = grad.add(group["weight_decay"], p.data) + + clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"]) + + # clip + clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"]) + + if grad.is_sparse: + # the update is non-linear so indices must be unique + grad = grad.coalesce() + grad_indices = grad._indices() + grad_values = grad._values() + size = grad.size() + + def make_sparse(values): + constructor = grad.new + if grad_indices.dim() == 0 or values.dim() == 0: + return constructor().resize_as_(grad) + return constructor(grad_indices, values, size) + + state["sum"].add_(make_sparse(grad_values.pow(2))) + std = state["sum"]._sparse_mask(grad) + std_values = std._values().sqrt_().add_(1e-10) + p.data.add_(-clr, make_sparse(grad_values / std_values)) + else: + state["sum"].addcmul_(1, grad, grad) + std = state["sum"].sqrt().add_(1e-10) + p.data.addcdiv_(-clr, grad, std) + + return loss diff --git a/fairseq/examples/adaptive_span/adaptive_span_attention.py b/fairseq/examples/adaptive_span/adaptive_span_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..07f757bb8e1a8a67b1124175ee338c8735aa8d65 --- /dev/null +++ b/fairseq/examples/adaptive_span/adaptive_span_attention.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AdaptiveMask(nn.Module): + """Soft masking function for adaptive size. + It masks out the last K values of an input. The masking value + goes from 1 to 0 gradually, so K can be learned with + back-propagation. + Args: + max_size: maximum size (i.e. input dimension) + ramp_size: size of the ramp going from 0 to 1 + init_val: initial size proportion not to be masked out + shape: learn multiple sizes independent of each other + """ + + def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)): + nn.Module.__init__(self) + self._max_size = max_size + self._ramp_size = ramp_size + self.current_val = nn.Parameter(torch.zeros(*shape) + init_val) + mask_template = torch.linspace(1 - max_size, 0, steps=max_size) + self.register_buffer("mask_template", mask_template) + + def forward(self, x): + mask = self.mask_template.float() + self.current_val.float() * self._max_size + mask = mask / self._ramp_size + 1 + mask = mask.clamp(0, 1) + if x.size(-1) < self._max_size: + # the input could have been trimmed beforehand to save computation + mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1)) + x = (x * mask).type_as(x) + return x + + def get_current_max_size(self, include_ramp=True): + current_size = math.ceil(self.current_val.max().item() * self._max_size) + if include_ramp: + current_size += self._ramp_size + current_size = max(0, min(self._max_size, current_size)) + return current_size + + def get_current_avg_size(self, include_ramp=True): + current_size = math.ceil( + self.current_val.float().mean().item() * self._max_size + ) + if include_ramp: + current_size += self._ramp_size + current_size = max(0, min(self._max_size, current_size)) + return current_size + + def clamp_param(self): + """this need to be called after each update""" + self.current_val.data.clamp_(0, 1) + + +class AdaptiveSpan(nn.Module): + """Adaptive attention span for Transformerself. + This module learns an attention span length from data for each + self-attention head. + Args: + attn_span: maximum attention span + adapt_span_loss: loss coefficient for the span length + adapt_span_ramp: length of the masking ramp + adapt_span_init: initial size ratio + adapt_span_cache: adapt cache size to reduce memory usage + """ + + def __init__( + self, + attn_span, + adapt_span_ramp, + adapt_span_init, + n_head, + adapt_span_layer, + **kargs + ): + nn.Module.__init__(self) + self._max_span = attn_span + self._n_head = n_head + self._adapt_span_layer = adapt_span_layer + if self._adapt_span_layer: + self._mask = AdaptiveMask( + max_size=self._max_span, + ramp_size=adapt_span_ramp, + init_val=adapt_span_init, + ) + else: + self._mask = AdaptiveMask( + max_size=self._max_span, + ramp_size=adapt_span_ramp, + init_val=adapt_span_init, + shape=(n_head, 1, 1), + ) + + def forward(self, attn, normalize=True): + """mask attention with the right span""" + # batch and head dimensions are merged together, so separate them first + self.clamp_param() + if self._adapt_span_layer: + attn = self._mask(attn) + else: + B = attn.size(0) # batch size + M = attn.size(1) # block size + attn = attn.reshape(B // self._n_head, self._n_head, M, -1) + attn = self._mask(attn) + attn = attn.view(B, M, -1) + return attn + + def get_trim_len(self): + """how much of memory can be trimmed to reduce computation""" + L = self._max_span + trim_len = min(L - 1, L - self._mask.get_current_max_size()) + # too fine granularity might be bad for the memory management + trim_len = math.floor(trim_len / 64) * 64 + return trim_len + + def trim_memory(self, query, key, value, key_pe): + """trim out unnecessary memory beforehand to reduce computation""" + trim_len = self.get_trim_len() + cache_size = key.size(1) - query.size(1) + trim_len_cache = trim_len - (self._max_span - cache_size) + if trim_len_cache > 0: + key = key[:, trim_len_cache:, :] + value = value[:, trim_len_cache:, :] + elif trim_len_cache < 0: + # cache is too short! this happens when validation resumes + # after a lot of updates. + key = F.pad(key, [0, 0, -trim_len_cache, 0]) + value = F.pad(value, [0, 0, -trim_len_cache, 0]) + if trim_len > 0: + if key_pe is not None: + key_pe = key_pe[:, :, trim_len:] + return key, value, key_pe + + def get_cache_size(self): + """determine how long the cache should be""" + trim_len = self.get_trim_len() + # give a buffer of 64 steps since a span might increase + # in future updates + return min(self._max_span, self._max_span - trim_len + 64) + + def get_loss(self): + """a loss term for regularizing the span length""" + return self._max_span * self._mask.current_val.float().mean() + + def get_current_max_span(self): + return self._mask.get_current_max_size() + + def get_current_avg_span(self): + return self._mask.get_current_avg_size() + + def clamp_param(self): + self._mask.clamp_param() diff --git a/fairseq/examples/adaptive_span/adaptive_span_loss.py b/fairseq/examples/adaptive_span/adaptive_span_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fe95b0d949250087eb4712172f7e3e04559de547 --- /dev/null +++ b/fairseq/examples/adaptive_span/adaptive_span_loss.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass + +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class AdaptiveSpanCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + + +@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig) +class AdaptiveSpanCriterion(CrossEntropyCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task, sentence_avg) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss here is summed, different from the adaptive span code + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, aux_loss, avg_span, max_span = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + loss /= sample_size + total_loss = loss + aux_loss + sample_size = 1 + + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "total_loss": total_loss.data, + "avg_span": avg_span * sample_size, + "max_span": max_span * sample_size, + } + return total_loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + loss, _ = super().compute_loss(model, net_output, sample, reduce) + aux_loss = model.get_aux_loss() + avg_span = model.get_current_avg_span() + max_span = model.get_current_max_span() + return loss, aux_loss, avg_span, max_span + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs) + avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs) + max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3) + metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3) + # total loss contains the L1 norm on adaptive-span + metrics.log_scalar( + "total_loss", + total_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/examples/adaptive_span/adaptive_span_model.py b/fairseq/examples/adaptive_span/adaptive_span_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d96c95b85dbcf29e9384cc6d8d9630d2489991b2 --- /dev/null +++ b/fairseq/examples/adaptive_span/adaptive_span_model.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq.modules.layer_norm import LayerNorm + +from .adaptive_span_attention import AdaptiveSpan + +# Size notations: +# B = batch_size, H = d_model, M = block_size, L = attn_span + + +def _skew(X, pad_value): + """shift every row 1 step to right""" + # X = B x M x L + B, M, L = X.size() + X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1) + X = X.view(B, -1) # B x ML+MM+M + X = X[:, :-M] # B x ML+MM + X = X.view(B, M, M + L) # B x M x L+M + return X + + +def _unskew(X): + """reverse _skew operation""" + # X = B x M x L+M + B, M, L = X.size() + L -= M + X = X.view(B, -1) # B x ML+MM + X = F.pad(X, (0, M)) # B x ML+MM+M + X = X.view(B, M, M + L + 1) # B x M x L+M+1 + X = X[:, :, :L] # B x M x L + return X + + +class SeqAttention(nn.Module): + """Sequential self-attention layer. + Each token will attend to its previous fixed number of steps. + Note that attention doesn't include the current step itself. + """ + + def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs): + nn.Module.__init__(self) + self.dropout = nn.Dropout(dropout) + self.d_model = d_model # size of a single head + self.attn_span = attn_span + self.adaptive_span = AdaptiveSpan( + attn_span=attn_span, + n_head=n_head, + adapt_span_layer=adapt_span_layer, + **kargs + ) + + def forward(self, query, key, value, key_pe): + # query size = B x M x H + # key, value sizes = B x (M+L) x H + + key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe) + + # compute attention from context + # B x M (dest) x (M+L) (src) + attn_cont = torch.matmul(query, key.transpose(-1, -2)) + attn_cont = _unskew(attn_cont) # B x M x L + + # compute the effect of position embedding + attn_pos = torch.matmul(query, key_pe) # B x M x L_pos + attn = attn_cont + attn_pos + + attn = attn / math.sqrt(self.d_model) # B x M X L_pos + + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + + # trim attention lengths according to the learned span + attn = self.adaptive_span(attn) + + attn = self.dropout(attn) # B x M X L_pos + + attn_cont = _skew(attn, 0) # B x M X (L+M) + out = torch.matmul(attn_cont, value) # B x M x H + return out + + def get_cache_size(self): + return self.adaptive_span.get_cache_size() + + +class MultiHeadSeqAttention(nn.Module): + def __init__(self, d_model, n_head, **kargs): + nn.Module.__init__(self) + assert d_model % n_head == 0 + self.n_head = n_head + self.head_dim = d_model // n_head + self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs) + self.proj_query = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_query.weight) + self.proj_out = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_out.weight) + self.proj_val = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_val.weight) + self.proj_key = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_key.weight) + + def head_reshape(self, x): + K = self.n_head + D = self.head_dim + x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D + x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D + x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D + return x + + def forward(self, query, key, value, key_pe): + B = query.size(0) + K = self.n_head + D = self.head_dim + M = query.size(1) + + query = self.proj_query(query) + query = self.head_reshape(query) + value = self.proj_val(value) + value = self.head_reshape(value) + key = self.proj_key(key) + key = self.head_reshape(key) + + out = self.attn(query, key, value, key_pe) # B_K x M x D + out = out.view(B, K, M, D) # B x K x M x D + out = out.transpose(1, 2).contiguous() # B x M x K x D + out = out.view(B, M, -1) # B x M x K_D + out = self.proj_out(out) + return out + + +class FeedForwardLayer(nn.Module): + def __init__(self, d_model, d_inner, dropout, **kargs): + nn.Module.__init__(self) + self.fc1 = nn.Linear(d_model, d_inner) + self.fc2 = nn.Linear(d_inner, d_model) + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + self.dropout = nn.Dropout(dropout) + + def forward(self, h): + h1 = F.relu(self.fc1(h)) + h1 = self.dropout(h1) + h2 = self.fc2(h1) + return h2 + + +class TransformerSeqLayer(nn.Module): + def __init__(self, d_model, **kargs): + nn.Module.__init__(self) + self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs) + self.norm1 = LayerNorm(d_model) + self.ff = FeedForwardLayer(d_model=d_model, **kargs) + self.norm2 = LayerNorm(d_model) + + def forward(self, h, h_cache, key_pe): + # h = B x M x H + # h_cache = B x L x H + h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H + attn_out = self.attn(h, h_all, h_all, key_pe) + h = self.norm1(h + attn_out) # B x M x H + if self.ff is not None: + ff_out = self.ff(h) + out = self.norm2(h + ff_out) # B x M x H + else: + out = h + return out + + def get_cache_size(self): + return self.attn.attn.get_cache_size() + + +class TransformerSeq(nn.Module): + def __init__( + self, + vocab_size, + d_model, + n_head, + n_layer, + attn_span, + emb_dropout, + aux_loss_scaler, + adapt_span_layer, + **kargs + ): + nn.Module.__init__(self) + # token embeddings + self.in_emb = nn.Embedding(vocab_size, d_model) + nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5) + self.out_emb = nn.Linear(d_model, vocab_size) + self.aux_loss_scaler = aux_loss_scaler + if emb_dropout > 0: + self.emb_dropout = nn.Dropout(emb_dropout) + else: + self.emb_dropout = None + # position embeddings + self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span)) + + self.layers = nn.ModuleList() + self.layers.extend( + TransformerSeqLayer( + d_model=d_model, + n_head=n_head, + attn_span=attn_span, + adapt_span_layer=adapt_span_layer, + **kargs + ) + for _ in range(n_layer) + ) + + def forward(self, x, h_cache, target=None): + # x size = B x M + block_size = x.size(1) + h = self.in_emb(x) # B x M x H + if self.emb_dropout is not None: + h = self.emb_dropout(h) + + h_cache_next = [] + for l, layer in enumerate(self.layers): + cache_size = layer.attn.attn.get_cache_size() + if cache_size > block_size: + h_cache_next_l = torch.cat( + [h_cache[l][:, -cache_size + block_size :, :], h], dim=1 + ).detach() + else: + h_cache_next_l = h[:, -cache_size:, :].detach() + h_cache_next.append(h_cache_next_l) + h = layer(h, h_cache[l], self.key_pe) # B x M x H + + if self.emb_dropout is not None: + h = self.emb_dropout(h) + + out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h) + dummy_loss = None + + return out, h_cache_next, dummy_loss + + def get_aux_loss(self): + loss = 0.0 + for layer in self.layers: + loss += layer.attn.attn.adaptive_span.get_loss() + return self.aux_loss_scaler * loss + + def get_current_max_span(self): + max_span = 0.0 + for layer in self.layers: + max_span = max( + max_span, layer.attn.attn.adaptive_span.get_current_max_span() + ) + return max_span + + def get_current_avg_span(self): + avg_span = 0.0 + for layer in self.layers: + avg_span += layer.attn.attn.adaptive_span.get_current_avg_span() + return avg_span / len(self.layers) diff --git a/fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py b/fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5b147fe11f9d730438d036321a2d4a5d776efaa2 --- /dev/null +++ b/fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, +) +from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel + + +logger = logging.getLogger(__name__) + + +@dataclass +class AdaptiveSpanSmallConfig(FairseqDataclass): + # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh + vocab_size: int = 50 + d_model: int = 256 + n_head: int = 4 + d_inner: int = 1024 + n_layer: int = 8 + attn_span: int = 1024 + dropout: float = 0.0 + emb_dropout: float = 0.0 + adapt_span_ramp: int = 32 + adapt_span_init: float = 0.0 + aux_loss_scaler: float = 0.000002 + adapt_span_layer: bool = False + + +@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig) +class AdaptiveSpanTransformer(FairseqLanguageModel): + @classmethod + def build_model(cls, cfg: AdaptiveSpanSmallConfig, task): + return cls(AdaptiveSpanDecoder(cfg, task)) + + def get_aux_loss(self): + return self.decoder.get_aux_loss() + + def get_current_max_span(self): + return self.decoder.get_current_max_span() + + def get_current_avg_span(self): + return self.decoder.get_current_avg_span() + + +class AdaptiveSpanDecoder(FairseqIncrementalDecoder): + def __init__(self, cfg, task): + + super().__init__(task.target_dictionary) + + self.config = cfg + config = AdaptiveSpanSmallConfig( + vocab_size=len(task.target_dictionary), + d_model=cfg.d_model, + n_head=cfg.n_head, + d_inner=cfg.d_inner, + n_layer=cfg.n_layer, + attn_span=cfg.attn_span, + dropout=cfg.dropout, + emb_dropout=cfg.emb_dropout, + adapt_span_ramp=cfg.adapt_span_ramp, + adapt_span_init=cfg.adapt_span_init, + aux_loss_scaler=cfg.aux_loss_scaler, + adapt_span_layer=cfg.adapt_span_layer, + ) + logger.info(config) + self.model = AdaptiveSpanTransformerModel(**config.__dict__) + + self._mems = None + + def forward( + self, + src_tokens, + incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None, + encoder_out=None, + ): + bsz = src_tokens.size(0) + if incremental_state is not None: # used during inference + mems = self.get_incremental_state("mems") + src_tokens = src_tokens[:, -1:] # only keep the most recent token + else: + mems = self._mems + + if mems is None: + # first time init + mems = self.init_hid_cache(bsz) + output = self.model(x=src_tokens, h_cache=mems,) + if incremental_state is not None: + self.set_incremental_state(incremental_state, "mems", output[1]) + else: + self._mems = output[1] + return (output[0],) + + def max_positions(self): + return self.config.attn_span + + def init_hid_cache(self, batch_sz): + hid = [] + for layer in self.model.layers: + param = next(self.model.parameters()) + h = torch.zeros( + batch_sz, + layer.get_cache_size(), + self.config.d_model, + dtype=param.dtype, + device=param.device, + ) + hid.append(h) + return hid + + def get_aux_loss(self): + return self.model.get_aux_loss() + + def get_current_max_span(self): + return self.model.get_current_max_span() + + def get_current_avg_span(self): + return self.model.get_current_avg_span() + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], + new_order: torch.Tensor, + ): + """Reorder incremental state. + + This will be called when the order of the input has changed from the + previous time step. A typical use case is beam search, where the input + order changes between time steps based on the selection of beams. + """ + raise NotImplementedError("This is required for generation/beam search") + # mems = self.get_incremental_state(incremental_state, "mems") + # if mems is not None: + # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems] + # self.set_incremental_state(incremental_state, "mems", new_mems) diff --git a/fairseq/examples/adaptive_span/truncated_bptt_lm_task.py b/fairseq/examples/adaptive_span/truncated_bptt_lm_task.py new file mode 100644 index 0000000000000000000000000000000000000000..9978481b6d95134ab609499586c609913aee35df --- /dev/null +++ b/fairseq/examples/adaptive_span/truncated_bptt_lm_task.py @@ -0,0 +1,285 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +import torch +from fairseq import utils +from fairseq.data import ( + Dictionary, + TokenBlockDataset, + data_utils, + iterators, +) +from fairseq.dataclass import FairseqDataclass +from fairseq.distributed import utils as dist_utils +from fairseq.tasks import FairseqTask, register_task +from omegaconf import II + + +logger = logging.getLogger(__name__) + + +@dataclass +class TruncatedBPTTLMConfig(FairseqDataclass): + data: str = field(default="???", metadata={"help": "path to data directory"}) + tokens_per_sample: int = field( + default=1024, metadata={"help": "max number of tokens per sequence"}, + ) + batch_size: int = II("dataset.batch_size") + # Some models use *max_target_positions* to know how many positional + # embeddings to learn. We use II(...) to make it default to + # *tokens_per_sample*, but in principle there could be more positional + # embeddings than tokens in a single batch. This may also be irrelevant for + # custom model implementations. + max_target_positions: int = II("task.tokens_per_sample") + # these will be populated automatically if not provided + data_parallel_rank: Optional[int] = None + data_parallel_size: Optional[int] = None + + +@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig) +class TruncatedBPTTLMTask(FairseqTask): + def __init__(self, cfg: TruncatedBPTTLMConfig): + super().__init__(cfg) + + if cfg.data_parallel_rank is None or cfg.data_parallel_size is None: + if torch.distributed.is_initialized(): + cfg.data_parallel_rank = dist_utils.get_data_parallel_rank() + cfg.data_parallel_size = dist_utils.get_data_parallel_world_size() + else: + cfg.data_parallel_rank = 0 + cfg.data_parallel_size = 1 + + # load the dictionary + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(self.dictionary))) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test)""" + + # support sharded datasets + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + # each element of *data* will be a tensorized line from the original + # text dataset, similar to ``open(split_path).readlines()`` + data = data_utils.load_indexed_dataset( + split_path, self.dictionary, combine=combine + ) + if data is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + # this is similar to ``data.view(-1).split(tokens_per_sample)`` + data = TokenBlockDataset( + data, + data.sizes, + block_size=self.cfg.tokens_per_sample, + pad=None, # unused + eos=None, # unused + break_mode="none", + ) + + self.datasets[split] = TruncatedBPTTDataset( + data=data, + bsz_per_shard=self.cfg.batch_size, + shard_id=self.cfg.data_parallel_rank, + num_shards=self.cfg.data_parallel_size, + ) + + def dataset(self, split): + return self.datasets[split] + + def get_batch_iterator( + self, + dataset, + num_workers=0, + epoch=1, + data_buffer_size=0, + skip_remainder_batch=False, + **kwargs + ): + return iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=self._collate_fn, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + # we don't use the batching functionality from EpochBatchIterator; + # instead every item in *dataset* is a whole batch + batch_sampler=[[i] for i in range(len(dataset))], + disable_shuffling=True, + skip_remainder_batch=skip_remainder_batch, + ) + + def _collate_fn(self, items: List[List[torch.Tensor]]): + # we don't use fairseq's batching functionality, so we expect a single + # Tensor of type List[torch.Tensor] + assert len(items) == 1 + + # item will have shape B x T (the last batch may have length < T) + id, item = items[0] + item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad()) + B, T = item.size() + + # shift item one position over and append a padding token for the target + target = torch.nn.functional.pad( + item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad() + ) + + # fairseq expects batches to have the following structure + return { + "id": torch.tensor([id] * item.size(0)), + "net_input": {"src_tokens": item,}, + "target": target, + "nsentences": item.size(0), + "ntokens": item.numel(), + } + + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + eos = self.source_dictionary.eos() + dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=eos, + break_mode="eos", + ) + + class Dataset(torch.utils.data.Dataset): + def __getitem__(self, i): + item = dataset[i] + if item[-1] == eos: + # remove eos to support generating with a prefix + item = item[:-1] + return (i, [item]) + + def __len__(self): + return len(dataset) + + return Dataset() + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + if constraints is not None: + raise NotImplementedError + + # SequenceGenerator doesn't use *src_tokens* directly, we need to + # pass the *prefix_tokens* argument instead. + if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): + prefix_tokens = sample["net_input"]["src_tokens"] + + # begin generation with the end-of-sentence token + bos_token = self.source_dictionary.eos() + + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + context_window: int = 0, + ): + if context_window > 0: + raise NotImplementedError( + "Transformer-XL doesn't need --context-window, try " + "--model-overrides '{\"mem_len\":42}' instead " + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class TruncatedBPTTDataset(torch.utils.data.Dataset): + def __init__( + self, + data: List[torch.Tensor], # ordered list of items + bsz_per_shard, # number of items processed per GPUs per forward + shard_id, # current GPU ID + num_shards, # number of GPUs + ): + super().__init__() + self.data = data + + def batchify(data, bsz): + # Work out how cleanly we can divide the dataset into bsz parts. + nbatch = data.size(0) // bsz + # Trim off any extra elements that wouldn't cleanly fit (remainders). + data = data.narrow(0, 0, nbatch * bsz) + # Evenly divide the data across the bsz batches. + data = data.view(bsz, -1).contiguous() + return data + + # total number of sequences processed by all GPUs in each forward pass + global_batch_size = bsz_per_shard * num_shards + + """ + With a 16 item dataset, bsz_per_shard=2 and num_shards=3, + *indices* might look like: + + indices = [[0, 1], + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11]] + + The size of the TruncatedBPTTDataset instance will be 2, + and shard 1 will see items: + + [(0, [data[4], data[6]]), + (1, [data[5], data[7]])] + """ + indices = batchify(torch.arange(len(data)), global_batch_size) + assert indices.size(0) == global_batch_size + + self.my_indices = indices[ + shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard + ] + assert self.my_indices.size(0) == bsz_per_shard + + def __len__(self): + return self.my_indices.size(1) + + def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]: + return (i, [self.data[idx] for idx in self.my_indices[:, i]]) diff --git a/fairseq/examples/attention_head_selection/README.md b/fairseq/examples/attention_head_selection/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2434f1fb212fa075335f97507dc89508f29402ec --- /dev/null +++ b/fairseq/examples/attention_head_selection/README.md @@ -0,0 +1,161 @@ +# Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling (Gong et al., 2021) + +[https://arxiv.org/pdf/2106.10840.pdf](https://arxiv.org/pdf/2106.10840.pdf) + +## Introduction + +We present attention head selection strategies in multilingual and multi-domain sequence modeling including text translation, speech recognition and speech translation tasks. + +Below is an example of training multilingual/multi-domain speech recognition models. + +## Data Preparation +Prepare mTEDx data as in [mTEDx example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/mtedx_example.md) and CoVoST data as in [CoVoST example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/covost_example.md). Similarly prepare EuroParl data. + + +## Training a multilingual ASR model with attention head selection + +```bash +data_dir= +train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx" +valid_subset="valid_ar_ar_tedx,valid_de_de_tedx,valid_el_el_tedx,valid_es_es_tedx,valid_fr_fr_tedx,valid_it_it_tedx,valid_pt_pt_tedx,valid_ru_ru_tedx" +strateg= + +fairseq-train ${data_dir} \ + --user-dir examples/attention_head_selection/src \ + --train-subset "${train_subset}" \ + --valid-subset "${valid_subset}" \ + --config-yaml 'config_asr.yaml' \ + --arch 'head_selection_s2t_transformer_s' \ + --task 'speech_to_text_head_selection' \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \ + --lr 5e-4 \ + --clip-norm 10.0 \ + --seed 1 \ + --max-epoch 400 \ + --max-tokens 32000 \ + --ignore-prefix-size 1 \ + --dropout 0.3 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --skip-invalid-size-inputs-valid-test \ + --encoder-attn-head-select \ + --total-encoder-attention-heads 8 \ + --decoder-self-attn-head-select \ + --total-decoder-attention-heads 8 \ + --attn-head-select-strategy ${strategy} \ + --task-type lang \ +``` + +## Training a multi-domain ASR model with attention head selection + +```bash +data_dir= +train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep" +valid_subset="dev_es_es_tedx,dev_fr_fr_tedx,dev_pt_pt_tedx,dev_it_it_tedx,dev_ru_ru_tedx,dev_el_el_tedx,dev_ar_ar_tedx,dev_de_de_tedx,dev_ar_ar_cv,dev_de_de_cv,dev_es_es_cv,dev_fr_fr_cv,dev_it_it_cv,dev_pt_pt_cv,dev_ru_ru_cv,dev_de_de_ep,dev_es_es_ep,dev_fr_fr_ep,dev_it_it_ep,dev_pt_pt_ep" +strateg= + +fairseq-train ${data_dir} \ + --user-dir examples/attention_head_selection/src \ + --train-subset "${train_subset}" \ + --valid-subset "${valid_subset}" \ + --config-yaml 'config_asr.yaml' \ + --arch head_selection_s2t_transformer_s \ + --task speech_to_text_head_selection \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \ + --lr 5e-4 \ + --clip-norm 10.0 \ + --seed 1 \ + --max-epoch 400 \ + --max-tokens 32000 \ + --ignore-prefix-size 1 \ + --dropout 0.3 \ + --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ + --skip-invalid-size-inputs-valid-test \ + --encoder-attn-head-select \ + --total-encoder-attention-heads 8 \ + --decoder-self-attn-head-select \ + --total-decoder-attention-heads 8 \ + --attn-head-select-strategy ${strategy} \ + --task-type domain +``` + +## Inference in multilingual setting + +```bash +MODEL_DIR= +data_dir= +gen_subset= +train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx" +last_n=10 +CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt" +CHECKPOINT="_avg" +RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}" +if [ ! -d $RESULTS ]; then + mkdir -p $RESULTS +fi; + +python scripts/average_checkpoints.py \ + --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \ + --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${data_dir} \ + --user-dir examples/attention_head_selection/src \ + --arch 'head_selection_s2t_transformer_s' \ + --task 'speech_to_text_head_selection' \ + --train-subset ${train_subset} \ + --gen-subset ${gen_subset} \ + --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \ + --config-yaml 'config_asr.yaml' \ + --prefix-size 1 \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --results-path ${RESULTS} \ + --scoring wer --wer-tokenizer 13a \ + --wer-lowercase --wer-remove-punct --remove-bpe +``` + +## Inference in multi-domain setting + +```bash +MODEL_DIR= +data_dir= +gen_subset= +train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep" +last_n=10 +CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt" +CHECKPOINT="_avg" +RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}" +if [ ! -d $RESULTS ]; then + mkdir -p $RESULTS +fi; + +python scripts/average_checkpoints.py \ + --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \ + --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}" + +fairseq-generate ${data_dir} \ + --user-dir examples/attention_head_selection/src \ + --arch 'head_selection_s2t_transformer_s' \ + --task 'speech_to_text_head_selection' \ + --train-subset ${train_subset} \ + --gen-subset ${gen_subset} \ + --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \ + --config-yaml 'config_asr.yaml' \ + --prefix-size 1 \ + --max-tokens 40000 --beam 5 \ + --skip-invalid-size-inputs-valid-test \ + --results-path ${RESULTS} \ + --scoring wer --wer-tokenizer 13a \ + --wer-lowercase --wer-remove-punct --remove-bpe +``` + +## Citation +```bibtex +@article{gong2021pay, + title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling}, + author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian}, + journal={arXiv preprint arXiv:2106.10840}, + year={2021} +} +''' diff --git a/fairseq/examples/attention_head_selection/src/__init__.py b/fairseq/examples/attention_head_selection/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/attention_head_selection/src/data/__init__.py b/fairseq/examples/attention_head_selection/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py b/fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1823a7ac7c92c7a0e93fade59162b247f5db36 --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py @@ -0,0 +1,242 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path +from typing import Dict, List, Optional +from dataclasses import dataclass + +import torch +from fairseq.data import ( + ConcatDataset, + Dictionary, + FairseqDataset, + ResamplingDataset +) +from fairseq.data.audio.data_cfg import S2TDataConfig +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDatasetItem, + SpeechToTextDataset, + SpeechToTextDatasetCreator +) + +logger = logging.getLogger(__name__) + + +@dataclass +class SpeechToTextDatasetItemWithDomain(SpeechToTextDatasetItem): + src_lang_id: Optional[torch.Tensor] = None + tgt_lang_id: Optional[torch.Tensor] = None + domain_id: Optional[torch.Tensor] = None + + +class SpeechToTextDatasetWithDomain(SpeechToTextDataset): + + def __init__( + self, + split: str, + is_train_split: bool, + cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + src_lang_ids: Optional[List[int]] = None, + tgt_lang_ids: Optional[List[int]] = None, + domain_ids: Optional[List[int]] = None + ): + super().__init__( + split, is_train_split, cfg, audio_paths, n_frames, + src_texts, tgt_texts, speakers, src_langs, tgt_langs, + ids, tgt_dict, pre_tokenizer, bpe_tokenizer, + n_frames_per_step, speaker_to_id + ) + assert src_lang_ids is None or len(src_lang_ids) == self.n_samples + assert tgt_lang_ids is None or len(tgt_lang_ids) == self.n_samples + assert domain_ids is None or len(domain_ids) == self.n_samples + + self.src_lang_ids = src_lang_ids + self.tgt_lang_ids = tgt_lang_ids + self.domain_ids = domain_ids + + def __getitem__(self, index: int) -> SpeechToTextDatasetItemWithDomain: + item = super().__getitem__(index) + src_lang_id = self.src_lang_ids[index] + tgt_lang_id = self.tgt_lang_ids[index] + domain_id = self.domain_ids[index] + return SpeechToTextDatasetItemWithDomain( + index=item.index, source=item.source, + target=item.target, speaker_id=item.speaker_id, + src_lang_id=src_lang_id, + tgt_lang_id=tgt_lang_id, + domain_id=domain_id + ) + + def collater( + self, samples: List[SpeechToTextDatasetItem], return_order: bool = False + ) -> Dict: + if len(samples) == 0: + return {} + out = super().collater(samples, return_order=True) + order = out["order"] + src_lang_ids = torch.tensor([x.src_lang_id for x in samples], dtype=torch.long).index_select(0, order) + tgt_lang_ids = torch.tensor([x.tgt_lang_id for x in samples], dtype=torch.long).index_select(0, order) + domain_ids = torch.tensor([x.domain_id for x in samples], dtype=torch.long).index_select(0, order) + + out["src_lang_ids"] = src_lang_ids + out["tgt_lang_ids"] = tgt_lang_ids + out["domain_ids"] = domain_ids + if not return_order: + del out["order"] + return out + + +class SpeechToTextDatasetCreatorWithDomain(SpeechToTextDatasetCreator): + KEY_SRC_LANG_ID, KEY_TGT_LANG_ID = "src_lang_id", "tgt_lang_id" + KEY_DOMAIN_ID = "domain_id" + # default values + DEFAULT_SRC_LANG_ID, DEFAULT_TGT_LANG_ID, DEFAULT_DOMAIN_ID = 0, 0, 0 + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + cfg: S2TDataConfig, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id + ) -> SpeechToTextDatasetWithDomain: + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + src_lang_ids = [s.get(cls.KEY_SRC_LANG_ID, cls.DEFAULT_SRC_LANG_ID) for s in samples] + tgt_lang_ids = [s.get(cls.KEY_TGT_LANG_ID, cls.DEFAULT_TGT_LANG_ID) for s in samples] + domain_ids = [s.get(cls.KEY_DOMAIN_ID, cls.DEFAULT_DOMAIN_ID) for s in samples] + return SpeechToTextDatasetWithDomain( + split_name, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + src_lang_ids=src_lang_ids, + tgt_lang_ids=tgt_lang_ids, + domain_ids=domain_ids + ) + + @classmethod + def _load_samples_from_tsv( + cls, + root: str, + split: str, + src_lang_map, + tgt_lang_map, + domain_map + ): + # metadata from split + _, src_lang, tgt_lang, domain = split.split("_") + src_lang_id = src_lang_map[src_lang] + tgt_lang_id = tgt_lang_map[tgt_lang] + domain_id = domain_map[domain] + + samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split) + for s in samples: + s.update({ + cls.KEY_SRC_LANG_ID: src_lang_id, + cls.KEY_TGT_LANG_ID: tgt_lang_id, + cls.KEY_DOMAIN_ID: domain_id + }) + return samples + + @classmethod + def _from_tsv( + cls, + root: str, + cfg: S2TDataConfig, + split: str, + tgt_dict, + is_train_split: bool, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + src_lang_map: Dict[str, int], + tgt_lang_map: Dict[str, int], + domain_map: Dict[str, int] + ) -> SpeechToTextDatasetItemWithDomain: + samples = cls._load_samples_from_tsv( + root, split, src_lang_map, + tgt_lang_map, domain_map + ) + return cls._from_list( + split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer, + bpe_tokenizer, n_frames_per_step, speaker_to_id + ) + + @classmethod + def from_tsv( + cls, + root: str, + cfg: S2TDataConfig, + splits: str, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split: bool, + epoch: int, + seed: int, + src_lang_map: Dict[str, int], + tgt_lang_map: Dict[str, int], + domain_map: Dict[str, int], + n_frames_per_step: int = 1, + speaker_to_id=None + ) -> SpeechToTextDatasetWithDomain: + datasets = [ + cls._from_tsv( + root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, src_lang_map, tgt_lang_map, domain_map + ) + for split in splits.split(",") + ] + + if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: + # temperature-based sampling + size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) + datasets = [ + ResamplingDataset( + d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) + ) + for r, d in zip(size_ratios, datasets) + ] + + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] diff --git a/fairseq/examples/attention_head_selection/src/loss/__init__.py b/fairseq/examples/attention_head_selection/src/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py b/fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba33954d0171572111eca94ef39e0e9a683e0ed --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from torch.nn.modules.loss import _Loss + + +class HeadSelectionLoss(_Loss): + + def __init__(self, args): + super().__init__() + self.args = args + self.kl_weight = getattr(args, "kl_weight", 0.0) + + def forward(self, head_samples, sample_sizes, prior=0.5, eps=1e-7): + """ + head_scores: (num_tasks, num_layers, num_heads) + sample_sizes: (num_tasks, ) + """ + kl_loss = (head_samples * (torch.log(head_samples + eps) - math.log(prior))).sum(-1).sum(-1) + kl_loss /= (torch.numel(head_samples) / head_samples.size(0)) + kl_loss = self.kl_weight * torch.matmul(kl_loss, sample_sizes) + return kl_loss diff --git a/fairseq/examples/attention_head_selection/src/models/__init__.py b/fairseq/examples/attention_head_selection/src/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py b/fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7ed89e89d2a902bd8419a735676941ece125f1 --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py @@ -0,0 +1,170 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, List, Optional +from pathlib import Path +import torch.nn as nn +from torch import Tensor +from fairseq import checkpoint_utils + +from fairseq.models import register_model, register_model_architecture +from fairseq.utils import safe_hasattr +from fairseq.models.speech_to_text.s2t_transformer import ( + S2TTransformerModel, + S2TTransformerEncoder, + TransformerDecoderScriptable +) +from fairseq.models.speech_to_text.s2t_transformer import base_architecture as s2t_base_architecture + +from ..modules.attn_head_selector import AttnHeadSelector +from ..modules.head_selection_transformer_layer import HeadSelectionTransformerEncoderLayer +from .head_selection_transformer import HeadSelectionTransformerDecoder + + +logger = logging.getLogger(__name__) + + +@register_model("head_selection_s2t_transformer") +class HeadSelectionS2TTransformerModel(S2TTransformerModel): + """ + Head selection implemented in S2TTransformer + """ + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + S2TTransformerModel.add_args(parser) + # encoder head selection + parser.add_argument( + "--encoder-attn-head-select", + action="store_true", + default=False, + help="encoder head selection" + ) + parser.add_argument( + "--total-encoder-attention-heads", + type=int, + help="total number of encoder attention heads" + ) + # decoder self attention selection + parser.add_argument( + "--decoder-self-attn-head-select", + action="store_true", + default=False, + help="decoder self-attention head selection" + ) + # decoder-encoder attention selection + parser.add_argument( + "--dec-enc-attn-head-select", + action="store_true", + default=False, + help="decoder-encoder attention head selection" + ) + parser.add_argument( + "--total-decoder-attention-heads", + type=int, + help="total number of decoder attention heads" + ) + # selection strategy + parser.add_argument( + "--attn-head-select-strategy", + type=str, + help="attention head selection strategy, subset or group" + ) + + @classmethod + def build_encoder(cls, args): + if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select: + encoder = HeadSelectionS2TTransformerEncoder(args) + else: + encoder = S2TTransformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select): + return HeadSelectionTransformerDecoderScriptable(args, task.target_dictionary, embed_tokens) + else: + return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens) + + +class HeadSelectionS2TTransformerEncoder(S2TTransformerEncoder): + + def __init__(self, args): + super().__init__(args) + self.attn_head_selector = AttnHeadSelector( + args.encoder_tasks, + args.encoder_layers, + args.total_encoder_attention_heads, + args.encoder_attention_heads, + args.attn_head_select_strategy, + ) + self.task_ids = None + self.transformer_layers = nn.ModuleList([ + HeadSelectionTransformerEncoderLayer(args, layer_idx, attn_head_selector=self.attn_head_selector) for layer_idx in range(args.encoder_layers) + ]) + + def set_task_ids(self, task_ids): + self.task_ids = task_ids + + def _forward(self, src_tokens, src_lengths, return_all_hiddens=False): + self.attn_head_selector.head_select(self.task_ids) + return super()._forward(src_tokens, src_lengths, return_all_hiddens) + + +class HeadSelectionTransformerDecoderScriptable(HeadSelectionTransformerDecoder): + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + # call scriptable method from parent class + x, _ = self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + return x, None + + +@register_model_architecture(model_name="head_selection_s2t_transformer", arch_name="head_selection_s2t_transformer") +def base_architecture(args): + s2t_base_architecture(args) + args.encoder_attn_head_select = getattr(args, "encoder_attn_head_select", False) + args.decoder_self_attn_head_select = getattr(args, "decoder_self_attn_head_select", False) + args.dec_enc_attn_head_select = getattr(args, "dec_enc_attn_head_select", False) + args.total_encoder_attention_heads = getattr(args, "total_encoder_attention_heads", 8) + args.total_decoder_attention_heads = getattr(args, "total_decoder_attention_heads", 8) + args.attn_head_select_strategy = getattr(args, "attn_head_select_strategy", "group") + + +@register_model_architecture("head_selection_s2t_transformer", "head_selection_s2t_transformer_s") +def head_selection_s2t_transformer_s(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + base_architecture(args) diff --git a/fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py b/fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d595699db62b247ad3dd82edac366f53e324bb --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py @@ -0,0 +1,215 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Dict, Optional +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq.utils import safe_hasattr +from fairseq.models.transformer import ( + TransformerModel, + TransformerEncoder, + TransformerDecoder +) + +from ..modules.attn_head_selector import AttnHeadSelector +from ..modules.head_selection_transformer_layer import ( + HeadSelectionTransformerEncoderLayer, + HeadSelectionTransformerDecoderLayer +) + + +class HeadSelectionTransformerModel(TransformerModel): + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) + + @staticmethod + def add_args(parser): + TransformerModel.add_args(parser) + # encoder head selection + parser.add_argument( + "--encoder-attn-head-select", + action="store_true", + default=False, + help="encoder head selection" + ) + parser.add_argument( + "--total-encoder-attention-heads", + type=int, + help="total number of encoder attention heads" + ) + # decoder self attention + parser.add_argument( + "--decoder-self-attn-head-select", + action="store_true", + default=False, + help="decoder self-attention head selection" + ) + # decoder-encoder attention + parser.add_argument( + "--dec-enc-attn-head-select", + action="store_true", + default=False, + help="decoder-encoder attention head selection" + ) + parser.add_argument( + "--total-decoder-attention-heads", + type=int, + help="total number of decoder attention heads" + ) + # selection strategy + parser.add_argument( + "--attn-head-select-strategy", + type=str, + help="attention head selection strategy, subset or group" + ) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select: + return HeadSelectionTransformerEncoder( + args, src_dict, embed_tokens + ) + else: + return TransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select): + return HeadSelectionTransformerDecoder( + args, tgt_dict, embed_tokens + ) + else: + return TransformerDecoder(args, tgt_dict, embed_tokens) + + +class HeadSelectionTransformerEncoder(TransformerEncoder): + + def __init__(self, args, dictionary, embed_tokens): + self.num_tasks = args.encoder_tasks + self.num_layers = args.encoder_layers + self.total_num_heads = args.total_encoder_attention_heads + self.num_heads = args.encoder_attention_heads + self.select_strategy = args.attn_head_select_strategy + + super().__init__(args, dictionary, embed_tokens) + self.attn_head_selector = AttnHeadSelector( + self.num_tasks, + self.num_layers, + self.total_num_heads, + self.num_heads, + self.select_strategy + ) + self.task_ids = None + self.layers = nn.ModuleList( + [self.build_encoder_layer(args, i) for i in range(args.encoder_layers)] + ) + + def set_task_ids(self, task_ids): + self.task_ids = task_ids + + def build_encoder_layer(self, args, layer_idx=None): + return HeadSelectionTransformerEncoderLayer( + args, + layer_idx, + attn_head_selector=self.attn_head_selector + ) + + def forward( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + self.attn_head_selector.head_select(self.task_ids) + return super().forward(src_tokens, src_lengths, return_all_hiddens, token_embeddings) + + +class HeadSelectionTransformerDecoder(TransformerDecoder): + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + self.num_tasks = args.decoder_tasks + self.num_layers = args.decoder_layers + self.total_num_heads = args.total_decoder_attention_heads + self.num_heads = args.decoder_attention_heads + self.select_strategy = args.attn_head_select_strategy + super().__init__( + args, dictionary, embed_tokens, + no_encoder_attn=no_encoder_attn, + output_projection=output_projection + ) + self.self_attn_head_selector = None + self.enc_attn_head_selector = None + if safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select: + self.self_attn_head_selector = AttnHeadSelector( + self.num_tasks, + self.num_layers, + self.total_num_heads, + self.num_heads, + self.select_strategy + ) + if safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select: + self.enc_attn_head_selector = AttnHeadSelector( + self.num_tasks, + self.num_layers, + self.total_num_heads, + self.num_heads, + self.select_strategy + ) + self.task_ids = None + self.layers = nn.ModuleList( + [ + self.build_head_selection_decoder_layer(args, no_encoder_attn, idx) for idx in range(args.decoder_layers) + ] + ) + + def set_task_ids(self, task_ids): + self.task_ids = task_ids + + def build_head_selection_decoder_layer(self, args, no_encoder_attn=False, layer_idx=None): + return HeadSelectionTransformerDecoderLayer( + args, + layer_idx, + self.self_attn_head_selector, + self.enc_attn_head_selector, + no_encoder_attn=no_encoder_attn + ) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + if self.self_attn_head_selector is not None: + self.self_attn_head_selector.head_select(self.task_ids) + if self.enc_attn_head_selector is not None: + self.enc_attn_head_selector.head_select(self.task_ids) + return super().forward( + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + features_only=features_only, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens + ) diff --git a/fairseq/examples/attention_head_selection/src/modules/__init__.py b/fairseq/examples/attention_head_selection/src/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py b/fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..346fc623089989e36e707f0ff80d68e4d35c3ed7 --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py @@ -0,0 +1,81 @@ +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import math + + +class AttnHeadSelector(nn.Module): + """ + Latent variable modeling of attention head selection + """ + def __init__( + self, num_tasks, num_layers, + total_num_heads, num_heads, + select_strategy="group", + head_select_temp=5.0 + ): + super(AttnHeadSelector, self).__init__() + self.num_tasks = num_tasks + self.num_layers = num_layers + self.total_num_heads = total_num_heads + self.num_heads = num_heads + self.select_strategy = select_strategy + self.temp = head_select_temp + + self.head_logits = torch.nn.Parameter( + torch.Tensor(self.num_tasks, self.num_layers, total_num_heads), + requires_grad=True + ) + nn.init.uniform_( + self.head_logits, a=math.log(0.01), + b=math.log(1.0) + ) + + def gumbel_sample(self, logits, tau=1.0): + gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + gumbels1 = (logits + gumbels1 - gumbels2) / tau + y_soft = gumbels1.sigmoid() + return y_soft + + def subset_select(self, y_soft, topk, dim=-1): + top_values, top_inds = torch.topk(y_soft, k=topk, dim=dim) + top_ret = 1.0 - top_values.detach() + top_values + return top_inds.detach(), top_ret + + def group_selet(self, y_soft, topk, dim=-1): + # top_values: (num_tasks, num_layers, topk) + top_values, top_inds = torch.max( + y_soft.view(self.num_tasks, self.num_layers, -1, topk), dim=2 + ) + top_inds = top_inds * topk + torch.arange(topk, device=top_inds.device).unsqueeze(0).unsqueeze(1) + top_ret = 1.0 - top_values.detach() + top_values + return top_inds.detach(), top_ret + + def head_select(self, task_ids=None): + # gumbel_sample + self.head_samples = self.gumbel_sample(self.head_logits, tau=self.temp) + # head select + if self.select_strategy == "subset": + self.subset_heads, self.subset_weights = self.subset_select( + self.head_samples, + topk=self.num_heads, + ) + elif self.select_strategy == "group": + self.subset_heads, self.subset_weights = self.group_selet( + self.head_samples, + topk=self.num_heads, + ) + else: + raise ValueError("{} is not supported".format(self.select_strategy)) + + self.batch_subset = self.subset_heads[task_ids, :, :] + self.batch_weights = self.subset_weights[task_ids, :, :] + + def forward(self, layer_idx): + assert layer_idx is not None + batch_subset = self.batch_subset[:, layer_idx, :] + batch_weights = self.batch_weights[:, layer_idx, :] + return batch_subset, batch_weights diff --git a/fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py b/fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7921435035666a27ba78cb9193cd555cf54a9d8 --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.utils import safe_getattr +from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer +from ..modules.multihead_attention_selection import MultiheadAttentionSelection + + +class HeadSelectionTransformerEncoderLayer(TransformerEncoderLayer): + + def __init__(self, args, layer_idx, attn_head_selector=None): + super().__init__(args) + self.layer_idx = layer_idx + self.self_attn = self.build_self_attention_selection( + self.embed_dim, args, attn_head_selector + ) + + def build_self_attention_selection(self, embed_dim, args, attn_head_selector=None): + return MultiheadAttentionSelection( + embed_dim, + args.total_encoder_attention_heads, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + layer_idx=self.layer_idx, + attn_head_selector=attn_head_selector + ) + + +class HeadSelectionTransformerDecoderLayer(TransformerDecoderLayer): + + def __init__( + self, + args, + layer_idx, + self_attn_head_selector=None, + enc_attn_head_selector=None, + no_encoder_attn=False, + add_bias_kv=False, + add_zero_attn=False, + ): + self.layer_idx = layer_idx + super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn) + if self_attn_head_selector is not None: + self.self_attn = self.build_self_attention_selection( + self.embed_dim, args, + self_attn_head_selector=self_attn_head_selector, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn + ) + if enc_attn_head_selector is not None: + self.encoder_attn = self.build_encoder_attention_selection( + self.embed_dim, args, + enc_attn_head_selector=enc_attn_head_selector + ) + + def build_self_attention_selection( + self, embed_dim, args, self_attn_head_selector=None, + add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttentionSelection( + embed_dim, + args.total_decoder_attention_heads, + args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=not safe_getattr(args, "cross_self_attention"), + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + layer_idx=self.layer_idx, + attn_head_selector=self_attn_head_selector, + ) + + def build_encoder_attention_selection(self, embed_dim, args, enc_attn_head_selector=None): + return MultiheadAttentionSelection( + embed_dim, + args.total_decoder_attention_heads, + args.decoder_attention_heads, + kdim=args.encoder_embed_dim, + vdim=args.encoder_embed_dim, + dropout=args.attention_dropout, + encoder_decoder_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + layer_idx=self.layer_idx, + attn_head_selector=enc_attn_head_selector, + ) diff --git a/fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py b/fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py new file mode 100644 index 0000000000000000000000000000000000000000..566ad822ac4b1bdac573f3b419430d33300c076d --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py @@ -0,0 +1,355 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional, Tuple +import torch +from fairseq import utils +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor, nn +from torch.nn import Parameter + +from fairseq.modules.multihead_attention import MultiheadAttention +from ..modules.multihead_functional import multi_head_attention_forward + + +class MultiheadAttentionSelection(MultiheadAttention): + + def __init__( + self, + embed_dim, + total_num_heads, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + layer_idx=0, + attn_head_selector=None + ): + super().__init__( + embed_dim, + num_heads, + kdim=kdim, + vdim=vdim, + dropout=dropout, + bias=bias, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=self_attention, + encoder_decoder_attention=encoder_decoder_attention, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + self.layer_idx = layer_idx + self.attn_head_selector = attn_head_selector + self.total_num_heads = total_num_heads + self.total_embed_dim = self.head_dim * total_num_heads + self.k_proj = quant_noise( + nn.Linear(self.kdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, self.total_embed_dim, bias=bias), q_noise, qn_block_size + ) + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, self.total_embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, self.total_embed_dim)) + else: + self.bias_k = self.bias_v = None + self.reset_parameters() + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + # subset_heads: Optional[Tensor] = None, + # subset_weights: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + subset_heads, subset_weights = self.attn_head_selector(self.layer_idx) + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert list(query.size()) == [tgt_len, bsz, self.embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + ): + assert key is not None and value is not None + return multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.total_num_heads, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + subset_heads=subset_heads, + subset_weights=subset_weights + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.total_num_heads, self.head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.total_num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.total_num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.total_num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.total_num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.total_num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.total_num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.total_num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.total_num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + + # evaluation + if subset_heads is not None and subset_heads.numel() == 1: + subset_heads = subset_heads.repeat(bsz) + subset_weights = subset_weights.repeat(bsz) + + if subset_heads is None: + attn = torch.bmm(attn_probs, v) + else: + # training with head selection + mixed_attn = torch.bmm(attn_probs, v).contiguous().view(bsz, self.total_num_heads, tgt_len, self.head_dim) + attn = torch.stack( + [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1 + ) + attn = attn * subset_weights.unsqueeze(2).unsqueeze(3) + attn = attn.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim) + + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + if subset_heads is None: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + else: + mixed_attn_weights = attn_weights_float.view( + bsz, self.total_num_heads, tgt_len, src_len + ) + attn_weights = torch.stack( + [mixed_attn_weights[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1 + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights diff --git a/fairseq/examples/attention_head_selection/src/modules/multihead_functional.py b/fairseq/examples/attention_head_selection/src/modules/multihead_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..d5edc777e364f625986f402326aba9c4276bb75a --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/modules/multihead_functional.py @@ -0,0 +1,278 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple +import torch +from torch import Tensor +from torch.nn.functional import ( + linear, softmax, dropout, pad, + has_torch_function, + handle_torch_function, + _in_projection_packed, +) +import math +import warnings + + +def _scaled_dot_product_attention( + q: Tensor, + k: Tensor, + v: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + bsz: int = 1, + subset_heads: Optional[Tensor] = None, + subset_weights: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor]: + B, Nt, E = q.shape + q = q / math.sqrt(E) + # B: bsz * total_num_heads + # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) + attn = torch.bmm(q, k.transpose(-2, -1)) + if attn_mask is not None: + attn += attn_mask + attn = softmax(attn, dim=-1) + if dropout_p > 0.0: + attn = dropout(attn, p=dropout_p) + if subset_heads is None: + # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) + output = torch.bmm(attn, v) + else: + mixed_output = torch.bmm(attn, v).contiguous().view(bsz, -1, Nt, E) + output = torch.stack( + [mixed_output[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], + dim=1 + ) + output = output * subset_weights.unsqueeze(2).unsqueeze(3) + output = output.contiguous().view(-1, Nt, E) + if subset_heads is not None: + _, Nt, Ns = attn.size() + mixed_attn = attn.view(bsz, -1, Nt, Ns) + attn = torch.stack( + [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1 + ) + return output, attn + + +def _in_projection( + q: Tensor, + k: Tensor, + v: Tensor, + w_q: Tensor, + w_k: Tensor, + w_v: Tensor, + b_q: Optional[Tensor] = None, + b_k: Optional[Tensor] = None, + b_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + total_num_heads: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Optional[Tensor], + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Optional[Tensor], + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + subset_heads: Optional[Tensor] = None, + subset_weights: Optional[Tensor] = None, +): + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + total_num_heads, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + subset_heads=subset_heads, + subset_weights=subset_weights + ) + + # set up shape vars + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + assert embed_dim == embed_dim_to_check, \ + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + if isinstance(embed_dim, torch.Tensor): + # embed_dim can be a tensor when JIT tracing + head_dim = embed_dim.div(num_heads, rounding_mode='trunc') + else: + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert key.shape[:2] == value.shape[:2], \ + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + + # + # compute in-projection + # + if not use_separate_proj_weight: + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) + else: + assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" + assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" + assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.chunk(3) + q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) + + # prep attention mask + if attn_mask is not None: + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + else: + assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ + f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" + # ensure attn_mask's dim is 3 + if attn_mask.dim() == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + attn_mask = attn_mask.unsqueeze(0) + elif attn_mask.dim() == 3: + correct_3d_size = (bsz * total_num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") + + # prep key padding mask + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + + # add bias along batch dimension (currently second) + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # + # reshape q, k, v for multihead attention and make em batch first + # + q = q.contiguous().view(tgt_len, bsz * total_num_heads, head_dim).transpose(0, 1) + if static_k is None: + k = k.contiguous().view(k.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_k.size(0) == bsz * total_num_heads, \ + f"expecting static_k.size(0) of {bsz * total_num_heads}, but got {static_k.size(0)}" + assert static_k.size(2) == head_dim, \ + f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.contiguous().view(v.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_v.size(0) == bsz * total_num_heads, \ + f"expecting static_v.size(0) of {bsz * total_num_heads}, but got {static_v.size(0)}" + assert static_v.size(2) == head_dim, \ + f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * total_num_heads, 1, head_dim) + k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.size(1) + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == (bsz, src_len), \ + f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \ + expand(-1, total_num_heads, -1, -1).reshape(bsz * total_num_heads, 1, src_len) + if attn_mask is None: + attn_mask = key_padding_mask + elif attn_mask.dtype == torch.bool: + attn_mask = attn_mask.logical_or(key_padding_mask) + else: + attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) + + # convert mask to float + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + # adjust dropout probability + if not training: + dropout_p = 0.0 + + # + # (deep breath) calculate attention and out projection + # + attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, bsz, subset_heads, subset_weights) + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None diff --git a/fairseq/examples/attention_head_selection/src/speech_to_text_head_selection.py b/fairseq/examples/attention_head_selection/src/speech_to_text_head_selection.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0ce11d6307493b30da865ba23adf23d0015b7c --- /dev/null +++ b/fairseq/examples/attention_head_selection/src/speech_to_text_head_selection.py @@ -0,0 +1,180 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.optim.amp_optimizer import AMPOptimizer +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask + +from .data.speech_to_text_dataset_with_domain import SpeechToTextDatasetCreatorWithDomain +from .loss.attention_head_selection import HeadSelectionLoss + + +@register_task("speech_to_text_head_selection") +class SpeechToTextHeadSelectionTask(SpeechToTextTask): + + @classmethod + def add_args(cls, parser): + SpeechToTextTask.add_args(parser) + parser.add_argument( + "--task-type", + type=str, + default="lang", + help="task type for head selection, lang or domain" + ) + parser.add_argument( + "--kl-weight", + type=float, + default=0.0, + help="the weight of KL loss" + ) + + def __init__(self, args, tgt_dict): + super().__init__(args, tgt_dict) + self.task_type = args.task_type + assert self.task_type in ["lang", "domain"], "invalid task_type: {}, should be either lang or domain".format(self.task_type) + self.map_task_to_id(args.train_subset) + self.encoder_head_prior = float(args.decoder_attention_heads) / args.total_decoder_attention_heads + self.decoder_head_prior = float(args.encoder_attention_heads) / args.total_encoder_attention_heads + self.kl_loss = HeadSelectionLoss(args) + + def map_task_to_id(self, train_subset): + src_lang_set, tgt_lang_set, domain_set = set(), set(), set() + for split in train_subset.split(","): + seq = split.split("_") + assert len(seq) == 4, "subset {} should be in the format of train_src_tgt_domain".format(split) + _, src_lang, tgt_lang, domain = seq + src_lang_set.add(src_lang) + tgt_lang_set.add(tgt_lang) + domain_set.add(domain) + src_langs = sorted(src_lang_set) + tgt_langs = sorted(tgt_lang_set) + domains = sorted(domain_set) + self.src_lang_map = {src_lang: i for (i, src_lang) in enumerate(src_langs)} + self.tgt_lang_map = {tgt_lang: i for (i, tgt_lang) in enumerate(tgt_langs)} + self.domain_map = {domain: i for (i, domain) in enumerate(domains)} + if self.task_type == "lang": + self.encoder_tasks = len(self.src_lang_map) + self.decoder_tasks = len(self.tgt_lang_map) + elif self.task_type == "domain": + self.encoder_tasks = len(self.domain_map) + self.decoder_tasks = len(self.domain_map) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = SpeechToTextDatasetCreatorWithDomain.from_tsv( + self.args.data, + self.data_cfg, + split, + self.tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, + src_lang_map=self.src_lang_map, + tgt_lang_map=self.tgt_lang_map, + domain_map=self.domain_map, + speaker_to_id=self.speaker_to_id + ) + + def build_model(self, args): + args.encoder_tasks = self.encoder_tasks + args.decoder_tasks = self.decoder_tasks + return super(SpeechToTextHeadSelectionTask, self).build_model(args) + + def get_sample_sizes(self, sample, task_ids, num_tasks): + """ + task_ids: (bsz,) + get sample sizes for each task + """ + bsz = task_ids.size(0) + mat = torch.zeros((num_tasks, bsz), device=task_ids.device) + mat[task_ids, torch.arange(bsz)] = 1.0 + ntokens = torch.sum(sample['target'] != 1, dim=-1) + sample_sizes = torch.matmul(mat, ntokens.float()) + return sample_sizes + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + model.train() + model.set_num_updates(update_num) + # task ids + if self.task_type == "lang": + encoder_task_ids = sample["src_lang_ids"] + decoder_task_ids = sample["tgt_lang_ids"] + elif self.task_type == "domain": + encoder_task_ids = sample["domain_ids"] + decoder_task_ids = sample["domain_ids"] + model.encoder.set_task_ids(encoder_task_ids) + model.decoder.set_task_ids(decoder_task_ids) + + with torch.autograd.profiler.record_function("forward"): + with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): + loss, sample_size, logging_output = criterion(model, sample) + # KL loss + if self.args.encoder_attn_head_select: + sample_sizes = self.get_sample_sizes(sample, encoder_task_ids, self.encoder_tasks) + loss += self.kl_loss( + model.encoder.attn_head_selector.head_samples, + sample_sizes, + self.encoder_head_prior + ) + if self.args.decoder_self_attn_head_select: + sample_sizes = self.get_sample_sizes(sample, decoder_task_ids, self.decoder_tasks) + loss += self.kl_loss( + model.decoder.self_attn_head_selector.head_samples, + sample_sizes, + self.decoder_head_prior + ) + if self.args.dec_enc_attn_head_select: + sample_sizes = self.get_sample_sizes(sample, decoder_task_ids, self.decoder_tasks) + loss += self.kl_loss( + model.decoder.enc_attn_head_selector.head_sampes, + sample_sizes, + self.decoder_head_prior + ) + + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + # task ids + if self.task_type == "lang": + encoder_task_ids = sample["src_lang_ids"] + decoder_task_ids = sample["tgt_lang_ids"] + elif self.task_type == "domain": + encoder_task_ids = sample["domain_ids"] + decoder_task_ids = sample["domain_ids"] + model.encoder.set_task_ids(encoder_task_ids) + model.decoder.set_task_ids(decoder_task_ids) + with torch.no_grad(): + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + # task ids + if self.task_type == "lang": + encoder_task_ids = sample["src_lang_ids"][:1] + decoder_task_ids = sample["tgt_lang_ids"][:1] + elif self.task_type == "domain": + encoder_task_ids = sample["domain_ids"][:1] + decoder_task_ids = sample["domain_ids"][:1] + for model in models: + model.encoder.set_task_ids(encoder_task_ids) + model.decoder.set_task_ids(decoder_task_ids) + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, constraints=constraints + ) diff --git a/fairseq/examples/audio_nlp/nlu/README.md b/fairseq/examples/audio_nlp/nlu/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a11b3f306584c81a74aa7bf768d5fd264f177cd9 --- /dev/null +++ b/fairseq/examples/audio_nlp/nlu/README.md @@ -0,0 +1,53 @@ +# End-to-end NLU + +End-to-end spoken language understanding (SLU) predicts intent directly from audio using a single model. It promises to improve the performance of assistant systems by leveraging acoustic information lost in the intermediate textual representation and preventing cascading errors from Automatic Speech Recognition (ASR). Further, having one unified model has efficiency advantages when deploying assistant systems on-device. + +This page releases the code for reproducing the results in [STOP: A dataset for Spoken Task Oriented Semantic Parsing](https://arxiv.org/abs/2207.10643) + +The dataset can be downloaded here: [download link](https://dl.fbaipublicfiles.com/stop/stop.tar.gz) + +The low-resource splits can be downloaded here: [download link](http://dl.fbaipublicfiles.com/stop/low_resource_splits.tar.gz) + +## Pretrained models end-to-end NLU Models + +| Speech Pretraining | ASR Pretraining | Test EM Accuracy | Tesst EM-Tree Accuracy | Link | +| ----------- | ----------- |----------|----------|----------| +| None | None | 36.54 | 57.01 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-none-none.pt) | +| Wav2Vec | None | 68.05 | 82.53 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-none.pt) | +| HuBERT | None | 68.40 | 82.85 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-none.pt) | +| Wav2Vec | STOP | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-stop.pt) | +| HuBERT | STOP | 69.23 | 82.87 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-stop.pt) | +| Wav2Vec | Librispeech | 68.47 | 82.49 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-wav2vec-ls.pt) | +| HuBERT | Librispeech | 68.70 | 82.78 | [link](https://dl.fbaipublicfiles.com/stop/end-to-end-nlu-hubert-ls.pt) | + +## Pretrained models ASR Models +| Speech Pre-training | ASR Dataset | STOP Eval WER | STOP Test WER | dev\_other WER | dev\_clean WER | test\_clean WER | test\_other WER | Link | +| ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | ----------- | +| HuBERT | Librispeech | 8.47 | 2.99 | 3.25 | 8.06 | 25.68 | 26.19 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls.pt) | +| Wav2Vec | Librispeech | 9.215 | 3.204 | 3.334 | 9.006 | 27.257 | 27.588 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls.pt) | +| HuBERT | STOP | 46.31 | 31.30 | 31.52 | 47.16 | 4.29 | 4.26 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-stop.pt) | +| Wav2Vec | STOP | 43.103 | 27.833 | 28.479 | 28.479 | 4.679 | 4.667 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-stop.pt) | +| HuBERT | Librispeech + STOP | 9.015 | 3.211 | 3.372 | 8.635 | 5.133 | 5.056 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-hubert-ls-stop.pt) | +| Wav2Vec | Librispeech + STOP | 9.549 | 3.537 | 3.625 | 9.514 | 5.59 | 5.562 | [link](https://dl.fbaipublicfiles.com/stop/ctc-asr-wav2vec-ls-stop.pt) | + +## Creating the fairseq datasets from STOP + +First, create the audio file manifests and label files: + +``` +python examples/audio_nlp/nlu/generate_manifests.py --stop_root $STOP_DOWNLOAD_DIR/stop --output $FAIRSEQ_DATASET_OUTPUT/ +``` + + +Run `./examples/audio_nlp/nlu/create_dict_stop.sh $FAIRSEQ_DATASET_OUTPUT` to generate the fairseq dictionaries. + + +## Training an End-to-end NLU Model + + +Download a wav2vec or hubert model from [link](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert) or [link](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec) + + +``` +python fairseq_cli/hydra-train --config-dir examples/audio_nlp/nlu/configs/ --config-name nlu_finetuning task.data=$FAIRSEQ_DATA_OUTPUT model.w2v_path=$PRETRAINED_MODEL_PATH +``` diff --git a/fairseq/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml b/fairseq/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb90f45a307bd36040ba579a012bac8db911ec5c --- /dev/null +++ b/fairseq/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 10 + tensorboard_logdir: tb + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: em_error + save_interval: 10 + +task: + _name: nlu_finetuning + data: ??? + labels: parse + eval_wer_parse: true + autoregressive: true + +dataset: + num_workers: 6 + max_tokens: 1600000 + skip_invalid_size_inputs_valid_test: true + valid_subset: eval,test + train_subset: train + validate_interval: 10 + +criterion: + _name: label_smoothed_cross_entropy + +optimization: + max_update: 320000 + lr: [0.0001] + sentence_avg: true + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_seq2seq + w2v_path: ??? + autoregressive: true + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 diff --git a/fairseq/examples/audio_nlp/nlu/create_dict_stop.sh b/fairseq/examples/audio_nlp/nlu/create_dict_stop.sh new file mode 100644 index 0000000000000000000000000000000000000000..753393284de3703247db31ab224ebf11fab0242b --- /dev/null +++ b/fairseq/examples/audio_nlp/nlu/create_dict_stop.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +### Script handling creation of data binaries +### for model training within fairseq + + +fairseq_root="." + +data_root=$1 +train_prefix="${data_root}/train" +valid_prefix="${data_root}/eval" +test_prefix="${data_root}/test" + +dest_dir="$data_root/" + +#echo "src dict: $src_dict" > "$dest_dir/src_dict.txt" +#echo "trg dict: $tgt_dict" > "$dest_dir/tgt_dict.txt" + + #--tgtdict $tgt_dict \ +PYTHONPATH=$fairseq_root \ + python $fairseq_root/fairseq_cli/preprocess.py \ + --source-lang "parse" \ + --trainpref "$train_prefix" \ + --validpref "$valid_prefix" \ + --destdir "$dest_dir" \ + --only-source \ + --dict-only \ + --workers 60; + +PYTHONPATH=$fairseq_root \ + python $fairseq_root/fairseq_cli/preprocess.py \ + --source-lang "ltr" \ + --trainpref "$train_prefix" \ + --validpref "$valid_prefix" \ + --destdir "$dest_dir" \ + --only-source \ + --dict-only \ + --workers 60; diff --git a/fairseq/examples/audio_nlp/nlu/generate_manifests.py b/fairseq/examples/audio_nlp/nlu/generate_manifests.py new file mode 100644 index 0000000000000000000000000000000000000000..e2176099cbd3993d3488ccc85b60f5a0da45d4df --- /dev/null +++ b/fairseq/examples/audio_nlp/nlu/generate_manifests.py @@ -0,0 +1,83 @@ +import argparse +from pathlib import Path +import soundfile + +def get_insl_frame(parse): + out = [] + def is_ont_token(tok): + return tok[0] in ["[", "]"]; + + res = [] + x = [] + for tok in parse.split(): + if is_ont_token(tok): + res.extend('_'.join(x)) + x = [] + res.append(tok.upper()) + else: + x.append(tok.upper()) + + return " ".join(res) + ' | ' + +def sequencify_utterance(utterance): + utterance = utterance.upper() + utterance = utterance.replace(' ', '|') + '|' + utterance = list(utterance) + utterance = ' '.join(utterance) + return utterance + + +def generate_fairseq_manifests(manifest, output_path, audio_root=None): + + with open(manifest, 'r') as i: + parses = [] + utterances = [] + filepaths = [] + keys = None + for (idx, line) in enumerate(i): + if idx == 0: keys = line.strip().split('\t') + else: + data = { k: v for (k, v) in zip(keys, line.split('\t'))} + parses.append(get_insl_frame(data['decoupled_normalized_seqlogical'])) + utterances.append(sequencify_utterance(data['normalized_utterance'])) + filepaths.append(data['file_id']) + + parses_fp = output_path.with_suffix('.parse') + with open(str(parses_fp), 'w') as o: + for p in parses: + o.write(p + '\n') + + utterances_fp = output_path.with_suffix('.ltr') + with open(str(utterances_fp), 'w') as o: + for u in utterances: + o.write(u + '\n') + + filepaths_fp = output_path.with_suffix('.tsv') + with open(str(filepaths_fp), 'w') as o: + o.write(str(audio_root) + '\n') + for f in filepaths: + fullpath = audio_root / f + assert fullpath.exists(), f'{fullpath}' + frames = soundfile.info(fullpath).frames + o.write(f'{f}\t{frames}\n') + +def main(args): + + splits = ['train', 'eval', 'test'] + root = Path(args.stop_root) + output_root = Path(args.output) + + for split in splits: + stop_manifest_path = root / 'manifests' / (split + '.tsv') + output_path = output_root / (split) + + generate_fairseq_manifests(stop_manifest_path, output_path, root) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument('--stop_root', type=str, + help='path to stop root directory') + parser.add_argument('--output', type=str, + help='output directory') + args = parser.parse_args() + main(args) diff --git a/fairseq/examples/backtranslation/README.md b/fairseq/examples/backtranslation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..73675f1125d80f58aa824db67d8970504d4d6b2a --- /dev/null +++ b/fairseq/examples/backtranslation/README.md @@ -0,0 +1,297 @@ +# Understanding Back-Translation at Scale (Edunov et al., 2018) + +This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381). + +## Pre-trained models + +Model | Description | Dataset | Download +---|---|---|--- +`transformer.wmt18.en-de` | Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381))
WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz)
See NOTE in the archive + +## Example usage (torch.hub) + +We require a few additional Python dependencies for preprocessing: +```bash +pip install subword_nmt sacremoses +``` + +Then to generate translations from the full model ensemble: +```python +import torch + +# List available models +torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt18.en-de', ... ] + +# Load the WMT'18 En-De ensemble +en2de_ensemble = torch.hub.load( + 'pytorch/fairseq', 'transformer.wmt18.en-de', + checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt', + tokenizer='moses', bpe='subword_nmt') + +# The ensemble contains 5 models +len(en2de_ensemble.models) +# 5 + +# Translate +en2de_ensemble.translate('Hello world!') +# 'Hallo Welt!' +``` + +## Training your own model (WMT'18 English-German) + +The following instructions can be adapted to reproduce the models from the paper. + + +#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model + +First download and preprocess the data: +```bash +# Download and prepare the data +cd examples/backtranslation/ +bash prepare-wmt18en2de.sh +cd ../.. + +# Binarize the data +TEXT=examples/backtranslation/wmt18_en_de +fairseq-preprocess \ + --joined-dictionary \ + --source-lang en --target-lang de \ + --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ + --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \ + --workers 20 + +# Copy the BPE code into the data-bin directory for future use +cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code +``` + +(Optionally) Train a baseline model (English-German) using just the parallel data: +```bash +CHECKPOINT_DIR=checkpoints_en_de_parallel +fairseq-train --fp16 \ + data-bin/wmt18_en_de \ + --source-lang en --target-lang de \ + --arch transformer_wmt_en_de_big --share-all-embeddings \ + --dropout 0.3 --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --max-tokens 3584 --update-freq 16 \ + --max-update 30000 \ + --save-dir $CHECKPOINT_DIR +# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a +# different number of GPUs. +``` + +Average the last 10 checkpoints: +```bash +python scripts/average_checkpoints.py \ + --inputs $CHECKPOINT_DIR \ + --num-epoch-checkpoints 10 \ + --output $CHECKPOINT_DIR/checkpoint.avg10.pt +``` + +Evaluate BLEU: +```bash +# tokenized BLEU on newstest2017: +bash examples/backtranslation/tokenized_bleu.sh \ + wmt17 \ + en-de \ + data-bin/wmt18_en_de \ + data-bin/wmt18_en_de/code \ + $CHECKPOINT_DIR/checkpoint.avg10.pt +# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152) +# compare to 29.46 in Table 1, which is also for tokenized BLEU + +# generally it's better to report (detokenized) sacrebleu though: +bash examples/backtranslation/sacrebleu.sh \ + wmt17 \ + en-de \ + data-bin/wmt18_en_de \ + data-bin/wmt18_en_de/code \ + $CHECKPOINT_DIR/checkpoint.avg10.pt +# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287) +``` + + +#### Step 2. Back-translate monolingual German data + +Train a reverse model (German-English) to do the back-translation: +```bash +CHECKPOINT_DIR=checkpoints_de_en_parallel +fairseq-train --fp16 \ + data-bin/wmt18_en_de \ + --source-lang de --target-lang en \ + --arch transformer_wmt_en_de_big --share-all-embeddings \ + --dropout 0.3 --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --max-tokens 3584 --update-freq 16 \ + --max-update 30000 \ + --save-dir $CHECKPOINT_DIR +# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a +# different number of GPUs. +``` + +Let's evaluate the back-translation (BT) model to make sure it is well trained: +```bash +bash examples/backtranslation/sacrebleu.sh \ + wmt17 \ + de-en \ + data-bin/wmt18_en_de \ + data-bin/wmt18_en_de/code \ + $CHECKPOINT_DIR/checkpoint_best.py +# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399) +# compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868 +``` + +Next prepare the monolingual data: +```bash +# Download and prepare the monolingual data +# By default the script samples 25M monolingual sentences, which after +# deduplication should be just over 24M sentences. These are split into 25 +# shards, each with 1M sentences (except for the last shard). +cd examples/backtranslation/ +bash prepare-de-monolingual.sh +cd ../.. + +# Binarize each shard of the monolingual data +TEXT=examples/backtranslation/wmt18_de_mono +for SHARD in $(seq -f "%02g" 0 24); do \ + fairseq-preprocess \ + --only-source \ + --source-lang de --target-lang en \ + --joined-dictionary \ + --srcdict data-bin/wmt18_en_de/dict.de.txt \ + --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \ + --destdir data-bin/wmt18_de_mono/shard${SHARD} \ + --workers 20; \ + cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \ +done +``` + +Now we're ready to perform back-translation over the monolingual data. The +following command generates via sampling, but it's possible to use greedy +decoding (`--beam 1`), beam search (`--beam 5`), +top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.: +```bash +mkdir backtranslation_output +for SHARD in $(seq -f "%02g" 0 24); do \ + fairseq-generate --fp16 \ + data-bin/wmt18_de_mono/shard${SHARD} \ + --path $CHECKPOINT_DIR/checkpoint_best.pt \ + --skip-invalid-size-inputs-valid-test \ + --max-tokens 4096 \ + --sampling --beam 1 \ + > backtranslation_output/sampling.shard${SHARD}.out; \ +done +``` + +After BT, use the `extract_bt_data.py` script to re-combine the shards, extract +the back-translations and apply length ratio filters: +```bash +python examples/backtranslation/extract_bt_data.py \ + --minlen 1 --maxlen 250 --ratio 1.5 \ + --output backtranslation_output/bt_data --srclang en --tgtlang de \ + backtranslation_output/sampling.shard*.out + +# Ensure lengths are the same: +# wc -l backtranslation_output/bt_data.{en,de} +# 21795614 backtranslation_output/bt_data.en +# 21795614 backtranslation_output/bt_data.de +# 43591228 total +``` + +Binarize the filtered BT data and combine it with the parallel data: +```bash +TEXT=backtranslation_output +fairseq-preprocess \ + --source-lang en --target-lang de \ + --joined-dictionary \ + --srcdict data-bin/wmt18_en_de/dict.en.txt \ + --trainpref $TEXT/bt_data \ + --destdir data-bin/wmt18_en_de_bt \ + --workers 20 + +# We want to train on the combined data, so we'll symlink the parallel + BT data +# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train" +# and the BT data as "train1", so that fairseq will combine them automatically +# and so that we can use the `--upsample-primary` option to upsample the +# parallel data (if desired). +PARA_DATA=$(readlink -f data-bin/wmt18_en_de) +BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt) +COMB_DATA=data-bin/wmt18_en_de_para_plus_bt +mkdir -p $COMB_DATA +for LANG in en de; do \ + ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \ + for EXT in bin idx; do \ + ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \ + ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \ + ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \ + ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \ + done; \ +done +``` + + +#### 3. Train an English-German model over the combined parallel + BT data + +Finally we can train a model over the parallel + BT data: +```bash +CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt +fairseq-train --fp16 \ + data-bin/wmt18_en_de_para_plus_bt \ + --upsample-primary 16 \ + --source-lang en --target-lang de \ + --arch transformer_wmt_en_de_big --share-all-embeddings \ + --dropout 0.3 --weight-decay 0.0 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --max-tokens 3584 --update-freq 16 \ + --max-update 100000 \ + --save-dir $CHECKPOINT_DIR +# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a +# different number of GPUs. +``` + +Average the last 10 checkpoints: +```bash +python scripts/average_checkpoints.py \ + --inputs $CHECKPOINT_DIR \ + --num-epoch-checkpoints 10 \ + --output $CHECKPOINT_DIR/checkpoint.avg10.pt +``` + +Evaluate BLEU: +```bash +# tokenized BLEU on newstest2017: +bash examples/backtranslation/tokenized_bleu.sh \ + wmt17 \ + en-de \ + data-bin/wmt18_en_de \ + data-bin/wmt18_en_de/code \ + $CHECKPOINT_DIR/checkpoint.avg10.pt +# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152) +# compare to 32.35 in Table 1, which is also for tokenized BLEU + +# generally it's better to report (detokenized) sacrebleu: +bash examples/backtranslation/sacrebleu.sh \ + wmt17 \ + en-de \ + data-bin/wmt18_en_de \ + data-bin/wmt18_en_de/code \ + $CHECKPOINT_DIR/checkpoint.avg10.pt +# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287) +``` + + +## Citation +```bibtex +@inproceedings{edunov2018backtranslation, + title = {Understanding Back-Translation at Scale}, + author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David}, + booktitle = {Conference of the Association for Computational Linguistics (ACL)}, + year = 2018, +} +``` diff --git a/fairseq/examples/backtranslation/deduplicate_lines.py b/fairseq/examples/backtranslation/deduplicate_lines.py new file mode 100644 index 0000000000000000000000000000000000000000..50e458328c80b71c42a66d473381ca7e98d294da --- /dev/null +++ b/fairseq/examples/backtranslation/deduplicate_lines.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import fileinput +import hashlib +import sys +from multiprocessing import Pool + + +def get_hashes_and_lines(raw_line): + hash = hashlib.md5(raw_line).hexdigest() + return hash, raw_line + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--workers", type=int, default=10) + parser.add_argument("files", nargs="*", help="input files") + args = parser.parse_args() + + seen = set() + with fileinput.input(args.files, mode="rb") as h: + pool = Pool(args.workers) + results = pool.imap_unordered(get_hashes_and_lines, h, 1000) + for i, (hash, raw_line) in enumerate(results): + if hash not in seen: + seen.add(hash) + sys.stdout.buffer.write(raw_line) + if i % 1000000 == 0: + print(i, file=sys.stderr, end="", flush=True) + elif i % 100000 == 0: + print(".", file=sys.stderr, end="", flush=True) + print(file=sys.stderr, flush=True) + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/backtranslation/extract_bt_data.py b/fairseq/examples/backtranslation/extract_bt_data.py new file mode 100644 index 0000000000000000000000000000000000000000..e766391e873d0d9a9561d67d5864934b2fad0681 --- /dev/null +++ b/fairseq/examples/backtranslation/extract_bt_data.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import fileinput + +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser( + description=( + "Extract back-translations from the stdout of fairseq-generate. " + "If there are multiply hypotheses for a source, we only keep the first one. " + ) + ) + parser.add_argument("--output", required=True, help="output prefix") + parser.add_argument( + "--srclang", required=True, help="source language (extracted from H-* lines)" + ) + parser.add_argument( + "--tgtlang", required=True, help="target language (extracted from S-* lines)" + ) + parser.add_argument("--minlen", type=int, help="min length filter") + parser.add_argument("--maxlen", type=int, help="max length filter") + parser.add_argument("--ratio", type=float, help="ratio filter") + parser.add_argument("files", nargs="*", help="input files") + args = parser.parse_args() + + def validate(src, tgt): + srclen = len(src.split(" ")) if src != "" else 0 + tgtlen = len(tgt.split(" ")) if tgt != "" else 0 + if ( + (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen)) + or ( + args.maxlen is not None + and (srclen > args.maxlen or tgtlen > args.maxlen) + ) + or ( + args.ratio is not None + and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio) + ) + ): + return False + return True + + def safe_index(toks, index, default): + try: + return toks[index] + except IndexError: + return default + + with open(args.output + "." + args.srclang, "w") as src_h, open( + args.output + "." + args.tgtlang, "w" + ) as tgt_h: + for line in tqdm(fileinput.input(args.files)): + if line.startswith("S-"): + tgt = safe_index(line.rstrip().split("\t"), 1, "") + elif line.startswith("H-"): + if tgt is not None: + src = safe_index(line.rstrip().split("\t"), 2, "") + if validate(src, tgt): + print(src, file=src_h) + print(tgt, file=tgt_h) + tgt = None + + +if __name__ == "__main__": + main() diff --git a/fairseq/examples/backtranslation/prepare-de-monolingual.sh b/fairseq/examples/backtranslation/prepare-de-monolingual.sh new file mode 100644 index 0000000000000000000000000000000000000000..5e67b2b3bcf27d3436031453e796e58a0ae79ec4 --- /dev/null +++ b/fairseq/examples/backtranslation/prepare-de-monolingual.sh @@ -0,0 +1,98 @@ +#!/bin/bash + +SCRIPTS=mosesdecoder/scripts +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl +BPEROOT=subword-nmt/subword_nmt + + +BPE_CODE=wmt18_en_de/code +SUBSAMPLE_SIZE=25000000 +LANG=de + + +OUTDIR=wmt18_${LANG}_mono +orig=orig +tmp=$OUTDIR/tmp +mkdir -p $OUTDIR $tmp + + +URLS=( + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz" + "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz" + "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz" + "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz" + "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz" + "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz" +) +FILES=( + "news.2007.de.shuffled.gz" + "news.2008.de.shuffled.gz" + "news.2009.de.shuffled.gz" + "news.2010.de.shuffled.gz" + "news.2011.de.shuffled.gz" + "news.2012.de.shuffled.gz" + "news.2013.de.shuffled.gz" + "news.2014.de.shuffled.v2.gz" + "news.2015.de.shuffled.gz" + "news.2016.de.shuffled.gz" + "news.2017.de.shuffled.deduped.gz" +) + + +cd $orig +for ((i=0;i<${#URLS[@]};++i)); do + file=${FILES[i]} + if [ -f $file ]; then + echo "$file already exists, skipping download" + else + url=${URLS[i]} + wget "$url" + fi +done +cd .. + + +if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then + echo "found monolingual sample, skipping shuffle/sample/tokenize" +else + gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \ + | shuf -n $SUBSAMPLE_SIZE \ + | perl $NORM_PUNC $LANG \ + | perl $REM_NON_PRINT_CHAR \ + | perl $TOKENIZER -threads 8 -a -l $LANG \ + > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} +fi + + +if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then + echo "found BPE monolingual sample, skipping BPE step" +else + python $BPEROOT/apply_bpe.py -c $BPE_CODE \ + < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \ + > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} +fi + + +if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then + echo "found deduplicated monolingual sample, skipping deduplication step" +else + python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \ + > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} +fi + + +if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then + echo "found sharded data, skipping sharding step" +else + split --lines 1000000 --numeric-suffixes \ + --additional-suffix .${LANG} \ + $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \ + $OUTDIR/bpe.monolingual.dedup. +fi diff --git a/fairseq/examples/byte_level_bpe/gru_transformer.py b/fairseq/examples/byte_level_bpe/gru_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d4efa93a4d75da71c78e786d7f62101ef3266af4 --- /dev/null +++ b/fairseq/examples/byte_level_bpe/gru_transformer.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch.nn.functional as F +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import TransformerEncoder, TransformerModel + + +@register_model("gru_transformer") +class GRUTransformerModel(TransformerModel): + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return GRUTransformerEncoder(args, src_dict, embed_tokens) + + +class GRUTransformerEncoder(TransformerEncoder): + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + self.emb_ctx = nn.GRU( + input_size=embed_tokens.embedding_dim, + hidden_size=embed_tokens.embedding_dim // 2, + num_layers=1, + bidirectional=True, + ) + + def forward_embedding(self, src_tokens): + # embed tokens and positions + x = embed = self.embed_scale * self.embed_tokens(src_tokens) + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + + # contextualize embeddings + x = x.transpose(0, 1) + x = self.dropout_module(x) + x, _ = self.emb_ctx.forward(x) + x = x.transpose(0, 1) + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + return x, embed + + +@register_model_architecture("gru_transformer", "gru_transformer") +def gru_transformer_base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.no_cross_attention = getattr(args, "no_cross_attention", False) + args.cross_self_attention = getattr(args, "cross_self_attention", False) + args.layer_wise_attention = getattr(args, "layer_wise_attention", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + + +@register_model_architecture("gru_transformer", "gru_transformer_big") +def gru_transformer_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + gru_transformer_base_architecture(args) diff --git a/fairseq/examples/camembert/README.md b/fairseq/examples/camembert/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5ef4fe3f151bb468712f3be935ea5bb1b1360bf7 --- /dev/null +++ b/fairseq/examples/camembert/README.md @@ -0,0 +1,75 @@ +# CamemBERT: a Tasty French Language Model + +## Introduction + +[CamemBERT](https://arxiv.org/abs/1911.03894) is a pretrained language model trained on 138GB of French text based on RoBERTa. + +Also available in [github.com/huggingface/transformers](https://github.com/huggingface/transformers/). + +## Pre-trained models + +| Model | #params | Download | Arch. | Training data | +|--------------------------------|---------|--------------------------------------------------------------------------------------------------------------------------|-------|-----------------------------------| +| `camembert` / `camembert-base` | 110M | [camembert-base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz) | Base | OSCAR (138 GB of text) | +| `camembert-large` | 335M | [camembert-large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz) | Large | CCNet (135 GB of text) | +| `camembert-base-ccnet` | 110M | [camembert-base-ccnet.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz) | Base | CCNet (135 GB of text) | +| `camembert-base-wikipedia-4gb` | 110M | [camembert-base-wikipedia-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz) | Base | Wikipedia (4 GB of text) | +| `camembert-base-oscar-4gb` | 110M | [camembert-base-oscar-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz) | Base | Subsample of OSCAR (4 GB of text) | +| `camembert-base-ccnet-4gb` | 110M | [camembert-base-ccnet-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz) | Base | Subsample of CCNet (4 GB of text) | + +## Example usage + +### fairseq +##### Load CamemBERT from torch.hub (PyTorch >= 1.1): +```python +import torch +camembert = torch.hub.load('pytorch/fairseq', 'camembert') +camembert.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Load CamemBERT (for PyTorch 1.0 or custom models): +```python +# Download camembert model +wget https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz +tar -xzvf camembert.tar.gz + +# Load the model in fairseq +from fairseq.models.roberta import CamembertModel +camembert = CamembertModel.from_pretrained('/path/to/camembert') +camembert.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Filling masks: +```python +masked_line = 'Le camembert est :)' +camembert.fill_mask(masked_line, topk=3) +# [('Le camembert est délicieux :)', 0.4909118115901947, ' délicieux'), +# ('Le camembert est excellent :)', 0.10556942224502563, ' excellent'), +# ('Le camembert est succulent :)', 0.03453322499990463, ' succulent')] +``` + +##### Extract features from Camembert: +```python +# Extract the last layer's features +line = "J'aime le camembert !" +tokens = camembert.encode(line) +last_layer_features = camembert.extract_features(tokens) +assert last_layer_features.size() == torch.Size([1, 10, 768]) + +# Extract all layer's features (layer 0 is the embedding layer) +all_layers = camembert.extract_features(tokens, return_all_hiddens=True) +assert len(all_layers) == 13 +assert torch.all(all_layers[-1] == last_layer_features) +``` + +## Citation +If you use our work, please cite: + +```bibtex +@inproceedings{martin2020camembert, + title={CamemBERT: a Tasty French Language Model}, + author={Martin, Louis and Muller, Benjamin and Su{\'a}rez, Pedro Javier Ortiz and Dupont, Yoann and Romary, Laurent and de la Clergerie, {\'E}ric Villemonte and Seddah, Djam{\'e} and Sagot, Beno{\^\i}t}, + booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics}, + year={2020} +} +``` diff --git a/fairseq/examples/constrained_decoding/README.md b/fairseq/examples/constrained_decoding/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e04b8b6a018214c8233fa87fd91d46a6dd1519d4 --- /dev/null +++ b/fairseq/examples/constrained_decoding/README.md @@ -0,0 +1,123 @@ +# (Vectorized) Lexically constrained decoding with dynamic beam allocation + +This page provides instructions for how to use lexically constrained decoding in Fairseq. +Fairseq implements the code described in the following papers: + +* [Fast Lexically Constrained Decoding With Dynamic Beam Allocation](https://www.aclweb.org/anthology/N18-1119/) (Post & Vilar, 2018) +* [Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting](https://www.aclweb.org/anthology/N19-1090/) (Hu et al., 2019) + +## Quick start + +Constrained search is enabled by adding the command-line argument `--constraints` to `fairseq-interactive`. +Constraints are appended to each line of input, separated by tabs. Each constraint (one or more tokens) +is a separate field. + +The following command, using [Fairseq's WMT19 German--English model](https://github.com/pytorch/fairseq/blob/main/examples/wmt19/README.md), +translates the sentence *Die maschinelle Übersetzung ist schwer zu kontrollieren.* with the constraints +"hard" and "to influence". + + echo -e "Die maschinelle Übersetzung ist schwer zu kontrollieren.\thard\ttoinfluence" \ + | normalize.py | tok.py \ + | fairseq-interactive /path/to/model \ + --path /path/to/model/model1.pt \ + --bpe fastbpe \ + --bpe-codes /path/to/model/bpecodes \ + --constraints \ + -s de -t en \ + --beam 10 + +(tok.py and normalize.py can be found in the same directory as this README; they are just shortcuts around Fairseq's WMT19 preprocessing). +This will generate the following output: + + [snip] + S-0 Die masch@@ in@@ elle Über@@ setzung ist schwer zu kontrollieren . + W-0 1.844 seconds + C-0 hard + C-0 influence + H-0 -1.5333266258239746 Mach@@ ine trans@@ lation is hard to influence . + D-0 -1.5333266258239746 Machine translation is hard to influence . + P-0 -0.5434 -0.1423 -0.1930 -0.1415 -0.2346 -1.8031 -0.1701 -11.7727 -0.1815 -0.1511 + +By default, constraints are generated in the order supplied, with any number (zero or more) of tokens generated +between constraints. If you wish for the decoder to order the constraints, then use `--constraints unordered`. +Note that you may want to use a larger beam. + +## Implementation details + +The heart of the implementation is in `fairseq/search.py`, which adds a `LexicallyConstrainedBeamSearch` instance. +This instance of beam search tracks the progress of each hypothesis in the beam through the set of constraints +provided for each input sentence. It does this using one of two classes, both found in `fairseq/token_generation_contstraints.py`: + +* OrderedConstraintState: assumes the `C` input constraints will be generated in the provided order +* UnorderedConstraintState: tries to apply `C` (phrasal) constraints in all `C!` orders + +## Differences from Sockeye + +There are a number of [differences from Sockeye's implementation](https://awslabs.github.io/sockeye/inference.html#lexical-constraints). + +* Generating constraints in the order supplied (the default option here) is not available in Sockeye. +* Due to an improved beam allocation method, there is no need to prune the beam. +* Again due to better allocation, beam sizes as low as 10 or even 5 are often sufficient. +* [The vector extensions described in Hu et al.](https://github.com/edwardjhu/sockeye/tree/trie_constraints) (NAACL 2019) were never merged + into the main Sockeye branch. + +## Citation + +The paper first describing lexical constraints for seq2seq decoding is: + +```bibtex +@inproceedings{hokamp-liu-2017-lexically, + title = "Lexically Constrained Decoding for Sequence Generation Using Grid Beam Search", + author = "Hokamp, Chris and + Liu, Qun", + booktitle = "Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", + month = jul, + year = "2017", + address = "Vancouver, Canada", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/P17-1141", + doi = "10.18653/v1/P17-1141", + pages = "1535--1546", +} +``` + +The fairseq implementation uses the extensions described in + +```bibtex +@inproceedings{post-vilar-2018-fast, + title = "Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation", + author = "Post, Matt and + Vilar, David", + booktitle = "Proceedings of the 2018 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers)", + month = jun, + year = "2018", + address = "New Orleans, Louisiana", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/N18-1119", + doi = "10.18653/v1/N18-1119", + pages = "1314--1324", +} +``` + +and + +```bibtex +@inproceedings{hu-etal-2019-improved, + title = "Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting", + author = "Hu, J. Edward and + Khayrallah, Huda and + Culkin, Ryan and + Xia, Patrick and + Chen, Tongfei and + Post, Matt and + Van Durme, Benjamin", + booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)", + month = jun, + year = "2019", + address = "Minneapolis, Minnesota", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/N19-1090", + doi = "10.18653/v1/N19-1090", + pages = "839--850", +} +``` diff --git a/fairseq/examples/constrained_decoding/normalize.py b/fairseq/examples/constrained_decoding/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae2b5111ba025acb9e1613865c92fdc339a58d5 --- /dev/null +++ b/fairseq/examples/constrained_decoding/normalize.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +from sacremoses.normalize import MosesPunctNormalizer + + +def main(args): + normalizer = MosesPunctNormalizer(lang=args.lang, penn=args.penn) + for line in sys.stdin: + print(normalizer.normalize(line.rstrip()), flush=True) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--lang", "-l", default="en") + parser.add_argument("--penn", "-p", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/fairseq/examples/constrained_decoding/tok.py b/fairseq/examples/constrained_decoding/tok.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f888a8c0d1b8ec7174859476cc3222456e0d2c --- /dev/null +++ b/fairseq/examples/constrained_decoding/tok.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +import sacremoses + + +def main(args): + """Tokenizes, preserving tabs""" + mt = sacremoses.MosesTokenizer(lang=args.lang) + + def tok(s): + return mt.tokenize(s, return_str=True) + + for line in sys.stdin: + parts = list(map(tok, line.split("\t"))) + print(*parts, sep="\t", flush=True) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--lang", "-l", default="en") + parser.add_argument("--penn", "-p", action="store_true") + parser.add_argument("--fields", "-f", help="fields to tokenize") + args = parser.parse_args() + + main(args) diff --git a/fairseq/examples/conv_seq2seq/README.md b/fairseq/examples/conv_seq2seq/README.md new file mode 100644 index 0000000000000000000000000000000000000000..95fe7e7909a77ee0e50fe31d4b8be38daa8f3be7 --- /dev/null +++ b/fairseq/examples/conv_seq2seq/README.md @@ -0,0 +1,25 @@ +# Convolutional Sequence to Sequence Learning (Gehring et al., 2017) + +## Pre-trained models + +Description | Dataset | Model | Test set(s) +---|---|---|--- +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) +Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2) | newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2) + +## Example usage + +See the [translation README](../translation/README.md) for instructions on reproducing results for WMT'14 En-De and +WMT'14 En-Fr using the `fconv_wmt_en_de` and `fconv_wmt_en_fr` model architectures. + +## Citation + +```bibtex +@inproceedings{gehring2017convs2s, + title = {Convolutional Sequence to Sequence Learning}, + author = {Gehring, Jonas, and Auli, Michael and Grangier, David and Yarats, Denis and Dauphin, Yann N}, + booktitle = {Proc. of ICML}, + year = 2017, +} +``` diff --git a/fairseq/examples/criss/README.md b/fairseq/examples/criss/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4689ed7c10497a5100b28fe6d6801a7c089da569 --- /dev/null +++ b/fairseq/examples/criss/README.md @@ -0,0 +1,61 @@ +# Cross-lingual Retrieval for Iterative Self-Supervised Training + +https://arxiv.org/pdf/2006.09526.pdf + +## Introduction + +CRISS is a multilingual sequence-to-sequnce pretraining method where mining and training processes are applied iteratively, improving cross-lingual alignment and translation ability at the same time. + +## Requirements: + +* faiss: https://github.com/facebookresearch/faiss +* mosesdecoder: https://github.com/moses-smt/mosesdecoder +* flores: https://github.com/facebookresearch/flores +* LASER: https://github.com/facebookresearch/LASER + +## Unsupervised Machine Translation +##### 1. Download and decompress CRISS checkpoints +``` +cd examples/criss +wget https://dl.fbaipublicfiles.com/criss/criss_3rd_checkpoints.tar.gz +tar -xf criss_checkpoints.tar.gz +``` +##### 2. Download and preprocess Flores test dataset +Make sure to run all scripts from examples/criss directory +``` +bash download_and_preprocess_flores_test.sh +``` + +##### 3. Run Evaluation on Sinhala-English +``` +bash unsupervised_mt/eval.sh +``` + +## Sentence Retrieval +##### 1. Download and preprocess Tatoeba dataset +``` +bash download_and_preprocess_tatoeba.sh +``` + +##### 2. Run Sentence Retrieval on Tatoeba Kazakh-English +``` +bash sentence_retrieval/sentence_retrieval_tatoeba.sh +``` + +## Mining +##### 1. Install faiss +Follow instructions on https://github.com/facebookresearch/faiss/blob/master/INSTALL.md +##### 2. Mine pseudo-parallel data between Kazakh and English +``` +bash mining/mine_example.sh +``` + +## Citation +```bibtex +@article{tran2020cross, + title={Cross-lingual retrieval for iterative self-supervised training}, + author={Tran, Chau and Tang, Yuqing and Li, Xian and Gu, Jiatao}, + journal={arXiv preprint arXiv:2006.09526}, + year={2020} +} +``` diff --git a/fairseq/examples/criss/download_and_preprocess_flores_test.sh b/fairseq/examples/criss/download_and_preprocess_flores_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..ed4b390fbdee3991efeb298050e12065d7fe605b --- /dev/null +++ b/fairseq/examples/criss/download_and_preprocess_flores_test.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +SPM_ENCODE=flores/scripts/spm_encode.py +DATA=data_tmp +SPM_MODEL=criss_checkpoints/sentence.bpe.model +DICT=criss_checkpoints/dict.txt + +download_data() { + CORPORA=$1 + URL=$2 + + if [ -f $CORPORA ]; then + echo "$CORPORA already exists, skipping download" + else + echo "Downloading $URL" + wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA + if [ -f $CORPORA ]; then + echo "$URL successfully downloaded." + else + echo "$URL not successfully downloaded." + rm -f $CORPORA + fi + fi +} + +if [[ -f flores ]]; then + echo "flores already cloned" +else + git clone https://github.com/facebookresearch/flores +fi + +mkdir -p $DATA +download_data $DATA/wikipedia_en_ne_si_test_sets.tgz "https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz" +pushd $DATA +pwd +tar -vxf wikipedia_en_ne_si_test_sets.tgz +popd + + +for lang in ne_NP si_LK; do + datadir=$DATA/${lang}-en_XX-flores + rm -rf $datadir + mkdir -p $datadir + TEST_PREFIX=$DATA/wikipedia_en_ne_si_test_sets/wikipedia.test + python $SPM_ENCODE \ + --model ${SPM_MODEL} \ + --output_format=piece \ + --inputs ${TEST_PREFIX}.${lang:0:2}-en.${lang:0:2} ${TEST_PREFIX}.${lang:0:2}-en.en \ + --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX + + # binarize data + fairseq-preprocess \ + --source-lang ${lang} --target-lang en_XX \ + --testpref $datadir/test.bpe.${lang}-en_XX \ + --destdir $datadir \ + --srcdict ${DICT} \ + --joined-dictionary \ + --workers 4 +done diff --git a/fairseq/examples/criss/download_and_preprocess_tatoeba.sh b/fairseq/examples/criss/download_and_preprocess_tatoeba.sh new file mode 100644 index 0000000000000000000000000000000000000000..7ed64f017d5e62695ba73745c840507b994abc0f --- /dev/null +++ b/fairseq/examples/criss/download_and_preprocess_tatoeba.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +SPM_ENCODE=flores/scripts/spm_encode.py +DATA=data_tmp +SPM_MODEL=criss_checkpoints/sentence.bpe.model +DICT=criss_checkpoints/dict.txt + +if [[ -f flores ]]; then + echo "flores already cloned" +else + git clone https://github.com/facebookresearch/flores +fi +if [[ -f LASER ]]; then + echo "LASER already cloned" +else + git clone https://github.com/facebookresearch/LASER +fi +mkdir -p data_tmp +declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn") +for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do + lang_tatoeba=${lang_tatoeba_map[$lang]} + echo $lang_tatoeba + datadir=$DATA/${lang}-en_XX-tatoeba + rm -rf $datadir + mkdir -p $datadir + TEST_PREFIX=LASER/data/tatoeba/v1/tatoeba + python $SPM_ENCODE \ + --model ${SPM_MODEL} \ + --output_format=piece \ + --inputs ${TEST_PREFIX}.${lang_tatoeba}-eng.${lang_tatoeba} ${TEST_PREFIX}.${lang_tatoeba}-eng.eng \ + --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX + + # binarize data + fairseq-preprocess \ + --source-lang ${lang} --target-lang en_XX \ + --testpref $datadir/test.bpe.${lang}-en_XX \ + --destdir $datadir \ + --srcdict ${DICT} \ + --joined-dictionary \ + --workers 4 +done diff --git a/fairseq/examples/criss/mining/mine.py b/fairseq/examples/criss/mining/mine.py new file mode 100644 index 0000000000000000000000000000000000000000..c872da196fe0df776622365748ad7963fee1f0a0 --- /dev/null +++ b/fairseq/examples/criss/mining/mine.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import glob +from subprocess import check_call + +try: + import faiss + + has_faiss = True +except ImportError: + has_faiss = False +import numpy as np + + +GB = 1024 * 1024 * 1024 + + +def call(cmd): + print(cmd) + check_call(cmd, shell=True) + + +def get_batches(directory, lang, prefix="all_avg_pool"): + print(f"Finding in {directory}/{prefix}.{lang}*") + files = glob.glob(f"{directory}/{prefix}.{lang}*") + emb_files = [] + txt_files = [] + for emb_fi in files: + emb_files.append(emb_fi) + txt_fi = emb_fi.replace(prefix, "sentences") + txt_files.append(txt_fi) + return emb_files, txt_files + + +def load_batch(emb_file, dim): + embeddings = np.fromfile(emb_file, dtype=np.float32) + num_rows = int(embeddings.shape[0] / dim) + embeddings = embeddings.reshape((num_rows, dim)) + faiss.normalize_L2(embeddings) + return embeddings + + +def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"): + if not has_faiss: + raise ImportError("Please install Faiss") + sims = [] + inds = [] + xfrom = 0 + xto = 0 + for x_batch_f in x_batches_f: + yfrom = 0 + yto = 0 + x_batch = load_batch(x_batch_f, dim) + xto = xfrom + x_batch.shape[0] + bsims, binds = [], [] + for y_batch_f in y_batches_f: + y_batch = load_batch(y_batch_f, dim) + neighbor_size = min(k, y_batch.shape[0]) + yto = yfrom + y_batch.shape[0] + print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto)) + idx = faiss.IndexFlatIP(dim) + idx = faiss.index_cpu_to_all_gpus(idx) + idx.add(y_batch) + bsim, bind = idx.search(x_batch, neighbor_size) + + bsims.append(bsim) + binds.append(bind + yfrom) + yfrom += y_batch.shape[0] + del idx + del y_batch + bsims = np.concatenate(bsims, axis=1) + binds = np.concatenate(binds, axis=1) + aux = np.argsort(-bsims, axis=1) + sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32) + ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64) + for i in range(x_batch.shape[0]): + for j in range(k): + sim_batch[i, j] = bsims[i, aux[i, j]] + ind_batch[i, j] = binds[i, aux[i, j]] + sims.append(sim_batch) + inds.append(ind_batch) + xfrom += x_batch.shape[0] + del x_batch + sim = np.concatenate(sims, axis=0) + ind = np.concatenate(inds, axis=0) + return sim, ind + + +def score(sim, fwd_mean, bwd_mean, margin): + return margin(sim, (fwd_mean + bwd_mean) / 2) + + +def score_candidates( + sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False +): + print(" - scoring {:d} candidates".format(sim_mat.shape[0])) + scores = np.zeros(candidate_inds.shape) + for i in range(scores.shape[0]): + for j in range(scores.shape[1]): + k = int(candidate_inds[i, j]) + scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin) + return scores + + +def load_text(files): + all_sentences = [] + for fi in files: + with open(fi) as sentence_fi: + for line in sentence_fi: + all_sentences.append(line.strip()) + print(f"Read {len(all_sentences)} sentences") + return all_sentences + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mine bitext") + parser.add_argument("--src-lang", help="Source language") + parser.add_argument("--tgt-lang", help="Target language") + parser.add_argument( + "--dict-path", help="Path to dictionary file", default="dict.txt" + ) + parser.add_argument( + "--spm-path", help="Path to SPM model file", default="sentence.bpe.model" + ) + parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension") + parser.add_argument("--mem", type=int, default=5, help="Memory in GB") + parser.add_argument("--src-dir", help="Source directory") + parser.add_argument("--tgt-dir", help="Target directory") + parser.add_argument("--output", help="Output path") + parser.add_argument( + "--neighborhood", type=int, default=4, help="Embedding dimension" + ) + parser.add_argument( + "--threshold", type=float, default=1.06, help="Threshold on mined bitext" + ) + parser.add_argument( + "--valid-size", + type=int, + default=2000, + help="Number of sentences used for validation set", + ) + parser.add_argument( + "--min-count", + type=int, + default=50000, + help="Min num sentences used for each language", + ) + args = parser.parse_args() + + x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang) + y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang) + margin = lambda a, b: a / b + y2x_sim, y2x_ind = knnGPU_sharded( + y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x" + ) + x2y_sim, x2y_ind = knnGPU_sharded( + x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y" + ) + + x2y_mean = x2y_sim.mean(axis=1) + y2x_mean = y2x_sim.mean(axis=1) + fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin) + bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin) + fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)] + bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)] + indices = np.stack( + ( + np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)), + np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))), + ), + axis=1, + ) + scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1))) + + x_sentences = load_text(x_sents_f) + y_sentences = load_text(y_sents_f) + + threshold = args.threshold + min_count = args.min_count + seen_src, seen_trg = set(), set() + directory = args.output + call(f"mkdir -p {directory}") + src_out = open( + f"{directory}/all.{args.src_lang}", + mode="w", + encoding="utf-8", + errors="surrogateescape", + ) + tgt_out = open( + f"{directory}/all.{args.tgt_lang}", + mode="w", + encoding="utf-8", + errors="surrogateescape", + ) + scores_out = open( + f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape" + ) + count = 0 + for i in np.argsort(-scores): + src_ind, trg_ind = indices[i] + if src_ind not in seen_src and trg_ind not in seen_trg: + seen_src.add(src_ind) + seen_trg.add(trg_ind) + if scores[i] > threshold or count < min_count: + if x_sentences[src_ind]: + print(scores[i], file=scores_out) + print(x_sentences[src_ind], file=src_out) + print(y_sentences[trg_ind], file=tgt_out) + count += 1 + else: + print(f"Ignoring sentence: {x_sentences[src_ind]}") + src_out.close() + tgt_out.close() + scores_out.close() + + print(f"Found {count} pairs for threshold={threshold}") + with open(f"{directory}/all.{args.src_lang}") as all_s, open( + f"{directory}/all.{args.tgt_lang}" + ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open( + f"{directory}/valid.{args.tgt_lang}", "w" + ) as valid_t, open( + f"{directory}/train.{args.src_lang}", "w" + ) as train_s, open( + f"{directory}/train.{args.tgt_lang}", "w" + ) as train_t: + count = 0 + for s_line, t_line in zip(all_s, all_t): + s_line = s_line.split("\t")[1] + t_line = t_line.split("\t")[1] + if count >= args.valid_size: + train_s.write(s_line) + train_t.write(t_line) + else: + valid_s.write(s_line) + valid_t.write(t_line) + count += 1 diff --git a/fairseq/examples/criss/mining/mine_example.sh b/fairseq/examples/criss/mining/mine_example.sh new file mode 100644 index 0000000000000000000000000000000000000000..ace995ac44665f99d904b6a89d7fbbce24103afe --- /dev/null +++ b/fairseq/examples/criss/mining/mine_example.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +source_lang=kk_KZ +target_lang=en_XX +MODEL=criss_checkpoints/criss.3rd.pt +SPM=criss_checkpoints/sentence.bpe.model +SPLIT=test +LANG_DICT=criss_checkpoints/lang_dict.txt +SPM_ENCODE=flores/scripts/spm_encode.py +SAVE_ENCODER=save_encoder.py +ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL +DICT=criss_checkpoints/dict.txt +THRESHOLD=1.02 +MIN_COUNT=500 + +DATA_DIR=data_tmp +SAVE_DIR=mining/${source_lang}_${target_lang}_mined +ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang} +INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba + +mkdir -p $ENCODER_SAVE_DIR/${target_lang} +mkdir -p $ENCODER_SAVE_DIR/${source_lang} +mkdir -p $SAVE_DIR + +## Save encoder outputs + +# Save encoder outputs for source sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --task translation_multi_simple_epoch \ + --lang-pairs ${source_lang}-${target_lang} \ + --lang-dict ${LANG_DICT} \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + -s ${source_lang} -t ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang} + +## Save encoder outputs for target sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --lang-pairs ${source_lang}-${target_lang} \ + --lang-dict ${LANG_DICT} \ + --task translation_multi_simple_epoch \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + -t ${source_lang} -s ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang} + +## Mining +python mining/mine.py \ + --src-lang ${source_lang} \ + --tgt-lang ${target_lang} \ + --dim 1024 \ + --mem 10 \ + --neighborhood 4 \ + --src-dir ${ENCODER_SAVE_DIR}/${source_lang} \ + --tgt-dir ${ENCODER_SAVE_DIR}/${target_lang} \ + --output $SAVE_DIR \ + --threshold ${THRESHOLD} \ + --min-count ${MIN_COUNT} \ + --valid-size 100 \ + --dict-path ${DICT} \ + --spm-path ${SPM} \ + + +## Process and binarize mined data +python $SPM_ENCODE \ + --model ${SPM} \ + --output_format=piece \ + --inputs mining/${source_lang}_${target_lang}_mined/train.${source_lang} mining/${source_lang}_${target_lang}_mined/train.${target_lang} \ + --outputs mining/${source_lang}_${target_lang}_mined/train.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/train.bpe.${target_lang} + +python $SPM_ENCODE \ + --model ${SPM} \ + --output_format=piece \ + --inputs mining/${source_lang}_${target_lang}_mined/valid.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.${target_lang} \ + --outputs mining/${source_lang}_${target_lang}_mined/valid.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.bpe.${target_lang} + + +fairseq-preprocess \ + --source-lang ${source_lang} \ + --target-lang ${target_lang} \ + --trainpref mining/${source_lang}_${target_lang}_mined/train.bpe \ + --validpref mining/${source_lang}_${target_lang}_mined/valid.bpe \ + --destdir mining/${source_lang}_${target_lang}_mined \ + --srcdict ${DICT} \ + --joined-dictionary \ + --workers 8 diff --git a/fairseq/examples/criss/save_encoder.py b/fairseq/examples/criss/save_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..24a842e4092663c79c92a299fa85747b7c0bed64 --- /dev/null +++ b/fairseq/examples/criss/save_encoder.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Translate pre-processed data with a trained model. +""" + +import numpy as np +import torch +from fairseq import checkpoint_utils, options, progress_bar, tasks, utils +from fairseq.sequence_generator import EnsembleModel +from fairseq.utils import safe_hasattr + + +def get_avg_pool( + models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False +): + model = EnsembleModel(models) + + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" + } + + # compute the encoder output for each beam + encoder_outs = model.forward_encoder(encoder_input) + np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32) + encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype( + np.float32 + ) + encoder_mask = np.expand_dims(encoder_mask.T, axis=2) + if has_langtok: + encoder_mask = encoder_mask[1:, :, :] + np_encoder_outs = np_encoder_outs[1, :, :] + masked_encoder_outs = encoder_mask * np_encoder_outs + avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0) + return avg_pool + + +def main(args): + assert args.path is not None, "--path required for generation!" + assert ( + not args.sampling or args.nbest == args.beam + ), "--sampling requires --nbest to be equal to --beam" + assert ( + args.replace_unk is None or args.raw_text + ), "--replace-unk requires a raw text dataset (--raw-text)" + + args.beam = 1 + utils.import_user_module(args) + + if args.max_tokens is None: + args.max_tokens = 12000 + print(args) + use_cuda = torch.cuda.is_available() and not args.cpu + + # Load dataset splits + task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) + + # Set dictionaries + try: + src_dict = getattr(task, "source_dictionary", None) + except NotImplementedError: + src_dict = None + tgt_dict = task.target_dictionary + + # Load ensemble + print("| loading model(s) from {}".format(args.path)) + models, _model_args = checkpoint_utils.load_model_ensemble( + args.path.split(":"), + arg_overrides=eval(args.model_overrides), + task=task, + ) + + # Optimize ensemble for generation + for model in models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if args.fp16: + model.half() + if use_cuda: + model.cuda() + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(args.replace_unk) + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(args.gen_subset), + max_tokens=args.max_tokens, + max_positions=utils.resolve_max_positions( + task.max_positions(), + ), + ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=args.required_batch_size_multiple, + num_shards=args.num_shards, + shard_id=args.shard_id, + num_workers=args.num_workers, + ).next_epoch_itr(shuffle=False) + + num_sentences = 0 + source_sentences = [] + shard_id = 0 + all_avg_pool = None + encoder_has_langtok = ( + safe_hasattr(task.args, "encoder_langtok") + and task.args.encoder_langtok is not None + and safe_hasattr(task.args, "lang_tok_replacing_bos_eos") + and not task.args.lang_tok_replacing_bos_eos + ) + with progress_bar.build_progress_bar(args, itr) as t: + for sample in t: + if sample is None: + print("Skipping None") + continue + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if args.prefix_size > 0: + prefix_tokens = sample["target"][:, : args.prefix_size] + + with torch.no_grad(): + avg_pool = get_avg_pool( + models, + sample, + prefix_tokens, + src_dict, + args.post_process, + has_langtok=encoder_has_langtok, + ) + if all_avg_pool is not None: + all_avg_pool = np.concatenate((all_avg_pool, avg_pool)) + else: + all_avg_pool = avg_pool + + if not isinstance(sample["id"], list): + sample_ids = sample["id"].tolist() + else: + sample_ids = sample["id"] + for i, sample_id in enumerate(sample_ids): + # Remove padding + src_tokens = utils.strip_pad( + sample["net_input"]["src_tokens"][i, :], tgt_dict.pad() + ) + + # Either retrieve the original sentences or regenerate them from tokens. + if align_dict is not None: + src_str = task.dataset(args.gen_subset).src.get_original_text( + sample_id + ) + else: + if src_dict is not None: + src_str = src_dict.string(src_tokens, args.post_process) + else: + src_str = "" + + if not args.quiet: + if src_dict is not None: + print("S-{}\t{}".format(sample_id, src_str)) + + source_sentences.append(f"{sample_id}\t{src_str}") + + num_sentences += sample["nsentences"] + if all_avg_pool.shape[0] >= 1000000: + with open( + f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", + "w", + ) as avg_pool_file: + all_avg_pool.tofile(avg_pool_file) + with open( + f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", + "w", + ) as sentence_file: + sentence_file.writelines(f"{line}\n" for line in source_sentences) + all_avg_pool = None + source_sentences = [] + shard_id += 1 + + if all_avg_pool is not None: + with open( + f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w" + ) as avg_pool_file: + all_avg_pool.tofile(avg_pool_file) + with open( + f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w" + ) as sentence_file: + sentence_file.writelines(f"{line}\n" for line in source_sentences) + return None + + +def cli_main(): + parser = options.get_generation_parser() + parser.add_argument( + "--encoder-save-dir", + default="", + type=str, + metavar="N", + help="directory to save encoder outputs", + ) + args = options.parse_args_and_arch(parser) + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/criss/sentence_retrieval/encoder_analysis.py b/fairseq/examples/criss/sentence_retrieval/encoder_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..b41bfbe38789ba14e6a5ea938c75d761424c00ab --- /dev/null +++ b/fairseq/examples/criss/sentence_retrieval/encoder_analysis.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse +import glob + +import numpy as np + + +DIM = 1024 + + +def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False): + target_ids = [tid for tid in target_embs] + source_mat = np.stack(source_embs.values(), axis=0) + normalized_source_mat = source_mat / np.linalg.norm( + source_mat, axis=1, keepdims=True + ) + target_mat = np.stack(target_embs.values(), axis=0) + normalized_target_mat = target_mat / np.linalg.norm( + target_mat, axis=1, keepdims=True + ) + sim_mat = normalized_source_mat.dot(normalized_target_mat.T) + if return_sim_mat: + return sim_mat + neighbors_map = {} + for i, sentence_id in enumerate(source_embs): + idx = np.argsort(sim_mat[i, :])[::-1][:k] + neighbors_map[sentence_id] = [target_ids[tid] for tid in idx] + return neighbors_map + + +def load_embeddings(directory, LANGS): + sentence_embeddings = {} + sentence_texts = {} + for lang in LANGS: + sentence_embeddings[lang] = {} + sentence_texts[lang] = {} + lang_dir = f"{directory}/{lang}" + embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*") + for embed_file in embedding_files: + shard_id = embed_file.split(".")[-1] + embeddings = np.fromfile(embed_file, dtype=np.float32) + num_rows = embeddings.shape[0] // DIM + embeddings = embeddings.reshape((num_rows, DIM)) + + with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file: + for idx, line in enumerate(sentence_file): + sentence_id, sentence = line.strip().split("\t") + sentence_texts[lang][sentence_id] = sentence + sentence_embeddings[lang][sentence_id] = embeddings[idx, :] + + return sentence_embeddings, sentence_texts + + +def compute_accuracy(directory, LANGS): + sentence_embeddings, sentence_texts = load_embeddings(directory, LANGS) + + top_1_accuracy = {} + + top1_str = " ".join(LANGS) + "\n" + for source_lang in LANGS: + top_1_accuracy[source_lang] = {} + top1_str += f"{source_lang} " + for target_lang in LANGS: + top1 = 0 + top5 = 0 + neighbors_map = compute_dist( + sentence_embeddings[source_lang], sentence_embeddings[target_lang] + ) + for sentence_id, neighbors in neighbors_map.items(): + if sentence_id == neighbors[0]: + top1 += 1 + if sentence_id in neighbors[:5]: + top5 += 1 + n = len(sentence_embeddings[target_lang]) + top1_str += f"{top1/n} " + top1_str += "\n" + + print(top1_str) + print(top1_str, file=open(f"{directory}/accuracy", "w")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Analyze encoder outputs") + parser.add_argument("directory", help="Source language corpus") + parser.add_argument("--langs", help="List of langs") + args = parser.parse_args() + langs = args.langs.split(",") + compute_accuracy(args.directory, langs) diff --git a/fairseq/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh b/fairseq/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh new file mode 100644 index 0000000000000000000000000000000000000000..0428d8bef9d426ac3e664cd281ce0b688f5f580f --- /dev/null +++ b/fairseq/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +source_lang=kk_KZ +target_lang=en_XX +MODEL=criss_checkpoints/criss.3rd.pt +SPM=criss_checkpoints/sentence.bpe.model +SPLIT=test +LANG_DICT=criss_checkpoints/lang_dict.txt +ENCODER_ANALYSIS=sentence_retrieval/encoder_analysis.py +SAVE_ENCODER=save_encoder.py +ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL + + + +DATA_DIR=data_tmp +INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba +ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang} +mkdir -p $ENCODER_SAVE_DIR/${target_lang} +mkdir -p $ENCODER_SAVE_DIR/${source_lang} + +# Save encoder outputs for source sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --task translation_multi_simple_epoch \ + --lang-dict ${LANG_DICT} \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + --lang-pairs ${source_lang}-${target_lang} \ + -s ${source_lang} -t ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang} + +# Save encoder outputs for target sentences +python $SAVE_ENCODER \ + ${INPUT_DIR} \ + --path ${MODEL} \ + --lang-dict ${LANG_DICT} \ + --task translation_multi_simple_epoch \ + --gen-subset ${SPLIT} \ + --bpe 'sentencepiece' \ + --lang-pairs ${target_lang}-${source_lang} \ + -t ${source_lang} -s ${target_lang} \ + --sentencepiece-model ${SPM} \ + --remove-bpe 'sentencepiece' \ + --beam 1 \ + --lang-tok-style mbart \ + --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang} + +# Analyze sentence retrieval accuracy +python $ENCODER_ANALYSIS --langs "${source_lang},${target_lang}" ${ENCODER_SAVE_DIR} diff --git a/fairseq/examples/criss/unsupervised_mt/eval.sh b/fairseq/examples/criss/unsupervised_mt/eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..03b773ed5a522eb82186fea8ffbb6c557e14b6d3 --- /dev/null +++ b/fairseq/examples/criss/unsupervised_mt/eval.sh @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +SRC=si_LK +TGT=en_XX +MODEL=criss_checkpoints/criss.3rd.pt + +MULTIBLEU=mosesdecoder/scripts/generic/multi-bleu.perl +MOSES=mosesdecoder +REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl +NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl +TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl +GEN_TMP_DIR=gen_tmp +LANG_DICT=criss_checkpoints/lang_dict.txt + +if [ ! -d "mosesdecoder" ]; then + git clone https://github.com/moses-smt/mosesdecoder +fi +mkdir -p $GEN_TMP_DIR +fairseq-generate data_tmp/${SRC}-${TGT}-flores \ + --task translation_multi_simple_epoch \ + --max-tokens 2000 \ + --path ${MODEL} \ + --skip-invalid-size-inputs-valid-test \ + --beam 5 --lenpen 1.0 --gen-subset test \ + --remove-bpe=sentencepiece \ + --source-lang ${SRC} --target-lang ${TGT} \ + --decoder-langtok --lang-pairs 'en_XX-ar_AR,en_XX-de_DE,en_XX-es_XX,en_XX-fr_XX,en_XX-hi_IN,en_XX-it_IT,en_XX-ja_XX,en_XX-ko_KR,en_XX-nl_XX,en_XX-ru_RU,en_XX-zh_CN,en_XX-tr_TR,en_XX-vi_VN,en_XX-ro_RO,en_XX-my_MM,en_XX-ne_NP,en_XX-si_LK,en_XX-cs_CZ,en_XX-lt_LT,en_XX-kk_KZ,en_XX-gu_IN,en_XX-fi_FI,en_XX-et_EE,en_XX-lv_LV,ar_AR-en_XX,cs_CZ-en_XX,de_DE-en_XX,es_XX-en_XX,et_EE-en_XX,fi_FI-en_XX,fr_XX-en_XX,gu_IN-en_XX,hi_IN-en_XX,it_IT-en_XX,ja_XX-en_XX,kk_KZ-en_XX,ko_KR-en_XX,lt_LT-en_XX,lv_LV-en_XX,my_MM-en_XX,ne_NP-en_XX,nl_XX-en_XX,ro_RO-en_XX,ru_RU-en_XX,si_LK-en_XX,tr_TR-en_XX,vi_VN-en_XX,zh_CN-en_XX,ar_AR-es_XX,es_XX-ar_AR,ar_AR-hi_IN,hi_IN-ar_AR,ar_AR-zh_CN,zh_CN-ar_AR,cs_CZ-es_XX,es_XX-cs_CZ,cs_CZ-hi_IN,hi_IN-cs_CZ,cs_CZ-zh_CN,zh_CN-cs_CZ,de_DE-es_XX,es_XX-de_DE,de_DE-hi_IN,hi_IN-de_DE,de_DE-zh_CN,zh_CN-de_DE,es_XX-hi_IN,hi_IN-es_XX,es_XX-zh_CN,zh_CN-es_XX,et_EE-es_XX,es_XX-et_EE,et_EE-hi_IN,hi_IN-et_EE,et_EE-zh_CN,zh_CN-et_EE,fi_FI-es_XX,es_XX-fi_FI,fi_FI-hi_IN,hi_IN-fi_FI,fi_FI-zh_CN,zh_CN-fi_FI,fr_XX-es_XX,es_XX-fr_XX,fr_XX-hi_IN,hi_IN-fr_XX,fr_XX-zh_CN,zh_CN-fr_XX,gu_IN-es_XX,es_XX-gu_IN,gu_IN-hi_IN,hi_IN-gu_IN,gu_IN-zh_CN,zh_CN-gu_IN,hi_IN-zh_CN,zh_CN-hi_IN,it_IT-es_XX,es_XX-it_IT,it_IT-hi_IN,hi_IN-it_IT,it_IT-zh_CN,zh_CN-it_IT,ja_XX-es_XX,es_XX-ja_XX,ja_XX-hi_IN,hi_IN-ja_XX,ja_XX-zh_CN,zh_CN-ja_XX,kk_KZ-es_XX,es_XX-kk_KZ,kk_KZ-hi_IN,hi_IN-kk_KZ,kk_KZ-zh_CN,zh_CN-kk_KZ,ko_KR-es_XX,es_XX-ko_KR,ko_KR-hi_IN,hi_IN-ko_KR,ko_KR-zh_CN,zh_CN-ko_KR,lt_LT-es_XX,es_XX-lt_LT,lt_LT-hi_IN,hi_IN-lt_LT,lt_LT-zh_CN,zh_CN-lt_LT,lv_LV-es_XX,es_XX-lv_LV,lv_LV-hi_IN,hi_IN-lv_LV,lv_LV-zh_CN,zh_CN-lv_LV,my_MM-es_XX,es_XX-my_MM,my_MM-hi_IN,hi_IN-my_MM,my_MM-zh_CN,zh_CN-my_MM,ne_NP-es_XX,es_XX-ne_NP,ne_NP-hi_IN,hi_IN-ne_NP,ne_NP-zh_CN,zh_CN-ne_NP,nl_XX-es_XX,es_XX-nl_XX,nl_XX-hi_IN,hi_IN-nl_XX,nl_XX-zh_CN,zh_CN-nl_XX,ro_RO-es_XX,es_XX-ro_RO,ro_RO-hi_IN,hi_IN-ro_RO,ro_RO-zh_CN,zh_CN-ro_RO,ru_RU-es_XX,es_XX-ru_RU,ru_RU-hi_IN,hi_IN-ru_RU,ru_RU-zh_CN,zh_CN-ru_RU,si_LK-es_XX,es_XX-si_LK,si_LK-hi_IN,hi_IN-si_LK,si_LK-zh_CN,zh_CN-si_LK,tr_TR-es_XX,es_XX-tr_TR,tr_TR-hi_IN,hi_IN-tr_TR,tr_TR-zh_CN,zh_CN-tr_TR,vi_VN-es_XX,es_XX-vi_VN,vi_VN-hi_IN,hi_IN-vi_VN,vi_VN-zh_CN,zh_CN-vi_VN' \ + --lang-dict ${LANG_DICT} --lang-tok-style 'mbart' --sampling-method 'temperature' --sampling-temperature '1.0' > $GEN_TMP_DIR/${SRC}_${TGT}.gen +cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^T-" | cut -f2 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.hyp +cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^H-" | cut -f3 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.ref +${MULTIBLEU} $GEN_TMP_DIR/${SRC}_${TGT}.ref < $GEN_TMP_DIR/${SRC}_${TGT}.hyp diff --git a/fairseq/examples/cross_lingual_language_model/README.md b/fairseq/examples/cross_lingual_language_model/README.md new file mode 100644 index 0000000000000000000000000000000000000000..af9128e39e5925e9411d162c2f24a19e4532d618 --- /dev/null +++ b/fairseq/examples/cross_lingual_language_model/README.md @@ -0,0 +1,77 @@ +# Cross-Lingual Language Model Pre-training + +Below are some details for training Cross-Lingual Language Models (XLM) - similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf) - in Fairseq. The current implementation only supports the Masked Language Model (MLM) from the paper above. + +## Downloading and Tokenizing Monolingual Data + +Pointers to the monolingual data from wikipedia, used for training the XLM-style MLM model as well as details on processing (tokenization and BPE) it can be found in the [XLM Github Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data). + +Let's assume the following for the code snippets in later sections to work +- Processed data is in the folder: monolingual_data/processed +- Each language has 3 files for train, test and validation. For example we have the following files for English: + train.en, valid.en +- We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr) +- The vocabulary file is monolingual_data/processed/vocab_mlm + + +## Fairseq Pre-processing and Binarization + +Pre-process and binarize the data with the MaskedLMDictionary and cross_lingual_lm task + +```bash +# Ensure the output directory exists +DATA_DIR=monolingual_data/fairseq_processed +mkdir -p "$DATA_DIR" + +for lg in ar de en hi fr +do + + fairseq-preprocess \ + --task cross_lingual_lm \ + --srcdict monolingual_data/processed/vocab_mlm \ + --only-source \ + --trainpref monolingual_data/processed/train \ + --validpref monolingual_data/processed/valid \ + --testpref monolingual_data/processed/test \ + --destdir monolingual_data/fairseq_processed \ + --workers 20 \ + --source-lang $lg + + # Since we only have a source language, the output file has a None for the + # target language. Remove this + + for stage in train test valid + + sudo mv "$DATA_DIR/$stage.$lg-None.$lg.bin" "$stage.$lg.bin" + sudo mv "$DATA_DIR/$stage.$lg-None.$lg.idx" "$stage.$lg.idx" + + done + +done +``` + +## Train a Cross-lingual Language Model similar to the XLM MLM model + +Use the following command to train the model on 5 languages. + +``` +fairseq-train \ +--task cross_lingual_lm monolingual_data/fairseq_processed \ +--save-dir checkpoints/mlm \ +--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \ +--arch xlm_base \ +--optimizer adam --lr-scheduler reduce_lr_on_plateau \ +--lr-shrink 0.5 --lr 0.0001 --stop-min-lr 1e-09 \ +--dropout 0.1 \ +--criterion legacy_masked_lm_loss \ +--max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \ +--dataset-impl lazy --seed 0 \ +--masked-lm-only \ +--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \ +--ddp-backend=legacy_ddp +``` + +Some Notes: +- Using tokens_per_sample greater than 256 can cause OOM (out-of-memory) issues. Usually since MLM packs in streams of text, this parameter doesn't need much tuning. +- The Evaluation workflow for computing MLM Perplexity on test data is in progress. +- Finetuning this model on a downstream task is something which is not currently available. diff --git a/fairseq/examples/data2vec/README.md b/fairseq/examples/data2vec/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a0ff21b82ac420259f0abdc1d80479d7a6edd47c --- /dev/null +++ b/fairseq/examples/data2vec/README.md @@ -0,0 +1,261 @@ +# data2vec 2.0 + +data2vec 2.0 improves the training efficiency of the original data2vec algorithm. We make the following improvements for efficiency considerations - we forward only the unmasked timesteps through the encoder, we use convolutional decoder and we use multimasking to amortize the compute overhead of the teacher model. You can find details in the paper [Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language](https://arxiv.org/abs/2212.07525) and our [blog post](https://ai.facebook.com/blog/ai-self-supervised-learning-data2vec/). + +## Pretrained and finetuned models +### Vision +| Model | Finetuning split | Link +|---|---|--- +data2vec ViT-B | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet.pt) +data2vec ViT-B | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_imagenet_ft.pt) +data2vec ViT-L | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet.pt) +data2vec ViT-L | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_imagenet_ft.pt) +data2vec ViT-H | No fine-tuning | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet.pt) +data2vec ViT-H | Imagenet-1K | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/huge_imagenet_ft.pt) + +Vision models only are license under CC-BY-NC. +### Speech + +| Model | Finetuning split | Dataset | Link +|---|---|---|--- +data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri.pt) +data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/base_libri_960h.pt) +data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox.pt) +data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/large_vox_960h.pt) + +### NLP + +| Model | Fine-tuning data | Dataset | Link | Dict | BPE +|---|---|---|---|---|--- +data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec2/nlp_base.pt) | [dict](https://dl.fbaipublicfiles.com/fairseq/data2vec2/dict.txt) | [encoder](https://dl.fbaipublicfiles.com/fairseq/data2vec2/encoder.json) / [vocab](https://dl.fbaipublicfiles.com/fairseq/data2vec2/vocab.bpe) + +[//]: # (## Data Preparation) + +[//]: # () +[//]: # (### Vision) + +[//]: # (add details) + +[//]: # (### Speech) + +[//]: # (add details) + +[//]: # () +[//]: # (### NLP) + +[//]: # (add details) + + +## Commands to train different models using data2vec 2.0 + +### Vision + +Commands to pretrain different model configurations +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name base_images_only_task task.data=/path/to/dir +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name large_images_only_task task.data=/path/to/dir +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name huge_images14_only_task task.data=/path/to/dir +``` + +Commands to finetune different model configurations + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \ +--config-name mae_imagenet_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \ +--config-name mae_imagenet_large_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/vision/finetuning \ +--config-name mae_imagenet_huge_clean task.data=/path/to/dir model.model_path=/path/to/pretrained/model +``` + +### Speech + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name base_audio_only_task task.data=/path/to/manifests +``` + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name large_audio_only_task task.data=/path/to/manifests +``` + +Finetuning: + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/wav2vec/config/finetuning --config-name vox_10h \ +task.data=/path/to/manifests model.w2v_path=/path/to/pretrained/model common.user_dir=examples/data2vec +``` + +Replace vox_10h with the right config depending on your model and fine-tuning split. +See examples/wav2vec/config/finetuning for all available configs. + +### NLP + +Commands to pretrain +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2 \ +--config-name base_text_only_task task.data=/path/to/file +``` + +Commands to fine-tune all GLUE tasks +```shell script +$ task=cola # choose from [cola|qnli|mrpc|rte|sst_2|mnli|qqp|sts_b] +$ lr=1e-5 # sweep [1e-5|2e-5|4e-5|6e-5] for each task +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/v2/text_finetuning \ +--config-name $task task.data=/path/to/file model.model_path=/path/to/pretrained/model "optimization.lr=[${lr}]" +``` + +# data2vec + +data2vec is a framework for self-supervised representation learning for images, speech, and text as described in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and Language (Baevski et al., 2022)](https://ai.facebook.com/research/data2vec-a-general-framework-for-self-supervised-learning-in-speech-vision-and-language). The algorithm uses the same learning mechanism for different modalities. + + +## Pre-trained models + +### Vision + +Code and pre-trained models for data2vec visions can be found [here](https://github.com/facebookresearch/data2vec_vision/tree/main/beit). + +### Speech + +| Model | Finetuning split | Dataset | Link +|---|---|---|--- +data2vec Base | No fine-tuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls.pt) +data2vec Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls_10m.pt) +data2vec Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls_100h.pt) +data2vec Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/audio_base_ls_960h.pt) +data2vec Large | No fine-tuning | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_pretrained.pt) +data2vec Large | 10 minutes | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_10m.pt) +data2vec Large | 100 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_100h.pt) +data2vec Large | 960 hours | [Libri-light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/vox_960h.pt) +--- + +### NLP + +Model | Fine-tuning data | Dataset | Link +|---|---|---|---| +data2vec Base | No fine-tuning | Books + Wiki | [download](https://dl.fbaipublicfiles.com/fairseq/data2vec/nlp_base.pt) + +## Training a new speech model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) + +### Prepare training data manifest: + +First, install the `soundfile` library: +```shell script +pip install soundfile +``` + +Next, run: + +```shell script +$ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid +``` + +$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. + +$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. +To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a +separately pre-processed manifest file. + +### Train a data2vec Base model: + +This configuration was used for the base model trained on the Librispeech dataset in the data2vec paper + +Note that the input is expected to be single channel, sampled at 16 kHz + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/audio/pretraining \ +--config-name base_librispeech task.data=/path/to/manifests common.user_dir=examples/data2vec +``` + +Note: you can simulate 16 GPUs by using k GPUs and adding command line parameters +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 16/k + +### Fine-tune a pre-trained model with CTC: + +Fine-tuning a model requires parallel audio and labels file, as well as a vocabulary file in fairseq format. +A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). +An example [script](../wav2vec/libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: + +```shell script +split=train +$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split +``` + +Fine-tuning on 100h of Librispeech with letter targets: +```shell script +$ fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data \ + model.w2v_path=/path/to/model.pt \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-name base_100h common.user_dir=examples/data2vec +``` + +There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. +You can specify the right config via the `--config-name` parameter. + +Decoding with a language model during training requires flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter). +If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. + +### Evaluating a CTC model: + +Evaluating a CTC model with a language model requires [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter) to be installed. + +Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). +Be sure to upper-case the language model vocab after downloading it. + +Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). + +Next, run the evaluation command: + +```shell script +python examples/speech_recognition/new/infer.py --config-dir examples/speech_recognition/new/conf \ +--config-name infer task=audio_finetuning task.data=/path/to/manifests common.user_dir=examples/data2vec \ +task.labels=ltr decoding.type=kenlm \ +decoding.lmweight=${lmweight} decoding.wordscore=${wordscore} decoding.silweight=${silscore} \ +decoding.lexicon=/path/to/lexicon \ +decoding.lmpath=/path/to/lm decoding.unique_wer_file=True \ +dataset.gen_subset=dev_clean,dev_other,test_clean,test_other \ +common_eval.path=/path/to/checkpoint.pt decoding.beam=1500 distributed_training.distributed_world_size=${num_gpus} +``` + +To get raw numbers, use decoding.type=viterbi and omit the lexicon. To use the transformer language model, use decoding.type=fairseqlm. + +## Training a new NLP model with the CLI tools + +Please follow the [RoBERTa](../roberta/README.md) instructions to preprocess your data. To train a data2vec model on run: + +```shell script +$ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/text/pretraining \ +--config-name base task.data=/path/to/data common.user_dir=examples/data2vec +``` + +As for speech models, you can simulate 16 gpus by using the update_freq parameter. + +### Finetuning data2vec-text on GLUE + +Please use a command similar to this: + +```shell +$ python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \ + --config-name $task task.data=$data_path checkpoint.restore_file="${/path/to/pretrained/model.pt}" +``` diff --git a/fairseq/examples/data2vec/__init__.py b/fairseq/examples/data2vec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/data2vec/config/audio/classification/base_classification.yaml b/fairseq/examples/data2vec/config/audio/classification/base_classification.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fdb9c8d3d7c11b15c50d9a661603438f3ad0ffec --- /dev/null +++ b/fairseq/examples/data2vec/config/audio/classification/base_classification.yaml @@ -0,0 +1,70 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + all_gather_list_size: 70000 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: mAP + maximize_best_checkpoint_metric: true + +task: + _name: audio_classification + data: ??? + normalize: true + labels: lbl + +dataset: + num_workers: 6 + max_tokens: 2560000 + skip_invalid_size_inputs_valid_test: true + valid_subset: eval + validate_interval: 5 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: model + can_sum: false + log_keys: + - _predictions + - _targets + +optimization: + max_update: 30000 + lr: [0.00006] # scratch 53-5 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 5000 + +model: + _name: audio_classification + model_path: ??? + apply_mask: true + mask_prob: 0.6 + mask_length: 5 # scratch 1 + mask_channel_prob: 0 + mask_channel_length: 64 + layerdrop: 0.1 + dropout: 0.1 + activation_dropout: 0.1 + attention_dropout: 0.2 + feature_grad_mult: 0 # scratch 1 + label_mixup: true + source_mixup: 0.5 + prediction_mode: lin_softmax # scratch average_sigmoid +