import sys
sys.path.append('../')

from typing import Optional
from copy import deepcopy

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor, WhisperFeatureExtractor, WhisperModel
# from .modeling_whisper import WhisperModel

from my_laion_clap.CLAP.src.laion_clap.clap_module.htsat import create_htsat_model

import torch
import torch.nn.functional as F  # needed for F.pad in get_audio_features
import torchaudio
import torchaudio.transforms as T
import numpy as np
from torch import nn
import torchvision.transforms
from contextlib import suppress

try:
    from .flamingo import Flamingo
    from .flamingo_lm import FlamingoLMMixin
    from .utils import extend_instance
except ImportError:
    from flamingo import Flamingo
    from flamingo_lm import FlamingoLMMixin
    from utils import extend_instance


def int16_to_float32(x):
    return (x / 32767.0).astype(np.float32)


def float32_to_int16(x):
    x = np.clip(x, a_min=-1., a_max=1.)
    return (x * 32767.).astype(np.int16)


def int16_to_float32_torch(x):
    return (x / 32767.0).type(torch.float32)


def float32_to_int16_torch(x):
    x = torch.clamp(x, min=-1., max=1.)
    return (x * 32767.).type(torch.int16)


class CLAPAudioCfp:
    model_type: str = "HTSAT"
    model_name: str = "large"
    sample_rate: int = 16000
    audio_length: int = 1024
    window_size: int = 1024
    hop_size: int = 160
    fmin: int = 50
    fmax: int = 14000
    class_num: int = 527
    mel_bins: int = 64
    clip_samples: int = 160000
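

# Illustrative note (added for clarity, not part of the original module): with
# the CLAPAudioCfp settings above, a full 10 s clip at 16 kHz has
# clip_samples = 160000 samples, and a centered STFT with hop_size = 160
# produces 160000 // 160 + 1 = 1001 mel frames. This is the same "+1" that
# CLAP.get_audio_features uses below when it computes chunk_frames for the
# "fusion" truncation path.
def _example_fusion_frame_math(clip_samples: int = 160000, hop_size: int = 160) -> int:
    """Sketch only: number of mel frames produced for a full-length clip."""
    return clip_samples // hop_size + 1  # 1001 frames for the defaults above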


class CLAP(nn.Module):
    def __init__(self, clap_config):
        super(CLAP, self).__init__()
        self.clap_config = clap_config
        self.method = clap_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in clap_config) and clap_config['finetune']:
            self.finetune = True
            print('Finetuning CLAP encoder as well!')
        else:
            self.finetune = False

        audio_cfg = CLAPAudioCfp()
        enable_fusion = True
        fusion_type = "aff_2d"

        self.nvclap = create_htsat_model(audio_cfg, enable_fusion, fusion_type)

        # keep only the audio-branch weights and strip the 'module.audio_branch.' prefix
        clap_state_dict = torch.load(clap_config["checkpoint"], map_location='cpu')
        clap_state_dict_copy = clap_state_dict['state_dict'].copy()
        for key in list(clap_state_dict['state_dict'].keys()):
            if 'audio' in key:
                clap_state_dict_copy[key.replace('module.audio_branch.', '')] = clap_state_dict_copy[key]
                del clap_state_dict_copy[key]
            else:
                del clap_state_dict_copy[key]
        self.nvclap.load_state_dict(clap_state_dict_copy, strict=False)
        self.nvclap = self.nvclap.to(device_id)

        for param in self.nvclap.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.nvclap.train()
        else:
            self.nvclap.eval()

        print('loaded NVCLAP model: {}'.format(clap_config["checkpoint"]))

    def get_mel(self, audio_data):
        # mel shape: (n_mels, T)
        mel_tf = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_fft=1024,
            win_length=1024,
            hop_length=160,
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm=None,
            onesided=True,
            n_mels=64,
            f_min=50,
            f_max=14000,
        ).to(audio_data.device)

        mel = mel_tf(audio_data)

        # we use the log-mel spectrogram as input
        mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)

        return mel.T  # (T, n_mels)

    def get_audio_features(self, sample, audio_data, max_len, data_truncating, data_filling, require_grad=False):
        grad_fn = suppress if require_grad else torch.no_grad
        with grad_fn():
            if len(audio_data) > max_len:
                if data_truncating == "rand_trunc":
                    longer = torch.tensor([True])
                elif data_truncating == "fusion":
                    # fusion
                    mel = self.get_mel(audio_data)
                    # split to three parts
                    chunk_frames = max_len // 160 + 1  # the +1 is related to how the spectrogram is computed
                    total_frames = mel.shape[0]
                    if chunk_frames == total_frames:
                        # there is a corner case where the audio length is
                        # larger than max_len but smaller than max_len + hop_size.
                        # In this case, we just use the whole audio.
                        mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                        sample["mel_fusion"] = mel_fusion
                        longer = torch.tensor([False])
                    else:
                        ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
                        if len(ranges[1]) == 0:
                            # if the audio is too short, we just use the first chunk
                            ranges[1] = [0]
                        if len(ranges[2]) == 0:
                            # if the audio is too short, we just use the first chunk
                            ranges[2] = [0]
                        # randomly choose an index for each part
                        idx_front = np.random.choice(ranges[0])
                        idx_middle = np.random.choice(ranges[1])
                        idx_back = np.random.choice(ranges[2])
                        # select mel chunks
                        mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
                        mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
                        mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
                        # shrink the full mel to chunk size
                        mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(mel[None])[0]
                        # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")

                        # stack
                        mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
                        sample["mel_fusion"] = mel_fusion
                        longer = torch.tensor([True])
                else:
                    raise NotImplementedError(
                        f"data_truncating {data_truncating} not implemented"
                    )
                # random crop to max_len (for compatibility)
                overflow = len(audio_data) - max_len
                idx = np.random.randint(0, overflow + 1)
                audio_data = audio_data[idx: idx + max_len]
            else:  # padding if too short
                if len(audio_data) < max_len:  # do nothing if equal
                    if data_filling == "repeatpad":
                        n_repeat = int(max_len / len(audio_data))
                        audio_data = audio_data.repeat(n_repeat)
                        # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
                        # audio_data = F.interpolate(audio_data, size=max_len, mode="bicubic")[0, 0, 0]
                        audio_data = F.pad(
                            audio_data,
                            (0, max_len - len(audio_data)),
                            mode="constant",
                            value=0,
                        )
                    elif data_filling == "pad":
                        audio_data = F.pad(
                            audio_data,
                            (0, max_len - len(audio_data)),
                            mode="constant",
                            value=0,
                        )
                    elif data_filling == "repeat":
                        n_repeat = int(max_len / len(audio_data))
                        audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
                    else:
                        raise NotImplementedError(
                            f"data_filling {data_filling} not implemented"
                        )
                if data_truncating == 'fusion':
                    mel = self.get_mel(audio_data)
                    mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
                    sample["mel_fusion"] = mel_fusion
                longer = torch.tensor([False])

        sample["longer"] = longer
        sample["waveform"] = audio_data

        return sample

    def load_audio(self, clips):
        # waveform, sr = torchaudio.load(filename)
        # waveform = torchaudio.functional.resample(waveform, orig_freq=self.clap_config['sampling_rate'], new_freq=16000)
        processed_clips = []
        for clip in clips:
            audio_data = int16_to_float32_torch(float32_to_int16_torch(clip))
            sample = self.get_audio_features({}, audio_data, 160000, "fusion", "repeatpad")
            processed_clips.append(sample)

        waveforms = {}
        waveforms["mel_fusion"] = torch.stack([item["mel_fusion"] for item in processed_clips], dim=0)
        waveforms["longer"] = torch.stack([item["longer"] for item in processed_clips], dim=0)
        waveforms["waveform"] = torch.stack([item["waveform"] for item in processed_clips], dim=0)

        return waveforms

    def forward(self, audio_clips):
        # Handles audios split into multiple segments: input shape is [B x n_segments x time].
        # Expand the batch dimension during inference.
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip)
            audio_embed = self.nvclap(audio)  # .reshape(-1, self.clap_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # audio_embeds.requires_grad = self.finetune

        return audio_embeds
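

# Illustrative usage sketch (added for clarity, not part of the original
# module). The config values below are assumptions for the example:
# "checkpoint" must point to a CLAP/HTSAT state dict, and a CUDA device must
# be available because the encoder is moved to the current CUDA device.
#
#   clap_config = {"method": "nvclap", "checkpoint": "/path/to/clap_ckpt.pt", "finetune": False}
#   encoder = CLAP(clap_config)
#   device = f'cuda:{torch.cuda.current_device()}'
#   clips = torch.randn(2, 4, 160000, device=device)   # [batch, n_segments, samples @ 16 kHz]
#   embeds = encoder(clips)                             # stacked per-segment CLAP embeddings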


class Whisper(nn.Module):
    def __init__(self, whisper_config):
        super(Whisper, self).__init__()
        self.whisper_config = whisper_config
        self.method = self.whisper_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in self.whisper_config) and self.whisper_config['finetune']:
            self.finetune = True
            print('Finetuning Whisper encoder as well!')
        else:
            self.finetune = False

        self.whisper = WhisperModel.from_pretrained(self.whisper_config['path']).encoder
        self.whisper = self.whisper.to(device_id)
        self.wav_processor = WhisperFeatureExtractor.from_pretrained(self.whisper_config['path'])

        for param in self.whisper.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.whisper.train()
        else:
            self.whisper.eval()

        print('loaded Whisper model: {}'.format(self.whisper_config['path']))

    def load_audio(self, clips):
        device_id = f'cuda:{torch.cuda.current_device()}'
        sample = self.wav_processor(
            clips.cpu().numpy(),
            sampling_rate=self.whisper_config['sampling_rate'],
            return_tensors="pt",
        )["input_features"].to(device_id)
        return sample

    def forward(self, audio_clips):
        # Handles audios split into multiple segments: input shape is [batch x n_segments x time].
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip)
            audio_embed = self.whisper(audio).last_hidden_state  # .reshape(-1, self.whisper_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # audio_embeds.requires_grad = self.finetune

        return audio_embeds


class MERT(nn.Module):
    def __init__(self, mert_config):
        super(MERT, self).__init__()
        self.mert_config = mert_config
        self.method = mert_config["method"]
        device_id = f'cuda:{torch.cuda.current_device()}'

        if ('finetune' in mert_config) and mert_config['finetune']:
            self.finetune = True
            print('Finetuning MERT encoder as well!')
        else:
            self.finetune = False

        self.mert = AutoModel.from_pretrained(mert_config['path'], trust_remote_code=True)
        self.mert = self.mert.to(device_id)
        self.resampler = T.Resample(16000, mert_config['sampling_rate']).to(device_id)
        self.wav_processor = Wav2Vec2FeatureExtractor.from_pretrained(mert_config['path'], trust_remote_code=True)

        for param in self.mert.parameters():
            param.requires_grad = self.finetune

        if self.finetune:
            self.mert.train()
        else:
            self.mert.eval()

        print('loaded MERT model: {}'.format(mert_config['path']))

    def load_audio(self, clips):
        device_id = f'cuda:{torch.cuda.current_device()}'
        clips = self.resampler(clips.float()).float()
        sample = self.wav_processor(
            clips,
            sampling_rate=self.mert_config['sampling_rate'],
            return_tensors="pt",
        )["input_values"]
        if len(sample.shape) == 1:
            sample = sample.unsqueeze(0)
        return sample.to(device_id)

    def forward(self, audio_clips):
        # Handles audios split into multiple segments: input shape is [batch x n_segments x time].
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for audio_clip in audio_clips:
            audio = self.load_audio(audio_clip).to(torch.bfloat16)  # all processing happens in float
            if len(audio.shape) > 2:
                audio = audio.squeeze(0)
            audio_embed = self.mert(audio, output_hidden_states=True).last_hidden_state  # .reshape(-1, self.mert_config["audio_embed_dim"])
            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        audio_embeds.requires_grad = self.finetune

        return audio_embeds
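

# Illustrative sketch (added for clarity, not part of the original module):
# CLAP, Whisper, and MERT all expose the same interface -- forward() accepts
# raw audio shaped [batch, n_segments, samples] (a 2-D [n_segments, samples]
# input is unsqueezed to a batch of one) and returns per-segment embeddings
# stacked along the batch dimension. A hypothetical config-driven dispatcher
# could therefore look like the helper below; the "name" key is an assumption,
# not part of the original configs.
def _build_audio_encoder(encoder_config: dict) -> nn.Module:
    """Hypothetical helper, shown for illustration only."""
    name = encoder_config.get("name", "clap")
    if name == "clap":
        return CLAP(encoder_config)
    if name == "whisper":
        return Whisper(encoder_config)
    if name == "mert":
        return MERT(encoder_config)
    raise ValueError(f"unknown audio encoder: {name}")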


def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["
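    # Note (added for clarity, not part of the original function): once new
    # special tokens are registered with text_tokenizer.add_special_tokens(...),
    # the standard Hugging Face pattern is to resize the language model's input
    # embeddings so the new token ids have embedding rows, e.g. a model loaded
    # with AutoModelForCausalLM would call
    #   model.resize_token_embeddings(len(text_tokenizer))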