Audio-to-Audio
BeauKang01 commited on
Commit
d972bc8
·
1 Parent(s): 09ce0e3

add checkpoint

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ckpt/best.pt.tar +3 -0
  2. ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/config.json +81 -0
  3. ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/preprocessor_config.json +11 -0
  4. ckpt/codec_ckpt/hub/version.txt +1 -0
  5. ckpt/download.sh +18 -0
  6. ckpt/download_ckpt.py +58 -0
  7. config/test.yml +21 -0
  8. inference.py +474 -0
  9. inference.sh +6 -0
  10. loader/__pycache__/datareader.cpython-310.pyc +0 -0
  11. loader/__pycache__/datareader_aec.cpython-310.pyc +0 -0
  12. loader/__pycache__/datareader_fe.cpython-310.pyc +0 -0
  13. loader/__pycache__/datareader_tse.cpython-310.pyc +0 -0
  14. loader/datareader.py +65 -0
  15. loader/datareader_aec.py +86 -0
  16. loader/datareader_tse.py +85 -0
  17. nnet/WavLM.py +793 -0
  18. nnet/__pycache__/WavLM.cpython-310.pyc +0 -0
  19. nnet/__pycache__/embedding.cpython-310.pyc +0 -0
  20. nnet/__pycache__/llase.cpython-310.pyc +0 -0
  21. nnet/__pycache__/modules.cpython-310.pyc +0 -0
  22. nnet/llase.py +104 -0
  23. nnet/modules.py +825 -0
  24. vq/__init__.py +4 -0
  25. vq/__pycache__/__init__.cpython-310.pyc +0 -0
  26. vq/__pycache__/__init__.cpython-311.pyc +0 -0
  27. vq/__pycache__/__init__.cpython-312.pyc +0 -0
  28. vq/__pycache__/__init__.cpython-37.pyc +0 -0
  29. vq/__pycache__/__init__.cpython-38.pyc +0 -0
  30. vq/__pycache__/__init__.cpython-39.pyc +0 -0
  31. vq/__pycache__/activations.cpython-310.pyc +0 -0
  32. vq/__pycache__/activations.cpython-311.pyc +0 -0
  33. vq/__pycache__/activations.cpython-312.pyc +0 -0
  34. vq/__pycache__/activations.cpython-37.pyc +0 -0
  35. vq/__pycache__/activations.cpython-38.pyc +0 -0
  36. vq/__pycache__/activations.cpython-39.pyc +0 -0
  37. vq/__pycache__/blocks.cpython-310.pyc +0 -0
  38. vq/__pycache__/blocks.cpython-39.pyc +0 -0
  39. vq/__pycache__/bs_roformer5.cpython-310.pyc +0 -0
  40. vq/__pycache__/bs_roformer5.cpython-37.pyc +0 -0
  41. vq/__pycache__/bs_roformer5.cpython-38.pyc +0 -0
  42. vq/__pycache__/bs_roformer5.cpython-39.pyc +0 -0
  43. vq/__pycache__/codec_decoder.cpython-310.pyc +0 -0
  44. vq/__pycache__/codec_decoder.cpython-311.pyc +0 -0
  45. vq/__pycache__/codec_decoder.cpython-312.pyc +0 -0
  46. vq/__pycache__/codec_decoder.cpython-39.pyc +0 -0
  47. vq/__pycache__/codec_decoder_vocos.cpython-310.pyc +0 -0
  48. vq/__pycache__/codec_decoder_vocos.cpython-311.pyc +0 -0
  49. vq/__pycache__/codec_decoder_vocos.cpython-312.pyc +0 -0
  50. vq/__pycache__/codec_decoder_vocos.cpython-39.pyc +0 -0
