diff --git a/ckpt/best.pt.tar b/ckpt/best.pt.tar new file mode 100644 index 0000000000000000000000000000000000000000..ae76da2174eb8b582b50d0116347e96d82899cb2 --- /dev/null +++ b/ckpt/best.pt.tar @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:298e4fccf3baacd8623574e7a767d22198f3ddb0c35cdb17e71e03dd4edf0fc5 +size 11826355083 diff --git a/ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/config.json b/ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/config.json new file mode 100644 index 0000000000000000000000000000000000000000..a383a594dac18459628cd2837168cd276342a31a --- /dev/null +++ b/ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/config.json @@ -0,0 +1,81 @@ +{ + "activation_dropout": 0.0, + "adapter_act": "relu", + "adapter_kernel_size": 3, + "adapter_stride": 2, + "add_adapter": false, + "apply_spec_augment": false, + "architectures": [ + "Wav2Vec2BertModel" + ], + "attention_dropout": 0.0, + "bos_token_id": 1, + "classifier_proj_size": 768, + "codevector_dim": 768, + "conformer_conv_dropout": 0.1, + "contrastive_logits_temperature": 0.1, + "conv_depthwise_kernel_size": 31, + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "eos_token_id": 2, + "feat_proj_dropout": 0.0, + "feat_quantizer_dropout": 0.0, + "feature_projection_input_dim": 160, + "final_dropout": 0.1, + "hidden_act": "swish", + "hidden_dropout": 0.0, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "left_max_position_embeddings": 64, + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_prob": 0.05, + "max_source_positions": 5000, + "model_type": "wav2vec2-bert", + "num_adapter_layers": 1, + "num_attention_heads": 16, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_hidden_layers": 24, + "num_negatives": 100, + "output_hidden_size": 1024, + "pad_token_id": 0, + "position_embeddings_type": "relative_key", + "proj_codevector_dim": 768, + "right_max_position_embeddings": 8, + "rotary_embedding_base": 10000, + "tdnn_dilation": [ + 1, + 2, + 3, + 1, + 1 + ], + "tdnn_dim": [ + 512, + 512, + 512, + 512, + 1500 + ], + "tdnn_kernel": [ + 5, + 3, + 3, + 1, + 1 + ], + "torch_dtype": "float32", + "transformers_version": "4.37.0.dev0", + "use_intermediate_ffn_before_adapter": false, + "use_weighted_layer_sum": false, + "vocab_size": null, + "xvector_output_dim": 512 +} diff --git a/ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/preprocessor_config.json b/ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/preprocessor_config.json new file mode 100644 index 0000000000000000000000000000000000000000..5db61951cdf5edab6337fd84ee619500c27aaa3d --- /dev/null +++ b/ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/preprocessor_config.json @@ -0,0 +1,11 @@ +{ + "feature_extractor_type": "SeamlessM4TFeatureExtractor", + "feature_size": 80, + "num_mel_bins": 80, + "padding_side": "right", + "padding_value": 1, + "processor_class": "Wav2Vec2BertProcessor", + "return_attention_mask": true, + "sampling_rate": 16000, + "stride": 2 +} diff --git a/ckpt/codec_ckpt/hub/version.txt b/ckpt/codec_ckpt/hub/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..56a6051ca2b02b04ef92d5150c9ef600403cb1de --- /dev/null +++ b/ckpt/codec_ckpt/hub/version.txt @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/ckpt/download.sh b/ckpt/download.sh new 
file mode 100644
index 0000000000000000000000000000000000000000..ca0b3c175b09561607808724c6a8ca34cd559492
--- /dev/null
+++ b/ckpt/download.sh
@@ -0,0 +1,18 @@
+python download_ckpt.py \
+    --source hf \
+    --repo_id microsoft/wavlm-large \
+    --filename pytorch_model.bin \
+    --save_path ./WavLM-Large.pt
+
+python download_ckpt.py \
+    --source hf \
+    --repo_id facebook/w2v-bert-2.0 \
+    --filename model.safetensors \
+    --save_path \
+    ./codec_ckpt/hub/models--facebook--w2v-bert-2.0/model.safetensors
+
+python download_ckpt.py \
+    --source hf \
+    --repo_id HKUSTAudio/xcodec2 \
+    --filename ckpt/epoch=4-step=1400000.ckpt \
+    --save_path ./codec_ckpt/epoch=4-step=1400000.ckpt
\ No newline at end of file
diff --git a/ckpt/download_ckpt.py b/ckpt/download_ckpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c137c2936c7dcc7af017744cc42a2f59e627046
--- /dev/null
+++ b/ckpt/download_ckpt.py
@@ -0,0 +1,58 @@
+import os
+import requests
+import argparse
+from huggingface_hub import hf_hub_download
+from tqdm import tqdm
+
+def download_from_url(url, save_path):
+    """Download a file from a given URL and save it locally."""
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1024  # 1 KB
+    progress_bar = tqdm(total=total_size, unit="B", unit_scale=True)
+
+    with open(save_path, "wb") as file:
+        for data in response.iter_content(block_size):
+            progress_bar.update(len(data))
+            file.write(data)
+    progress_bar.close()
+
+    if total_size != 0 and progress_bar.n != total_size:
+        print("Download failed!")
+    else:
+        print(f"File downloaded to: {save_path}")
+
+def download_from_hf(repo_id, filename, save_path):
+    """Download a file from Hugging Face Hub."""
+    print(f"Downloading from Hugging Face Hub: {repo_id}/{filename}")
+    try:
+        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=os.path.dirname(save_path), local_dir_use_symlinks=False)
+        print(f"File downloaded to: {save_path}")
+    except Exception as e:
+        print(f"Download failed: {e}")
+
+def main():
+    parser = argparse.ArgumentParser(description="Automatically download model checkpoints")
+    parser.add_argument("--source", type=str, required=True, choices=["hf", "url"], help="Download source: hf (Hugging Face Hub) or url (custom URL)")
+    parser.add_argument("--repo_id", type=str, help="Hugging Face model repository ID (e.g., google/bert-base-uncased)")
+    parser.add_argument("--filename", type=str, help="Filename in the Hugging Face repository")
+    parser.add_argument("--url", type=str, help="Custom download URL")
+    parser.add_argument("--save_path", type=str, required=True, help="Path to save the file (including filename)")
+    args = parser.parse_args()
+
+    # Ensure the save directory exists
+    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
+
+    if args.source == "hf":
+        if not args.repo_id or not args.filename:
+            print("Please provide a Hugging Face repository ID and filename!")
+            return
+        download_from_hf(args.repo_id, args.filename, args.save_path)
+    elif args.source == "url":
+        if not args.url:
+            print("Please provide a download URL!")
+            return
+        download_from_url(args.url, args.save_path)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/config/test.yml b/config/test.yml
new file mode 100644
index 0000000000000000000000000000000000000000..17b0597fd3f4c87af04a5d30cb320fb2d85b90ba
--- /dev/null
+++ b/config/test.yml
@@ -0,0 +1,21 @@
+test:
+  checkpoint: ./ckpt/best.pt.tar
+  use_cuda: True
+  infer_feat_too: True
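+  # Inferred from inference.py: infer_feat_too decides whether WavLM features are
+  # extracted before decoding, and inference_time sets the number of decoding passes
+  # (later passes re-decode the previous pass's output; SS needs at least 2 passes).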
+  inference_time: 1
+
+save:
+  feat_dir: ./decode/feat/se
+  wav_dir: ./decode/wav/se
+
+task: SE
+
+# LLaSE config
+nnet_conf:
+  d_model: 1024
+  nhead: 16
+  num_layers: 16
+
+datareader:
+  sample_rate: 16000
+  filename: /home/node57_data2/bykang/work_plus/test_set/interspeech2020/syn_no_reverb.scp # /path/to/your/filelist
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..5019a5fcf602be39c7b40077fd114e3c22e43c53
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,474 @@
+import os
+import sys
+import time
+import librosa
+import yaml
+import joblib
+import argparse
+
+import soundfile as sf
+import numpy as np
+
+from pathlib import Path
+from collections import defaultdict
+from typing import Optional
+from tqdm import tqdm
+
+sys.path.append(os.path.dirname(__file__))
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+
+# Torch
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+
+# WavLM
+from nnet.WavLM import WavLM, WavLMConfig
+
+# Xcodec2
+from vq.codec_encoder import CodecEncoder_Transformer
+from vq.codec_decoder_vocos import CodecDecoderVocos
+from vq.module import SemanticEncoder
+from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
+from collections import OrderedDict
+
+# Dataloader
+from loader.datareader import DataReader
+from loader.datareader_aec import DataReaderAEC
+from loader.datareader_tse import DataReaderTSE
+
+# LLaSE
+from nnet.llase import LLM_AR as model
+
+class Encodec():
+    '''
+    Load Xcodec2
+    '''
+    def __init__(self,device="cpu") -> None:
+        self.device=device
+        ckpt = "./ckpt/codec_ckpt/epoch=4-step=1400000.ckpt"
+        # ckpt = '/home/bykang/codec_ckpt/epoch=4-step=1400000.ckpt'
+        ckpt = torch.load(ckpt, map_location='cpu')
+        state_dict = ckpt['state_dict']
+        filtered_state_dict_codec = OrderedDict()
+        filtered_state_dict_semantic_encoder = OrderedDict()
+        filtered_state_dict_gen = OrderedDict()
+        filtered_state_dict_fc_post_a = OrderedDict()
+        filtered_state_dict_fc_prior = OrderedDict()
+        for key, value in state_dict.items():
+            if key.startswith('CodecEnc.'):
+                new_key = key[len('CodecEnc.'):]
+                filtered_state_dict_codec[new_key] = value
+            elif key.startswith('generator.'):
+                new_key = key[len('generator.'):]
+                filtered_state_dict_gen[new_key] = value
+            elif key.startswith('fc_post_a.'):
+                new_key = key[len('fc_post_a.'):]
+                filtered_state_dict_fc_post_a[new_key] = value
+            elif key.startswith('SemanticEncoder_module.'):
+                new_key = key[len('SemanticEncoder_module.'):]
+                filtered_state_dict_semantic_encoder[new_key] = value
+            elif key.startswith('fc_prior.'):
+                new_key = key[len('fc_prior.'):]
+                filtered_state_dict_fc_prior[new_key] = value
+
+        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
+            "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0",
+            # "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b",
+            output_hidden_states=True)
+        self.semantic_model=self.semantic_model.eval().to(self.device)
+
+        self.SemanticEncoder_module = SemanticEncoder(1024,1024,1024)
+        self.SemanticEncoder_module.load_state_dict(filtered_state_dict_semantic_encoder)
+        self.SemanticEncoder_module = self.SemanticEncoder_module.eval().to(self.device)
+
+        self.encoder = CodecEncoder_Transformer()
+        self.encoder.load_state_dict(filtered_state_dict_codec)
+        self.encoder = self.encoder.eval().to(self.device)
+
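+        # xcodec2 decoder (Vocos-based) that turns quantized codes back into
+        # waveform samples; used by token2wav below.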
+        self.decoder = CodecDecoderVocos()
+        self.decoder.load_state_dict(filtered_state_dict_gen)
+        self.decoder = self.decoder.eval().to(self.device)
+
+        self.fc_post_a = nn.Linear( 2048, 1024 )
+        self.fc_post_a.load_state_dict(filtered_state_dict_fc_post_a)
+        self.fc_post_a = self.fc_post_a.eval().to(self.device)
+
+        self.fc_prior = nn.Linear( 2048, 2048 )
+        self.fc_prior.load_state_dict(filtered_state_dict_fc_prior)
+        self.fc_prior = self.fc_prior.eval().to(self.device)
+
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
+            "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0")
+            # "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b")
+
+    def get_feat(self, wav_batch, pad=None):
+
+        if len(wav_batch.shape) != 2:
+            return self.feature_extractor(F.pad(wav_batch, pad), sampling_rate=16000, return_tensors="pt").data['input_features']
+
+        padded_wavs = torch.stack([F.pad(wav, pad) for wav in wav_batch])
+        batch_feats = []
+
+        for wav in padded_wavs:
+            feat = self.feature_extractor(
+                wav,
+                sampling_rate=16000,
+                return_tensors="pt"
+            ).data['input_features']
+
+            batch_feats.append(feat)
+        feat_batch = torch.concat(batch_feats, dim=0).to(self.device)
+        return feat_batch
+
+    def get_embedding(self, wav_cpu):
+        wav_cpu = wav_cpu.cpu()
+        feat = self.get_feat(wav_cpu,pad=(160,160))
+        feat = feat.to(self.device)
+
+        if(len(wav_cpu.shape)==1):
+            wav = wav_cpu.unsqueeze(0).to(self.device)
+        else:
+            wav = wav_cpu.to(self.device)
+
+        wav = torch.nn.functional.pad(wav, (0, (200 - (wav.shape[1] % 200))))
+        with torch.no_grad():
+            vq_emb = self.encoder(wav.unsqueeze(1))
+            vq_emb = vq_emb.transpose(1, 2)
+
+            if vq_emb.shape[2]!=feat.shape[1]:
+                feat = self.get_feat(wav_cpu)
+                feat = feat.to(self.device)
+
+            semantic_target = self.semantic_model(feat[:, :,:])
+            semantic_target = semantic_target.hidden_states[16]
+            semantic_target = semantic_target.transpose(1, 2)
+            semantic_target = self.SemanticEncoder_module(semantic_target)
+
+            vq_emb = torch.cat([semantic_target, vq_emb], dim=1)
+
+        return vq_emb
+
+    def emb2token(self, emb):
+        emb.to(self.device)
+        emb = self.fc_prior(emb.transpose(1, 2)).transpose(1, 2)
+        _, vq_code, _ = self.decoder(emb, vq=True)
+        return vq_code
+
+    def token2wav(self, vq_code):
+        vq_code.to(self.device)
+        vq_post_emb = self.decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
+        vq_post_emb = vq_post_emb.transpose(1, 2)
+        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1,2)).transpose(1,2)
+        recon = self.decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze()
+        # To write the wav to disk, append .squeeze().detach().cpu().numpy();
+        # keep the tensor as-is if gradients are needed downstream.
+        return recon
+
+class WavLM_feat(object):
+    '''
+    Load WavLM
+    '''
+    def __init__(self, device):
+        self.wavlm = self._reload_wavLM_large(device=device)
+
+    def __call__(self, wav):
+        T = wav.shape[-1]
+        wav = wav.reshape(-1, T)
+        with torch.no_grad():
+            feat = self.wavlm.extract_features(wav, output_layer=6, ret_layer_results=False)[0]
+            B, T, D = feat.shape
+            feat = torch.reshape(feat, (-1, D))
+
+        return feat
+
+    def _reload_wavLM_large(self, path="./ckpt/WavLM-Large.pt", device: Optional[torch.device] = None):
+        cpt = torch.load(path, map_location="cpu")
+        cfg = WavLMConfig(cpt['cfg'])
+        wavLM = WavLM(cfg)
+        wavLM.load_state_dict(cpt['model'])
+        wavLM.eval()
+        if device != None:
+            wavLM = wavLM.to(device)
+        for p in wavLM.parameters():
+            p.requires_grad = False
+        print('Successfully reloaded WavLM from', path)
+        return wavLM
+
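+# --------------------------------------------------------------------------
+# Illustrative sketch (not part of the original pipeline): how the classes above
+# fit together for a single utterance on the default SE path of run() below.
+# `nnet` is assumed to be a loaded LLM_AR model and `wav` a 16 kHz mono tensor
+# of shape [1, T] on the same device as `codec` and `wavlm`.
+def enhance_utterance_sketch(wav, nnet, codec, wavlm):
+    with torch.no_grad():
+        feat = wavlm(wav).unsqueeze(0)                # [1, T_frames, 1024] WavLM layer-6 features
+        logits_y, logits_x = nnet(feat)               # token logits, each [1, 65536, T_frames]
+        tokens = torch.argmax(logits_x, dim=1)        # greedy selection, as in the SE branch of run()
+        recon = codec.token2wav(tokens.unsqueeze(0))  # xcodec2 tokens -> waveform
+    return recon.squeeze().detach().cpu().numpy()
+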
+def get_firstchannel_read(path, fs=16000): + ''' + Get first channel of the wav + ''' + wave_data, sr = sf.read(path) + if sr != fs: + if len(wave_data.shape) != 1: + wave_data = wave_data.transpose((1, 0)) + wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) + if len(wave_data.shape) != 1: + wave_data = wave_data.transpose((1, 0)) + if len(wave_data.shape) > 1: + wave_data = wave_data[:, 0] + return wave_data + +def load_obj(obj, device): + ''' + Offload tensor object in obj to cuda device + ''' + def cuda(obj): + return obj.to(device) if isinstance(obj, torch.Tensor) else obj + + if isinstance(obj, dict): + return {key: load_obj(obj[key], device) for key in obj} + elif isinstance(obj, list): + return [load_obj(val, device) for val in obj] + else: + return cuda(obj) + +def run(args): + LOCAL_RANK = int(os.environ['LOCAL_RANK']) + WORLD_SIZE = int(os.environ['WORLD_SIZE']) + WORLD_RANK = int(os.environ['RANK']) + dist.init_process_group(args.backend, rank=WORLD_RANK, world_size=WORLD_SIZE) + torch.cuda.set_device(LOCAL_RANK) + device = torch.device('cuda', LOCAL_RANK) + print(f"[{os.getpid()}] using device: {device}", torch.cuda.current_device(), "local rank", LOCAL_RANK) + + with open(args.conf, "r") as f: + conf = yaml.load(f, Loader=yaml.FullLoader) + + # Dataloader + if conf["task"]=="AEC": + data_reader = DataReaderAEC(**conf["datareader"]) + elif conf["task"]=="TSE": + data_reader = DataReaderTSE(**conf["datareader"]) + else: + data_reader = DataReader(**conf["datareader"]) + + # Load WavLM and XCodec2 + codec = Encodec(device) + wavlm_feat = WavLM_feat(device) + + # Load LLaSE + nnet = model(**conf["nnet_conf"]) + cpt_fname = Path(conf["test"]["checkpoint"]) + cpt = torch.load(cpt_fname, map_location="cpu") + + nnet = nnet.to(device) + nnet = DistributedDataParallel(nnet, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, find_unused_parameters=True) + nnet.load_state_dict(cpt["model_state_dict"]) + nnet.eval() + + # Make sure the dir exists + if conf["task"]=="AEC": + if not os.path.exists(conf["save"]["feat_dir"]+"/mic"): + os.makedirs(conf["save"]["feat_dir"]+"/mic") + if not os.path.exists(conf["save"]["feat_dir"]+"/ref"): + os.makedirs(conf["save"]["feat_dir"]+"/ref") + elif conf["task"]=="TSE": + if not os.path.exists(conf["save"]["feat_dir"]+"/mic"): + os.makedirs(conf["save"]["feat_dir"]+"/mic") + if not os.path.exists(conf["save"]["feat_dir"]+"/ref"): + os.makedirs(conf["save"]["feat_dir"]+"/ref") + else: + if not os.path.exists(conf["save"]["feat_dir"]): + os.makedirs(conf["save"]["feat_dir"]) + + if not os.path.exists(conf["save"]["wav_dir"]): + os.makedirs(conf["save"]["wav_dir"]) + + # Main of inference + if_feat_too = conf["test"]["infer_feat_too"] + + origin_feat_dir = conf["save"]["feat_dir"] + origin_wav_dir = conf["save"]["wav_dir"] + + last_feat_dir = origin_feat_dir + last_wav_dir = origin_wav_dir + + for inference_time in range(conf["test"]["inference_time"]): + # For multi-inference + if inference_time > 0: + feat_dir = origin_feat_dir + "inference" + str(inference_time) + wav_dir = origin_wav_dir + "inference" + str(inference_time) + else: + feat_dir = origin_feat_dir + wav_dir = origin_wav_dir + + if not os.path.exists(feat_dir): + os.makedirs(feat_dir) + if not os.path.exists(wav_dir): + os.makedirs(wav_dir) + + with torch.no_grad(): + # Extract WavLM features + if if_feat_too ==True or inference_time>0: + for egs in tqdm(data_reader): + egs = load_obj(egs, device) + + if conf["task"]=="AEC" or conf["task"]=="TSE": + if inference_time > 0: + 
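+                            # Multi-pass decoding: from the second pass onward, re-read the wav
+                            # written by the previous pass and treat it as the new mic input.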
mic_path = last_wav_dir + '/' + egs["mic_name"] + ".wav" + egs["mic"] = torch.from_numpy(get_firstchannel_read(mic_path).astype(np.float32)).unsqueeze(0).to(device) + else: + egs["mic"]=egs["mic"].contiguous() + egs["ref"]=egs["ref"].contiguous() + + feat_mic = wavlm_feat(egs["mic"]) + out_mic = feat_mic.detach().squeeze(0).cpu().numpy() + + if not os.path.exists(os.path.join(feat_dir, "mic")): + os.makedirs(os.path.join(feat_dir, "mic")) + np.save(os.path.join(feat_dir, "mic", egs["mic_name"]), out_mic) + + # For AEC and TSE, reference audio only need to extract feats at first time + if inference_time == 0: + feat_ref = wavlm_feat(egs["ref"]) + out_ref = feat_ref.detach().squeeze(0).cpu().numpy() + np.save(os.path.join(origin_feat_dir, "ref", egs["ref_name"]), out_ref) + + torch.cuda.empty_cache() + + else: + if inference_time > 0: + mix_path = last_wav_dir + '/' + egs["name"] + ".wav" + egs["mix"] = torch.from_numpy(get_firstchannel_read(mix_path).astype(np.float32)).unsqueeze(0).to(device) + else: + egs["mix"]=egs["mix"].contiguous() + + feat = wavlm_feat(egs["mix"]) + out = feat.detach().squeeze(0).cpu().numpy() + np.save(os.path.join(feat_dir, egs["name"]), out) + + # Predict the clean tokens and token2wav + for egs in tqdm(data_reader): + egs = load_obj(egs, device) + sr = 16000 + + if conf["task"] == "AEC": + # Get feat + feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy" + feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy" + + feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0) + feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0) + + # For multi-inference + if inference_time > 0: + est = nnet(feat_mic) + else: + est = nnet(feat_mic, feat_ref) + + # Get tokens and token2wav + max, max_indices_1 = torch.max(est[1], dim=1) + recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() + + # Save the wav + target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav") + print(target_path) + sf.write(target_path , recon_1, sr) + + elif conf["task"] == "TSE" : + # Get feat + feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy" + feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy" + + feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0) + feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0) + + # Choose if keep the enroallment audio while multi-inference + if_keep_ref = True + + if inference_time>0 and if_keep_ref== False: + est = nnet(feat_mic) + else: + est = nnet(feat_mic, feat_ref) + + # Get tokens and token2wav + max, max_indices_1 = torch.max(est[0], dim=1) + recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() + + # Save the wav + target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav") + print(target_path) + sf.write(target_path , recon_1, sr) + + elif conf["task"] == "PLC": + # Get feat + feat_path = os.path.join(feat_dir, egs["name"]) + ".npy" + feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0) + + # Get tokens and token2wav + est = nnet(feat) + max, max_indices_1 = torch.max(est[1], dim=1) + recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() + + # Save the wav + target_path = os.path.join(wav_dir, egs["name"] + ".wav") + print(target_path) + sf.write(target_path , recon_1, sr) + + elif conf["task"] == "SS": + # Get feat + feat_path = os.path.join(feat_dir, egs["name"]) + ".npy" + feat = 
torch.from_numpy(np.load(feat_path)).unsqueeze(0) + + # Separate the first speaker + est = nnet(feat) + max, max_indices_1 = torch.max(est[1], dim=1) + recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() + + target_path_1 = os.path.join(wav_dir, egs["name"] + ".wav") + sf.write(target_path_1 , recon_1, sr) + + # Separate the second speaker, SS need at least 2 inference time in config + if inference_time > 0: + origin_feat_path = os.path.join(origin_feat_dir, egs["name"]) + ".npy" + origin_feat = torch.from_numpy(np.load(origin_feat_path)).unsqueeze(0) + + est2 = nnet(origin_feat, feat) + max, max_indices_2 = torch.max(est2[1], dim=1) + recon_2 = codec.token2wav(max_indices_2.unsqueeze(0)).squeeze().detach().cpu().numpy() + + if not os.path.exists(last_wav_dir + "s2"): + os.makedirs(last_wav_dir + "s2") + + target_path_2 = os.path.join(last_wav_dir + "s2", egs["name"] + ".wav") + sf.write(target_path_2 , recon_2, sr) + + else: + # Get feat + feat_path = os.path.join(feat_dir, egs["name"]) + ".npy" + feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0) + + # Get tokens and token2wav + est = nnet(feat) + max, max_indices_1 = torch.max(est[1], dim=1) + recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() + + # Save the wav + target_path = os.path.join(wav_dir, egs["name"] + ".wav") + print(target_path) + sf.write(target_path , recon_1, sr) + + # For next inference + last_feat_dir = feat_dir + last_wav_dir = wav_dir + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description = "Command to test separation model in Pytorch", + formatter_class = argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-conf", + type=str, + required=True, + help="Yaml configuration file for training") + parser.add_argument("--backend", + type=str, + default="nccl", + choices=["nccl", "gloo"]) + args = parser.parse_args() + # for nccl debug + os.environ["NCCL_DEBUG"] = "INFO" + run(args) \ No newline at end of file diff --git a/inference.sh b/inference.sh new file mode 100755 index 0000000000000000000000000000000000000000..c6b5ace19d1bdf4809dde0d67326585497c31be4 --- /dev/null +++ b/inference.sh @@ -0,0 +1,6 @@ +CUDA_VISIBLE_DEVICES=1 torchrun \ + --nnodes=1 \ + --nproc_per_node=1 \ + --master_port=21547 \ + inference.py \ + -conf ./config/test.yml \ No newline at end of file diff --git a/loader/__pycache__/datareader.cpython-310.pyc b/loader/__pycache__/datareader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5126eacb5110d9f8fe191ed11471e9bc91e465e Binary files /dev/null and b/loader/__pycache__/datareader.cpython-310.pyc differ diff --git a/loader/__pycache__/datareader_aec.cpython-310.pyc b/loader/__pycache__/datareader_aec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..912b5ba3723a2e0507e5ccfa37afa18ce235ea5a Binary files /dev/null and b/loader/__pycache__/datareader_aec.cpython-310.pyc differ diff --git a/loader/__pycache__/datareader_fe.cpython-310.pyc b/loader/__pycache__/datareader_fe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae442cf46322afffd6421824c30ab5891a9dd356 Binary files /dev/null and b/loader/__pycache__/datareader_fe.cpython-310.pyc differ diff --git a/loader/__pycache__/datareader_tse.cpython-310.pyc b/loader/__pycache__/datareader_tse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..839cb11a2bc0d77495238b206c852ab7471519f0 Binary files /dev/null 
and b/loader/__pycache__/datareader_tse.cpython-310.pyc differ diff --git a/loader/datareader.py b/loader/datareader.py new file mode 100755 index 0000000000000000000000000000000000000000..f03e566d448e4d05a995825cbe5cac4a1496bec4 --- /dev/null +++ b/loader/datareader.py @@ -0,0 +1,65 @@ +import numpy as np +import torchaudio +import torch + +def get_firstchannel_read(path, fs=16000): + wave_data, sr = torchaudio.load(path) + if sr != fs: + wave_data = torchaudio.functional.resample(wave_data, sr, fs) + if len(wave_data.shape) > 1: + wave_data = wave_data[0,...] + wave_data = wave_data.cpu().numpy() + return wave_data + +def parse_scp(scp, path_list): + with open(scp) as fid: + for line in fid: + tmp = line.strip().split() + if len(tmp) > 1: + path_list.append({"inputs": tmp[0], "duration": tmp[1]}) + else: + path_list.append({"inputs": tmp[0]}) + +class DataReader(object): + def __init__(self, filename, sample_rate): + self.file_list = [] + self.sample_rate = sample_rate + parse_scp(filename, self.file_list) + + def extract_feature(self, path): + path = path["inputs"] + name = path.split("/")[-1].split(".")[0] + data = get_firstchannel_read(path, fs=self.sample_rate).astype(np.float32) + max_norm = np.max(np.abs(data)) + if max_norm == 0: + max_norm = 1 + data = data / max_norm + inputs = np.reshape(data, [1, data.shape[0]]) + inputs = torch.from_numpy(inputs) + + egs = { + "mix": inputs, + "max_norm": max_norm, + "name": name + } + return egs + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, index): + return self.extract_feature(self.file_list[index]) + + def get_utt2spk(self, path): + lines = open(path, "r").readlines() + for line in lines: + line = line.strip().split() + utt_path, spk_id = line[0], line[1] + self.utt2spk[utt_path] = spk_id + + def get_spk2utt(self, path): + lines = open(path, "r").readlines() + for line in lines: + line = line.strip().split() + utt_path, spk_id = line[0], line[1] + self.spk2aux[spk_id] = utt_path diff --git a/loader/datareader_aec.py b/loader/datareader_aec.py new file mode 100755 index 0000000000000000000000000000000000000000..d4a06774cef37da1ba6f6901e00b826f8cb85ab7 --- /dev/null +++ b/loader/datareader_aec.py @@ -0,0 +1,86 @@ +import librosa +import torch as th +import numpy as np +import soundfile as sf + +import sys, os +sys.path.append(os.path.dirname(__file__)) +# from speex_linear.lp_or_tde import LP_or_TDE + + +def audio(path, fs=16000): + wave_data, sr = sf.read(path) + if sr != fs: + wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) + return wave_data + +def get_firstchannel_read(path, fs=16000): + wave_data, sr = sf.read(path) + if sr != fs: + if len(wave_data.shape) != 1: + wave_data = wave_data.transpose((1, 0)) + wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) + if len(wave_data.shape) != 1: + wave_data = wave_data.transpose((1, 0)) + if len(wave_data.shape) > 1: + wave_data = wave_data[:, 0] + return wave_data + +def parse_scp(scp, path_list): + with open(scp) as fid: + for line in fid: + tmp = line.strip().split() + if len(tmp) > 1: + path_list.append({"inputs": tmp[0], "duration": tmp[1]}) + else: + path_list.append({"inputs": tmp[0]}) + +class DataReaderAEC(object): + def __init__(self, filename, sample_rate): #, aux_segment): # filename是不带id的待解码音频,noisy_id是带id的带解码音频,clean是带id的注册音频 + self.file_list = [] + parse_scp(filename, self.file_list) + self.sample_rate = sample_rate + + # self.aux_segment_length = aux_segment * sample_rate + + def extract_feature(self, path): + 
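+        '''
+        Build one AEC example: the scp entry points at the microphone recording
+        ("...mic.wav"); the far-end/loopback reference is assumed to sit next to it
+        with "mic.wav" replaced by "lpb.wav", and both signals are trimmed to the
+        shorter of the two lengths.
+        '''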
mic_path = path["inputs"] + utt_id = mic_path.split("/")[-1] + mic_name = mic_path.split("/")[-1].split(".")[0] + + ref_path = mic_path.replace("mic.wav", "lpb.wav") + ref_name = ref_path.split("/")[-1].split(".")[0] + + mic = get_firstchannel_read(mic_path, self.sample_rate).astype(np.float32) + ref = get_firstchannel_read(ref_path, self.sample_rate).astype(np.float32) + + min_len = min(mic.shape[0], ref.shape[0]) + mic = mic[:min_len] + ref = ref[:min_len] + + inputs_mic = np.reshape(mic, [1, mic.shape[0]]) + inputs_ref = np.reshape(ref, [1, ref.shape[0]]).astype(np.float32) + + + inputs_mic = th.from_numpy(inputs_mic) + inputs_ref = th.from_numpy(inputs_ref) + + # print(f'e: {inputs_e.shape}') + # print(f'mic: {inputs_mic.shape}') + # print(f'ref: {inputs_ref.shape}') + + egs = { + "mic": inputs_mic, + "ref": inputs_ref, + "utt_id": utt_id, + "mic_name": mic_name, + "ref_name": ref_name + # "max_norm": max_norm + } + return egs + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, index): + return self.extract_feature(self.file_list[index]) diff --git a/loader/datareader_tse.py b/loader/datareader_tse.py new file mode 100755 index 0000000000000000000000000000000000000000..a9e34f90b4f13ffd3f4941d1d83eeedab542d6c6 --- /dev/null +++ b/loader/datareader_tse.py @@ -0,0 +1,85 @@ +import librosa +import torch as th +import numpy as np +import soundfile as sf + +import sys, os +sys.path.append(os.path.dirname(__file__)) +# from speex_linear.lp_or_tde import LP_or_TDE + +def audio(path, fs=16000): + wave_data, sr = sf.read(path) + if sr != fs: + wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) + return wave_data + +def get_firstchannel_read(path, fs=16000): + wave_data, sr = sf.read(path) + if sr != fs: + if len(wave_data.shape) != 1: + wave_data = wave_data.transpose((1, 0)) + wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) + if len(wave_data.shape) != 1: + wave_data = wave_data.transpose((1, 0)) + if len(wave_data.shape) > 1: + wave_data = wave_data[:, 0] + return wave_data + +def parse_scp(scp, path_list): + with open(scp) as fid: + for line in fid: + tmp = line.strip().split() + if len(tmp) > 1: + path_list.append({"inputs": tmp[0], "duration": tmp[1]}) + else: + path_list.append({"inputs": tmp[0]}) + +class DataReaderTSE(object): + def __init__(self, filename, sample_rate): # filename是不带id的待解码音频,noisy_id是带id的带解码音频,clean是带id的注册音频 + self.file_list = [] + parse_scp(filename, self.file_list) + self.sample_rate = sample_rate + + def extract_feature(self, path): + mic_path = path["inputs"] + utt_id = mic_path.split("/")[-1] + mic_name = mic_path.split("/")[-1].split(".")[0] + + ref_path = mic_path.replace("noisy/", "enrol/") + ref_name = ref_path.split("/")[-1].split(".")[0] + + mic = get_firstchannel_read(mic_path, self.sample_rate).astype(np.float32) + ref = get_firstchannel_read(ref_path, self.sample_rate).astype(np.float32) + + if ref.shape[0] > mic.shape[0]: + min_len = mic.shape[0] + ref = ref[:min_len] + + # print(ref.shape[0]) + # print(mic.shape[0]) + + inputs_mic = np.reshape(mic, [1, mic.shape[0]]).astype(np.float32) + inputs_ref = np.reshape(ref, [1, ref.shape[0]]).astype(np.float32) + + inputs_mic = th.from_numpy(inputs_mic) + inputs_ref = th.from_numpy(inputs_ref) + + # print(f'e: {inputs_e.shape}') + # print(f'mic: {inputs_mic.shape}') + # print(f'ref: {inputs_ref.shape}') + + egs = { + "mic": inputs_mic, + "ref": inputs_ref, + "utt_id": utt_id, + "mic_name": mic_name, + "ref_name": ref_name + # "max_norm": max_norm + } + 
return egs + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, index): + return self.extract_feature(self.file_list[index]) diff --git a/nnet/WavLM.py b/nnet/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..861ddb2648a57a2b125e7e472e864022b637fbab --- /dev/null +++ b/nnet/WavLM.py @@ -0,0 +1,793 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import sys,os +sys.path.append(os.path.dirname(sys.path[0])) +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from nnet.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + 
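+        # cfg.conv_feature_layers is a python-list string such as
+        # "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", i.e. one
+        # (dim, kernel_size, stride) tuple per conv block; the last dim becomes self.embed.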
self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), 
bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + + def long_term_modeling( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + features = source.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in 
enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + y = x.transpose(1, 2).clone() + x_conv = self.pos_conv(y) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer 
in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias diff --git a/nnet/__pycache__/WavLM.cpython-310.pyc b/nnet/__pycache__/WavLM.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c333101ec60f2423cd4f90122b362a3d8c863568 Binary files /dev/null and b/nnet/__pycache__/WavLM.cpython-310.pyc differ diff --git a/nnet/__pycache__/embedding.cpython-310.pyc b/nnet/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c28a0294a4afb7afdf979607d61a593e687234b3 Binary files /dev/null and b/nnet/__pycache__/embedding.cpython-310.pyc differ diff --git a/nnet/__pycache__/llase.cpython-310.pyc b/nnet/__pycache__/llase.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5462a7de0771d9b556c8d27d11bf39b2400c7fe4 Binary files /dev/null and b/nnet/__pycache__/llase.cpython-310.pyc differ diff --git a/nnet/__pycache__/modules.cpython-310.pyc b/nnet/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93dd7df59a5996d3784b2a1c048ea8b5b191b8c9 Binary files /dev/null and b/nnet/__pycache__/modules.cpython-310.pyc differ diff --git a/nnet/llase.py b/nnet/llase.py new file mode 100644 index 0000000000000000000000000000000000000000..255c895d008b6da9711dc22094b18bb33429ccde --- /dev/null +++ b/nnet/llase.py @@ -0,0 +1,104 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import sys,os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from typing import Union, Optional +from transformers import LlamaConfig, LlamaForCausalLM + +NUM_AUDIO_TOKENS = 65536 # Codebook size + +class LLM_AR(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + num_layers: int + ): + super().__init__() + self.d_model = d_model + + self.audio_linear_y = nn.Linear(1024, d_model) + self.audio_linear_x = nn.Linear(1024, d_model) + + self.Llama_config = LlamaConfig( + hidden_size=d_model*2, + intermediate_size=d_model * 4, + num_attention_heads=nhead, + num_hidden_layers=num_layers, + dropout_rate=0.1, + attention_dropout=0.1, + is_decoder=True, + use_cache=True + ) + + self.llama= LlamaForCausalLM(config=self.Llama_config) + self.predict_layer_x = nn.Linear(2*d_model, NUM_AUDIO_TOKENS) + self.predict_layer_y = nn.Linear(2*d_model, NUM_AUDIO_TOKENS) + + def forward( + self, + y: torch.Tensor, + x: Union[torch.Tensor, None] = None, + ) -> torch.Tensor: + # y = y.transpose(1,2) # if 
codec input use this transpose + + if x is None: + x = torch.zeros_like(y) + elif x.dim() == 2: + x = x.unsqueeze(-1) + x = x.expand_as(y) + + + y_emb = self.audio_linear_y(y) # [B, T, D] + x_emb = self.audio_linear_x(x) # [B, T, D] + + if x_emb.shape[1] < y_emb.shape[1]: + pad_length = y_emb.shape[1] - x_emb.shape[1] + x_emb= F.pad(x_emb, (0, 0, 0, pad_length), mode='constant', value=0) + + if y_emb.shape[1] < x_emb.shape[1]: + pad_length = x_emb.shape[1] - y_emb.shape[1] + y_emb= F.pad(y_emb, (0, 0, 0, pad_length), mode='constant', value=0) + + y_emb = torch.concat([x_emb, y_emb], dim = -1) # [B, T_y, D*2] + + outputs = self.llama(inputs_embeds = y_emb, output_hidden_states=True) + + dec = outputs.hidden_states[-1] # [B, T_y, D*2] + + logits_y = self.predict_layer_y(dec) # [B, T, NUM_AUDIO_TOKENS] + logits_x = self.predict_layer_x(dec) + + logits_y = logits_y.transpose(-1, -2) # [B, NUM_AUDIO_TOKENS, T] + logits_x = logits_x.transpose(-1, -2) + + return logits_y, logits_x + +if __name__=="__main__": + # Simple test + model = LLM_AR(d_model=1024, nhead=8, num_layers=16) + ce_loss = nn.CrossEntropyLoss() + + y = torch.randn([1,199,1024]) + x = torch.randn([1,99,1024]) + label = torch.from_numpy(np.random.randint(0, 300, size=[2,1,199])) + + total_params = sum(p.numel() for p in model.parameters()) + + print(f"Total Params: {total_params}") + + logits = model(y) + print(logits[0].shape) + print(logits[1].shape) + + logits = model(y,x) + print(logits[0].shape) + print(logits[1].shape) + + logits = model(y,y) + print(logits[0].shape) + print(logits[1].shape) diff --git a/nnet/modules.py b/nnet/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..936b323a49da8569406e563df05cfbf5a12d907c --- /dev/null +++ b/nnet/modules.py @@ -0,0 +1,825 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +from torch.nn import Parameter +import torch.nn.functional as F + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, 
scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
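
    A hedged usage sketch (the module tree below is illustrative, not from this repo):

        model = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))
        model.apply(init_bert_params)  # normal(0, 0.02) weights, zeroed biases, applied recursively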
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. 
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights \ No newline at end of file diff --git a/vq/__init__.py b/vq/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..82b8462e5c4c1d1e8ac0e6e6ce02589f4e55d959 --- /dev/null +++ b/vq/__init__.py @@ -0,0 +1,4 @@ +from vq.codec_encoder import CodecEncoder +from vq.codec_decoder import CodecDecoder +from vq.codec_decoder_vocos import CodecDecoderVocos +from vq.codec_encoder import CodecEncoder_Transformer,CodecEncoder_only_Transformer \ No newline at end of file diff --git a/vq/__pycache__/__init__.cpython-310.pyc b/vq/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..c68796e8af84a54a6defc8fdc7abac8ec9b7fd17 Binary files /dev/null and b/vq/__pycache__/__init__.cpython-310.pyc differ diff --git a/vq/__pycache__/__init__.cpython-311.pyc b/vq/__pycache__/__init__.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..3a622e8da2ba3969a82261d2c1e6a3dbe2a87155 Binary files /dev/null and b/vq/__pycache__/__init__.cpython-311.pyc differ diff --git a/vq/__pycache__/__init__.cpython-312.pyc b/vq/__pycache__/__init__.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a1bd6ded01b7b342483f5e8ab7993e74a8b74054 Binary files /dev/null and b/vq/__pycache__/__init__.cpython-312.pyc differ diff --git a/vq/__pycache__/__init__.cpython-37.pyc b/vq/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a571424b73fb4d5c89be94267d5453b60117ee95 Binary files /dev/null and b/vq/__pycache__/__init__.cpython-37.pyc differ diff --git a/vq/__pycache__/__init__.cpython-38.pyc b/vq/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25a045acaafb7e3b9a2914e2258f7ddb2e67ac3b Binary files /dev/null and b/vq/__pycache__/__init__.cpython-38.pyc differ diff --git a/vq/__pycache__/__init__.cpython-39.pyc b/vq/__pycache__/__init__.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..28496dc7d0b52aa03def26726eda7030966289cf Binary files /dev/null and b/vq/__pycache__/__init__.cpython-39.pyc differ diff --git a/vq/__pycache__/activations.cpython-310.pyc b/vq/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebe30b67180263a6e0dd4c14d5bad1f6330a7c2d Binary files /dev/null and b/vq/__pycache__/activations.cpython-310.pyc differ diff --git a/vq/__pycache__/activations.cpython-311.pyc b/vq/__pycache__/activations.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..b68f81ee8be448f64ef4c8f7bab31d57bdd561a8 Binary files /dev/null and b/vq/__pycache__/activations.cpython-311.pyc differ diff --git a/vq/__pycache__/activations.cpython-312.pyc b/vq/__pycache__/activations.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..7b65c1363e2538d7f7da65eacf3801535d9dfe7e Binary files /dev/null and b/vq/__pycache__/activations.cpython-312.pyc differ diff --git a/vq/__pycache__/activations.cpython-37.pyc b/vq/__pycache__/activations.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b897add29bf0a6d160206ac3835ea52c7fb7a890 Binary files /dev/null and b/vq/__pycache__/activations.cpython-37.pyc differ diff --git a/vq/__pycache__/activations.cpython-38.pyc b/vq/__pycache__/activations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a51ea75125cf19cf4eb1399cdee8e4491d245e2 Binary files /dev/null and b/vq/__pycache__/activations.cpython-38.pyc differ diff --git a/vq/__pycache__/activations.cpython-39.pyc b/vq/__pycache__/activations.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..fd07520223d095049de727a8ea397de500720406 Binary files /dev/null and b/vq/__pycache__/activations.cpython-39.pyc differ diff --git a/vq/__pycache__/blocks.cpython-310.pyc b/vq/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fae42c2f63b8f0dd029aba584da8656ebc0a218 Binary files /dev/null and b/vq/__pycache__/blocks.cpython-310.pyc 
differ diff --git a/vq/__pycache__/blocks.cpython-39.pyc b/vq/__pycache__/blocks.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d205eda8af24088edadcaa771f74e9638a757b24 Binary files /dev/null and b/vq/__pycache__/blocks.cpython-39.pyc differ diff --git a/vq/__pycache__/bs_roformer5.cpython-310.pyc b/vq/__pycache__/bs_roformer5.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52c02700d0bfbfdbbe533139d780563857148d76 Binary files /dev/null and b/vq/__pycache__/bs_roformer5.cpython-310.pyc differ diff --git a/vq/__pycache__/bs_roformer5.cpython-37.pyc b/vq/__pycache__/bs_roformer5.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ece31623799ab8c594554a9fcf6fa3aac5665d64 Binary files /dev/null and b/vq/__pycache__/bs_roformer5.cpython-37.pyc differ diff --git a/vq/__pycache__/bs_roformer5.cpython-38.pyc b/vq/__pycache__/bs_roformer5.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70574c1f525b62cf89dec196202da4fcca6d00f9 Binary files /dev/null and b/vq/__pycache__/bs_roformer5.cpython-38.pyc differ diff --git a/vq/__pycache__/bs_roformer5.cpython-39.pyc b/vq/__pycache__/bs_roformer5.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..49c3fd85f007ce66b708f685487968600a40e91e Binary files /dev/null and b/vq/__pycache__/bs_roformer5.cpython-39.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-310.pyc b/vq/__pycache__/codec_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffdb653b296c935d2db80367b64b290d50ef83c0 Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-310.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-311.pyc b/vq/__pycache__/codec_decoder.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d3bba0654a05fc2d5df7f0cfa85d4e1e8f9e43a9 Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-311.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-312.pyc b/vq/__pycache__/codec_decoder.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9fd232bd7d3e8217ae32122d85158c8d3d07a0ba Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-312.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-39.pyc b/vq/__pycache__/codec_decoder.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8240066ea0a78b8cc0a833066f7cd33fce0c1320 Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-39.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb6f676179bd3f23d2b0648e47eaf76cb63edf1a Binary files /dev/null and b/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..16d7559f3e464f7cc8229cf5cf3864cf6ce5270a Binary files /dev/null and b/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4d14cbb3475d4adca86027ce48447562dd8bd872 Binary files /dev/null and 
b/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..1afcbb8027bec2b3ab56a26ffaa19231d775d7e9 Binary files /dev/null and b/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-310.pyc b/vq/__pycache__/codec_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6d5ca9de924ea453eacca5ab0c0619dd46c40fa Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-310.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-311.pyc b/vq/__pycache__/codec_encoder.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a95e106c52b918fc9879df154585f48f5b97f084 Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-311.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-312.pyc b/vq/__pycache__/codec_encoder.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..033919b2465e77531ac1f4340a42e8247dcf927e Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-312.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-37.pyc b/vq/__pycache__/codec_encoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..105e82671abfd71efea2b8ba642524348ce8fcfa Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-37.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-38.pyc b/vq/__pycache__/codec_encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..024fc56cf4d11d9d91f707c77b6dff07d4c223a8 Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-38.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-39.pyc b/vq/__pycache__/codec_encoder.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..7840da79435c2ae0bce6f6c6b7566769499ee537 Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-39.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc91616c04b61c019d6fcecf344a503fbb76daa6 Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d1018bb516e0ca39612d3046cc66e93d3e9fc857 Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a79936f97258dfac4e4c1fa00bcba84459f754bf Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..732de3e9520dcc785bc296be3a7f955d03155ee5 Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc differ diff --git a/vq/__pycache__/module.cpython-310.pyc 
b/vq/__pycache__/module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a69a387c02bcf3924b3d76595bef2944dd9de97 Binary files /dev/null and b/vq/__pycache__/module.cpython-310.pyc differ diff --git a/vq/__pycache__/module.cpython-311.pyc b/vq/__pycache__/module.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8bae56a74114e5ecb523e39275bf451e406e43a2 Binary files /dev/null and b/vq/__pycache__/module.cpython-311.pyc differ diff --git a/vq/__pycache__/module.cpython-312.pyc b/vq/__pycache__/module.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..67dc7809ed198ecba09efecfa433224dc7b5ea4a Binary files /dev/null and b/vq/__pycache__/module.cpython-312.pyc differ diff --git a/vq/__pycache__/module.cpython-37.pyc b/vq/__pycache__/module.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e5a49b73770a4cf1504dca7818eb8d97f282aa5 Binary files /dev/null and b/vq/__pycache__/module.cpython-37.pyc differ diff --git a/vq/__pycache__/module.cpython-38.pyc b/vq/__pycache__/module.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f2acfaa353b7e80e18f5292c53a7c4f1f334aeb Binary files /dev/null and b/vq/__pycache__/module.cpython-38.pyc differ diff --git a/vq/__pycache__/module.cpython-39.pyc b/vq/__pycache__/module.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..1e4e0bd81ad6519ea28c0c697bb725e04253bcab Binary files /dev/null and b/vq/__pycache__/module.cpython-39.pyc differ diff --git a/vq/__pycache__/residual_vq.cpython-310.pyc b/vq/__pycache__/residual_vq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d6dd3e0d0ea758f571066eaeb2b2af4e98dbfe2 Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-310.pyc differ diff --git a/vq/__pycache__/residual_vq.cpython-311.pyc b/vq/__pycache__/residual_vq.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..92069858ec62cb51a7cce7cdf1550695c656e557 Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-311.pyc differ diff --git a/vq/__pycache__/residual_vq.cpython-312.pyc b/vq/__pycache__/residual_vq.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4d81960f8bbec9c309fa08267b44245d8a5b4959 Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-312.pyc differ diff --git a/vq/__pycache__/residual_vq.cpython-39.pyc b/vq/__pycache__/residual_vq.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..1c85f8cdf25b033d78e683be4425efe1d6f63f5a Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-39.pyc differ diff --git a/vq/__pycache__/unet.cpython-312.pyc b/vq/__pycache__/unet.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..e431b555864d1029ed9735ed1bad0bae1b75c4f8 Binary files /dev/null and b/vq/__pycache__/unet.cpython-312.pyc differ diff --git a/vq/__pycache__/unet.cpython-39.pyc b/vq/__pycache__/unet.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9e12a03c0fba3df01e7d87e249fd77a3bab599f1 Binary files /dev/null and b/vq/__pycache__/unet.cpython-39.pyc differ diff --git a/vq/activations.py b/vq/activations.py new file mode 100755 index 0000000000000000000000000000000000000000..61f2808a5466b3cf4d041059700993af5527dd29 --- /dev/null +++ b/vq/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from 
https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. 
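
        A minimal usage sketch (shapes are illustrative):

            act = SnakeBeta(in_features=64, alpha_logscale=True)
            y = act(torch.randn(8, 64, 160))   # (B, C, T) -> (B, C, T)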
+ Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/vq/alias_free_torch/__init__.py b/vq/alias_free_torch/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a2318b63198250856809c0cb46210a4147b829bc --- /dev/null +++ b/vq/alias_free_torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * +from .act import * \ No newline at end of file diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55b73cf35cfea876c394ef287e398c54f091c47c Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8c624a5d6d4dc886d41f429249030bde5539887e Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a43e41d400b719ffcbd2edeb74835601d56febf5 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-37.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8de4524f53ac458650478855e7e8f588243935a7 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-37.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667519d1bbca0db750319a8d119800bec758a8c9 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..88b0eba48452d5c08f2ffcca460880a45db0c925 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-310.pyc b/vq/alias_free_torch/__pycache__/act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d27e6ced61b5af60bc0acf38ed6755d082001805 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-310.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-311.pyc b/vq/alias_free_torch/__pycache__/act.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..6db1958197fb0a213b4ff1df75dcfbf89317ac7d Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-311.pyc differ diff --git 
a/vq/alias_free_torch/__pycache__/act.cpython-312.pyc b/vq/alias_free_torch/__pycache__/act.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8cfb68556b30eb0c90e80734d84ce4630c84f6fe Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-37.pyc b/vq/alias_free_torch/__pycache__/act.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..403a8fdca04b919f787340923154694220f21505 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-37.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-38.pyc b/vq/alias_free_torch/__pycache__/act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a017eef192c8c6a87c8721c254629defa879f219 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-39.pyc b/vq/alias_free_torch/__pycache__/act.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..777550951bf21859da3845a422a30a631c3ae5c8 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-39.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f31500966250c153bc78aa7e3abc2b2dead1485c Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..ce1c1840a49147f4771cb6c43c75fafdff331ebf Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..e7c8b0886d3813d2f3184509934becbdd6a7cf6f Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-37.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a6c6d42fa3cc1505cda48620e0fd6763d43d5e3 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-37.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1309f6d11298214af247a586bc4b1c662d668737 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..6e9b10d722d25e9c8f7dc3c6f6fe95936055e36b Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8f4021159da37f82e2c0ca6f48d9f825573a191 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc differ diff --git 
a/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..51f3f964031846bf45ae58bc751a8d74afe4495f Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..bd086d178282ace7cda97ede64bd598c43012dab Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-37.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad54f2206095699ca0b516b550b07aa95922db1e Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-37.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e002c7e51c834ff56a60d104482fb17d5e606cb Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..11487a7ab175df19af75772e09aae63ede9731d3 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc differ diff --git a/vq/alias_free_torch/act.py b/vq/alias_free_torch/act.py new file mode 100755 index 0000000000000000000000000000000000000000..028debd697dd60458aae75010057df038bd3518a --- /dev/null +++ b/vq/alias_free_torch/act.py @@ -0,0 +1,28 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/vq/alias_free_torch/filter.py b/vq/alias_free_torch/filter.py new file mode 100755 index 0000000000000000000000000000000000000000..7ad6ea87c1f10ddd94c544037791d7a4634d5ae1 --- /dev/null +++ b/vq/alias_free_torch/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 
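
        A quick numeric check of that convention (values rounded):

            sinc(torch.tensor([0.0, 0.5, 1.0]))   # -> tensor([1.0000, 0.6366, 0.0000])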
+ """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/vq/alias_free_torch/resample.py b/vq/alias_free_torch/resample.py new file mode 100755 index 0000000000000000000000000000000000000000..750e6c3402cc5ac939c4b9d075246562e0e1d1a7 --- /dev/null +++ b/vq/alias_free_torch/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/vq/blocks.py b/vq/blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..3996fec146cbf4f3caef4f9da3bbbe04f7729bbb --- /dev/null +++ b/vq/blocks.py @@ -0,0 +1,183 @@ +from typing import Callable, Sequence, Type, Union + +import numpy as np +import torch +import torch.nn as nn + +ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]] + + +class FeedForwardModule(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.net = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class Residual(nn.Module): + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.module(x) + x + + +class DilatedConvolutionalUnit(FeedForwardModule): + + def __init__( + self, + hidden_dim: int, + dilation: int, + kernel_size: int, + activation: ModuleFactory, + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + self.net = nn.Sequential( + activation(), + normalization( + nn.Conv1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=kernel_size, + dilation=dilation, + padding=((kernel_size - 1) * dilation) // 2, + )), + activation(), + nn.Conv1d(in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=1), + ) + + +class UpsamplingUnit(FeedForwardModule): + + def __init__( + self, + input_dim: int, + output_dim: int, + stride: int, + activation: ModuleFactory, + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + self.net = nn.Sequential( + activation(), + normalization( + nn.ConvTranspose1d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2+ stride % 2, + output_padding=1 if stride % 2 != 0 else 0 + ))) + + +class DownsamplingUnit(FeedForwardModule): + + def __init__( + self, + input_dim: int, + output_dim: int, + stride: int, + activation: ModuleFactory, + normalization: Callable[[nn.Module], + nn.Module] = lambda x: 
x) -> None: + super().__init__() + self.net = nn.Sequential( + activation(), + normalization( + nn.Conv1d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=2 * stride, + stride=stride, + padding= stride // 2+ stride % 2, + + ))) + + +class DilatedResidualEncoder(FeedForwardModule): + + def __init__( + self, + capacity: int, + dilated_unit: Type[DilatedConvolutionalUnit], + downsampling_unit: Type[DownsamplingUnit], + ratios: Sequence[int], + dilations: Union[Sequence[int], Sequence[Sequence[int]]], + pre_network_conv: Type[nn.Conv1d], + post_network_conv: Type[nn.Conv1d], + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + channels = capacity * 2**np.arange(len(ratios) + 1) + + dilations_list = self.normalize_dilations(dilations, ratios) + + net = [normalization(pre_network_conv(out_channels=channels[0]))] + + for ratio, dilations, input_dim, output_dim in zip( + ratios, dilations_list, channels[:-1], channels[1:]): + for dilation in dilations: + net.append(Residual(dilated_unit(input_dim, dilation))) + net.append(downsampling_unit(input_dim, output_dim, ratio)) + + net.append(post_network_conv(in_channels=output_dim)) + + self.net = nn.Sequential(*net) + + @staticmethod + def normalize_dilations(dilations: Union[Sequence[int], + Sequence[Sequence[int]]], + ratios: Sequence[int]): + if isinstance(dilations[0], int): + dilations = [dilations for _ in ratios] + return dilations + + +class DilatedResidualDecoder(FeedForwardModule): + + def __init__( + self, + capacity: int, + dilated_unit: Type[DilatedConvolutionalUnit], + upsampling_unit: Type[UpsamplingUnit], + ratios: Sequence[int], + dilations: Union[Sequence[int], Sequence[Sequence[int]]], + pre_network_conv: Type[nn.Conv1d], + post_network_conv: Type[nn.Conv1d], + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + channels = capacity * 2**np.arange(len(ratios) + 1) + channels = channels[::-1] + + dilations_list = self.normalize_dilations(dilations, ratios) + dilations_list = dilations_list[::-1] + + net = [pre_network_conv(out_channels=channels[0])] + + for ratio, dilations, input_dim, output_dim in zip( + ratios, dilations_list, channels[:-1], channels[1:]): + net.append(upsampling_unit(input_dim, output_dim, ratio)) + for dilation in dilations: + net.append(Residual(dilated_unit(output_dim, dilation))) + + net.append(normalization(post_network_conv(in_channels=output_dim))) + + self.net = nn.Sequential(*net) + + @staticmethod + def normalize_dilations(dilations: Union[Sequence[int], + Sequence[Sequence[int]]], + ratios: Sequence[int]): + if isinstance(dilations[0], int): + dilations = [dilations for _ in ratios] + return dilations \ No newline at end of file diff --git a/vq/bs_roformer5.py b/vq/bs_roformer5.py new file mode 100755 index 0000000000000000000000000000000000000000..08aa016d731a6a5cae3e4f38514d97187ad7adb4 --- /dev/null +++ b/vq/bs_roformer5.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Module, ModuleList +import torchaudio +from einops import rearrange +import numpy as np +# from rotary_embedding_torch import RotaryEmbedding + +from torchtune.modules import RotaryPositionalEmbeddings + + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + r"""https://github.com/meta-llama/llama/blob/main/llama/model.py""" + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): 
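+        # Root-mean-square normalization: rescale x by the reciprocal RMS over its last
+        # dimension (with eps for numerical stability), then apply the learned per-dim gain.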
+ norm_x = torch.mean(x ** 2, dim=-1, keepdim=True) + output = x * torch.rsqrt(norm_x + self.eps) * self.weight + return output + + + +class MLP(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + + self.fc1 = nn.Linear(dim, 4 * dim, bias=False) + self.silu = nn.SiLU() + self.fc2 = nn.Linear(4 * dim, dim, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.silu(x) + x = self.fc2(x) + return x + + +class Attention(nn.Module): + + def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): + super().__init__() + + assert dim % n_heads == 0 + + self.n_heads = n_heads + self.dim = dim + self.rotary_embed = rotary_embed + + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + assert self.flash, "Must have flash attention." + + self.c_attn = nn.Linear(dim, 3 * dim, bias=False) + self.c_proj = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + r""" + Args: + x: (b, t, h*d) + + Constants: + b: batch_size + t: time steps + r: 3 + h: heads_num + d: heads_dim + """ + B, T, C = x.size() + + q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads) + # q, k, v: (b, h, t, d) + + q = self.rotary_embed(q) + k = self.rotary_embed(k) + + if self.flash: + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False) + + y = rearrange(y, 'b h t d -> b t (h d)') + + y = self.c_proj(y) + # shape: (b, t, h*d) + + return y + + +class TransformerBlock(nn.Module): + def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): + + super().__init__() + self.dim = dim + self.n_heads = n_heads + + self.att_norm = RMSNorm(dim) + self.ffn_norm = RMSNorm(dim) + self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed) + self.mlp = MLP(dim=dim) + + + def forward( + self, + x: torch.Tensor, + ): + x = x + self.att(self.att_norm(x)) + x = x + self.mlp(self.ffn_norm(x)) + return x + + +if __name__ == '__main__': + rotary_embed_128 = RotaryPositionalEmbeddings(dim=128) + transformer_block = TransformerBlock( + dim=1024, + n_heads=8, + rotary_embed=rotary_embed_128 + ) + x = torch.randn(2, 128, 1024) + y = transformer_block(x) + print(y.shape) + c=1 \ No newline at end of file diff --git a/vq/codec_decoder.py b/vq/codec_decoder.py new file mode 100755 index 0000000000000000000000000000000000000000..e0f11327ac328a9ac3d33cce17d65336a5437d0d --- /dev/null +++ b/vq/codec_decoder.py @@ -0,0 +1,304 @@ +import sys +sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv1d_transformer') +import numpy as np +import torch +import torch.nn as nn +from vq.residual_vq import ResidualVQ +from vq.module import WNConv1d, DecoderBlock, ResLSTM +from vq.alias_free_torch import * +from vq import activations +import vq.blocks as blocks +from torch.nn import utils + +from vq.bs_roformer5 import TransformerBlock + +from torchtune.modules import RotaryPositionalEmbeddings + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecDecoder(nn.Module): + def __init__(self, + in_channels=1024, + upsample_initial_channel=1536, + ngf=48, + use_rnn=True, + rnn_bidirectional=False, + rnn_num_layers=2, + up_ratios=(5, 4, 4, 4, 2), + dilations=(1, 3, 9), + vq_num_quantizers=1, + vq_dim=2048, + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=32, + ): + super().__init__() + self.hop_length = 
np.prod(up_ratios) + self.ngf = ngf + self.up_ratios = up_ratios + + self.quantizer = ResidualVQ( + num_quantizers=vq_num_quantizers, + dim=vq_dim, # double the dim for acousitc and semantic + codebook_size=codebook_size, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + ) + channels = upsample_initial_channel + layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)] + + if use_rnn: + layers += [ + ResLSTM(channels, + num_layers=rnn_num_layers, + bidirectional=rnn_bidirectional + ) + ] + + for i, stride in enumerate(up_ratios): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride, dilations)] + + layers += [ + Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)), + WNConv1d(output_dim, 1, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + x, q, commit_loss = self.quantizer(x) + return x, q, commit_loss + x = self.model(x) + return x + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + +class CodecDecoder_oobleck_Transformer(nn.Module): + def __init__(self, + ngf=32, + up_ratios=(5, 4, 4, 4, 2), + dilations=(1, 3, 9), + vq_num_quantizers=1, + vq_dim=1024, + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=16, + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.capacity = ngf + self.up_ratios = up_ratios + self.hidden_dim = hidden_dim + self.quantizer = ResidualVQ( + num_quantizers=vq_num_quantizers, + dim=vq_dim, # double the dim for acousitc and semantic + codebook_size=codebook_size, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + ) + + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + self.conv_blocks = blocks.DilatedResidualDecoder( + capacity=self.capacity, + dilated_unit=self.dilated_unit, + 
upsampling_unit=self.upsampling_unit, + ratios=up_ratios, # 逆转编码器的下采样比率 + dilations=dilations, + pre_network_conv=self.pre_conv, + post_network_conv=self.post_conv, + ) + + + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + x, q, commit_loss = self.quantizer(x) + return x, q, commit_loss + x= self.transformers(x) + x = self.final_layer_norm(x) + x = x.permute(0, 2, 1) + x = self.conv_blocks(x) + return x + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + def pre_conv(self, out_channels): + return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1) + + # 定义后处理卷积层,将模型的输出映射到最终的输出通道数 + def post_conv(self,in_channels): + return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1) + + def dilated_unit(self, hidden_dim, dilation): + return blocks.DilatedConvolutionalUnit( + hidden_dim=hidden_dim, + dilation=dilation, + kernel_size=3, + activation=nn.ReLU , + normalization=utils.weight_norm + ) + + # 定义上采样单元 + def upsampling_unit(self,input_dim, output_dim, stride): + return blocks.UpsamplingUnit( + input_dim=input_dim, + output_dim=output_dim, + stride=stride, + activation=nn.ReLU , + normalization=utils.weight_norm + ) + +def main(): + # 设置设备 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # 初始化模型 + model = CodecDecoder_oobleck_Transformer().to(device) + print("Model initialized.") + + # 创建测试输入: batch_size x in_channels x sequence_length + batch_size = 2 + in_channels = 1024 + sequence_length = 100 # 示例长度,可以根据需要调整 + dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device) + print(f"Dummy input shape: {dummy_input.shape}") + + # 将模型设为评估模式 + model.eval() + + + + output_no_vq = model(dummy_input, vq=False) + c=1 + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/vq/codec_decoder_vocos.py b/vq/codec_decoder_vocos.py new file mode 100755 index 0000000000000000000000000000000000000000..83aac48f0bbe4a03b82ddcb0c424f31acf1a3b54 --- /dev/null +++ b/vq/codec_decoder_vocos.py @@ -0,0 +1,632 @@ +import sys +sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv_transformer_vocos') +import numpy as np +import torch +import torch.nn as nn +from vq.residual_vq import ResidualVQ +from vq.module import WNConv1d, DecoderBlock, ResLSTM +from vq.alias_free_torch import * +from vq import activations +from typing import Optional +from vq.module import ConvNeXtBlock, AdaLayerNorm +from vq.bs_roformer5 
import TransformerBlock +# from rotary_embedding_torch import RotaryEmbedding +from torchtune.modules import RotaryPositionalEmbeddings +from vector_quantize_pytorch import ResidualFSQ +from torch.nn import Module, ModuleList +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
+ """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x_pred = self.out(x ) + # x_pred = x + x_pred = x_pred.transpose(1, 2) + mag, p = x_pred.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio = self.istft(S) + return audio.unsqueeze(1),x_pred + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv1d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb=None): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + 
super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h = q.shape + q = q.permute(0, 2, 1) # b,hw,c + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + h_ = self.proj_out(h_) + + return x + h_ + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. 
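+    Note: in this variant the constructor below takes hidden_dim, depth, heads and
+    pos_meb_dim, and the backbone wraps a stack of rotary-embedding TransformerBlocks
+    between ResnetBlock pre-/post-networks rather than ConvNeXt blocks.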
+ """ + + def __init__( + self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): + super().__init__() + + self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3) + + + + self.temb_ch = 0 + block_in = hidden_dim + dropout = 0.1 + + prior_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ] + self.prior_net = nn.Sequential(*prior_net) + + depth = depth + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + + self.transformers = nn.Sequential(*transformer_blocks) + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + post_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ] + self.post_net = nn.Sequential(*post_net) + + def forward(self, x: torch.Tensor ) -> torch.Tensor: + x = x.transpose(1, 2) + x = self.embed(x) + x = self.prior_net(x) + x = x.transpose(1, 2) + x= self.transformers(x) + x = x.transpose(1, 2) + x = self.post_net(x) + x = x.transpose(1, 2) + x = self.final_layer_norm(x) + return x + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecDecoderVocos(nn.Module): + def __init__(self, + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + hop_length=320, + vq_num_quantizers=1, + vq_dim=2048, #1024 2048 + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=16, + ): + super().__init__() + self.hop_length = hop_length + + self.quantizer = ResidualFSQ( + dim = vq_dim, + levels = [4, 4, 4, 4, 4,4,4,4], + num_quantizers = 1 + ) + + # self.quantizer = ResidualVQ( + # num_quantizers=vq_num_quantizers, + # dim=vq_dim, + # codebook_size=codebook_size, + # codebook_dim=codebook_dim, + # threshold_ema_dead_code=2, + # commitment=vq_commit_weight, + # weight_init=vq_weight_init, + # full_commit_loss=vq_full_commit_loss, + # ) + + + self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) + + self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + # x, q, commit_loss = self.quantizer(x) + x = x.permute(0, 2, 1) + x, q = self.quantizer(x) + x = x.permute(0, 2, 1) + q = q.permute(0, 2, 1) + return x, q, None + x = self.backbone(x) + x,_ = self.head(x) + + return x ,_ + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + 
except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + +class CodecDecoderVocos_transpose(nn.Module): + def __init__(self, + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + hop_length=320, + vq_num_quantizers=1, + vq_dim=1024, #1024 2048 + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=16, + ): + super().__init__() + self.hop_length = hop_length + + + self.quantizer = ResidualVQ( + num_quantizers=vq_num_quantizers, + dim=vq_dim, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + ) + + + self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) + + self.inverse_mel_conv = nn.Sequential( + nn.GELU(), + nn.ConvTranspose1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1 # 确保输出长度与编码前匹配 + ), + nn.GELU(), + nn.ConvTranspose1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + padding=1 + ) + ) + + self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + x, q, commit_loss = self.quantizer(x) + return x, q, commit_loss + x = self.backbone(x) + x,_ = self.head(x) + + return x ,_ + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + + +def main(): + # 设置设备 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # 初始化模型 + model = CodecDecoderVocos_transpose().to(device) + print("Model initialized.") + + # 创建测试输入: batch_size x in_channels x sequence_length + batch_size = 2 + in_channels = 1024 + sequence_length = 50 # 示例长度,可以根据需要调整 + dummy_input = torch.randn(batch_size, in_channels, sequence_length).to(device) + print(f"Dummy input shape: {dummy_input.shape}") + + # 将模型设为评估模式 + model.eval() + + # 前向传播(使用 VQ) + # with torch.no_grad(): + # try: + # output, q, commit_loss = 
model(dummy_input, vq=True) + # print("Forward pass with VQ:") + # print(f"Output shape: {output.shape}") + # print(f"Quantized codes shape: {q.shape}") + # print(f"Commitment loss: {commit_loss}") + # except Exception as e: + # print(f"Error during forward pass with VQ: {e}") + + # 前向传播(不使用 VQ) + with torch.no_grad(): + # try: + output_no_vq = model(dummy_input, vq=False) + print("\nForward pass without VQ:") + print(f"Output shape: {output_no_vq.shape}") + c=1 + # except Exception as e: + # print(f"Error during forward pass without VQ: {e}") + + + # model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters()) + # model_size_mb = model_size_bytes / (1024 ** 2) + # print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/vq/codec_encoder.py b/vq/codec_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..5657697f8497333884ca501dd28a8b0a931e93d9 --- /dev/null +++ b/vq/codec_encoder.py @@ -0,0 +1,335 @@ +import sys +sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv1d_transformer') +import torch +from torch import nn +import numpy as np +from vq.module import WNConv1d, EncoderBlock, ResLSTM +from vq.alias_free_torch import * +from vq import activations +from vq.bs_roformer5 import TransformerBlock +# from rotary_embedding_torch import RotaryEmbedding +from torchtune.modules import RotaryPositionalEmbeddings +import vq.blocks as blocks +from torch.nn import utils +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecEncoder(nn.Module): + def __init__(self, + ngf=48, + use_rnn=True, + rnn_bidirectional=False, + rnn_num_layers=2, + up_ratios=(2, 2, 4, 4, 5), + dilations=(1, 3, 9), + out_channels=1024): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf = ngf + self.up_ratios = up_ratios + + # Create first convolution + d_model = ngf + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for i, stride in enumerate(up_ratios): + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)] + # RNN + if use_rnn: + self.block += [ + ResLSTM(d_model, + num_layers=rnn_num_layers, + bidirectional=rnn_bidirectional + ) + ] + # Create last convolution + self.block += [ + Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, out_channels, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + self.reset_parameters() + + def forward(self, x): + out = self.block(x) + return out + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + +class Transpose(nn.Module): + def __init__(self, dim1, dim2): + super(Transpose, self).__init__() + self.dim1 = dim1 + 
self.dim2 = dim2 + + def forward(self, x): + return x.transpose(self.dim1, self.dim2) + +class CodecEncoder_Transformer(nn.Module): + def __init__(self, + ngf=48, + up_ratios=[2, 2, 4, 4, 5], + dilations=(1, 3, 9), + hidden_dim=1024, + depth=12, + heads=12, + pos_meb_dim=64, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf =ngf + self.up_ratios = up_ratios + + d_model = ngf + self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + + for i, stride in enumerate(up_ratios): + d_model *= 2 + self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)] + + self.conv_blocks = nn.Sequential(*self.conv_blocks) + + + # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + # transformer_blocks = [ + # TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + # for _ in range(depth) + # ] + + + # self.transformers = nn.Sequential(*transformer_blocks) + + # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + self.conv_final_block = [ + Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1), + ] + self.conv_final_block = nn.Sequential(*self.conv_final_block) + + self.reset_parameters() + + def forward(self, x): + x = self.conv_blocks(x) + # x = x.permute(0, 2, 1) + # x= self.transformers(x) + # x = self.final_layer_norm(x) + # x = x.permute(0, 2, 1) + x = self.conv_final_block (x) + x = x.permute(0, 2, 1) + return x + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + +class Codec_oobleck_Transformer(nn.Module): + def __init__(self, + ngf=32, + up_ratios=(2, 2,4,4, 5), + dilations=(1, 3, 9), + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf =ngf + self.up_ratios = up_ratios + self.hidden_dim = hidden_dim + + + self.conv_blocks = blocks.DilatedResidualEncoder( + capacity=ngf, + dilated_unit=self.dilated_unit, + downsampling_unit=self.downsampling_unit, + ratios=up_ratios, + dilations=dilations, + pre_network_conv=self.pre_conv, + post_network_conv=self.post_conv, + ) + + + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + + self.reset_parameters() + + def forward(self, x): + x = self.conv_blocks(x) + x = x.permute(0, 2, 1) + x= self.transformers(x) + x = self.final_layer_norm(x) + return x + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + 
self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + def dilated_unit(self,hidden_dim, dilation): + return blocks.DilatedConvolutionalUnit(hidden_dim, + dilation, + kernel_size=3, + activation=nn.ReLU, + normalization=utils.weight_norm) + + def downsampling_unit(self, input_dim: int, output_dim: int, stride: int): + return blocks.DownsamplingUnit(input_dim, + output_dim, + stride, + nn.ReLU, + normalization=utils.weight_norm) + + def pre_conv(self,out_channels): + return nn.Conv1d(1, out_channels, 1) + + def post_conv(self,in_channels): + return nn.Conv1d(in_channels, self.hidden_dim, 1) + + + + + +class CodecEncoder_only_Transformer(nn.Module): + def __init__(self,hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): + super().__init__() + # self.embed = nn.Linear(input_dim, hidden_dim )input_dim=300, + + depth = depth + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + def forward(self, x: torch.Tensor ) -> torch.Tensor: + # x = self.embed(x) + + + x= self.transformers(x) + x = self.final_layer_norm(x) + + return x + + + + + + + +def get_model_size(model): + # 计算总参数数 + total_params = sum(p.numel() for p in model.parameters()) + + # 假设每个参数都是32位浮点数,计算模型大小(以字节为单位) + model_size_bytes = total_params # 每个参数4字节 + + # 转换为更易读的单位(例如,MB) + model_size_mb = model_size_bytes / (1024 ** 2) + + return total_params, model_size_mb + +if __name__ == '__main__': + model = Codec_oobleck_Transformer() + x = torch.randn(1, 1, 16000) # example input tensor + output = model(x) + print("Output shape:", output.shape) diff --git a/vq/factorized_vector_quantize.py b/vq/factorized_vector_quantize.py new file mode 100755 index 0000000000000000000000000000000000000000..35f0c66736112f771a1933ca7e156b8cd5259e66 --- /dev/null +++ b/vq/factorized_vector_quantize.py @@ -0,0 +1,109 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +class FactorizedVectorQuantize(nn.Module): + def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + + if dim != self.codebook_dim: + self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim)) + self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim)) + else: + self.in_proj = nn.Identity() + self.out_proj = nn.Identity() + self._codebook = nn.Embedding(codebook_size, self.codebook_dim) + + @property + def codebook(self): + return self._codebook + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x 
T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + # transpose since we use linear + + z = rearrange(z, "b d t -> b t d") + + # Factorized codes project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x T x D) + z_e = rearrange(z_e, "b t d -> b d t") + z_q, indices = self.decode_latents(z_e) + + + if self.training: + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2]) + commit_loss = commitment_loss + codebook_loss + else: + commit_loss = torch.zeros(z.shape[0], device = z.device) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = rearrange(z_q, "b d t -> b t d") + z_q = self.out_proj(z_q) + z_q = rearrange(z_q, "b t d -> b d t") + + return z_q, indices, commit_loss + + def vq2emb(self, vq, proj=True): + emb = self.embed_code(vq) + if proj: + emb = self.out_proj(emb) + return emb + + def get_emb(self): + return self.codebook.weight + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices \ No newline at end of file diff --git a/vq/module.py b/vq/module.py new file mode 100755 index 0000000000000000000000000000000000000000..0c4f69b351abbc3906ced487f4609ed784c29975 --- /dev/null +++ b/vq/module.py @@ -0,0 +1,420 @@ +import torch.nn as nn +from einops import rearrange +from . 
import activations +from .alias_free_torch import * +from torch.nn.utils import weight_norm + +from typing import Optional, Tuple + +from torch.nn.utils import weight_norm, remove_weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + return x + self.block(x) + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)): + super().__init__() + runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations] + self.block = nn.Sequential( + *runits, + Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + ), + ) + + def forward(self, x): + return self.block(x) + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)): + super().__init__() + self.block = nn.Sequential( + Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding= stride % 2, + ) + ) + self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations]) + + def forward(self, x): + return self.block(x) + +class ResLSTM(nn.Module): + def __init__(self, dimension: int, + num_layers: int = 2, + bidirectional: bool = False, + skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2, + num_layers, batch_first=True, + bidirectional=bidirectional) + + def forward(self, x): + """ + Args: + x: [B, F, T] + + Returns: + y: [B, F, T] + """ + x = rearrange(x, "b f t -> b t f") + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = rearrange(y, "b t f -> b f t") + return y + + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. 
+ """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + + + +class SemanticEncoder(nn.Module): + def __init__( + self, + input_channels: int, + code_dim: int, + encode_channels: int, + kernel_size: int = 3, + bias: bool = True, + ): + super(SemanticEncoder, self).__init__() + + # 初始卷积,将 input_channels 映射到 encode_channels + self.initial_conv = nn.Conv1d( + in_channels=input_channels, + out_channels=encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + # 残差块 + self.residual_blocks = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv1d( + encode_channels, + encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias + ), + nn.ReLU(inplace=True), + nn.Conv1d( + encode_channels, + encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias + ) + ) + + # 最终卷积,将 encode_channels 映射到 code_dim + self.final_conv = nn.Conv1d( + in_channels=encode_channels, + out_channels=code_dim, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + def forward(self, x): + """ + 前向传播方法。 + + Args: + x (Tensor): 输入张量,形状为 (Batch, Input_channels, Length) + + Returns: + Tensor: 编码后的张量,形状为 (Batch, Code_dim, Length) + """ + x = self.initial_conv(x) # (Batch, Encode_channels, Length) + x = self.residual_blocks(x) + x # 残差连接 + x = self.final_conv(x) # (Batch, Code_dim, Length) + return x + +class SemanticDecoder(nn.Module): + def __init__( + self, + code_dim: int, + output_channels: int, + decode_channels: int, + kernel_size: int = 3, + bias: bool = True, + ): + super(SemanticDecoder, self).__init__() + + # Initial convolution to map code_dim to decode_channels + self.initial_conv = nn.Conv1d( + in_channels=code_dim, + out_channels=decode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + # Residual Blocks + self.residual_blocks = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias), + nn.ReLU(inplace=True), + nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias) + ) + + # Final convolution to map decode_channels to output_channels + self.final_conv = nn.Conv1d( + in_channels=decode_channels, + out_channels=output_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + def forward(self, z): + # z: (Batch, Code_dim, Length) + x = self.initial_conv(z) # (Batch, Decode_channels, Length) + x = self.residual_blocks(x) + x # Residual connection + x = self.final_conv(x) # (Batch, Output_channels, Length) + return x \ No newline at end of file diff --git a/vq/residual_vq.py b/vq/residual_vq.py new file mode 100755 index 0000000000000000000000000000000000000000..40d3338fd940aa2e41177827d6e24c5269765b86 --- /dev/null +++ b/vq/residual_vq.py @@ -0,0 +1,53 @@ +import math +import torch +from torch import nn +from .factorized_vector_quantize import FactorizedVectorQuantize + +class ResidualVQ(nn.Module): + def __init__( + self, + *, + num_quantizers, + codebook_size, + **kwargs + ): + super().__init__() + VQ = FactorizedVectorQuantize + if type(codebook_size) == int: + codebook_size = 
\ No newline at end of file
diff --git a/vq/residual_vq.py b/vq/residual_vq.py
new file mode 100755
index 0000000000000000000000000000000000000000..40d3338fd940aa2e41177827d6e24c5269765b86
--- /dev/null
+++ b/vq/residual_vq.py
@@ -0,0 +1,53 @@
+import math
+import torch
+from torch import nn
+from .factorized_vector_quantize import FactorizedVectorQuantize
+
+
+class ResidualVQ(nn.Module):
+    def __init__(
+        self,
+        *,
+        num_quantizers,
+        codebook_size,
+        **kwargs
+    ):
+        super().__init__()
+        VQ = FactorizedVectorQuantize
+        if isinstance(codebook_size, int):
+            codebook_size = [codebook_size] * num_quantizers
+        self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size])
+        self.num_quantizers = num_quantizers
+
+    def forward(self, x):
+        quantized_out = 0.
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        for idx, layer in enumerate(self.layers):
+            quantized, indices, loss = layer(residual)
+
+            residual = residual - quantized
+
+            quantized_out = quantized_out + quantized
+
+            loss = loss.mean()
+
+            all_indices.append(indices)
+            all_losses.append(loss)
+        all_losses, all_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, all_indices, all_losses
+
+    def vq2emb(self, vq, proj=True):
+        # [B, T, num_quantizers]
+        quantized_out = 0.
+        for idx, layer in enumerate(self.layers):
+            quantized = layer.vq2emb(vq[:, :, idx], proj=proj)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+    def get_emb(self):
+        embs = []
+        for idx, layer in enumerate(self.layers):
+            embs.append(layer.get_emb())
+        return embs
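+
+
+def _example_residual_refinement():
+    # Standalone sketch of the residual scheme that ResidualVQ.forward implements:
+    # each stage quantizes whatever the previous stages left over, and the stage
+    # outputs sum to the final approximation. Plain rounding stands in for
+    # FactorizedVectorQuantize so the sketch has no external dependencies.
+    x = torch.randn(4, 8)
+    quantized_out, residual = torch.zeros_like(x), x
+    for step in range(3):
+        quantized = torch.round(residual * 10 ** step) / 10 ** step
+        residual = residual - quantized            # pass on only the remaining error
+        quantized_out = quantized_out + quantized  # coarse-to-fine accumulation
+    assert (x - quantized_out).abs().max() < 0.01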
diff --git a/vq/unet.py b/vq/unet.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca31029d0866b61663f75045c7770cc7208d9482
--- /dev/null
+++ b/vq/unet.py
@@ -0,0 +1,210 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import numpy as np
+
+
+class EncoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+        super(EncoderBlock, self).__init__()
+
+        self.pool_size = 2
+
+        self.conv_block = ConvBlock(in_channels, out_channels, kernel_size)
+
+    def forward(self, x):
+        latent = self.conv_block(x)
+        output = F.avg_pool2d(latent, kernel_size=self.pool_size)
+        return output, latent
+
+
+class DecoderBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+        super(DecoderBlock, self).__init__()
+
+        stride = 2
+
+        self.upsample = nn.ConvTranspose2d(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            kernel_size=stride,
+            stride=stride,
+            padding=(0, 0),
+            bias=False,
+        )
+
+        self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size)
+
+    def forward(self, x, latent):
+        x = self.upsample(x)
+        x = torch.cat((x, latent), dim=1)
+        output = self.conv_block(x)
+        return output
+
+
+class UNet(nn.Module):
+    def __init__(self, freq_dim=1281, out_channel=1024):
+        super(UNet, self).__init__()
+
+        self.downsample_ratio = 16
+
+        in_channels = 1  # self.audio_channels * self.cmplx_num
+
+        self.encoder_block1 = EncoderBlock(in_channels, 16)
+        self.encoder_block2 = EncoderBlock(16, 64)
+        self.encoder_block3 = EncoderBlock(64, 256)
+        self.encoder_block4 = EncoderBlock(256, 1024)
+        self.middle = EncoderBlock(1024, 1024)
+        self.decoder_block1 = DecoderBlock(1024, 256)
+        self.decoder_block2 = DecoderBlock(256, 64)
+        self.decoder_block3 = DecoderBlock(64, 16)
+        self.decoder_block4 = DecoderBlock(16, 16)
+
+        self.fc = nn.Linear(freq_dim * 16, out_channel)
+
+    def forward(self, x_ori):
+        """
+        Args:
+            x_ori: (batch_size, channels_num, time_steps, freq_bins), real-valued spectrogram-like tensor
+
+        Returns:
+            output: (batch_size, time_steps, out_channel) feature tensor
+        """
+        x = self.process_image(x_ori)
+        x1, latent1 = self.encoder_block1(x)
+        x2, latent2 = self.encoder_block2(x1)
+        x3, latent3 = self.encoder_block3(x2)
+        x4, latent4 = self.encoder_block4(x3)
+        _, h = self.middle(x4)
+        x5 = self.decoder_block1(h, latent4)
+        x6 = self.decoder_block2(x5, latent3)
+        x7 = self.decoder_block3(x6, latent2)
+        x8 = self.decoder_block4(x7, latent1)
+        x = self.unprocess_image(x8, x_ori.shape[2])
+        x = x.permute(0, 2, 1, 3).contiguous()  # (B, C, T, F) -> (B, T, C, F)
+        x = x.view(x.size(0), x.size(1), -1)    # flatten channels and frequency bins
+        x = self.fc(x)
+
+        return x
+
+    def process_image(self, x):
+        """
+        Pad the time axis so it is divisible by downsample_ratio and drop the last frequency bin.
+
+        Args:
+            x: (B, C, T, F)
+
+        Returns:
+            output: (B, C, T_padded, F_reduced)
+        """
+        B, C, T, Freq = x.shape
+
+        pad_len = (
+            int(np.ceil(T / self.downsample_ratio)) * self.downsample_ratio
+            - T
+        )
+        x = F.pad(x, pad=(0, 0, 0, pad_len))
+
+        output = x[:, :, :, 0 : Freq - 1]
+
+        return output
+
+    def unprocess_image(self, x, time_steps):
+        """
+        Restore the spectrogram to its original time/frequency extent.
+
+        Args:
+            x: (B, C, T_padded, F_reduced)
+
+        Returns:
+            output: (B, C, T_original, F_original)
+        """
+        x = F.pad(x, pad=(0, 1))
+
+        output = x[:, :, 0:time_steps, :]
+
+        return output
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
+        super(ConvBlock, self).__init__()
+
+        padding = [kernel_size[0] // 2, kernel_size[1] // 2]
+
+        self.bn1 = nn.BatchNorm2d(in_channels)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.conv1 = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            bias=False,
+        )
+
+        self.conv2 = nn.Conv2d(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            bias=False,
+        )
+
+        if in_channels != out_channels:
+            self.shortcut = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1),
+                padding=(0, 0),
+            )
+            self.is_shortcut = True
+        else:
+            self.is_shortcut = False
+
+    def forward(self, x):
+        h = self.conv1(F.leaky_relu_(self.bn1(x)))
+        h = self.conv2(F.leaky_relu_(self.bn2(h)))
+
+        if self.is_shortcut:
+            return self.shortcut(x) + h
+        else:
+            return x + h
+
+
+def test_unet():
+    # Input dimensions; freq_bins must match the model's freq_dim (1281 by default)
+    batch_size = 6
+    channels = 1      # number of audio channels
+    time_steps = 256  # number of time frames
+    freq_bins = 1281  # number of frequency bins
+
+    # Create a random real-valued spectrogram-like input
+    x = torch.randn(batch_size, channels, time_steps, freq_bins)
+
+    # Instantiate the UNet model
+    model = UNet()
+
+    # Forward pass
+    output = model(x)
+
+    # Print the input and output shapes
+    print("Input shape:", x.shape)
+    print("Output shape:", output.shape)
+
+    # The model maps (B, C, T, F) to (B, T, out_channel)
+    assert output.shape == (batch_size, time_steps, 1024), "Unexpected output shape"
+
+    print("Test passed; the model works as expected.")
+
+
+# Run the test
+if __name__ == "__main__":
+    test_unet()
\ No newline at end of file