ckpt/best.pt.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:298e4fccf3baacd8623574e7a767d22198f3ddb0c35cdb17e71e03dd4edf0fc5
3
+ size 11826355083
ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/config.json ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "adapter_act": "relu",
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": false,
8
+ "architectures": [
9
+ "Wav2Vec2BertModel"
10
+ ],
11
+ "attention_dropout": 0.0,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 768,
14
+ "codevector_dim": 768,
15
+ "conformer_conv_dropout": 0.1,
16
+ "contrastive_logits_temperature": 0.1,
17
+ "conv_depthwise_kernel_size": 31,
18
+ "ctc_loss_reduction": "sum",
19
+ "ctc_zero_infinity": false,
20
+ "diversity_loss_weight": 0.1,
21
+ "eos_token_id": 2,
22
+ "feat_proj_dropout": 0.0,
23
+ "feat_quantizer_dropout": 0.0,
24
+ "feature_projection_input_dim": 160,
25
+ "final_dropout": 0.1,
26
+ "hidden_act": "swish",
27
+ "hidden_dropout": 0.0,
28
+ "hidden_size": 1024,
29
+ "initializer_range": 0.02,
30
+ "intermediate_size": 4096,
31
+ "layer_norm_eps": 1e-05,
32
+ "layerdrop": 0.1,
33
+ "left_max_position_embeddings": 64,
34
+ "mask_feature_length": 10,
35
+ "mask_feature_min_masks": 0,
36
+ "mask_feature_prob": 0.0,
37
+ "mask_time_length": 10,
38
+ "mask_time_min_masks": 2,
39
+ "mask_time_prob": 0.05,
40
+ "max_source_positions": 5000,
41
+ "model_type": "wav2vec2-bert",
42
+ "num_adapter_layers": 1,
43
+ "num_attention_heads": 16,
44
+ "num_codevector_groups": 2,
45
+ "num_codevectors_per_group": 320,
46
+ "num_hidden_layers": 24,
47
+ "num_negatives": 100,
48
+ "output_hidden_size": 1024,
49
+ "pad_token_id": 0,
50
+ "position_embeddings_type": "relative_key",
51
+ "proj_codevector_dim": 768,
52
+ "right_max_position_embeddings": 8,
53
+ "rotary_embedding_base": 10000,
54
+ "tdnn_dilation": [
55
+ 1,
56
+ 2,
57
+ 3,
58
+ 1,
59
+ 1
60
+ ],
61
+ "tdnn_dim": [
62
+ 512,
63
+ 512,
64
+ 512,
65
+ 512,
66
+ 1500
67
+ ],
68
+ "tdnn_kernel": [
69
+ 5,
70
+ 3,
71
+ 3,
72
+ 1,
73
+ 1
74
+ ],
75
+ "torch_dtype": "float32",
76
+ "transformers_version": "4.37.0.dev0",
77
+ "use_intermediate_ffn_before_adapter": false,
78
+ "use_weighted_layer_sum": false,
79
+ "vocab_size": null,
80
+ "xvector_output_dim": 512
81
+ }
ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/preprocessor_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "feature_extractor_type": "SeamlessM4TFeatureExtractor",
3
+ "feature_size": 80,
4
+ "num_mel_bins": 80,
5
+ "padding_side": "right",
6
+ "padding_value": 1,
7
+ "processor_class": "Wav2Vec2BertProcessor",
8
+ "return_attention_mask": true,
9
+ "sampling_rate": 16000,
10
+ "stride": 2
11
+ }
ckpt/codec_ckpt/hub/version.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
ckpt/download.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python download_script.py \
2
+ --source hf \
3
+ --repo_id microsoft/wavlm-large \
4
+ --filename pytorch_model.bin \
5
+ --save_path ./WavLM-Large.pt
6
+
7
+ python download_script.py \
8
+ --source hf \
9
+ --repo_id facebook/w2v-bert-2.0 \
10
+ --filename model.safetensors \
11
+ --save_path \
12
+ ./codec_ckpt/hub/models--facebook--w2v-bert-2.0/model.safetensors
13
+
14
+ python download_script.py \
15
+ --source hf \
16
+ --repo_id HKUSTAudio/xcodec2 \
17
+ --filename ckpt/epoch=4-step=1400000.ckpt \
18
+ --save_path ./codec_ckpt/epoch=4-step=1400000.ckpt
ckpt/download_ckpt.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import argparse
4
+ from huggingface_hub import hf_hub_download
5
+ from tqdm import tqdm
6
+
7
+ def download_from_url(url, save_path):
8
+ """Download a file from a given URL and save it locally."""
9
+ response = requests.get(url, stream=True)
10
+ total_size = int(response.headers.get("content-length", 0))
11
+ block_size = 1024 # 1 KB
12
+ progress_bar = tqdm(total=total_size, unit="B", unit_scale=True)
13
+
14
+ with open(save_path, "wb") as file:
15
+ for data in response.iter_content(block_size):
16
+ progress_bar.update(len(data))
17
+ file.write(data)
18
+ progress_bar.close()
19
+
20
+ if total_size != 0 and progress_bar.n != total_size:
21
+ print("Download failed!")
22
+ else:
23
+ print(f"File downloaded to: {save_path}")
24
+
25
+ def download_from_hf(repo_id, filename, save_path):
26
+ """Download a file from Hugging Face Hub."""
27
+ print(f"Downloading from Hugging Face Hub: {repo_id}/{filename}")
28
+ try:
29
+ hf_hub_download(repo_id=repo_id, filename=filename, local_dir=os.path.dirname(save_path), local_dir_use_symlinks=False)
30
+ print(f"File downloaded to: {save_path}")
31
+ except Exception as e:
32
+ print(f"Download failed: {e}")
33
+
34
+ def main():
35
+ parser = argparse.ArgumentParser(description="Automatically download model checkpoints")
36
+ parser.add_argument("--source", type=str, required=True, choices=["hf", "url"], help="Download source: hf (Hugging Face Hub) or url (custom URL)")
37
+ parser.add_argument("--repo_id", type=str, help="Hugging Face model repository ID (e.g., google/bert-base-uncased)")
38
+ parser.add_argument("--filename", type=str, help="Filename in the Hugging Face repository")
39
+ parser.add_argument("--url", type=str, help="Custom download URL")
40
+ parser.add_argument("--save_path", type=str, required=True, help="Path to save the file (including filename)")
41
+ args = parser.parse_args()
42
+
43
+ # Ensure the save directory exists
44
+ os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
45
+
46
+ if args.source == "hf":
47
+ if not args.repo_id or not args.filename:
48
+ print("Please provide a Hugging Face repository ID and filename!")
49
+ return
50
+ download_from_hf(args.repo_id, args.filename, args.save_path)
51
+ elif args.source == "url":
52
+ if not args.url:
53
+ print("Please provide a download URL!")
54
+ return
55
+ download_from_url(args.url, args.save_path)
56
+
57
+ if __name__ == "__main__":
58
+ main()
config/test.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ test:
2
+ checkpoint: ./ckpt/best.pt.tar
3
+ use_cuda: True
4
+ infer_feat_too: True
5
+ inference_time: 1
6
+
7
+ save:
8
+ feat_dir: ./decode/feat/se
9
+ wav_dir: ./decode/wav/se
10
+
11
+ task: SE
12
+
13
+ # LLaSE config
14
+ nnet_conf:
15
+ d_model: 1024
16
+ nhead: 16
17
+ num_layers: 16
18
+
19
+ datareader:
20
+ sample_rate: 16000
21
+ filename: /home/node57_data2/bykang/work_plus/test_set/interspeech2020/syn_no_reverb.scp # /path/to/your/filelist
inference.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import librosa
5
+ import yaml
6
+ import joblib
7
+ import argparse
8
+
9
+ import soundfile as sf
10
+ import numpy as np
11
+
12
+ from pathlib import Path
13
+ from collections import defaultdict
14
+ from typing import Optional
15
+ from tqdm import tqdm
16
+
17
+ sys.path.append(os.path.dirname(__file__))
18
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
19
+
20
+ # Torch
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torch.distributed as dist
25
+ from torch.nn.parallel import DistributedDataParallel
26
+
27
+ # WavLM
28
+ from nnet.WavLM import WavLM, WavLMConfig
29
+
30
+ # Xcodec2
31
+ from vq.codec_encoder import CodecEncoder_Transformer
32
+ from vq.codec_decoder_vocos import CodecDecoderVocos
33
+ from vq.module import SemanticEncoder
34
+ from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
35
+ from collections import OrderedDict
36
+
37
+ # Dataloader
38
+ from loader.datareader import DataReader
39
+ from loader.datareader_aec import DataReaderAEC
40
+ from loader.datareader_tse import DataReaderTSE
41
+
42
+ # LLaSE
43
+ from nnet.llase import LLM_AR as model
44
+
45
+ class Encodec():
46
+ '''
47
+ Load Xcodec2
48
+ '''
49
+ def __init__(self,device="cpu") -> None:
50
+ self.device=device
51
+ ckpt = "./ckpt/codec_ckpt/epoch=4-step=1400000.ckpt",
52
+ # ckpt = '/home/bykang/codec_ckpt/epoch=4-step=1400000.ckpt'
53
+ ckpt = torch.load(ckpt, map_location='cpu')
54
+ state_dict = ckpt['state_dict']
55
+ filtered_state_dict_codec = OrderedDict()
56
+ filtered_state_dict_semantic_encoder = OrderedDict()
57
+ filtered_state_dict_gen = OrderedDict()
58
+ filtered_state_dict_fc_post_a = OrderedDict()
59
+ filtered_state_dict_fc_prior = OrderedDict()
60
+ for key, value in state_dict.items():
61
+ if key.startswith('CodecEnc.'):
62
+ new_key = key[len('CodecEnc.'):]
63
+ filtered_state_dict_codec[new_key] = value
64
+ elif key.startswith('generator.'):
65
+ new_key = key[len('generator.'):]
66
+ filtered_state_dict_gen[new_key] = value
67
+ elif key.startswith('fc_post_a.'):
68
+ new_key = key[len('fc_post_a.'):]
69
+ filtered_state_dict_fc_post_a[new_key] = value
70
+ elif key.startswith('SemanticEncoder_module.'):
71
+ new_key = key[len('SemanticEncoder_module.'):]
72
+ filtered_state_dict_semantic_encoder[new_key] = value
73
+ elif key.startswith('fc_prior.'):
74
+ new_key = key[len('fc_prior.'):]
75
+ filtered_state_dict_fc_prior[new_key] = value
76
+
77
+ self.semantic_model = Wav2Vec2BertModel.from_pretrained(
78
+ "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0",
79
+ # "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b",
80
+ output_hidden_states=True)
81
+ self.semantic_model=self.semantic_model.eval().to(self.device)
82
+
83
+ self.SemanticEncoder_module = SemanticEncoder(1024,1024,1024)
84
+ self.SemanticEncoder_module.load_state_dict(filtered_state_dict_semantic_encoder)
85
+ self.SemanticEncoder_module = self.SemanticEncoder_module.eval().to(self.device)
86
+
87
+ self.encoder = CodecEncoder_Transformer()
88
+ self.encoder.load_state_dict(filtered_state_dict_codec)
89
+ self.encoder = self.encoder.eval().to(self.device)
90
+
91
+ self.decoder = CodecDecoderVocos()
92
+ self.decoder.load_state_dict(filtered_state_dict_gen)
93
+ self.decoder = self.decoder.eval().to(self.device)
94
+
95
+ self.fc_post_a = nn.Linear( 2048, 1024 )
96
+ self.fc_post_a.load_state_dict(filtered_state_dict_fc_post_a)
97
+ self.fc_post_a = self.fc_post_a.eval().to(self.device)
98
+
99
+ self.fc_prior = nn.Linear( 2048, 2048 )
100
+ self.fc_prior.load_state_dict(filtered_state_dict_fc_prior)
101
+ self.fc_prior = self.fc_prior.eval().to(self.device)
102
+
103
+ self.feature_extractor = AutoFeatureExtractor.from_pretrained(
104
+ "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0")
105
+ # "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b")
106
+
107
+ def get_feat(self, wav_batch, pad=None):
108
+
109
+ if len(wav_batch.shape) != 2:
110
+ return self.feature_extractor(F.pad(wav_batch, pad), sampling_rate=16000, return_tensors="pt") .data['input_features']
111
+
112
+ padded_wavs = torch.stack([F.pad(wav, pad) for wav in wav_batch])
113
+ batch_feats = []
114
+
115
+ for wav in padded_wavs:
116
+ feat = self.feature_extractor(
117
+ wav,
118
+ sampling_rate=16000,
119
+ return_tensors="pt"
120
+ ).data['input_features']
121
+
122
+ batch_feats.append(feat)
123
+ feat_batch = torch.concat(batch_feats, dim=0).to(self.device)
124
+ return feat_batch
125
+
126
+ def get_embedding(self, wav_cpu):
127
+ wav_cpu = wav_cpu.cpu()
128
+ feat = self.get_feat(wav_cpu,pad=(160,160))
129
+ feat = feat.to(self.device)
130
+
131
+ if(len(wav_cpu.shape)==1):
132
+ wav = wav_cpu.unsqueeze(0).to(self.device)
133
+ else:
134
+ wav = wav_cpu.to(self.device)
135
+
136
+ wav = torch.nn.functional.pad(wav, (0, (200 - (wav.shape[1] % 200))))
137
+ with torch.no_grad():
138
+ vq_emb = self.encoder(wav.unsqueeze(1))
139
+ vq_emb = vq_emb.transpose(1, 2)
140
+
141
+ if vq_emb.shape[2]!=feat.shape[1]:
142
+ feat = self.get_feat(wav_cpu)
143
+ feat = feat.to(self.device)
144
+
145
+ semantic_target = self.semantic_model(feat[:, :,:])
146
+ semantic_target = semantic_target.hidden_states[16]
147
+ semantic_target = semantic_target.transpose(1, 2)
148
+ semantic_target = self.SemanticEncoder_module(semantic_target)
149
+
150
+ vq_emb = torch.cat([semantic_target, vq_emb], dim=1)
151
+
152
+ return vq_emb
153
+
154
+ def emb2token(self, emb):
155
+ emb.to(self.device)
156
+ emb = self.fc_prior(emb.transpose(1, 2)).transpose(1, 2)
157
+ _, vq_code, _ = self.decoder(emb, vq=True)
158
+ return vq_code
159
+
160
+ def token2wav(self, vq_code):
161
+ vq_code.to(self.device)
162
+ vq_post_emb = self.decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
163
+ vq_post_emb = vq_post_emb.transpose(1, 2)
164
+ vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1,2)).transpose(1,2)
165
+ recon = self.decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze()
166
+ # if write the wav, add .squeeze().detach().cpu().numpy()
167
+ # if need gradient use the config right now
168
+ return recon
169
+
170
+ class WavLM_feat(object):
171
+ '''
172
+ Load WavLM
173
+ '''
174
+ def __init__(self, device):
175
+ self.wavlm = self._reload_wavLM_large(device=device)
176
+
177
+ def __call__(self, wav):
178
+ T = wav.shape[-1]
179
+ wav = wav.reshape(-1, T)
180
+ with torch.no_grad():
181
+ feat = self.wavlm.extract_features(wav, output_layer=6, ret_layer_results=False)[0]
182
+ B, T, D = feat.shape
183
+ feat = torch.reshape(feat, (-1, D))
184
+
185
+ return feat
186
+
187
+ def _reload_wavLM_large(self, path="/home/bykang/WavLM-Large.pt", device: Optional[torch.device] = None):
188
+ cpt = torch.load(path, map_location="cpu")
189
+ cfg = WavLMConfig(cpt['cfg'])
190
+ wavLM = WavLM(cfg)
191
+ wavLM.load_state_dict(cpt['model'])
192
+ wavLM.eval()
193
+ if device != None:
194
+ wavLM = wavLM.to(device)
195
+ for p in wavLM.parameters():
196
+ p.requires_grad = False
197
+ print('successful to reload wavLM', path)
198
+ return wavLM
199
+
200
+ def get_firstchannel_read(path, fs=16000):
201
+ '''
202
+ Get first channel of the wav
203
+ '''
204
+ wave_data, sr = sf.read(path)
205
+ if sr != fs:
206
+ if len(wave_data.shape) != 1:
207
+ wave_data = wave_data.transpose((1, 0))
208
+ wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs)
209
+ if len(wave_data.shape) != 1:
210
+ wave_data = wave_data.transpose((1, 0))
211
+ if len(wave_data.shape) > 1:
212
+ wave_data = wave_data[:, 0]
213
+ return wave_data
214
+
215
+ def load_obj(obj, device):
216
+ '''
217
+ Offload tensor object in obj to cuda device
218
+ '''
219
+ def cuda(obj):
220
+ return obj.to(device) if isinstance(obj, torch.Tensor) else obj
221
+
222
+ if isinstance(obj, dict):
223
+ return {key: load_obj(obj[key], device) for key in obj}
224
+ elif isinstance(obj, list):
225
+ return [load_obj(val, device) for val in obj]
226
+ else:
227
+ return cuda(obj)
228
+
229
+ def run(args):
230
+ LOCAL_RANK = int(os.environ['LOCAL_RANK'])
231
+ WORLD_SIZE = int(os.environ['WORLD_SIZE'])
232
+ WORLD_RANK = int(os.environ['RANK'])
233
+ dist.init_process_group(args.backend, rank=WORLD_RANK, world_size=WORLD_SIZE)
234
+ torch.cuda.set_device(LOCAL_RANK)
235
+ device = torch.device('cuda', LOCAL_RANK)
236
+ print(f"[{os.getpid()}] using device: {device}", torch.cuda.current_device(), "local rank", LOCAL_RANK)
237
+
238
+ with open(args.conf, "r") as f:
239
+ conf = yaml.load(f, Loader=yaml.FullLoader)
240
+
241
+ # Dataloader
242
+ if conf["task"]=="AEC":
243
+ data_reader = DataReaderAEC(**conf["datareader"])
244
+ elif conf["task"]=="TSE":
245
+ data_reader = DataReaderTSE(**conf["datareader"])
246
+ else:
247
+ data_reader = DataReader(**conf["datareader"])
248
+
249
+ # Load WavLM and XCodec2
250
+ codec = Encodec(device)
251
+ wavlm_feat = WavLM_feat(device)
252
+
253
+ # Load LLaSE
254
+ nnet = model(**conf["nnet_conf"])
255
+ cpt_fname = Path(conf["test"]["checkpoint"])
256
+ cpt = torch.load(cpt_fname, map_location="cpu")
257
+
258
+ nnet = nnet.to(device)
259
+ nnet = DistributedDataParallel(nnet, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, find_unused_parameters=True)
260
+ nnet.load_state_dict(cpt["model_state_dict"])
261
+ nnet.eval()
262
+
263
+ # Make sure the dir exists
264
+ if conf["task"]=="AEC":
265
+ if not os.path.exists(conf["save"]["feat_dir"]+"/mic"):
266
+ os.makedirs(conf["save"]["feat_dir"]+"/mic")
267
+ if not os.path.exists(conf["save"]["feat_dir"]+"/ref"):
268
+ os.makedirs(conf["save"]["feat_dir"]+"/ref")
269
+ elif conf["task"]=="TSE":
270
+ if not os.path.exists(conf["save"]["feat_dir"]+"/mic"):
271
+ os.makedirs(conf["save"]["feat_dir"]+"/mic")
272
+ if not os.path.exists(conf["save"]["feat_dir"]+"/ref"):
273
+ os.makedirs(conf["save"]["feat_dir"]+"/ref")
274
+ else:
275
+ if not os.path.exists(conf["save"]["feat_dir"]):
276
+ os.makedirs(conf["save"]["feat_dir"])
277
+
278
+ if not os.path.exists(conf["save"]["wav_dir"]):
279
+ os.makedirs(conf["save"]["wav_dir"])
280
+
281
+ # Main of inference
282
+ if_feat_too = conf["test"]["infer_feat_too"]
283
+
284
+ origin_feat_dir = conf["save"]["feat_dir"]
285
+ origin_wav_dir = conf["save"]["wav_dir"]
286
+
287
+ last_feat_dir = origin_feat_dir
288
+ last_wav_dir = origin_wav_dir
289
+
290
+ for inference_time in range(conf["test"]["inference_time"]):
291
+ # For multi-inference
292
+ if inference_time > 0:
293
+ feat_dir = origin_feat_dir + "inference" + str(inference_time)
294
+ wav_dir = origin_wav_dir + "inference" + str(inference_time)
295
+ else:
296
+ feat_dir = origin_feat_dir
297
+ wav_dir = origin_wav_dir
298
+
299
+ if not os.path.exists(feat_dir):
300
+ os.makedirs(feat_dir)
301
+ if not os.path.exists(wav_dir):
302
+ os.makedirs(wav_dir)
303
+
304
+ with torch.no_grad():
305
+ # Extract WavLM features
306
+ if if_feat_too ==True or inference_time>0:
307
+ for egs in tqdm(data_reader):
308
+ egs = load_obj(egs, device)
309
+
310
+ if conf["task"]=="AEC" or conf["task"]=="TSE":
311
+ if inference_time > 0:
312
+ mic_path = last_wav_dir + '/' + egs["mic_name"] + ".wav"
313
+ egs["mic"] = torch.from_numpy(get_firstchannel_read(mic_path).astype(np.float32)).unsqueeze(0).to(device)
314
+ else:
315
+ egs["mic"]=egs["mic"].contiguous()
316
+ egs["ref"]=egs["ref"].contiguous()
317
+
318
+ feat_mic = wavlm_feat(egs["mic"])
319
+ out_mic = feat_mic.detach().squeeze(0).cpu().numpy()
320
+
321
+ if not os.path.exists(os.path.join(feat_dir, "mic")):
322
+ os.makedirs(os.path.join(feat_dir, "mic"))
323
+ np.save(os.path.join(feat_dir, "mic", egs["mic_name"]), out_mic)
324
+
325
+ # For AEC and TSE, reference audio only need to extract feats at first time
326
+ if inference_time == 0:
327
+ feat_ref = wavlm_feat(egs["ref"])
328
+ out_ref = feat_ref.detach().squeeze(0).cpu().numpy()
329
+ np.save(os.path.join(origin_feat_dir, "ref", egs["ref_name"]), out_ref)
330
+
331
+ torch.cuda.empty_cache()
332
+
333
+ else:
334
+ if inference_time > 0:
335
+ mix_path = last_wav_dir + '/' + egs["name"] + ".wav"
336
+ egs["mix"] = torch.from_numpy(get_firstchannel_read(mix_path).astype(np.float32)).unsqueeze(0).to(device)
337
+ else:
338
+ egs["mix"]=egs["mix"].contiguous()
339
+
340
+ feat = wavlm_feat(egs["mix"])
341
+ out = feat.detach().squeeze(0).cpu().numpy()
342
+ np.save(os.path.join(feat_dir, egs["name"]), out)
343
+
344
+ # Predict the clean tokens and token2wav
345
+ for egs in tqdm(data_reader):
346
+ egs = load_obj(egs, device)
347
+ sr = 16000
348
+
349
+ if conf["task"] == "AEC":
350
+ # Get feat
351
+ feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy"
352
+ feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy"
353
+
354
+ feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0)
355
+ feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0)
356
+
357
+ # For multi-inference
358
+ if inference_time > 0:
359
+ est = nnet(feat_mic)
360
+ else:
361
+ est = nnet(feat_mic, feat_ref)
362
+
363
+ # Get tokens and token2wav
364
+ max, max_indices_1 = torch.max(est[1], dim=1)
365
+ recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
366
+
367
+ # Save the wav
368
+ target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav")
369
+ print(target_path)
370
+ sf.write(target_path , recon_1, sr)
371
+
372
+ elif conf["task"] == "TSE" :
373
+ # Get feat
374
+ feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy"
375
+ feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy"
376
+
377
+ feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0)
378
+ feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0)
379
+
380
+ # Choose if keep the enroallment audio while multi-inference
381
+ if_keep_ref = True
382
+
383
+ if inference_time>0 and if_keep_ref== False:
384
+ est = nnet(feat_mic)
385
+ else:
386
+ est = nnet(feat_mic, feat_ref)
387
+
388
+ # Get tokens and token2wav
389
+ max, max_indices_1 = torch.max(est[0], dim=1)
390
+ recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
391
+
392
+ # Save the wav
393
+ target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav")
394
+ print(target_path)
395
+ sf.write(target_path , recon_1, sr)
396
+
397
+ elif conf["task"] == "PLC":
398
+ # Get feat
399
+ feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
400
+ feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)
401
+
402
+ # Get tokens and token2wav
403
+ est = nnet(feat)
404
+ max, max_indices_1 = torch.max(est[1], dim=1)
405
+ recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
406
+
407
+ # Save the wav
408
+ target_path = os.path.join(wav_dir, egs["name"] + ".wav")
409
+ print(target_path)
410
+ sf.write(target_path , recon_1, sr)
411
+
412
+ elif conf["task"] == "SS":
413
+ # Get feat
414
+ feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
415
+ feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)
416
+
417
+ # Separate the first speaker
418
+ est = nnet(feat)
419
+ max, max_indices_1 = torch.max(est[1], dim=1)
420
+ recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
421
+
422
+ target_path_1 = os.path.join(wav_dir, egs["name"] + ".wav")
423
+ sf.write(target_path_1 , recon_1, sr)
424
+
425
+ # Separate the second speaker, SS need at least 2 inference time in config
426
+ if inference_time > 0:
427
+ origin_feat_path = os.path.join(origin_feat_dir, egs["name"]) + ".npy"
428
+ origin_feat = torch.from_numpy(np.load(origin_feat_path)).unsqueeze(0)
429
+
430
+ est2 = nnet(origin_feat, feat)
431
+ max, max_indices_2 = torch.max(est2[1], dim=1)
432
+ recon_2 = codec.token2wav(max_indices_2.unsqueeze(0)).squeeze().detach().cpu().numpy()
433
+
434
+ if not os.path.exists(last_wav_dir + "s2"):
435
+ os.makedirs(last_wav_dir + "s2")
436
+
437
+ target_path_2 = os.path.join(last_wav_dir + "s2", egs["name"] + ".wav")
438
+ sf.write(target_path_2 , recon_2, sr)
439
+
440
+ else:
441
+ # Get feat
442
+ feat_path = os.path.join(feat_dir, egs["name"]) + ".npy"
443
+ feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0)
444
+
445
+ # Get tokens and token2wav
446
+ est = nnet(feat)
447
+ max, max_indices_1 = torch.max(est[1], dim=1)
448
+ recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy()
449
+
450
+ # Save the wav
451
+ target_path = os.path.join(wav_dir, egs["name"] + ".wav")
452
+ print(target_path)
453
+ sf.write(target_path , recon_1, sr)
454
+
455
+ # For next inference
456
+ last_feat_dir = feat_dir
457
+ last_wav_dir = wav_dir
458
+
459
+ if __name__ == "__main__":
460
+ parser = argparse.ArgumentParser(
461
+ description = "Command to test separation model in Pytorch",
462
+ formatter_class = argparse.ArgumentDefaultsHelpFormatter)
463
+ parser.add_argument("-conf",
464
+ type=str,
465
+ required=True,
466
+ help="Yaml configuration file for training")
467
+ parser.add_argument("--backend",
468
+ type=str,
469
+ default="nccl",
470
+ choices=["nccl", "gloo"])
471
+ args = parser.parse_args()
472
+ # for nccl debug
473
+ os.environ["NCCL_DEBUG"] = "INFO"
474
+ run(args)
inference.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ CUDA_VISIBLE_DEVICES=1 torchrun \
2
+ --nnodes=1 \
3
+ --nproc_per_node=1 \
4
+ --master_port=21547 \
5
+ inference.py \
6
+ -conf ./config/test.yml
loader/__pycache__/datareader.cpython-310.pyc ADDED
Binary file (2.44 kB). View file
 
loader/__pycache__/datareader_aec.cpython-310.pyc ADDED
Binary file (2.54 kB). View file
 
loader/__pycache__/datareader_fe.cpython-310.pyc ADDED
Binary file (2.44 kB). View file
 
loader/__pycache__/datareader_tse.cpython-310.pyc ADDED
Binary file (2.53 kB). View file
 
loader/datareader.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torchaudio
3
+ import torch
4
+
5
+ def get_firstchannel_read(path, fs=16000):
6
+ wave_data, sr = torchaudio.load(path)
7
+ if sr != fs:
8
+ wave_data = torchaudio.functional.resample(wave_data, sr, fs)
9
+ if len(wave_data.shape) > 1:
10
+ wave_data = wave_data[0,...]
11
+ wave_data = wave_data.cpu().numpy()
12
+ return wave_data
13
+
14
+ def parse_scp(scp, path_list):
15
+ with open(scp) as fid:
16
+ for line in fid:
17
+ tmp = line.strip().split()
18
+ if len(tmp) > 1:
19
+ path_list.append({"inputs": tmp[0], "duration": tmp[1]})
20
+ else:
21
+ path_list.append({"inputs": tmp[0]})
22
+
23
+ class DataReader(object):
24
+ def __init__(self, filename, sample_rate):
25
+ self.file_list = []
26
+ self.sample_rate = sample_rate
27
+ parse_scp(filename, self.file_list)
28
+
29
+ def extract_feature(self, path):
30
+ path = path["inputs"]
31
+ name = path.split("/")[-1].split(".")[0]
32
+ data = get_firstchannel_read(path, fs=self.sample_rate).astype(np.float32)
33
+ max_norm = np.max(np.abs(data))
34
+ if max_norm == 0:
35
+ max_norm = 1
36
+ data = data / max_norm
37
+ inputs = np.reshape(data, [1, data.shape[0]])
38
+ inputs = torch.from_numpy(inputs)
39
+
40
+ egs = {
41
+ "mix": inputs,
42
+ "max_norm": max_norm,
43
+ "name": name
44
+ }
45
+ return egs
46
+
47
+ def __len__(self):
48
+ return len(self.file_list)
49
+
50
+ def __getitem__(self, index):
51
+ return self.extract_feature(self.file_list[index])
52
+
53
+ def get_utt2spk(self, path):
54
+ lines = open(path, "r").readlines()
55
+ for line in lines:
56
+ line = line.strip().split()
57
+ utt_path, spk_id = line[0], line[1]
58
+ self.utt2spk[utt_path] = spk_id
59
+
60
+ def get_spk2utt(self, path):
61
+ lines = open(path, "r").readlines()
62
+ for line in lines:
63
+ line = line.strip().split()
64
+ utt_path, spk_id = line[0], line[1]
65
+ self.spk2aux[spk_id] = utt_path
loader/datareader_aec.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch as th
3
+ import numpy as np
4
+ import soundfile as sf
5
+
6
+ import sys, os
7
+ sys.path.append(os.path.dirname(__file__))
8
+ # from speex_linear.lp_or_tde import LP_or_TDE
9
+
10
+
11
+ def audio(path, fs=16000):
12
+ wave_data, sr = sf.read(path)
13
+ if sr != fs:
14
+ wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs)
15
+ return wave_data
16
+
17
+ def get_firstchannel_read(path, fs=16000):
18
+ wave_data, sr = sf.read(path)
19
+ if sr != fs:
20
+ if len(wave_data.shape) != 1:
21
+ wave_data = wave_data.transpose((1, 0))
22
+ wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs)
23
+ if len(wave_data.shape) != 1:
24
+ wave_data = wave_data.transpose((1, 0))
25
+ if len(wave_data.shape) > 1:
26
+ wave_data = wave_data[:, 0]
27
+ return wave_data
28
+
29
+ def parse_scp(scp, path_list):
30
+ with open(scp) as fid:
31
+ for line in fid:
32
+ tmp = line.strip().split()
33
+ if len(tmp) > 1:
34
+ path_list.append({"inputs": tmp[0], "duration": tmp[1]})
35
+ else:
36
+ path_list.append({"inputs": tmp[0]})
37
+
38
+ class DataReaderAEC(object):
39
+ def __init__(self, filename, sample_rate): #, aux_segment): # filename是不带id的待解码音频,noisy_id是带id的带解码音频,clean是带id的注册音频
40
+ self.file_list = []
41
+ parse_scp(filename, self.file_list)
42
+ self.sample_rate = sample_rate
43
+
44
+ # self.aux_segment_length = aux_segment * sample_rate
45
+
46
+ def extract_feature(self, path):
47
+ mic_path = path["inputs"]
48
+ utt_id = mic_path.split("/")[-1]
49
+ mic_name = mic_path.split("/")[-1].split(".")[0]
50
+
51
+ ref_path = mic_path.replace("mic.wav", "lpb.wav")
52
+ ref_name = ref_path.split("/")[-1].split(".")[0]
53
+
54
+ mic = get_firstchannel_read(mic_path, self.sample_rate).astype(np.float32)
55
+ ref = get_firstchannel_read(ref_path, self.sample_rate).astype(np.float32)
56
+
57
+ min_len = min(mic.shape[0], ref.shape[0])
58
+ mic = mic[:min_len]
59
+ ref = ref[:min_len]
60
+
61
+ inputs_mic = np.reshape(mic, [1, mic.shape[0]])
62
+ inputs_ref = np.reshape(ref, [1, ref.shape[0]]).astype(np.float32)
63
+
64
+
65
+ inputs_mic = th.from_numpy(inputs_mic)
66
+ inputs_ref = th.from_numpy(inputs_ref)
67
+
68
+ # print(f'e: {inputs_e.shape}')
69
+ # print(f'mic: {inputs_mic.shape}')
70
+ # print(f'ref: {inputs_ref.shape}')
71
+
72
+ egs = {
73
+ "mic": inputs_mic,
74
+ "ref": inputs_ref,
75
+ "utt_id": utt_id,
76
+ "mic_name": mic_name,
77
+ "ref_name": ref_name
78
+ # "max_norm": max_norm
79
+ }
80
+ return egs
81
+
82
+ def __len__(self):
83
+ return len(self.file_list)
84
+
85
+ def __getitem__(self, index):
86
+ return self.extract_feature(self.file_list[index])
loader/datareader_tse.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch as th
3
+ import numpy as np
4
+ import soundfile as sf
5
+
6
+ import sys, os
7
+ sys.path.append(os.path.dirname(__file__))
8
+ # from speex_linear.lp_or_tde import LP_or_TDE
9
+
10
+ def audio(path, fs=16000):
11
+ wave_data, sr = sf.read(path)
12
+ if sr != fs:
13
+ wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs)
14
+ return wave_data
15
+
16
+ def get_firstchannel_read(path, fs=16000):
17
+ wave_data, sr = sf.read(path)
18
+ if sr != fs:
19
+ if len(wave_data.shape) != 1:
20
+ wave_data = wave_data.transpose((1, 0))
21
+ wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs)
22
+ if len(wave_data.shape) != 1:
23
+ wave_data = wave_data.transpose((1, 0))
24
+ if len(wave_data.shape) > 1:
25
+ wave_data = wave_data[:, 0]
26
+ return wave_data
27
+
28
+ def parse_scp(scp, path_list):
29
+ with open(scp) as fid:
30
+ for line in fid:
31
+ tmp = line.strip().split()
32
+ if len(tmp) > 1:
33
+ path_list.append({"inputs": tmp[0], "duration": tmp[1]})
34
+ else:
35
+ path_list.append({"inputs": tmp[0]})
36
+
37
+ class DataReaderTSE(object):
38
+ def __init__(self, filename, sample_rate): # filename是不带id的待解码音频,noisy_id是带id的带解码音频,clean是带id的注册音频
39
+ self.file_list = []
40
+ parse_scp(filename, self.file_list)
41
+ self.sample_rate = sample_rate
42
+
43
+ def extract_feature(self, path):
44
+ mic_path = path["inputs"]
45
+ utt_id = mic_path.split("/")[-1]
46
+ mic_name = mic_path.split("/")[-1].split(".")[0]
47
+
48
+ ref_path = mic_path.replace("noisy/", "enrol/")
49
+ ref_name = ref_path.split("/")[-1].split(".")[0]
50
+
51
+ mic = get_firstchannel_read(mic_path, self.sample_rate).astype(np.float32)
52
+ ref = get_firstchannel_read(ref_path, self.sample_rate).astype(np.float32)
53
+
54
+ if ref.shape[0] > mic.shape[0]:
55
+ min_len = mic.shape[0]
56
+ ref = ref[:min_len]
57
+
58
+ # print(ref.shape[0])
59
+ # print(mic.shape[0])
60
+
61
+ inputs_mic = np.reshape(mic, [1, mic.shape[0]]).astype(np.float32)
62
+ inputs_ref = np.reshape(ref, [1, ref.shape[0]]).astype(np.float32)
63
+
64
+ inputs_mic = th.from_numpy(inputs_mic)
65
+ inputs_ref = th.from_numpy(inputs_ref)
66
+
67
+ # print(f'e: {inputs_e.shape}')
68
+ # print(f'mic: {inputs_mic.shape}')
69
+ # print(f'ref: {inputs_ref.shape}')
70
+
71
+ egs = {
72
+ "mic": inputs_mic,
73
+ "ref": inputs_ref,
74
+ "utt_id": utt_id,
75
+ "mic_name": mic_name,
76
+ "ref_name": ref_name
77
+ # "max_norm": max_norm
78
+ }
79
+ return egs
80
+
81
+ def __len__(self):
82
+ return len(self.file_list)
83
+
84
+ def __getitem__(self, index):
85
+ return self.extract_feature(self.file_list[index])
nnet/WavLM.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import sys,os
15
+ sys.path.append(os.path.dirname(sys.path[0]))
16
+ import numpy as np
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import LayerNorm
22
+ from nnet.modules import (
23
+ Fp32GroupNorm,
24
+ Fp32LayerNorm,
25
+ GradMultiply,
26
+ MultiheadAttention,
27
+ SamePad,
28
+ init_bert_params,
29
+ get_activation_fn,
30
+ TransposeLast,
31
+ GLU_Linear,
32
+ )
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ def compute_mask_indices(
38
+ shape: Tuple[int, int],
39
+ padding_mask: Optional[torch.Tensor],
40
+ mask_prob: float,
41
+ mask_length: int,
42
+ mask_type: str = "static",
43
+ mask_other: float = 0.0,
44
+ min_masks: int = 0,
45
+ no_overlap: bool = False,
46
+ min_space: int = 0,
47
+ ) -> np.ndarray:
48
+ """
49
+ Computes random mask spans for a given shape
50
+
51
+ Args:
52
+ shape: the the shape for which to compute masks.
53
+ should be of size 2 where first element is batch size and 2nd is timesteps
54
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
55
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
56
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
57
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
58
+ mask_type: how to compute mask lengths
59
+ static = fixed size
60
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
61
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
62
+ poisson = sample from possion distribution with lambda = mask length
63
+ min_masks: minimum number of masked spans
64
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
65
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
66
+ """
67
+
68
+ bsz, all_sz = shape
69
+ mask = np.full((bsz, all_sz), False)
70
+
71
+ all_num_mask = int(
72
+ # add a random number for probabilistic rounding
73
+ mask_prob * all_sz / float(mask_length)
74
+ + np.random.rand()
75
+ )
76
+
77
+ all_num_mask = max(min_masks, all_num_mask)
78
+
79
+ mask_idcs = []
80
+ for i in range(bsz):
81
+ if padding_mask is not None:
82
+ sz = all_sz - padding_mask[i].long().sum().item()
83
+ num_mask = int(
84
+ # add a random number for probabilistic rounding
85
+ mask_prob * sz / float(mask_length)
86
+ + np.random.rand()
87
+ )
88
+ num_mask = max(min_masks, num_mask)
89
+ else:
90
+ sz = all_sz
91
+ num_mask = all_num_mask
92
+
93
+ if mask_type == "static":
94
+ lengths = np.full(num_mask, mask_length)
95
+ elif mask_type == "uniform":
96
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
97
+ elif mask_type == "normal":
98
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
99
+ lengths = [max(1, int(round(x))) for x in lengths]
100
+ elif mask_type == "poisson":
101
+ lengths = np.random.poisson(mask_length, size=num_mask)
102
+ lengths = [int(round(x)) for x in lengths]
103
+ else:
104
+ raise Exception("unknown mask selection " + mask_type)
105
+
106
+ if sum(lengths) == 0:
107
+ lengths[0] = min(mask_length, sz - 1)
108
+
109
+ if no_overlap:
110
+ mask_idc = []
111
+
112
+ def arrange(s, e, length, keep_length):
113
+ span_start = np.random.randint(s, e - length)
114
+ mask_idc.extend(span_start + i for i in range(length))
115
+
116
+ new_parts = []
117
+ if span_start - s - min_space >= keep_length:
118
+ new_parts.append((s, span_start - min_space + 1))
119
+ if e - span_start - keep_length - min_space > keep_length:
120
+ new_parts.append((span_start + length + min_space, e))
121
+ return new_parts
122
+
123
+ parts = [(0, sz)]
124
+ min_length = min(lengths)
125
+ for length in sorted(lengths, reverse=True):
126
+ lens = np.fromiter(
127
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
128
+ np.int,
129
+ )
130
+ l_sum = np.sum(lens)
131
+ if l_sum == 0:
132
+ break
133
+ probs = lens / np.sum(lens)
134
+ c = np.random.choice(len(parts), p=probs)
135
+ s, e = parts.pop(c)
136
+ parts.extend(arrange(s, e, length, min_length))
137
+ mask_idc = np.asarray(mask_idc)
138
+ else:
139
+ min_len = min(lengths)
140
+ if sz - min_len <= num_mask:
141
+ min_len = sz - num_mask - 1
142
+
143
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
144
+
145
+ mask_idc = np.asarray(
146
+ [
147
+ mask_idc[j] + offset
148
+ for j in range(len(mask_idc))
149
+ for offset in range(lengths[j])
150
+ ]
151
+ )
152
+
153
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
154
+
155
+ min_len = min([len(m) for m in mask_idcs])
156
+ for i, mask_idc in enumerate(mask_idcs):
157
+ if len(mask_idc) > min_len:
158
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
159
+ mask[i, mask_idc] = True
160
+
161
+ return mask
162
+
163
+
164
+ class WavLMConfig:
165
+ def __init__(self, cfg=None):
166
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
167
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
168
+
169
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
170
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
171
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
172
+ self.activation_fn: str = "gelu" # activation function to use
173
+
174
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
175
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
176
+ self.conv_bias: bool = False # include bias in conv encoder
177
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
178
+
179
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
180
+
181
+ # dropouts
182
+ self.dropout: float = 0.1 # dropout probability for the transformer
183
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
184
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
185
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
186
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
187
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
188
+
189
+ # masking
190
+ self.mask_length: int = 10 # mask length
191
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
192
+ self.mask_selection: str = "static" # how to choose mask length
193
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
194
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
195
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
196
+
197
+ # channel masking
198
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
199
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
200
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
201
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
202
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
203
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
204
+
205
+ # positional embeddings
206
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
207
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
208
+
209
+ # relative position embedding
210
+ self.relative_position_embedding: bool = False # apply relative position embedding
211
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
212
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
213
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
214
+
215
+ if cfg is not None:
216
+ self.update(cfg)
217
+
218
+ def update(self, cfg: dict):
219
+ self.__dict__.update(cfg)
220
+
221
+
222
+ class WavLM(nn.Module):
223
+ def __init__(
224
+ self,
225
+ cfg: WavLMConfig,
226
+ ) -> None:
227
+ super().__init__()
228
+ logger.info(f"WavLM Config: {cfg.__dict__}")
229
+
230
+ self.cfg = cfg
231
+ feature_enc_layers = eval(cfg.conv_feature_layers)
232
+ self.embed = feature_enc_layers[-1][0]
233
+
234
+ self.feature_extractor = ConvFeatureExtractionModel(
235
+ conv_layers=feature_enc_layers,
236
+ dropout=0.0,
237
+ mode=cfg.extractor_mode,
238
+ conv_bias=cfg.conv_bias,
239
+ )
240
+
241
+ self.post_extract_proj = (
242
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
243
+ if self.embed != cfg.encoder_embed_dim
244
+ else None
245
+ )
246
+
247
+ self.mask_prob = cfg.mask_prob
248
+ self.mask_selection = cfg.mask_selection
249
+ self.mask_other = cfg.mask_other
250
+ self.mask_length = cfg.mask_length
251
+ self.no_mask_overlap = cfg.no_mask_overlap
252
+ self.mask_min_space = cfg.mask_min_space
253
+
254
+ self.mask_channel_prob = cfg.mask_channel_prob
255
+ self.mask_channel_selection = cfg.mask_channel_selection
256
+ self.mask_channel_other = cfg.mask_channel_other
257
+ self.mask_channel_length = cfg.mask_channel_length
258
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
259
+ self.mask_channel_min_space = cfg.mask_channel_min_space
260
+
261
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
262
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
263
+
264
+ self.feature_grad_mult = cfg.feature_grad_mult
265
+
266
+ self.mask_emb = nn.Parameter(
267
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
268
+ )
269
+
270
+ self.encoder = TransformerEncoder(cfg)
271
+ self.layer_norm = LayerNorm(self.embed)
272
+
273
+ def apply_mask(self, x, padding_mask):
274
+ B, T, C = x.shape
275
+ if self.mask_prob > 0:
276
+ mask_indices = compute_mask_indices(
277
+ (B, T),
278
+ padding_mask,
279
+ self.mask_prob,
280
+ self.mask_length,
281
+ self.mask_selection,
282
+ self.mask_other,
283
+ min_masks=2,
284
+ no_overlap=self.no_mask_overlap,
285
+ min_space=self.mask_min_space,
286
+ )
287
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
288
+ x[mask_indices] = self.mask_emb
289
+ else:
290
+ mask_indices = None
291
+
292
+ if self.mask_channel_prob > 0:
293
+ mask_channel_indices = compute_mask_indices(
294
+ (B, C),
295
+ None,
296
+ self.mask_channel_prob,
297
+ self.mask_channel_length,
298
+ self.mask_channel_selection,
299
+ self.mask_channel_other,
300
+ no_overlap=self.no_mask_channel_overlap,
301
+ min_space=self.mask_channel_min_space,
302
+ )
303
+ mask_channel_indices = (
304
+ torch.from_numpy(mask_channel_indices)
305
+ .to(x.device)
306
+ .unsqueeze(1)
307
+ .expand(-1, T, -1)
308
+ )
309
+ x[mask_channel_indices] = 0
310
+
311
+ return x, mask_indices
312
+
313
+ def forward_padding_mask(
314
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
315
+ ) -> torch.Tensor:
316
+ extra = padding_mask.size(1) % features.size(1)
317
+ if extra > 0:
318
+ padding_mask = padding_mask[:, :-extra]
319
+ padding_mask = padding_mask.view(
320
+ padding_mask.size(0), features.size(1), -1
321
+ )
322
+ padding_mask = padding_mask.all(-1)
323
+ return padding_mask
324
+
325
+ def extract_features(
326
+ self,
327
+ source: torch.Tensor,
328
+ padding_mask: Optional[torch.Tensor] = None,
329
+ mask: bool = False,
330
+ ret_conv: bool = False,
331
+ output_layer: Optional[int] = None,
332
+ ret_layer_results: bool = False,
333
+ ):
334
+
335
+ if self.feature_grad_mult > 0:
336
+ features = self.feature_extractor(source)
337
+ if self.feature_grad_mult != 1.0:
338
+ features = GradMultiply.apply(features, self.feature_grad_mult)
339
+ else:
340
+ with torch.no_grad():
341
+ features = self.feature_extractor(source)
342
+
343
+ features = features.transpose(1, 2)
344
+ features = self.layer_norm(features)
345
+
346
+ if padding_mask is not None:
347
+ padding_mask = self.forward_padding_mask(features, padding_mask)
348
+
349
+ if self.post_extract_proj is not None:
350
+ features = self.post_extract_proj(features)
351
+
352
+ features = self.dropout_input(features)
353
+
354
+ if mask:
355
+ x, mask_indices = self.apply_mask(
356
+ features, padding_mask
357
+ )
358
+ else:
359
+ x = features
360
+
361
+ # feature: (B, T, D), float
362
+ # target: (B, T), long
363
+ # x: (B, T, D), float
364
+ # padding_mask: (B, T), bool
365
+ # mask_indices: (B, T), bool
366
+ x, layer_results = self.encoder(
367
+ x,
368
+ padding_mask=padding_mask,
369
+ layer=None if output_layer is None else output_layer - 1
370
+ )
371
+
372
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
373
+
374
+ feature = res["features"] if ret_conv else res["x"]
375
+ if ret_layer_results:
376
+ feature = (feature, res["layer_results"])
377
+ return feature, res["padding_mask"]
378
+
379
+
380
+ def long_term_modeling(
381
+ self,
382
+ source: torch.Tensor,
383
+ padding_mask: Optional[torch.Tensor] = None,
384
+ mask: bool = False,
385
+ ret_conv: bool = False,
386
+ output_layer: Optional[int] = None,
387
+ ret_layer_results: bool = False,
388
+ ):
389
+
390
+ features = source.transpose(1, 2)
391
+ features = self.layer_norm(features)
392
+
393
+ if padding_mask is not None:
394
+ padding_mask = self.forward_padding_mask(features, padding_mask)
395
+
396
+ if self.post_extract_proj is not None:
397
+ features = self.post_extract_proj(features)
398
+
399
+ features = self.dropout_input(features)
400
+
401
+ if mask:
402
+ x, mask_indices = self.apply_mask(
403
+ features, padding_mask
404
+ )
405
+ else:
406
+ x = features
407
+
408
+ # feature: (B, T, D), float
409
+ # target: (B, T), long
410
+ # x: (B, T, D), float
411
+ # padding_mask: (B, T), bool
412
+ # mask_indices: (B, T), bool
413
+ x, layer_results = self.encoder(
414
+ x,
415
+ padding_mask=padding_mask,
416
+ layer=None if output_layer is None else output_layer - 1
417
+ )
418
+
419
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
420
+
421
+ feature = res["features"] if ret_conv else res["x"]
422
+ if ret_layer_results:
423
+ feature = (feature, res["layer_results"])
424
+ return feature, res["padding_mask"]
425
+
426
+
427
+
428
+ class ConvFeatureExtractionModel(nn.Module):
429
+ def __init__(
430
+ self,
431
+ conv_layers: List[Tuple[int, int, int]],
432
+ dropout: float = 0.0,
433
+ mode: str = "default",
434
+ conv_bias: bool = False,
435
+ conv_type: str = "default"
436
+ ):
437
+ super().__init__()
438
+
439
+ assert mode in {"default", "layer_norm"}
440
+
441
+ def block(
442
+ n_in,
443
+ n_out,
444
+ k,
445
+ stride,
446
+ is_layer_norm=False,
447
+ is_group_norm=False,
448
+ conv_bias=False,
449
+ ):
450
+ def make_conv():
451
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
452
+ nn.init.kaiming_normal_(conv.weight)
453
+ return conv
454
+
455
+ assert (
456
+ is_layer_norm and is_group_norm
457
+ ) == False, "layer norm and group norm are exclusive"
458
+
459
+ if is_layer_norm:
460
+ return nn.Sequential(
461
+ make_conv(),
462
+ nn.Dropout(p=dropout),
463
+ nn.Sequential(
464
+ TransposeLast(),
465
+ Fp32LayerNorm(dim, elementwise_affine=True),
466
+ TransposeLast(),
467
+ ),
468
+ nn.GELU(),
469
+ )
470
+ elif is_group_norm:
471
+ return nn.Sequential(
472
+ make_conv(),
473
+ nn.Dropout(p=dropout),
474
+ Fp32GroupNorm(dim, dim, affine=True),
475
+ nn.GELU(),
476
+ )
477
+ else:
478
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
479
+
480
+ self.conv_type = conv_type
481
+ if self.conv_type == "default":
482
+ in_d = 1
483
+ self.conv_layers = nn.ModuleList()
484
+ for i, cl in enumerate(conv_layers):
485
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
486
+ (dim, k, stride) = cl
487
+
488
+ self.conv_layers.append(
489
+ block(
490
+ in_d,
491
+ dim,
492
+ k,
493
+ stride,
494
+ is_layer_norm=mode == "layer_norm",
495
+ is_group_norm=mode == "default" and i == 0,
496
+ conv_bias=conv_bias,
497
+ )
498
+ )
499
+ in_d = dim
500
+ elif self.conv_type == "conv2d":
501
+ in_d = 1
502
+ self.conv_layers = nn.ModuleList()
503
+ for i, cl in enumerate(conv_layers):
504
+ assert len(cl) == 3
505
+ (dim, k, stride) = cl
506
+
507
+ self.conv_layers.append(
508
+ torch.nn.Conv2d(in_d, dim, k, stride)
509
+ )
510
+ self.conv_layers.append(torch.nn.ReLU())
511
+ in_d = dim
512
+ elif self.conv_type == "custom":
513
+ in_d = 1
514
+ idim = 80
515
+ self.conv_layers = nn.ModuleList()
516
+ for i, cl in enumerate(conv_layers):
517
+ assert len(cl) == 3
518
+ (dim, k, stride) = cl
519
+ self.conv_layers.append(
520
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
521
+ )
522
+ self.conv_layers.append(
523
+ torch.nn.LayerNorm([dim, idim])
524
+ )
525
+ self.conv_layers.append(torch.nn.ReLU())
526
+ in_d = dim
527
+ if (i + 1) % 2 == 0:
528
+ self.conv_layers.append(
529
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
530
+ )
531
+ idim = int(math.ceil(idim / 2))
532
+ else:
533
+ pass
534
+
535
+ def forward(self, x, mask=None):
536
+
537
+ # BxT -> BxCxT
538
+ x = x.unsqueeze(1)
539
+ if self.conv_type == "custom":
540
+ for conv in self.conv_layers:
541
+ if isinstance(conv, nn.LayerNorm):
542
+ x = x.transpose(1, 2)
543
+ x = conv(x).transpose(1, 2)
544
+ else:
545
+ x = conv(x)
546
+ x = x.transpose(2, 3).contiguous()
547
+ x = x.view(x.size(0), -1, x.size(-1))
548
+ else:
549
+ for conv in self.conv_layers:
550
+ x = conv(x)
551
+ if self.conv_type == "conv2d":
552
+ b, c, t, f = x.size()
553
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
554
+ return x
555
+
556
+
557
+ class TransformerEncoder(nn.Module):
558
+ def __init__(self, args):
559
+ super().__init__()
560
+
561
+ self.dropout = args.dropout
562
+ self.embedding_dim = args.encoder_embed_dim
563
+
564
+ self.pos_conv = nn.Conv1d(
565
+ self.embedding_dim,
566
+ self.embedding_dim,
567
+ kernel_size=args.conv_pos,
568
+ padding=args.conv_pos // 2,
569
+ groups=args.conv_pos_groups,
570
+ )
571
+ dropout = 0
572
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
573
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
574
+ nn.init.constant_(self.pos_conv.bias, 0)
575
+
576
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
577
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
578
+
579
+ if hasattr(args, "relative_position_embedding"):
580
+ self.relative_position_embedding = args.relative_position_embedding
581
+ self.num_buckets = args.num_buckets
582
+ self.max_distance = args.max_distance
583
+ else:
584
+ self.relative_position_embedding = False
585
+ self.num_buckets = 0
586
+ self.max_distance = 0
587
+
588
+ self.layers = nn.ModuleList(
589
+ [
590
+ TransformerSentenceEncoderLayer(
591
+ embedding_dim=self.embedding_dim,
592
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
593
+ num_attention_heads=args.encoder_attention_heads,
594
+ dropout=self.dropout,
595
+ attention_dropout=args.attention_dropout,
596
+ activation_dropout=args.activation_dropout,
597
+ activation_fn=args.activation_fn,
598
+ layer_norm_first=args.layer_norm_first,
599
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
600
+ num_buckets=self.num_buckets,
601
+ max_distance=self.max_distance,
602
+ gru_rel_pos=args.gru_rel_pos,
603
+ )
604
+ for i in range(args.encoder_layers)
605
+ ]
606
+ )
607
+
608
+ self.layer_norm_first = args.layer_norm_first
609
+ self.layer_norm = LayerNorm(self.embedding_dim)
610
+ self.layerdrop = args.encoder_layerdrop
611
+
612
+ self.apply(init_bert_params)
613
+
614
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
615
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
616
+
617
+ if self.layer_norm_first and layer is None:
618
+ x = self.layer_norm(x)
619
+
620
+ return x, layer_results
621
+
622
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
623
+
624
+ if padding_mask is not None:
625
+ x[padding_mask] = 0
626
+
627
+ y = x.transpose(1, 2).clone()
628
+ x_conv = self.pos_conv(y)
629
+ x_conv = x_conv.transpose(1, 2)
630
+ x += x_conv
631
+
632
+ if not self.layer_norm_first:
633
+ x = self.layer_norm(x)
634
+
635
+ x = F.dropout(x, p=self.dropout, training=self.training)
636
+
637
+ # B x T x C -> T x B x C
638
+ x = x.transpose(0, 1)
639
+
640
+ layer_results = []
641
+ z = None
642
+ if tgt_layer is not None:
643
+ layer_results.append((x, z))
644
+ r = None
645
+ pos_bias = None
646
+ for i, layer in enumerate(self.layers):
647
+ dropout_probability = np.random.random()
648
+ if not self.training or (dropout_probability > self.layerdrop):
649
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
650
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
651
+ if tgt_layer is not None:
652
+ layer_results.append((x, z))
653
+ if i == tgt_layer:
654
+ r = x
655
+ break
656
+
657
+ if r is not None:
658
+ x = r
659
+
660
+ # T x B x C -> B x T x C
661
+ x = x.transpose(0, 1)
662
+
663
+ return x, layer_results
664
+
665
+
666
+ class TransformerSentenceEncoderLayer(nn.Module):
667
+ """
668
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
669
+ models.
670
+ """
671
+
672
+ def __init__(
673
+ self,
674
+ embedding_dim: float = 768,
675
+ ffn_embedding_dim: float = 3072,
676
+ num_attention_heads: float = 8,
677
+ dropout: float = 0.1,
678
+ attention_dropout: float = 0.1,
679
+ activation_dropout: float = 0.1,
680
+ activation_fn: str = "relu",
681
+ layer_norm_first: bool = False,
682
+ has_relative_attention_bias: bool = False,
683
+ num_buckets: int = 0,
684
+ max_distance: int = 0,
685
+ rescale_init: bool = False,
686
+ gru_rel_pos: bool = False,
687
+ ) -> None:
688
+
689
+ super().__init__()
690
+ # Initialize parameters
691
+ self.embedding_dim = embedding_dim
692
+ self.dropout = dropout
693
+ self.activation_dropout = activation_dropout
694
+
695
+ # Initialize blocks
696
+ self.activation_name = activation_fn
697
+ self.activation_fn = get_activation_fn(activation_fn)
698
+ self.self_attn = MultiheadAttention(
699
+ self.embedding_dim,
700
+ num_attention_heads,
701
+ dropout=attention_dropout,
702
+ self_attention=True,
703
+ has_relative_attention_bias=has_relative_attention_bias,
704
+ num_buckets=num_buckets,
705
+ max_distance=max_distance,
706
+ rescale_init=rescale_init,
707
+ gru_rel_pos=gru_rel_pos,
708
+ )
709
+
710
+ self.dropout1 = nn.Dropout(dropout)
711
+ self.dropout2 = nn.Dropout(self.activation_dropout)
712
+ self.dropout3 = nn.Dropout(dropout)
713
+
714
+ self.layer_norm_first = layer_norm_first
715
+
716
+ # layer norm associated with the self attention layer
717
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
718
+
719
+ if self.activation_name == "glu":
720
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
721
+ else:
722
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
723
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
724
+
725
+ # layer norm associated with the position wise feed-forward NN
726
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
727
+
728
+ def forward(
729
+ self,
730
+ x: torch.Tensor,
731
+ self_attn_mask: torch.Tensor = None,
732
+ self_attn_padding_mask: torch.Tensor = None,
733
+ need_weights: bool = False,
734
+ pos_bias=None
735
+ ):
736
+ """
737
+ LayerNorm is applied either before or after the self-attention/ffn
738
+ modules similar to the original Transformer imlementation.
739
+ """
740
+ residual = x
741
+
742
+ if self.layer_norm_first:
743
+ x = self.self_attn_layer_norm(x)
744
+ x, attn, pos_bias = self.self_attn(
745
+ query=x,
746
+ key=x,
747
+ value=x,
748
+ key_padding_mask=self_attn_padding_mask,
749
+ need_weights=False,
750
+ attn_mask=self_attn_mask,
751
+ position_bias=pos_bias
752
+ )
753
+ x = self.dropout1(x)
754
+ x = residual + x
755
+
756
+ residual = x
757
+ x = self.final_layer_norm(x)
758
+ if self.activation_name == "glu":
759
+ x = self.fc1(x)
760
+ else:
761
+ x = self.activation_fn(self.fc1(x))
762
+ x = self.dropout2(x)
763
+ x = self.fc2(x)
764
+ x = self.dropout3(x)
765
+ x = residual + x
766
+ else:
767
+ x, attn, pos_bias = self.self_attn(
768
+ query=x,
769
+ key=x,
770
+ value=x,
771
+ key_padding_mask=self_attn_padding_mask,
772
+ need_weights=need_weights,
773
+ attn_mask=self_attn_mask,
774
+ position_bias=pos_bias
775
+ )
776
+
777
+ x = self.dropout1(x)
778
+ x = residual + x
779
+
780
+ x = self.self_attn_layer_norm(x)
781
+
782
+ residual = x
783
+ if self.activation_name == "glu":
784
+ x = self.fc1(x)
785
+ else:
786
+ x = self.activation_fn(self.fc1(x))
787
+ x = self.dropout2(x)
788
+ x = self.fc2(x)
789
+ x = self.dropout3(x)
790
+ x = residual + x
791
+ x = self.final_layer_norm(x)
792
+
793
+ return x, attn, pos_bias
nnet/__pycache__/WavLM.cpython-310.pyc ADDED
Binary file (17.3 kB). View file
 
nnet/__pycache__/embedding.cpython-310.pyc ADDED
Binary file (4.67 kB). View file
 
nnet/__pycache__/llase.cpython-310.pyc ADDED
Binary file (2.84 kB). View file
 
nnet/__pycache__/modules.cpython-310.pyc ADDED
Binary file (19.2 kB). View file
 
nnet/llase.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import sys,os
7
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
8
+
9
+ from typing import Union, Optional
10
+ from transformers import LlamaConfig, LlamaForCausalLM
11
+
12
+ NUM_AUDIO_TOKENS = 65536 # Codebook size
13
+
14
+ class LLM_AR(nn.Module):
15
+ def __init__(
16
+ self,
17
+ d_model: int,
18
+ nhead: int,
19
+ num_layers: int
20
+ ):
21
+ super().__init__()
22
+ self.d_model = d_model
23
+
24
+ self.audio_linear_y = nn.Linear(1024, d_model)
25
+ self.audio_linear_x = nn.Linear(1024, d_model)
26
+
27
+ self.Llama_config = LlamaConfig(
28
+ hidden_size=d_model*2,
29
+ intermediate_size=d_model * 4,
30
+ num_attention_heads=nhead,
31
+ num_hidden_layers=num_layers,
32
+ dropout_rate=0.1,
33
+ attention_dropout=0.1,
34
+ is_decoder=True,
35
+ use_cache=True
36
+ )
37
+
38
+ self.llama= LlamaForCausalLM(config=self.Llama_config)
39
+ self.predict_layer_x = nn.Linear(2*d_model, NUM_AUDIO_TOKENS)
40
+ self.predict_layer_y = nn.Linear(2*d_model, NUM_AUDIO_TOKENS)
41
+
42
+ def forward(
43
+ self,
44
+ y: torch.Tensor,
45
+ x: Union[torch.Tensor, None] = None,
46
+ ) -> torch.Tensor:
47
+ # y = y.transpose(1,2) # if codec input use this transpose
48
+
49
+ if x is None:
50
+ x = torch.zeros_like(y)
51
+ elif x.dim() == 2:
52
+ x = x.unsqueeze(-1)
53
+ x = x.expand_as(y)
54
+
55
+
56
+ y_emb = self.audio_linear_y(y) # [B, T, D]
57
+ x_emb = self.audio_linear_x(x) # [B, T, D]
58
+
59
+ if x_emb.shape[1] < y_emb.shape[1]:
60
+ pad_length = y_emb.shape[1] - x_emb.shape[1]
61
+ x_emb= F.pad(x_emb, (0, 0, 0, pad_length), mode='constant', value=0)
62
+
63
+ if y_emb.shape[1] < x_emb.shape[1]:
64
+ pad_length = x_emb.shape[1] - y_emb.shape[1]
65
+ y_emb= F.pad(y_emb, (0, 0, 0, pad_length), mode='constant', value=0)
66
+
67
+ y_emb = torch.concat([x_emb, y_emb], dim = -1) # [B, T_y, D*2]
68
+
69
+ outputs = self.llama(inputs_embeds = y_emb, output_hidden_states=True)
70
+
71
+ dec = outputs.hidden_states[-1] # [B, T_y, D*2]
72
+
73
+ logits_y = self.predict_layer_y(dec) # [B, T, NUM_AUDIO_TOKENS]
74
+ logits_x = self.predict_layer_x(dec)
75
+
76
+ logits_y = logits_y.transpose(-1, -2) # [B, NUM_AUDIO_TOKENS, T]
77
+ logits_x = logits_x.transpose(-1, -2)
78
+
79
+ return logits_y, logits_x
80
+
81
+ if __name__=="__main__":
82
+ # Simple test
83
+ model = LLM_AR(d_model=1024, nhead=8, num_layers=16)
84
+ ce_loss = nn.CrossEntropyLoss()
85
+
86
+ y = torch.randn([1,199,1024])
87
+ x = torch.randn([1,99,1024])
88
+ label = torch.from_numpy(np.random.randint(0, 300, size=[2,1,199]))
89
+
90
+ total_params = sum(p.numel() for p in model.parameters())
91
+
92
+ print(f"Total Params: {total_params}")
93
+
94
+ logits = model(y)
95
+ print(logits[0].shape)
96
+ print(logits[1].shape)
97
+
98
+ logits = model(y,x)
99
+ print(logits[0].shape)
100
+ print(logits[1].shape)
101
+
102
+ logits = model(y,y)
103
+ print(logits[0].shape)
104
+ print(logits[1].shape)
nnet/modules.py ADDED
@@ -0,0 +1,825 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+ class TransposeLast(nn.Module):
19
+ def __init__(self, deconstruct_idx=None):
20
+ super().__init__()
21
+ self.deconstruct_idx = deconstruct_idx
22
+
23
+ def forward(self, x):
24
+ if self.deconstruct_idx is not None:
25
+ x = x[self.deconstruct_idx]
26
+ return x.transpose(-2, -1)
27
+
28
+
29
+ class Fp32LayerNorm(nn.LayerNorm):
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+
33
+ def forward(self, input):
34
+ output = F.layer_norm(
35
+ input.float(),
36
+ self.normalized_shape,
37
+ self.weight.float() if self.weight is not None else None,
38
+ self.bias.float() if self.bias is not None else None,
39
+ self.eps,
40
+ )
41
+ return output.type_as(input)
42
+
43
+
44
+ class Fp32GroupNorm(nn.GroupNorm):
45
+ def __init__(self, *args, **kwargs):
46
+ super().__init__(*args, **kwargs)
47
+
48
+ def forward(self, input):
49
+ output = F.group_norm(
50
+ input.float(),
51
+ self.num_groups,
52
+ self.weight.float() if self.weight is not None else None,
53
+ self.bias.float() if self.bias is not None else None,
54
+ self.eps,
55
+ )
56
+ return output.type_as(input)
57
+
58
+
59
+ class GradMultiply(torch.autograd.Function):
60
+ @staticmethod
61
+ def forward(ctx, x, scale):
62
+ ctx.scale = scale
63
+ res = x.new(x)
64
+ return res
65
+
66
+ @staticmethod
67
+ def backward(ctx, grad):
68
+ return grad * ctx.scale, None
69
+
70
+
71
+ class SamePad(nn.Module):
72
+ def __init__(self, kernel_size, causal=False):
73
+ super().__init__()
74
+ if causal:
75
+ self.remove = kernel_size - 1
76
+ else:
77
+ self.remove = 1 if kernel_size % 2 == 0 else 0
78
+
79
+ def forward(self, x):
80
+ if self.remove > 0:
81
+ x = x[:, :, : -self.remove]
82
+ return x
83
+
84
+
85
+ class Swish(nn.Module):
86
+ """Swish function
87
+ """
88
+
89
+ def __init__(self):
90
+ """Construct an MultiHeadedAttention object."""
91
+ super(Swish, self).__init__()
92
+ self.act = torch.nn.Sigmoid()
93
+
94
+ def forward(self, x):
95
+ return x * self.act(x)
96
+
97
+
98
+ class GLU_Linear(nn.Module):
99
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
100
+ super(GLU_Linear, self).__init__()
101
+
102
+ self.glu_type = glu_type
103
+ self.output_dim = output_dim
104
+
105
+ if glu_type == "sigmoid":
106
+ self.glu_act = torch.nn.Sigmoid()
107
+ elif glu_type == "swish":
108
+ self.glu_act = Swish()
109
+ elif glu_type == "relu":
110
+ self.glu_act = torch.nn.ReLU()
111
+ elif glu_type == "gelu":
112
+ self.glu_act = torch.nn.GELU()
113
+
114
+ if bias_in_glu:
115
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
116
+ else:
117
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
118
+
119
+ def forward(self, x):
120
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
121
+ x = self.linear(x)
122
+
123
+ if self.glu_type == "bilinear":
124
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
125
+ else:
126
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
127
+
128
+ return x
129
+
130
+ def gelu_accurate(x):
131
+ if not hasattr(gelu_accurate, "_a"):
132
+ gelu_accurate._a = math.sqrt(2 / math.pi)
133
+ return (
134
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
135
+ )
136
+
137
+
138
+ def gelu(x: torch.Tensor) -> torch.Tensor:
139
+ return torch.nn.functional.gelu(x.float()).type_as(x)
140
+
141
+
142
+ def get_activation_fn(activation: str):
143
+ """Returns the activation function corresponding to `activation`"""
144
+
145
+ if activation == "relu":
146
+ return F.relu
147
+ elif activation == "gelu":
148
+ return gelu
149
+ elif activation == "gelu_fast":
150
+ warnings.warn(
151
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
152
+ )
153
+ return gelu_accurate
154
+ elif activation == "gelu_accurate":
155
+ return gelu_accurate
156
+ elif activation == "tanh":
157
+ return torch.tanh
158
+ elif activation == "linear":
159
+ return lambda x: x
160
+ elif activation == "glu":
161
+ return lambda x: x
162
+ else:
163
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
164
+
165
+
166
+ def init_bert_params(module):
167
+ """
168
+ Initialize the weights specific to the BERT Model.
169
+ This overrides the default initializations depending on the specified arguments.
170
+ 1. If normal_init_linear_weights is set then weights of linear
171
+ layer will be initialized using the normal distribution and
172
+ bais will be set to the specified value.
173
+ 2. If normal_init_embed_weights is set then weights of embedding
174
+ layer will be initialized using the normal distribution.
175
+ 3. If normal_init_proj_weights is set then weights of
176
+ in_project_weight for MultiHeadAttention initialized using
177
+ the normal distribution (to be validated).
178
+ """
179
+
180
+ def normal_(data):
181
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
182
+ # so that the RNG is consistent with and without FSDP
183
+ data.copy_(
184
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
185
+ )
186
+
187
+ if isinstance(module, nn.Linear):
188
+ normal_(module.weight.data)
189
+ if module.bias is not None:
190
+ module.bias.data.zero_()
191
+ if isinstance(module, nn.Embedding):
192
+ normal_(module.weight.data)
193
+ if module.padding_idx is not None:
194
+ module.weight.data[module.padding_idx].zero_()
195
+ if isinstance(module, MultiheadAttention):
196
+ normal_(module.q_proj.weight.data)
197
+ normal_(module.k_proj.weight.data)
198
+ normal_(module.v_proj.weight.data)
199
+
200
+
201
+ def quant_noise(module, p, block_size):
202
+ """
203
+ Wraps modules and applies quantization noise to the weights for
204
+ subsequent quantization with Iterative Product Quantization as
205
+ described in "Training with Quantization Noise for Extreme Model Compression"
206
+
207
+ Args:
208
+ - module: nn.Module
209
+ - p: amount of Quantization Noise
210
+ - block_size: size of the blocks for subsequent quantization with iPQ
211
+
212
+ Remarks:
213
+ - Module weights must have the right sizes wrt the block size
214
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
215
+ - For more detail on how to quantize by blocks with convolutional weights,
216
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
217
+ - We implement the simplest form of noise here as stated in the paper
218
+ which consists in randomly dropping blocks
219
+ """
220
+
221
+ # if no quantization noise, don't register hook
222
+ if p <= 0:
223
+ return module
224
+
225
+ # supported modules
226
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
227
+
228
+ # test whether module.weight has the right sizes wrt block_size
229
+ is_conv = module.weight.ndim == 4
230
+
231
+ # 2D matrix
232
+ if not is_conv:
233
+ assert (
234
+ module.weight.size(1) % block_size == 0
235
+ ), "Input features must be a multiple of block sizes"
236
+
237
+ # 4D matrix
238
+ else:
239
+ # 1x1 convolutions
240
+ if module.kernel_size == (1, 1):
241
+ assert (
242
+ module.in_channels % block_size == 0
243
+ ), "Input channels must be a multiple of block sizes"
244
+ # regular convolutions
245
+ else:
246
+ k = module.kernel_size[0] * module.kernel_size[1]
247
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
248
+
249
+ def _forward_pre_hook(mod, input):
250
+ # no noise for evaluation
251
+ if mod.training:
252
+ if not is_conv:
253
+ # gather weight and sizes
254
+ weight = mod.weight
255
+ in_features = weight.size(1)
256
+ out_features = weight.size(0)
257
+
258
+ # split weight matrix into blocks and randomly drop selected blocks
259
+ mask = torch.zeros(
260
+ in_features // block_size * out_features, device=weight.device
261
+ )
262
+ mask.bernoulli_(p)
263
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
264
+
265
+ else:
266
+ # gather weight and sizes
267
+ weight = mod.weight
268
+ in_channels = mod.in_channels
269
+ out_channels = mod.out_channels
270
+
271
+ # split weight matrix into blocks and randomly drop selected blocks
272
+ if mod.kernel_size == (1, 1):
273
+ mask = torch.zeros(
274
+ int(in_channels // block_size * out_channels),
275
+ device=weight.device,
276
+ )
277
+ mask.bernoulli_(p)
278
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
279
+ else:
280
+ mask = torch.zeros(
281
+ weight.size(0), weight.size(1), device=weight.device
282
+ )
283
+ mask.bernoulli_(p)
284
+ mask = (
285
+ mask.unsqueeze(2)
286
+ .unsqueeze(3)
287
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
288
+ )
289
+
290
+ # scale weights and apply mask
291
+ mask = mask.to(
292
+ torch.bool
293
+ ) # x.bool() is not currently supported in TorchScript
294
+ s = 1 / (1 - p)
295
+ mod.weight.data = s * weight.masked_fill(mask, 0)
296
+
297
+ module.register_forward_pre_hook(_forward_pre_hook)
298
+ return module
299
+
300
+
301
+ class MultiheadAttention(nn.Module):
302
+ """Multi-headed attention.
303
+
304
+ See "Attention Is All You Need" for more details.
305
+ """
306
+
307
+ def __init__(
308
+ self,
309
+ embed_dim,
310
+ num_heads,
311
+ kdim=None,
312
+ vdim=None,
313
+ dropout=0.0,
314
+ bias=True,
315
+ add_bias_kv=False,
316
+ add_zero_attn=False,
317
+ self_attention=False,
318
+ encoder_decoder_attention=False,
319
+ q_noise=0.0,
320
+ qn_block_size=8,
321
+ has_relative_attention_bias=False,
322
+ num_buckets=32,
323
+ max_distance=128,
324
+ gru_rel_pos=False,
325
+ rescale_init=False,
326
+ ):
327
+ super().__init__()
328
+ self.embed_dim = embed_dim
329
+ self.kdim = kdim if kdim is not None else embed_dim
330
+ self.vdim = vdim if vdim is not None else embed_dim
331
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
332
+
333
+ self.num_heads = num_heads
334
+ self.dropout_module = nn.Dropout(dropout)
335
+
336
+ self.has_relative_attention_bias = has_relative_attention_bias
337
+ self.num_buckets = num_buckets
338
+ self.max_distance = max_distance
339
+ if self.has_relative_attention_bias:
340
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
341
+
342
+ self.head_dim = embed_dim // num_heads
343
+ self.q_head_dim = self.head_dim
344
+ self.k_head_dim = self.head_dim
345
+ assert (
346
+ self.head_dim * num_heads == self.embed_dim
347
+ ), "embed_dim must be divisible by num_heads"
348
+ self.scaling = self.head_dim ** -0.5
349
+
350
+ self.self_attention = self_attention
351
+ self.encoder_decoder_attention = encoder_decoder_attention
352
+
353
+ assert not self.self_attention or self.qkv_same_dim, (
354
+ "Self-attention requires query, key and " "value to be of the same size"
355
+ )
356
+
357
+ k_bias = True
358
+ if rescale_init:
359
+ k_bias = False
360
+
361
+ k_embed_dim = embed_dim
362
+ q_embed_dim = embed_dim
363
+
364
+ self.k_proj = quant_noise(
365
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
366
+ )
367
+ self.v_proj = quant_noise(
368
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
369
+ )
370
+ self.q_proj = quant_noise(
371
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
372
+ )
373
+
374
+ self.out_proj = quant_noise(
375
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
376
+ )
377
+
378
+ if add_bias_kv:
379
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
380
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
381
+ else:
382
+ self.bias_k = self.bias_v = None
383
+
384
+ self.add_zero_attn = add_zero_attn
385
+
386
+ self.gru_rel_pos = gru_rel_pos
387
+ if self.gru_rel_pos:
388
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
389
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
390
+
391
+ self.reset_parameters()
392
+
393
+ def reset_parameters(self):
394
+ if self.qkv_same_dim:
395
+ # Empirically observed the convergence to be much better with
396
+ # the scaled initialization
397
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
398
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
399
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
400
+ else:
401
+ nn.init.xavier_uniform_(self.k_proj.weight)
402
+ nn.init.xavier_uniform_(self.v_proj.weight)
403
+ nn.init.xavier_uniform_(self.q_proj.weight)
404
+
405
+ nn.init.xavier_uniform_(self.out_proj.weight)
406
+ if self.out_proj.bias is not None:
407
+ nn.init.constant_(self.out_proj.bias, 0.0)
408
+ if self.bias_k is not None:
409
+ nn.init.xavier_normal_(self.bias_k)
410
+ if self.bias_v is not None:
411
+ nn.init.xavier_normal_(self.bias_v)
412
+ if self.has_relative_attention_bias:
413
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
414
+
415
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
416
+ num_buckets = self.num_buckets
417
+ max_distance = self.max_distance
418
+ relative_buckets = 0
419
+
420
+ if bidirectional:
421
+ num_buckets = num_buckets // 2
422
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
423
+ relative_positions = torch.abs(relative_positions)
424
+ else:
425
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
426
+
427
+ max_exact = num_buckets // 2
428
+ is_small = relative_positions < max_exact
429
+
430
+ relative_postion_if_large = max_exact + (
431
+ torch.log(relative_positions.float() / max_exact)
432
+ / math.log(max_distance / max_exact)
433
+ * (num_buckets - max_exact)
434
+ ).to(torch.long)
435
+ relative_postion_if_large = torch.min(
436
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
437
+ )
438
+
439
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
440
+ return relative_buckets
441
+
442
+ def compute_bias(self, query_length, key_length):
443
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
444
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
445
+ relative_position = memory_position - context_position
446
+ relative_position_bucket = self._relative_positions_bucket(
447
+ relative_position,
448
+ bidirectional=True
449
+ )
450
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
451
+ values = self.relative_attention_bias(relative_position_bucket)
452
+ values = values.permute([2, 0, 1])
453
+ return values
454
+
455
+ def forward(
456
+ self,
457
+ query,
458
+ key: Optional[Tensor],
459
+ value: Optional[Tensor],
460
+ key_padding_mask: Optional[Tensor] = None,
461
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
462
+ need_weights: bool = True,
463
+ static_kv: bool = False,
464
+ attn_mask: Optional[Tensor] = None,
465
+ before_softmax: bool = False,
466
+ need_head_weights: bool = False,
467
+ position_bias: Optional[Tensor] = None
468
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
469
+ """Input shape: Time x Batch x Channel
470
+
471
+ Args:
472
+ key_padding_mask (ByteTensor, optional): mask to exclude
473
+ keys that are pads, of shape `(batch, src_len)`, where
474
+ padding elements are indicated by 1s.
475
+ need_weights (bool, optional): return the attention weights,
476
+ averaged over heads (default: False).
477
+ attn_mask (ByteTensor, optional): typically used to
478
+ implement causal attention, where the mask prevents the
479
+ attention from looking forward in time (default: None).
480
+ before_softmax (bool, optional): return the raw attention
481
+ weights and values before the attention softmax.
482
+ need_head_weights (bool, optional): return the attention
483
+ weights for each head. Implies *need_weights*. Default:
484
+ return the average attention weights over all heads.
485
+ """
486
+ if need_head_weights:
487
+ need_weights = True
488
+
489
+ is_tpu = query.device.type == "xla"
490
+
491
+ tgt_len, bsz, embed_dim = query.size()
492
+ src_len = tgt_len
493
+ assert embed_dim == self.embed_dim
494
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
495
+ if key is not None:
496
+ src_len, key_bsz, _ = key.size()
497
+ if not torch.jit.is_scripting():
498
+ assert key_bsz == bsz
499
+ assert value is not None
500
+ assert src_len, bsz == value.shape[:2]
501
+
502
+ if self.has_relative_attention_bias and position_bias is None:
503
+ position_bias = self.compute_bias(tgt_len, src_len)
504
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
505
+
506
+ if (
507
+ not is_tpu # don't use PyTorch version on TPUs
508
+ and incremental_state is None
509
+ and not static_kv
510
+ # A workaround for quantization to work. Otherwise JIT compilation
511
+ # treats bias in linear module as method.
512
+ and not torch.jit.is_scripting()
513
+ and self.q_head_dim == self.head_dim
514
+ ):
515
+ assert key is not None and value is not None
516
+ assert attn_mask is None
517
+
518
+ attn_mask_rel_pos = None
519
+ if position_bias is not None:
520
+ attn_mask_rel_pos = position_bias
521
+ if self.gru_rel_pos:
522
+ query_layer = query.transpose(0, 1)
523
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
524
+ query_layer = query_layer.view(*new_x_shape)
525
+ query_layer = query_layer.permute(0, 2, 1, 3)
526
+ _B, _H, _L, __ = query_layer.size()
527
+
528
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
529
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
530
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
531
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
532
+
533
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
534
+ k_proj_bias = self.k_proj.bias
535
+ if k_proj_bias is None:
536
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
537
+
538
+ x, attn = F.multi_head_attention_forward(
539
+ query,
540
+ key,
541
+ value,
542
+ self.embed_dim,
543
+ self.num_heads,
544
+ torch.empty([0]),
545
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
546
+ self.bias_k,
547
+ self.bias_v,
548
+ self.add_zero_attn,
549
+ self.dropout_module.p,
550
+ self.out_proj.weight,
551
+ self.out_proj.bias,
552
+ self.training,
553
+ # self.training or self.dropout_module.apply_during_inference,
554
+ key_padding_mask,
555
+ need_weights,
556
+ attn_mask_rel_pos,
557
+ use_separate_proj_weight=True,
558
+ q_proj_weight=self.q_proj.weight,
559
+ k_proj_weight=self.k_proj.weight,
560
+ v_proj_weight=self.v_proj.weight,
561
+ )
562
+ return x, attn, position_bias
563
+
564
+ if incremental_state is not None:
565
+ saved_state = self._get_input_buffer(incremental_state)
566
+ if saved_state is not None and "prev_key" in saved_state:
567
+ # previous time steps are cached - no need to recompute
568
+ # key and value if they are static
569
+ if static_kv:
570
+ assert self.encoder_decoder_attention and not self.self_attention
571
+ key = value = None
572
+ else:
573
+ saved_state = None
574
+
575
+ if self.self_attention:
576
+ q = self.q_proj(query)
577
+ k = self.k_proj(query)
578
+ v = self.v_proj(query)
579
+ elif self.encoder_decoder_attention:
580
+ # encoder-decoder attention
581
+ q = self.q_proj(query)
582
+ if key is None:
583
+ assert value is None
584
+ k = v = None
585
+ else:
586
+ k = self.k_proj(key)
587
+ v = self.v_proj(key)
588
+
589
+ else:
590
+ assert key is not None and value is not None
591
+ q = self.q_proj(query)
592
+ k = self.k_proj(key)
593
+ v = self.v_proj(value)
594
+ q *= self.scaling
595
+
596
+ if self.bias_k is not None:
597
+ assert self.bias_v is not None
598
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
599
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
600
+ if attn_mask is not None:
601
+ attn_mask = torch.cat(
602
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
603
+ )
604
+ if key_padding_mask is not None:
605
+ key_padding_mask = torch.cat(
606
+ [
607
+ key_padding_mask,
608
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
609
+ ],
610
+ dim=1,
611
+ )
612
+
613
+ q = (
614
+ q.contiguous()
615
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
616
+ .transpose(0, 1)
617
+ )
618
+ if k is not None:
619
+ k = (
620
+ k.contiguous()
621
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
622
+ .transpose(0, 1)
623
+ )
624
+ if v is not None:
625
+ v = (
626
+ v.contiguous()
627
+ .view(-1, bsz * self.num_heads, self.head_dim)
628
+ .transpose(0, 1)
629
+ )
630
+
631
+ if saved_state is not None:
632
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
633
+ if "prev_key" in saved_state:
634
+ _prev_key = saved_state["prev_key"]
635
+ assert _prev_key is not None
636
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
637
+ if static_kv:
638
+ k = prev_key
639
+ else:
640
+ assert k is not None
641
+ k = torch.cat([prev_key, k], dim=1)
642
+ src_len = k.size(1)
643
+ if "prev_value" in saved_state:
644
+ _prev_value = saved_state["prev_value"]
645
+ assert _prev_value is not None
646
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
647
+ if static_kv:
648
+ v = prev_value
649
+ else:
650
+ assert v is not None
651
+ v = torch.cat([prev_value, v], dim=1)
652
+ prev_key_padding_mask: Optional[Tensor] = None
653
+ if "prev_key_padding_mask" in saved_state:
654
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
655
+ assert k is not None and v is not None
656
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
657
+ key_padding_mask=key_padding_mask,
658
+ prev_key_padding_mask=prev_key_padding_mask,
659
+ batch_size=bsz,
660
+ src_len=k.size(1),
661
+ static_kv=static_kv,
662
+ )
663
+
664
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
665
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
666
+ saved_state["prev_key_padding_mask"] = key_padding_mask
667
+ # In this branch incremental_state is never None
668
+ assert incremental_state is not None
669
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
670
+ assert k is not None
671
+ assert k.size(1) == src_len
672
+
673
+ # This is part of a workaround to get around fork/join parallelism
674
+ # not supporting Optional types.
675
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
676
+ key_padding_mask = None
677
+
678
+ if key_padding_mask is not None:
679
+ assert key_padding_mask.size(0) == bsz
680
+ assert key_padding_mask.size(1) == src_len
681
+
682
+ if self.add_zero_attn:
683
+ assert v is not None
684
+ src_len += 1
685
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
686
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
687
+ if attn_mask is not None:
688
+ attn_mask = torch.cat(
689
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
690
+ )
691
+ if key_padding_mask is not None:
692
+ key_padding_mask = torch.cat(
693
+ [
694
+ key_padding_mask,
695
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
696
+ key_padding_mask
697
+ ),
698
+ ],
699
+ dim=1,
700
+ )
701
+
702
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
703
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
704
+
705
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
706
+
707
+ if attn_mask is not None:
708
+ attn_mask = attn_mask.unsqueeze(0)
709
+ attn_weights += attn_mask
710
+
711
+ if key_padding_mask is not None:
712
+ # don't attend to padding symbols
713
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
714
+ if not is_tpu:
715
+ attn_weights = attn_weights.masked_fill(
716
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
717
+ float("-inf"),
718
+ )
719
+ else:
720
+ attn_weights = attn_weights.transpose(0, 2)
721
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
724
+
725
+ if before_softmax:
726
+ return attn_weights, v, position_bias
727
+
728
+ if position_bias is not None:
729
+ if self.gru_rel_pos == 1:
730
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
731
+ _B, _H, _L, __ = query_layer.size()
732
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
733
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
734
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
735
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
736
+
737
+ position_bias = position_bias.view(attn_weights.size())
738
+
739
+ attn_weights = attn_weights + position_bias
740
+
741
+ attn_weights_float = F.softmax(
742
+ attn_weights, dim=-1
743
+ )
744
+ attn_weights = attn_weights_float.type_as(attn_weights)
745
+ attn_probs = self.dropout_module(attn_weights)
746
+
747
+ assert v is not None
748
+ attn = torch.bmm(attn_probs, v)
749
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
750
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
751
+ attn = self.out_proj(attn)
752
+ attn_weights: Optional[Tensor] = None
753
+ if need_weights:
754
+ attn_weights = attn_weights_float.view(
755
+ bsz, self.num_heads, tgt_len, src_len
756
+ ).transpose(1, 0)
757
+ if not need_head_weights:
758
+ # average attention weights over heads
759
+ attn_weights = attn_weights.mean(dim=0)
760
+
761
+ return attn, attn_weights, position_bias
762
+
763
+ @staticmethod
764
+ def _append_prev_key_padding_mask(
765
+ key_padding_mask: Optional[Tensor],
766
+ prev_key_padding_mask: Optional[Tensor],
767
+ batch_size: int,
768
+ src_len: int,
769
+ static_kv: bool,
770
+ ) -> Optional[Tensor]:
771
+ # saved key padding masks have shape (bsz, seq_len)
772
+ if prev_key_padding_mask is not None and static_kv:
773
+ new_key_padding_mask = prev_key_padding_mask
774
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
775
+ new_key_padding_mask = torch.cat(
776
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
777
+ )
778
+ # During incremental decoding, as the padding token enters and
779
+ # leaves the frame, there will be a time when prev or current
780
+ # is None
781
+ elif prev_key_padding_mask is not None:
782
+ if src_len > prev_key_padding_mask.size(1):
783
+ filler = torch.zeros(
784
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
785
+ device=prev_key_padding_mask.device,
786
+ )
787
+ new_key_padding_mask = torch.cat(
788
+ [prev_key_padding_mask.float(), filler.float()], dim=1
789
+ )
790
+ else:
791
+ new_key_padding_mask = prev_key_padding_mask.float()
792
+ elif key_padding_mask is not None:
793
+ if src_len > key_padding_mask.size(1):
794
+ filler = torch.zeros(
795
+ (batch_size, src_len - key_padding_mask.size(1)),
796
+ device=key_padding_mask.device,
797
+ )
798
+ new_key_padding_mask = torch.cat(
799
+ [filler.float(), key_padding_mask.float()], dim=1
800
+ )
801
+ else:
802
+ new_key_padding_mask = key_padding_mask.float()
803
+ else:
804
+ new_key_padding_mask = prev_key_padding_mask
805
+ return new_key_padding_mask
806
+
807
+ def _get_input_buffer(
808
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
809
+ ) -> Dict[str, Optional[Tensor]]:
810
+ result = self.get_incremental_state(incremental_state, "attn_state")
811
+ if result is not None:
812
+ return result
813
+ else:
814
+ empty_result: Dict[str, Optional[Tensor]] = {}
815
+ return empty_result
816
+
817
+ def _set_input_buffer(
818
+ self,
819
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
820
+ buffer: Dict[str, Optional[Tensor]],
821
+ ):
822
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
823
+
824
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
825
+ return attn_weights
vq/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from vq.codec_encoder import CodecEncoder
2
+ from vq.codec_decoder import CodecDecoder
3
+ from vq.codec_decoder_vocos import CodecDecoderVocos
4
+ from vq.codec_encoder import CodecEncoder_Transformer,CodecEncoder_only_Transformer
vq/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (399 Bytes). View file
 
vq/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (371 Bytes). View file
 
vq/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (318 Bytes). View file
 
vq/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (407 Bytes). View file
 
vq/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (400 Bytes). View file
 
vq/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (449 Bytes). View file
 
vq/__pycache__/activations.cpython-310.pyc ADDED
Binary file (4.01 kB). View file
 
vq/__pycache__/activations.cpython-311.pyc ADDED
Binary file (6.07 kB). View file
 
vq/__pycache__/activations.cpython-312.pyc ADDED
Binary file (5.65 kB). View file
 
vq/__pycache__/activations.cpython-37.pyc ADDED
Binary file (4.11 kB). View file
 
vq/__pycache__/activations.cpython-38.pyc ADDED
Binary file (4.05 kB). View file
 
vq/__pycache__/activations.cpython-39.pyc ADDED
Binary file (4.1 kB). View file
 
vq/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (5.98 kB). View file
 
vq/__pycache__/blocks.cpython-39.pyc ADDED
Binary file (6.35 kB). View file
 
vq/__pycache__/bs_roformer5.cpython-310.pyc ADDED
Binary file (3.92 kB). View file
 
vq/__pycache__/bs_roformer5.cpython-37.pyc ADDED
Binary file (3.93 kB). View file
 
vq/__pycache__/bs_roformer5.cpython-38.pyc ADDED
Binary file (3.9 kB). View file
 
vq/__pycache__/bs_roformer5.cpython-39.pyc ADDED
Binary file (3.93 kB). View file
 
vq/__pycache__/codec_decoder.cpython-310.pyc ADDED
Binary file (9.04 kB). View file
 
vq/__pycache__/codec_decoder.cpython-311.pyc ADDED
Binary file (8.78 kB). View file
 
vq/__pycache__/codec_decoder.cpython-312.pyc ADDED
Binary file (7.76 kB). View file
 
vq/__pycache__/codec_decoder.cpython-39.pyc ADDED
Binary file (9.38 kB). View file
 
vq/__pycache__/codec_decoder_vocos.cpython-310.pyc ADDED
Binary file (18.1 kB). View file
 
vq/__pycache__/codec_decoder_vocos.cpython-311.pyc ADDED
Binary file (27.7 kB). View file
 
vq/__pycache__/codec_decoder_vocos.cpython-312.pyc ADDED
Binary file (25.2 kB). View file
 
vq/__pycache__/codec_decoder_vocos.cpython-39.pyc ADDED
Binary file (18.5 kB). View file