diff --git a/model/slam_model_s2s.py b/model/slam_model_s2s.py new file mode 100644 index 0000000000000000000000000000000000000000..65ab83aa911f642aacb837d713fbe9f43801fcf2 --- /dev/null +++ b/model/slam_model_s2s.py @@ -0,0 +1,444 @@ +import torch +import os +import logging +import torch.nn.functional as F +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size +from typing import List, Optional +from slam_llm.utils.metric import compute_accuracy +from transformers import T5ForConditionalGeneration +from tqdm import tqdm +from utils.tts_adapter_utils import setup_tts_adapter +from utils.codec_utils import setup_codec +from utils.trick_utils import partial_freeze_weights, train_embedding_layer_only +from utils.snac_utils import layershift + +logger = logging.getLogger(__name__) + + +def model_factory(train_config, model_config, ckpt_path, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + if train_config.task_type == "s2s" or train_config.task_type == "asr": + encoder = setup_encoder(train_config, model_config, **kwargs) + elif train_config.task_type == "tts": + encoder = None + else: + raise NotImplementedError + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + if encoder is not None: + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + else: + encoder_projector = None + + codec_decoder = None + if model_config.codec_decode: + codec_decoder = setup_codec(train_config, model_config, **kwargs) + + tts_adapter = None + if model_config.tts_adapter: + adapter_config = model_config.tts_adapter_config + tts_adapter = setup_tts_adapter(adapter_config, model_config, **kwargs) + + model = slam_model_s2s( + encoder, + llm, + encoder_projector, + tokenizer, + tts_adapter, + codec_decoder, + train_config, + model_config, + **kwargs, + ) + + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + if train_config.train_audio_embed_only: + partial_freeze_weights(model, model_config.vocab_config.padded_text_vocabsize, model_config.vocab_config.total_vocabsize) + + if train_config.train_embed_only: + train_embedding_layer_only(model) + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_s2s(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + tts_adapter, + codec_decoder, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + # resize llm embedding layer + self.original_vocabsize = self.llm.lm_head.weight.size(0) + if self.model_config.vocab_config.total_vocabsize != self.original_vocabsize: + self.llm.resize_token_embeddings(self.model_config.vocab_config.total_vocabsize) + + if int(os.environ.get("RANK", "0")) == 0: + logger.info("Resize llm embedding layer's vocab size to {}".format(self.model_config.vocab_config.total_vocabsize)) + + self.codec_decoder = codec_decoder + self.tts_adapter = tts_adapter + self.code_layer = self.model_config.vocab_config.code_layer + + + def forward(self, + 
input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + audio_mel = kwargs.get("audio_mel", None) + audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper + + audio = kwargs.get("audio", None) + audio_mask = kwargs.get("audio_mask", None) + + modality_mask = kwargs.get("modality_mask", None) + + encoder_outs = None + if audio_mel is not None or audio is not None: + if self.train_config.freeze_encoder: # freeze encoder + self.encoder.eval() + + if self.model_config.encoder_name == "whisper": + encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim + if self.model_config.encoder_name == "wavlm": + encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask + if self.model_config.encoder_name == "hubert": + results = self.encoder(source = audio, padding_mask = 1-audio_mask) + if self.model_config.encoder_type == "pretrain": + encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"] + if self.model_config.encoder_type == "finetune": + encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] + encoder_outs = encoder_outs.transpose(0, 1) + if self.encoder is None: + encoder_outs = audio_mel if audio_mel is not None else audio + + if self.model_config.encoder_projector == "q-former": + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + if self.model_config.encoder_projector == "cov1d-linear": + encoder_outs = self.encoder_projector(encoder_outs) + + if input_ids is not None: + input_ids[input_ids == -1] = 0 # [btz, 8, seq_length] + + if isinstance(self.llm, T5ForConditionalGeneration): + inputs_embeds = self.llm.shared(input_ids) + else: + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(input_ids) # [btz, 8, seq_length, emb_dim] + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(input_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) + + if modality_mask is not None and encoder_outs is not None: + modality_mask = modality_mask.unsqueeze(1).repeat(1, self.code_layer, 1) # [btz, 8, seq_length] + modality_mask_start_indices = (modality_mask == True).float().argmax(dim=2) + modality_lengths = torch.clamp(modality_mask.sum(dim=2), max=encoder_outs.shape[1]).tolist() + + encoder_outs_pad = torch.zeros_like(inputs_embeds) + for i in range(encoder_outs.shape[0]): + for j in range(self.code_layer): + start_idx = modality_mask_start_indices[i, j].item() + length = modality_lengths[i][j] + encoder_outs_pad[i, j, start_idx:start_idx+length] = encoder_outs[i, :length] + + inputs_embeds[:, :self.code_layer, :, :] = encoder_outs_pad[:, :self.code_layer, :, :] + inputs_embeds[:, :self.code_layer, :, :] * (~modality_mask[:, :, :, None]) + + inputs_embeds = torch.mean(inputs_embeds, dim=1) # [btz, seq_length, emb_dim], average over the 8 layers + + 
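# Inference path: return the fused input embeddings and attention mask for generate(); the training path below runs the LLM once and splits its logits into one text stream and code_layer audio streams for the parallel loss. +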
if kwargs.get("inference_mode", False): + return inputs_embeds, attention_mask + + text_labels = labels[:,self.code_layer] if labels is not None else None + audio_labels = labels[:, :self.code_layer] if labels is not None else None + model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=text_labels) # here we use the text token layer as the target label + + # parallel generation + # TODO: add tts adapter forward + x_ori = model_outputs.logits + text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize + audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize + xt = x_ori[..., :text_vocab_size] + xa = [] + for i in range(self.code_layer): + xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)]) + + total_loss, loss_recorder = self.compute_parallel_loss(xt, text_labels, xa, audio_labels) + model_outputs.loss = total_loss + + text_acc = -1 + audio_acc = [-1 for _ in range(self.code_layer)] + if self.metric: + with torch.no_grad(): + preds = torch.argmax(xt, -1) + text_acc = compute_accuracy(preds.detach()[:, :-1], text_labels.detach()[:, 1:], ignore_label=-100) + + preds_audio = [torch.argmax(xa[i], -1) for i in range(self.code_layer)] + audio_acc = [compute_accuracy(preds_audio[i].detach()[:, :-1], audio_labels[:, i, 1:], ignore_label=-100) for i in range(self.code_layer)] + + # metrics = {"text_acc": text_acc, "audio_acc": audio_acc, "layer_loss": loss_recorder} + return model_outputs, text_acc, audio_acc, loss_recorder + + + + def compute_parallel_loss(self, xt, text_labels, xa, audio_labels): + """ + Compute the parallel loss for text and audio layers. + """ + text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize + audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize + layer_loss = [0 for _ in range(self.code_layer+1)] + + if text_labels is not None: + # text_loss = F.cross_entropy(xt.reshape(-1, text_vocab_size), text_labels.reshape(-1), ignore_index=-100) + text_loss = F.cross_entropy(xt[:, :-1, :].reshape(-1, text_vocab_size), text_labels[:, 1:].reshape(-1), ignore_index=-100) + layer_loss[self.code_layer] = text_loss + else: + text_loss = 0 + + total_audio_loss = 0 + single_audio_loss = 0 + for i in range(self.code_layer): + if audio_labels[:,i] is not None: + # audio_loss += F.cross_entropy(xa[i].reshape(-1, audio_vocab_size), audio_labels[:,i].reshape(-1), ignore_index=-100) + single_audio_loss = F.cross_entropy(xa[i][:, :-1, :].reshape(-1, audio_vocab_size), audio_labels[:, i, 1:].reshape(-1), ignore_index=-100) + layer_loss[i] = single_audio_loss + total_audio_loss += single_audio_loss + + total_loss = (text_loss + total_audio_loss) / (self.code_layer+1) + return total_loss, layer_loss + + + @torch.no_grad() + def generate(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + kwargs["inference_mode"] = True + + inputs_embeds, attention_mask = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, +
inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + generated_ids = [[] for _ in range((self.code_layer+1))] + current_input_text = None + current_audio_tokens = [None for _ in range(self.code_layer)] + # input_pos = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0) + past_key_values = None + + text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize + audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize + + max_new_tokens = kwargs.get("max_new_tokens", 360) + repetition_penalty = kwargs.get("repetition_penalty", 1.0) + decode_text_only = kwargs.get("decode_text_only", False) + + pad_t = self.model_config.vocab_config.pad_t + pad_a = self.model_config.vocab_config.pad_a + eot = self.model_config.vocab_config.eot + eoa = self.model_config.vocab_config.eoa + + text_end = False # Track whether text generation has ended + audio_end = False # Track whether audio generation has ended + + # NOTE: currently, we only support greedy decoding and sampling for parallel generation, no beam search + for step in tqdm(range(max_new_tokens), desc="Generating"): + if current_input_text is not None: + audio_tokens = torch.cat([layershift(current_audio_tokens[i], i).unsqueeze(1) for i in range(self.code_layer)], dim=1) + combined_input_ids = torch.cat([audio_tokens, current_input_text.unsqueeze(1)], dim=1) + inputs_embeds = self.llm.model.embed_tokens(combined_input_ids) + inputs_embeds = torch.mean(inputs_embeds, dim=1).unsqueeze(1) + + outputs = self.llm( + inputs_embeds=inputs_embeds, # [btz, seq_len / 1, emb_dim] + attention_mask=attention_mask, # single sample, no need for attention mask + past_key_values=past_key_values, + # position_ids=input_pos, + use_cache=True, + ) + + logits = outputs.logits + past_key_values = outputs.past_key_values # Update past_key_values for the next step + + # Split logits into text and audio layers based on vocab size + xt_logits = logits[..., :text_vocab_size] + xa_logits = [logits[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)] for i in range(self.code_layer)] + + # Apply repetition penalty to the logits + if repetition_penalty != 1.0: + xt_logits = self.repetition_penalty(xt_logits, generated_ids[self.code_layer], repetition_penalty) + for i in range(self.code_layer): + xa_logits[i] = self.repetition_penalty(xa_logits[i], generated_ids[i], repetition_penalty) + + if not text_end: + next_token_text = self.sample_next_token(xt_logits[:, -1, :], **kwargs) + else: + next_token_text = torch.tensor([pad_t], device=input_ids.device) + + next_tokens_audio = [] + for i in range(self.code_layer): + if not audio_end and not decode_text_only: + next_token_audio = self.sample_next_token(xa_logits[i][:, -1, :], **kwargs) + else: + next_token_audio = torch.full((input_ids.size(0),), pad_a, device=input_ids.device) + next_tokens_audio.append(next_token_audio) + + if next_tokens_audio[-1] == eoa or decode_text_only: + audio_end = True + if next_token_text == eot: + text_end = True + + # Update input_ids for the next step + current_input_text = next_token_text + for i in range(self.code_layer): + current_audio_tokens[i] = next_tokens_audio[i] + + # if input_pos.size(-1) > 1: + # input_pos = torch.tensor(input_pos.size(-1), device=input_ids.device).unsqueeze(0) + # else: + # input_pos = input_pos.add_(1) + attention_mask = 
torch.cat([attention_mask, torch.ones((input_ids.size(0), 1), device=input_ids.device)], dim=1) + + if audio_end and text_end: + break + + # Append generated tokens to the list + for i in range(self.code_layer): + generated_ids[i].append(next_tokens_audio[i].clone().tolist()[0]) # Audio layers + generated_ids[self.code_layer].append(next_token_text.clone().tolist()[0]) # Text layer + + # Trim the text layer at the end-of-text token and convert each layer to a tensor + text_tokens = generated_ids[-1] + generated_ids[-1] = text_tokens[: text_tokens.index(eot)] if eot in text_tokens else text_tokens + generated_ids = [torch.tensor(layer) for layer in generated_ids] + return generated_ids + + + @torch.no_grad() + def sample_next_token(self, logits, **kwargs): + """ + Generate the next token based on the model output logits. + Supports greedy decoding, top-k sampling, and top-p (nucleus) sampling. + """ + do_sample = kwargs.get("do_sample", False) + temperature = kwargs.get("temperature", 1.0) + top_k = kwargs.get("top_k", 50) + top_p = kwargs.get("top_p", 1.0) + num_samples = kwargs.get("num_samples", 1) + + # Adjust logits with temperature + logits = logits.squeeze(0) + logits = logits / temperature + + # Top-k filtering + if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Make sure top_k is within the vocab size + values, indices = torch.topk(logits, top_k) + logits[logits < values[..., [-1]]] = -float('Inf') # Filter tokens not in top_k + + # Top-p filtering (nucleus sampling) + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = -float('Inf') + + if do_sample: + # Perform sampling + return torch.multinomial(F.softmax(logits, dim=-1), num_samples=num_samples) + else: + # Greedy decoding (argmax) + return torch.argmax(logits, dim=-1, keepdim=True) + + + def repetition_penalty(self, logits, generated_ids, repetition_penalty): + """ + Apply repetition penalty to the logits.
+ """ + for token_id in set(generated_ids): + if logits[0, -1, token_id] < 0: + logits[0, -1, token_id] *= repetition_penalty + else: + logits[0, -1, token_id] /= repetition_penalty + + return logits \ No newline at end of file diff --git a/s2s.py b/s2s.py new file mode 100644 index 0000000000000000000000000000000000000000..23018d5b50b0d0b5e42376b7485dd0d0b5c61c84 --- /dev/null +++ b/s2s.py @@ -0,0 +1,178 @@ +import random +import torch +from slam_llm.utils.model_utils import get_custom_model_factory +from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift +import whisper +import numpy as np +from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME +import os +from omegaconf import OmegaConf +from huggingface_hub import hf_hub_download +from typing import Callable + + +def update_progress(progress_callback: Callable[[str], None] | None, message: str): + if progress_callback: + progress_callback(message) + + +def pull_model_ckpt(): + if not os.path.exists(CKPT_LOCAL_DIR): + os.makedirs(CKPT_LOCAL_DIR) + if os.path.exists(CKPT_PATH): + return + hf_hub_download( + repo_id=CKPT_REPO, + filename=CKPT_NAME, + local_dir=CKPT_LOCAL_DIR, + token=os.getenv("HF_TOKEN"), + ) + + +pull_model_ckpt() + + +def extract_audio_feature(audio_path, mel_size): + print("Extracting audio features from", audio_path) + audio_raw = whisper.load_audio(audio_path) + audio_raw = whisper.pad_or_trim(audio_raw) + audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0) + audio_length = (audio_mel.shape[0] + 1) // 2 + audio_length = audio_length // 5 + audio_res = audio_mel + + return audio_res, audio_length + + +def get_input_ids(length, special_token_a, special_token_t, vocab_config): + input_ids = [] + for i in range(vocab_config.code_layer): + input_ids_item = [] + input_ids_item.append(layershift(vocab_config.input_a, i)) + input_ids_item += [layershift(vocab_config.pad_a, i)] * length + input_ids_item += [ + (layershift(vocab_config.eoa, i)), + layershift(special_token_a, i), + ] + input_ids.append(torch.tensor(input_ids_item).unsqueeze(0)) + input_id_T = torch.tensor( + [vocab_config.input_t] + + [vocab_config.pad_t] * length + + [vocab_config.eot, special_token_t] + ) + input_ids.append(input_id_T.unsqueeze(0)) + return input_ids + + +def generate_from_wav( + wav_path, model, codec_decoder, dataset_config, decode_config, device +): + mel_size = dataset_config.mel_size + prompt = dataset_config.prompt + prompt_template = "USER: {}\n ASSISTANT: " + vocab_config = dataset_config.vocab_config + special_token_a = vocab_config.answer_a + special_token_t = vocab_config.answer_t + code_layer = vocab_config.code_layer + task_type = dataset_config.task_type + + audio_mel, audio_length = extract_audio_feature(wav_path, mel_size) + + prompt = prompt_template.format(prompt) + prompt_ids = model.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + + example_ids = get_input_ids( + audio_length + prompt_length, special_token_a, special_token_t, vocab_config + ) + text_layer = example_ids[code_layer] + text_layer = torch.cat( + ( + text_layer[:, : audio_length + 1], + prompt_ids.unsqueeze(0), + text_layer[:, -2:], + ), + dim=1, + ) # <bos> <audio> <prompt> <eos> <task> + example_ids[code_layer] = text_layer + + input_length = audio_length + example_mask = example_ids[0][0].ge(-1) + example_ids = torch.stack(example_ids).squeeze() + + input_ids = example_ids.unsqueeze(0).to(device) + 
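# Assemble the single-sample inference batch: stacked (code_layer+1)-layer input ids, attention/modality masks, mel features, and lengths, all moved to the target device. +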
attention_mask = example_mask.unsqueeze(0).to(device) + audio_mel = audio_mel.unsqueeze(0).to(device) + input_length = torch.tensor([input_length]).to(device) + audio_length = torch.tensor([audio_length]).to(device) + task_type = [task_type] + + modality_mask = torch.zeros_like(attention_mask) + padding_left = 1 # +1 for <bos> + modality_mask[0, padding_left : padding_left + audio_length] = True + + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "audio_mel": audio_mel, + "input_length": input_length, + "audio_length": audio_length, + "modality_mask": modality_mask, + "task_types": task_type, + } + + model_outputs = model.generate(**batch, **decode_config) + text_outputs = model_outputs[7] + audio_outputs = model_outputs[:7] + output_text = model.tokenizer.decode( + text_outputs, add_special_tokens=False, skip_special_tokens=True + ) + + if decode_config.decode_text_only: + return None, output_text + + audio_tokens = [audio_outputs[layer] for layer in range(7)] + audiolist = reconscruct_snac(audio_tokens) + audio = reconstruct_tensors(audiolist) + with torch.inference_mode(): + audio_hat = codec_decoder.decode(audio) + + return audio_hat, output_text + + +def generate( + wav_path: str, progress_callback: Callable[[str], None] | None = None +) -> tuple[np.ndarray, int | float]: + config = OmegaConf.structured(InferenceConfig()) + train_config, model_config, dataset_config, decode_config = ( + config.train_config, + config.model_config, + config.dataset_config, + config.decode_config, + ) + + torch.cuda.manual_seed(train_config.seed) + torch.manual_seed(train_config.seed) + random.seed(train_config.seed) + + update_progress(progress_callback, "Loading model") + + model_factory = get_custom_model_factory(model_config) + model, _ = model_factory(train_config, model_config, CKPT_PATH) + codec_decoder = model.codec_decoder + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + update_progress(progress_callback, "Generating") + output_wav, output_text = generate_from_wav( + wav_path, model, codec_decoder, dataset_config, decode_config, device + ) + + return output_wav.squeeze().cpu().numpy(), 24000 + + +if __name__ == "__main__": + wav_path = "sample.wav" + generate(wav_path) diff --git a/s2s_config.py b/s2s_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ea029518de41c02cdde5bf91f96099c529e224fd --- /dev/null +++ b/s2s_config.py @@ -0,0 +1,272 @@ +from dataclasses import dataclass, field +from typing import Optional, List +import os + +CKPT_NAME = "model.pt" +CKPT_LOCAL_DIR = "model_ckpts" +CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME) +CKPT_REPO = "xcczach/mini-omni" + + +@dataclass +class VocabConfig: + text_vocabsize: int = 151936 + text_specialtokens: int = 64 + audio_vocabsize: int = 4096 + audio_specialtokens: int = 64 + total_vocabsize: int = 181120 + code_layer: int = 7 + + padded_text_vocabsize: int = field(init=False) + padded_audio_vocabsize: int = field(init=False) + total_audio_vocabsize: int = field(init=False) + + eot: int = field(init=False) # end of text token + pad_t: int = field(init=False) # padding text token + input_t: int = field(init=False) # input text token + answer_t: int = field(init=False) # answer text token + asr: int = field(init=False) # ASR token + + eoa: int = field(init=False) # end of audio token + pad_a: int = field(init=False) # padding audio token + input_a: int = field(init=False) # input audio token + answer_a: int = field(init=False) # 
answer audio token + split: int = field(init=False) # split token + + def __post_init__(self): + self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens + self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens + self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer + + self.eot = self.text_vocabsize + self.pad_t = self.text_vocabsize + 1 + self.input_t = self.text_vocabsize + 2 + self.answer_t = self.text_vocabsize + 3 + self.asr = self.text_vocabsize + 4 + + self.eoa = self.audio_vocabsize + self.pad_a = self.audio_vocabsize + 1 + self.input_a = self.audio_vocabsize + 2 + self.answer_a = self.audio_vocabsize + 3 + self.split = self.audio_vocabsize + 4 + + +@dataclass +class TTSAdapterConfig: + add_qkv_bias: Optional[bool] = True + bias: bool = False + gelu_approximate: Optional[str] = None + head_size: Optional[int] = 64 + intermediate_size: Optional[int] = 4864 + lm_head_bias: bool = False + mlp_class_name: str = "GptNeoxMLP" + n_layer: int = 6 + n_head: int = 14 + n_embd: int = 896 + n_query_groups: Optional[int] = 2 + norm_class_name: str = "RMSNorm" + norm_eps: float = 1e-6 + parallel_residual: bool = False + rotary_percentage: float = 1 + shared_attention_norm: bool = False + + def __post_init__(self): + self.rope_n_elem = int(self.rotary_percentage * self.head_size) + + +@dataclass +class ModelConfig: + file: str = "model/slam_model_s2s.py:model_factory" + llm_name: str = "qwen2-0.5b" + llm_path: str = "Qwen/Qwen2-0.5B" + llm_type: str = "decoder_only" + llm_dim: int = 896 + encoder_name: Optional[str] = "whisper" + encoder_ds_rate: int = 2 + encoder_path: Optional[str] = "small" + encoder_dim: int = 768 + encoder_projector: str = "linear" + encoder_projector_ds_rate: int = 5 + modal: str = "audio" + normalize: Optional[bool] = field( + default=False, + metadata={"help": "whether input is normalized, used for models such as wavlm"}, + ) + encoder_type: str = field( + default="finetune", + metadata={ + "help": "whether model is only pretrained or finetuned, used for models such as hubert" + }, + ) + vocab_config: VocabConfig = field(default_factory=VocabConfig) + codec_decode: bool = True + codec_decoder_type: str = "SNAC" + codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz" + tts_adapter: bool = False + tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig) + + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None , llama_adapter, prefix + r: int = 8 + lora_alpha: int = 32 + target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + + +@dataclass +class TrainConfig: + model_name: str = "s2s" + enable_ddp: bool = False + enable_deepspeed: bool = False + enable_fsdp: bool = False + low_cpu_fsdp: bool = False + run_validation: bool = True + batch_size_training: int = 4 + batching_strategy: str = field( + default="custom", metadata={"help": "alternative: padding"} + ) # + context_length: int = 4096 + gradient_accumulation_steps: int = 1 + num_epochs: int = 1 + num_workers_dataloader: int = 2 + warmup_steps: int = 1000 + total_steps: int = 100000 + validation_interval: int = 1000 + lr: float = 1e-4 + weight_decay: float = 0.0 + gamma: float = 0.85 + seed: int = 42 + use_fp16: bool = False + mixed_precision: bool = True + val_batch_size: int = 1 + + use_peft: bool = False + peft_config: PeftConfig = field(default_factory=PeftConfig) + output_dir: str = 
"PATH/to/save/PEFT/model" + freeze_layers: bool = False + num_freeze_layers: int = 1 + quantization: bool = False + one_gpu: bool = False + save_model: bool = True + dist_checkpoint_root_folder: str = ( + "PATH/to/save/FSDP/model" # will be used if using FSDP + ) + dist_checkpoint_folder: str = "fine-tuned" # will be used if using FSDP + save_optimizer: bool = False # will be used if using FSDP + use_fast_kernels: bool = ( + False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + ) + run_test_during_validation: bool = False + run_test_during_validation_file: str = "test.wav" + run_test_during_validation_prompt: str = "<|S2S|>" + freeze_llm: bool = field( + default=True, + metadata={ + "help": "whether to freeze llm when finetuning, should be true when use peft finetuning" + }, + ) + freeze_encoder: bool = True + train_embed_only: bool = False + train_audio_embed_only: bool = False + task_type: str = "s2s" + + +@dataclass +class DataConfig: + dataset: str = "speech_dataset_s2s" + file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset" + train_data_path: Optional[str] = None + val_data_path: Optional[str] = None + train_split: str = "train" + test_split: str = "validation" + prompt: Optional[str] = None + data_path: Optional[str] = None + max_words: Optional[int] = None + max_mel: Optional[float] = None + fix_length_audio: int = -1 + inference_mode: bool = True + input_type: str = field( + default="mel", + metadata={"help": "Use raw when input is wav, mel when for whisper"}, + ) + mel_size: int = field( + default=80, metadata={"help": "80 for whisper large v1 and v2, 128 for v3"} + ) + normalize: Optional[bool] = field( + default=False, + metadata={"help": "whether input is normalized, used for models such as wavlm"}, + ) + seed: int = 42 + manifest_format: str = field( + default="datasets", metadata={"help": "alternative: jsonl"} + ) + split_size: float = 0.1 + + vocab_config: VocabConfig = field(default_factory=VocabConfig) + load_from_cache_file: bool = False + task_type: str = "s2s" + + +@dataclass +class DecodeConfig: + do_sample: bool = False + max_new_tokens: int = 300 + min_length: int = 10 + temperature: float = 1.0 + top_k: int = 50 + top_p: float = 0.9 + num_beams: int = 1 + num_return_sequences: int = 1 + num_samples: int = 1 + max_time: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False + no_repeat_ngram_size: int = 0 + bad_words_ids: List = field(default_factory=list) + num_beam_groups: int = 1 + diversity_penalty: float = 0.0 + task_type: str = "s2s" + decode_text_only: bool = False + + +@dataclass +class FSDPConfig: + mixed_precision: bool = True + use_fp16: bool = False + # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: str = ( + "NO_SHARD" # ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP + ) + checkpoint_type: str = ( + "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. 
+ ) + fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "/valleblob/v-wenxichen/exp/log/test.log" + log_interval: int = 10 + online_output_dir: Optional[str] = None + + +@dataclass +class InferenceConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + decode_config: DecodeConfig = field(default_factory=DecodeConfig) diff --git a/slam_llm/__init__.py b/slam_llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/slam_llm/data/__init__.py b/slam_llm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47bc944e530156d72412524d05a2804983a80c8d --- /dev/null +++ b/slam_llm/data/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. \ No newline at end of file diff --git a/slam_llm/data/concatenator.py b/slam_llm/data/concatenator.py new file mode 100644 index 0000000000000000000000000000000000000000..29f49052ff84a960c367ebd16ea9b53db7fbbf2d --- /dev/null +++ b/slam_llm/data/concatenator.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from tqdm import tqdm +from itertools import chain + +from torch.utils.data import Dataset + + +class ConcatDataset(Dataset): + def __init__(self, dataset, chunk_size=4096): + self.dataset = dataset + self.chunk_size = chunk_size + + self.samples = [] + + buffer = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } + + for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): + buffer = {k: v + sample[k] for k,v in buffer.items()} + + while len(next(iter(buffer.values()))) > self.chunk_size: + self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) + buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} + + def __getitem__(self, idx): + return self.samples[idx] + + def __len__(self): + return len(self.samples) diff --git a/slam_llm/data/sampler.py b/slam_llm/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9c91af5b946b2216e2a3003fe37a1f32d072fa52 --- /dev/null +++ b/slam_llm/data/sampler.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
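+# Length-grouped batch samplers: sort samples by length so each batch holds similarly sized items (less padding); the distributed variant deals whole batches out across ranks.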
+ +import random +from itertools import islice + +import numpy as np +import torch + + +class LengthBasedBatchSampler(torch.utils.data.BatchSampler): + def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None: + if isinstance(next(iter(data_source)), dict): + first_key = next(iter(next(iter(data_source)).keys())) + self.lengths = [len(d[first_key]) for d in data_source] + else: + self.lengths = [len(d) for d in data_source] + self.batch_size = batch_size + self.drop_last = drop_last + self.shuffle = shuffle + + def __iter__(self): + ids = np.argsort(self.lengths) + if self.drop_last: + ids = ids[:len(ids) // self.batch_size * self.batch_size] + + batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)] + + if self.shuffle: + random.shuffle(batches) + + for b in batches: + yield b + + def __len__(self): + if self.drop_last: + return len(self.lengths) // self.batch_size + else: + return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0) + + +class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler): + def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None: + random.seed(seed) + self.batch_sampler = LengthBasedBatchSampler( + data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle + ) + self.num_replicas = num_replicas + self.rank = rank + + def __iter__(self): + max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas + return islice(self.batch_sampler, self.rank, max_length, self.num_replicas) + + def __len__(self): + return len(self.batch_sampler) // self.num_replicas + \ No newline at end of file diff --git a/slam_llm/models/BEATs/BEATs.py b/slam_llm/models/BEATs/BEATs.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a2bbcb336bc89e2d8a0298d8165e41c87a1790 --- /dev/null +++ b/slam_llm/models/BEATs/BEATs.py @@ -0,0 +1,181 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +import torchaudio.compliance.kaldi as ta_kaldi + +from .backbone import ( + TransformerEncoder, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class BEATsConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_wise_gradient_decay_ratio: float = 1.0 # ratio for layer-wise gradient decay + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + 
self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # label predictor + self.finetuned_model: bool = False # whether the model is a fine-tuned model. + self.predictor_dropout: float = 0.1 # dropout probability for the predictor + self.predictor_class: int = 527 # target class number for the predictor + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class BEATs(nn.Module): + def __init__( + self, + cfg: BEATsConfig, + ) -> None: + super().__init__() + logger.info(f"BEATs Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + if cfg.finetuned_model: + self.predictor_dropout = nn.Dropout(cfg.predictor_dropout) + self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class) + else: + self.predictor = None + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + @classmethod + def preprocess( + cls, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + if len(source.shape) > 1: # batch + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + else: # single + waveform = source.unsqueeze(0) * 2 ** 15 + fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_features( + self, + fbank: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ): + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = 
self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + if self.predictor is not None: + x = self.predictor_dropout(x) + logits = self.predictor(x) + + if padding_mask is not None and padding_mask.any(): + logits[padding_mask] = 0 + logits = logits.sum(dim=1) + logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) + else: + logits = logits.mean(dim=1) + + lprobs = torch.sigmoid(logits) + + return lprobs, padding_mask + else: + return x, padding_mask diff --git a/slam_llm/models/BEATs/Tokenizers.py b/slam_llm/models/BEATs/Tokenizers.py new file mode 100644 index 0000000000000000000000000000000000000000..33e0f6ae711360bfa49777fd3c427eb73f2b6762 --- /dev/null +++ b/slam_llm/models/BEATs/Tokenizers.py @@ -0,0 +1,173 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +from torch.nn import LayerNorm +import torchaudio.compliance.kaldi as ta_kaldi + +from .backbone import ( + TransformerEncoder, +) +from .quantizer import ( + NormEMAVectorQuantizer, +) + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + + +class TokenizersConfig: + def __init__(self, cfg=None): + self.input_patch_size: int = -1 # path size of patch embedding + self.embed_dim: int = 512 # patch embedding dimension + self.conv_bias: bool = False # include bias in conv encoder + + self.encoder_layers: int = 12 # num encoder layers in the transformer + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.deep_norm: bool = False # apply deep_norm first in the transformer + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum 
distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + # quantizer + self.quant_n: int = 1024 # codebook number in quantizer + self.quant_dim: int = 256 # codebook dimension in quantizer + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class Tokenizers(nn.Module): + def __init__( + self, + cfg: TokenizersConfig, + ) -> None: + super().__init__() + logger.info(f"Tokenizers Config: {cfg.__dict__}") + + self.cfg = cfg + + self.embed = cfg.embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.input_patch_size = cfg.input_patch_size + self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size, + bias=cfg.conv_bias) + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + assert not cfg.deep_norm or not cfg.layer_norm_first + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.quantize = NormEMAVectorQuantizer( + n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99, + ) + self.quant_n = cfg.quant_n + self.quantize_layer = nn.Sequential( + nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), + nn.Tanh(), + nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize + ) + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def preprocess( + self, + source: torch.Tensor, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ) -> torch.Tensor: + fbanks = [] + for waveform in source: + waveform = waveform.unsqueeze(0) * 2 ** 15 + fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) + fbanks.append(fbank) + fbank = torch.stack(fbanks, dim=0) + fbank = (fbank - fbank_mean) / (2 * fbank_std) + return fbank + + def extract_labels( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + fbank_mean: float = 15.41663, + fbank_std: float = 6.55582, + ): + fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(fbank, padding_mask) + + fbank = fbank.unsqueeze(1) + features = self.patch_embedding(fbank) + features = features.reshape(features.shape[0], features.shape[1], -1) + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + x = self.dropout_input(features) + + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + ) + + quantize_input = self.quantize_layer(x) + quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) + + return embed_ind + diff --git a/slam_llm/models/BEATs/backbone.py b/slam_llm/models/BEATs/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..97caa0673f677f10a8ea1bd756bca6e208c439b7 --- /dev/null +++ b/slam_llm/models/BEATs/backbone.py @@ -0,0 +1,783 @@ +# 
-------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import numpy as np +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +import torch.nn.functional as F +from torch.nn import LayerNorm, Parameter +from .modules import ( + GradMultiply, + SamePad, + get_activation_fn, + GLU_Linear, + quant_noise, +) + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + deep_norm=args.deep_norm, + has_relative_attention_bias=self.relative_position_embedding, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + encoder_layers=args.encoder_layers, + ) + for i in range(args.encoder_layers) + ] + ) + if self.relative_position_embedding: + for i in range(1, args.encoder_layers): + del self.layers[i].self_attn.relative_attention_bias + self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + if args.deep_norm: + deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4) + for i in range(args.encoder_layers): + nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1) + nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta) + nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta) + + self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1) + + def forward(self, x, padding_mask=None, 
layer=None): + x, layer_results = self.extract_features(x, padding_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + if self.layer_wise_gradient_decay_ratio != 1.0: + x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio) + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + deep_norm: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + encoder_layers: int = 0, + ) -> None: + + super().__init__() + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.final_layer_norm = LayerNorm(self.embedding_dim) + + self.deep_norm = deep_norm + if self.deep_norm: + self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4) + else: + self.deep_norm_alpha = 1 + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + 
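# Feed-forward sub-block of the pre-norm (layer_norm_first) branch. +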
+ residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual * self.deep_norm_alpha + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual * self.deep_norm_alpha + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + 
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
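`_relative_positions_bucket` and `compute_bias` implement T5-style relative position bias: half of the buckets cover each direction, offsets below `num_buckets // 4` keep their own exact bucket, and longer offsets are binned logarithmically up to `max_distance`. A quick numeric check of that mapping for keys at or before the query position (keys after the query additionally get a `num_buckets // 2` offset), assuming the defaults `num_buckets=32`, `max_distance=128`:

```python
import torch

num_buckets, max_distance = 32, 128
half = num_buckets // 2          # 16: one half of the buckets per direction
max_exact = half // 2            # 8: short distances keep an exact bucket

rel = torch.tensor([0, 3, 7, 8, 20, 200])   # |key - query| offsets
exact = rel < max_exact
log_bucket = max_exact + (
    torch.log(rel.clamp(min=1).float() / max_exact)
    / torch.log(torch.tensor(max_distance / max_exact))
    * (half - max_exact)
).long()
buckets = torch.where(exact, rel, log_bucket.clamp(max=half - 1))
print(buckets.tolist())          # [0, 3, 7, 8, 10, 15] -- offsets >= max_distance saturate at 15
```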
+ """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + alpha = 32 + q *= 1 / alpha + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = 
v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size()) + + attn_weights = attn_weights + attn_mask_rel_pos + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( 
+ key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
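`init_bert_params` below is meant to be applied recursively over a module tree with `nn.Module.apply`, re-initialising `Linear`, `Embedding` and `MultiheadAttention` weights with N(0, 0.02) and leaving everything else untouched. A minimal usage sketch with a toy model (the model itself is illustrative):

```python
import torch.nn as nn

# apply() walks every submodule, so Linear / Embedding / MultiheadAttention
# weights get the N(0, 0.02) init while other modules are left as they are.
toy = nn.Sequential(
    nn.Embedding(100, 32, padding_idx=0),
    nn.Linear(32, 32),
)
toy.apply(init_bert_params)
print(toy[1].weight.std())   # roughly 0.02 after re-initialisation
```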
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) diff --git a/slam_llm/models/BEATs/modules.py b/slam_llm/models/BEATs/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..f32ef58ad00c809f0b134cd556adea7ff5fc0503 --- /dev/null +++ b/slam_llm/models/BEATs/modules.py @@ -0,0 +1,219 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +import torch +from torch import Tensor, nn +import torch.nn.functional as F + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + def __init__(self): + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def 
get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + 
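`quant_noise` only registers a forward pre-hook, so using it is just a matter of wrapping a supported module, exactly as `MultiheadAttention` wraps its q/k/v/out projections above. During training a random fraction `p` of weight blocks is zeroed and the remainder rescaled by `1 / (1 - p)`; at eval time the hook is a no-op. A short usage sketch (sizes are illustrative):

```python
import torch
import torch.nn as nn

# Wrap a projection the same way the attention module does. With p=0 the module
# is returned unchanged; p>0 drops weight blocks of size `block_size` along the
# input dimension at training time only (768 is divisible by 8, as required).
proj = quant_noise(nn.Linear(768, 768, bias=True), p=0.1, block_size=8)

proj.train()
y_train = proj(torch.randn(4, 768))   # noisy weights: blocks dropped and rescaled

proj.eval()
y_eval = proj(torch.randn(4, 768))    # clean weights: the hook returns early
```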
module.register_forward_pre_hook(_forward_pre_hook) + return module + diff --git a/slam_llm/models/BEATs/quantizer.py b/slam_llm/models/BEATs/quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6686444ebb9f58f6bb7c8c12c4580d26cecdf2ec --- /dev/null +++ b/slam_llm/models/BEATs/quantizer.py @@ -0,0 +1,215 @@ +# -------------------------------------------------------- +# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058) +# Github source: https://github.com/microsoft/unilm/tree/master/beats +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on VQGAN code bases +# https://github.com/CompVis/taming-transformers +# --------------------------------------------------------' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as distributed + +try: + from einops import rearrange, repeat +except ImportError: + pass + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False): + dim, dtype, device = samples.shape[-1], samples.dtype, samples.device + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + if use_cosine_sim: + dists = samples @ means.t() + else: + diffs = rearrange(samples, 'n d -> n () d') \ + - rearrange(means, 'c d -> () c d') + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + if use_cosine_sim: + new_means = l2norm(new_means) + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''): + super().__init__() + self.num_tokens = num_tokens + self.codebook_dim = codebook_dim + self.decay = decay + self.eps = eps + if codebook_init_path == '': + if not kmeans_init: + weight = torch.randn(num_tokens, codebook_dim) + weight = l2norm(weight) + else: + weight = torch.zeros(num_tokens, codebook_dim) + self.register_buffer('initted', torch.Tensor([not kmeans_init])) + else: + print(f"load init codebook weight from {codebook_init_path}") + codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu') + weight = codebook_ckpt_weight.clone() + self.register_buffer('initted', torch.Tensor([True])) + + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + # self.register_buffer('initted', torch.Tensor([not kmeans_init])) + self.update = True + + @torch.jit.ignore + def init_embed_(self, data): + if self.initted: + return + print("Performing Kemans init for codebook") + embed, 
cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True) + self.weight.data.copy_(embed) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1)) + self.weight.data.copy_(embed_normalized) + + +def norm_ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + moving_avg.data.copy_(l2norm(moving_avg.data)) + + +class NormEMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + statistic_code_usage=True, kmeans_init=False, codebook_init_path=''): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.decay = decay + + # learnable = True if orthogonal_reg_weight > 0 else False + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path) + + self.statistic_code_usage = statistic_code_usage + if statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(n_embed)) + if distributed.is_available() and distributed.is_initialized(): + print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!") + self.all_reduce_fn = distributed.all_reduce + else: + self.all_reduce_fn = nn.Identity() + + def reset_cluster_size(self, device): + if self.statistic_code_usage: + self.register_buffer('cluster_size', torch.zeros(self.num_tokens)) + self.cluster_size = self.cluster_size.to(device) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + # z = rearrange(z, 'b c h w -> b h w c') + # z = z.transpose(1, 2) + z = l2norm(z) + z_flattened = z.reshape(-1, self.codebook_dim) + + self.embedding.init_embed_(z_flattened) + + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + + if not self.training: + with torch.no_grad(): + cluster_size = encodings.sum(0) + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size, cluster_size, self.decay) + + if self.training and self.embedding.update: + # EMA cluster size + + bins = encodings.sum(0) + self.all_reduce_fn(bins) + + # self.embedding.cluster_size_ema_update(bins) + ema_inplace(self.cluster_size, bins, self.decay) + + zero_mask = (bins == 0) + bins = bins.masked_fill(zero_mask, 1.) 
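The training branch of `NormEMAVectorQuantizer.forward` is a standard EMA codebook update: accumulate per-code usage counts and feature sums, smooth both with `decay`, and renormalise the codewords back onto the unit sphere. A condensed standalone sketch of one such update step, assuming inputs are already L2-normalised (this is a simplification of the code here, not a drop-in replacement):

```python
import torch
import torch.nn.functional as F


def ema_codebook_update(codebook, z, codes, cluster_size, decay=0.99):
    """One EMA step for a (num_tokens, dim) codebook given assignments `codes`."""
    onehot = F.one_hot(codes, codebook.size(0)).type_as(z)      # (N, num_tokens)
    counts = onehot.sum(0)                                      # how often each code was used
    embed_sum = z.t() @ onehot                                  # (dim, num_tokens) feature sums

    cluster_size.mul_(decay).add_(counts, alpha=1 - decay)      # EMA of code usage
    target = embed_sum.t() / counts.clamp(min=1).unsqueeze(1)   # mean feature per code
    target = F.normalize(target, dim=-1)                        # stay on the unit sphere
    codebook.mul_(decay).add_(target, alpha=1 - decay)
    codebook.copy_(F.normalize(codebook, dim=-1))


# toy usage
codebook = F.normalize(torch.randn(16, 8), dim=-1)
z = F.normalize(torch.randn(32, 8), dim=-1)
codes = (z @ codebook.t()).argmax(dim=-1)                       # nearest code by cosine
ema_codebook_update(codebook, z, codes, cluster_size=torch.zeros(16))
```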
+ + embed_sum = z_flattened.t() @ encodings + self.all_reduce_fn(embed_sum) + + embed_normalized = (embed_sum / bins.unsqueeze(0)).t() + embed_normalized = l2norm(embed_normalized) + + embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight, + embed_normalized) + norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + # z_q = rearrange(z_q, 'b h w c -> b c h w') + # z_q = z_q.transpose(1, 2) + return z_q, loss, encoding_indices diff --git a/slam_llm/models/EAT/EAT.py b/slam_llm/models/EAT/EAT.py new file mode 100644 index 0000000000000000000000000000000000000000..8a42ba5ffa45d5fdd148174c8cb79824dd397434 --- /dev/null +++ b/slam_llm/models/EAT/EAT.py @@ -0,0 +1,32 @@ +import torch +import torchaudio +import random + +def EAT_preprocess(source, norm_mean = -4.268, norm_std = 4.569, target_length = 1024, fixed_length = False, random_crop = False): + source = source - source.mean() + source = source.unsqueeze(dim=0) + + source = torchaudio.compliance.kaldi.fbank(source, htk_compat=True, sample_frequency=16000, use_energy=False, + window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10).unsqueeze(dim=0) + + n_frames = source.shape[1] + if not fixed_length: + target_length = n_frames + if target_length % 16 != 0: + target_length = n_frames + (16 - n_frames % 16) + diff = target_length - n_frames + if diff > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, diff)) + source = m(source) + elif diff < 0: + if random_crop: + start_index = random.randint(0, n_frames - target_length) + source = source[:,start_index: start_index+target_length, :] + else: + source = source[:,0:target_length, :] + + # Normalize the mel spectrogram + source = (source - norm_mean) / (norm_std * 2) + source = source.squeeze() + + return source \ No newline at end of file diff --git a/slam_llm/models/SpatialAST/SpatialAST.py b/slam_llm/models/SpatialAST/SpatialAST.py new file mode 100644 index 0000000000000000000000000000000000000000..d08dceb147bcf8e487dee36b5fd08056175b5259 --- /dev/null +++ b/slam_llm/models/SpatialAST/SpatialAST.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn + +from torchlibrosa.stft import STFT, LogmelFilterBank +from timm.models.layers import to_2tuple + +from .vision_transformer import VisionTransformer as _VisionTransformer + +def conv3x3(in_channels, out_channels, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) + +class PatchEmbed_new(nn.Module): + """ Flexible Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + self.in_chans = in_chans + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h*w + + def get_output_shape(self, img_size): + return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + + x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 + x = 
x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 + x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 + return x + +class BinauralEncoder(_VisionTransformer): + """ Spatial Audio Spectrogram Transformer designed for Sound Event Localization and Detection + -------------------------------------------------------- + References: + Spatial-AST from BAT: https://github.com/zszheng147/Spatial-AST and https://arxiv.org/abs/2402.01591 + -------------------------------------------------------- + """ + def __init__(self, num_cls_tokens=3, **kwargs): + super(BinauralEncoder, self).__init__(**kwargs) + img_size = (1024, 128) # 1024, 128 + in_chans = 1 + emb_dim = 768 + + del self.cls_token + self.num_cls_tokens = num_cls_tokens + self.cls_tokens = nn.Parameter(torch.zeros(1, num_cls_tokens, emb_dim)) + + self.patch_embed = PatchEmbed_new( + img_size=img_size, patch_size=(16, 16), + in_chans=in_chans, embed_dim=emb_dim, stride=16 + ) # no overlap. stride=img_size=16 + + num_patches = self.patch_embed.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False) # fixed sin-cos embedding + + self.spectrogram_extractor = STFT( + n_fft=1024, hop_length=320, win_length=1024, window='hann', + center=True, pad_mode='reflect', freeze_parameters=True + ) + self.logmel_extractor = LogmelFilterBank( + sr=32000, n_fft=1024, n_mels=128, fmin=50, + fmax=14000, ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True + ) + + self.conv_downsample = nn.Sequential( + conv3x3(4, 1), + nn.BatchNorm2d(1), + nn.GELU(), + ) + + self.bn = nn.BatchNorm2d(2, affine=False) + del self.norm # remove the original norm + + self.target_frame = 1024 + + def forward_features_mask(self, x): + B = x.shape[0] #bsz, 512, 768 (unmasked) + + x = x + self.pos_embed[:, 1:, :] + + cls_tokens = self.cls_tokens + cls_tokens = cls_tokens.expand(B, -1, -1) + x = torch.cat([cls_tokens, x], dim=1) # bsz, 512 + 2 + 10, 768 + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + return x + + @torch.no_grad() + def forward(self, waveforms): + B, C, T = waveforms.shape + + waveforms = waveforms.reshape(B * C, T) + real, imag = self.spectrogram_extractor(waveforms) + + log_mel = self.logmel_extractor(torch.sqrt(real**2 + imag**2)).reshape(B, C, -1, 128) + log_mel = self.bn(log_mel) + + IPD = torch.atan2(imag[1::2], real[1::2]) - torch.atan2(imag[::2], real[::2]) + x = torch.cat([log_mel, torch.matmul(torch.cat([torch.cos(IPD), torch.sin(IPD)], dim=1), self.logmel_extractor.melW)], dim=1) + + if x.shape[2] < self.target_frame: + x = nn.functional.interpolate(x, (self.target_frame, x.shape[3]), mode="bicubic", align_corners=True) + + x = self.conv_downsample(x) + x = self.patch_embed(x) + x = self.forward_features_mask(x) + + return x \ No newline at end of file diff --git a/slam_llm/models/SpatialAST/vision_transformer.py b/slam_llm/models/SpatialAST/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ade970ab00bd506ff4c5078d9a4361afeaf7294e --- /dev/null +++ b/slam_llm/models/SpatialAST/vision_transformer.py @@ -0,0 +1,239 @@ +import torch +from torch import nn + +from timm.models.layers import to_2tuple, DropPath, trunc_normal_ + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. 
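The patch grid of the binaural encoder is easy to sanity-check: a (1, 1024, 128) log-mel input with 16x16 patches at stride 16 yields a 64x8 grid, i.e. 512 tokens of width 768, which matches the `bsz, 512, 768` shape noted in `forward_features_mask`. A quick shape check using the same convolutional embedding arithmetic as `PatchEmbed_new.get_output_shape`:

```python
import torch
import torch.nn as nn

# Same arithmetic as PatchEmbed_new for the Spatial-AST setup.
proj = nn.Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
dummy = torch.randn(1, 1, 1024, 128)          # (B, C, time, mel)
feat = proj(dummy)                            # -> (1, 768, 64, 8)
tokens = feat.flatten(2).transpose(1, 2)      # -> (1, 512, 768)
print(feat.shape, tokens.shape)
```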
+ """ + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed_new(nn.Module): + """ Flexible Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride = to_2tuple(stride) + + self.img_size = img_size + self.patch_size = patch_size + self.in_chans = in_chans + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches + _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w + self.patch_hw = (h, w) + self.num_patches = h*w + + def get_output_shape(self, img_size): + return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape + + def forward(self, x): + B, C, H, W = x.shape + + x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12 + x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212 + x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768 + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + + # NOTE as per official impl, 
we could have a pre-logits representation dense layer + tanh here + #self.repr = nn.Linear(embed_dim, representation_size) + #self.repr_act = nn.Tanh() + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x \ No newline at end of file diff --git a/slam_llm/models/avhubert/__init__.py b/slam_llm/models/avhubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..994cc7f52f5b785ea61720b40f698dc48d77af34 --- /dev/null +++ b/slam_llm/models/avhubert/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hubert import * # noqa +from .hubert_asr import * # noqa +from .hubert_dataset import * +from .hubert_pretraining import * +from .hubert_criterion import * diff --git a/slam_llm/models/avhubert/decoder.py b/slam_llm/models/avhubert/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e40868502d37cc4dddd9d7cfe784873930ef19dc --- /dev/null +++ b/slam_llm/models/avhubert/decoder.py @@ -0,0 +1,243 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace +import contextlib +import copy +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass, field +from omegaconf import MISSING, II, open_dict +from typing import Any, Optional + +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.tasks import FairseqTask +from fairseq.models import ( + BaseFairseqModel, + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, +) +# from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES +from fairseq.modules import ( + LayerNorm, + PositionalEmbedding, + TransformerDecoderLayer, +) + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. 
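For reference, the `VisionTransformer` defined above is used in Spatial-AST purely as a feature extractor: `forward_features` prepends the CLS token, adds the positional embedding, runs the blocks and returns the CLS feature after the final norm. A minimal usage sketch with a deliberately small, illustrative configuration:

```python
import torch

# Tiny ViT just to show the call pattern; the Spatial-AST encoder overrides
# patch embedding and CLS handling, but the feature path is the same.
vit = VisionTransformer(
    img_size=224, patch_size=16, in_chans=3, num_classes=0,   # num_classes=0 -> Identity head
    embed_dim=192, depth=4, num_heads=3,
)
images = torch.randn(2, 3, 224, 224)
cls_feat = vit.forward_features(images)   # (2, 192): CLS token after the final norm
logits = vit(images)                      # identical here, since the head is Identity
print(cls_feat.shape, logits.shape)
```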
+ + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + cfg, + dictionary, + embed_tokens, + no_encoder_attn=False, + ): + super().__init__(dictionary) + + self.dropout = cfg.decoder_dropout + self.share_input_output_embed = cfg.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = cfg.decoder_embed_dim + self.output_embed_dim = cfg.decoder_embed_dim + + self.layerdrop = cfg.decoder_layerdrop + + padding_idx = embed_tokens.padding_idx + self.max_target_positions = cfg.max_target_positions + + self.embed_tokens = embed_tokens + # self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + cfg.max_target_positions, + embed_dim, + padding_idx, + learned=cfg.decoder_learned_pos, + ) + if not cfg.no_token_positional_embeddings + else None + ) + + # TODO: update this when transformer gets converted to dataclass configs + transformer_cfg = copy.deepcopy(cfg) + # with open_dict(transformer_cfg): + transformer_cfg.dropout = transformer_cfg.decoder_dropout + transformer_cfg.attention_dropout = ( + transformer_cfg.decoder_attention_dropout + ) + transformer_cfg.activation_dropout = ( + transformer_cfg.decoder_activation_dropout + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerDecoderLayer(transformer_cfg, no_encoder_attn) + for _ in range(transformer_cfg.decoder_layers) + ] + ) + + if not self.share_input_output_embed: + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) + + if transformer_cfg.decoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + prev_output_tokens = prev_output_tokens.long() + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + x = self.output_layer(x) + return x, extra + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Similar to *forward* but only return features. 
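At inference this decoder is driven step by step: each call receives the token prefix plus an `incremental_state` dict so the attention layers can reuse cached keys and values, and `extract_features` then only embeds the newest token. A hedged sketch of a greedy loop; the `encoder_out` layout and the `bos`/`eos` ids below follow the usual fairseq conventions and are assumptions, not values fixed by this file:

```python
import torch


@torch.no_grad()
def greedy_decode(decoder, encoder_out, bos=0, eos=2, max_len=100):
    tokens = torch.full((1, 1), bos, dtype=torch.long)   # (batch=1, prefix)
    incremental_state = {}                               # filled in by the attention layers
    for _ in range(max_len):
        logits, _ = decoder(tokens, encoder_out=encoder_out,
                            incremental_state=incremental_state)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == eos:
            break
    return tokens
```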
+ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + for layer in self.layers: + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, attn, _ = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["padding_mask"] if encoder_out is not None else None, + incremental_state, + self_attn_mask=self.buffered_future_mask(x) + if incremental_state is None + else None, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + # project back to size of vocabulary + emb_mat = self.embed_tokens.weight if self.share_input_output_embed else self.embed_out + return torch.matmul(features, emb_mat.transpose(0, 1)) + # if self.share_input_output_embed: + # return F.linear(features, self.embed_tokens.weight) + # else: + # return F.linear(features, self.embed_out) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + diff --git a/slam_llm/models/avhubert/hubert.py b/slam_llm/models/avhubert/hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..96273879528ad13600c928dbb2cce5cb9cc2b962 --- /dev/null +++ b/slam_llm/models/avhubert/hubert.py @@ -0,0 +1,792 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
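`buffered_future_mask` above caches an upper-triangular additive mask of `-inf` so that each target position can only attend to itself and earlier positions, rebuilding the buffer only when the device changes or a longer sequence arrives. The same mask can be constructed directly as follows (a standalone sketch, not the fairseq helper itself):

```python
import torch


def future_mask(dim: int, device=None) -> torch.Tensor:
    """(dim, dim) additive mask: 0 on/below the diagonal, -inf strictly above it."""
    mask = torch.zeros(dim, dim, device=device)
    upper = torch.ones(dim, dim, device=device, dtype=torch.bool).triu(1)
    return mask.masked_fill(upper, float("-inf"))


print(future_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```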
+ +import os,sys +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +from dataclasses import dataclass, field +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import ( + ConvFeatureExtractionModel, + TransformerEncoder, +) +from fairseq.modules import GradMultiply, LayerNorm +from copy import deepcopy + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + from hubert_pretraining import ( + AVHubertPretrainingConfig, + AVHubertPretrainingTask, + ) + from resnet import ResEncoder + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, + ) + from utils import compute_mask_indices + from decoder import TransformerDecoder + +else: + from .hubert_pretraining import ( + AVHubertPretrainingConfig, + AVHubertPretrainingTask, + ) + from .resnet import ResEncoder + from .utils import compute_mask_indices + from .decoder import TransformerDecoder + +from omegaconf import II + +logger = logging.getLogger(__name__) + +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum( + ["static", "uniform", "normal", "poisson"] +) +# LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"]) + + +@dataclass +class AVHubertConfig(FairseqDataclass): + label_rate: int = II("task.label_rate") + input_modality: str = II("task.input_modality") + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a tarnsformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={ + "help": "dropout to apply to the features (after feat extr)" + }, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions. 
set to encoder_embed_dim is <= 0" + }, + ) + untie_final_proj: bool = field( + default=False, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + + # masking + mask_length_audio: int = field(default=10, metadata={"help": "mask length"}) + mask_prob_audio: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_length_image: int = field(default=10, metadata={"help": "mask length"}) + mask_prob_image: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={ + "help": "min space between spans (if no overlap is enabled)" + }, + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={ + "help": "min space between spans (if no overlap is enabled)" + }, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={ + "help": "number of filters for convolutional positional embeddings" + }, + ) + conv_pos_groups: int = field( + default=16, + metadata={ + "help": "number of groups for convolutional positional embedding" + }, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + resnet_relu_type: str = field(default='prelu', 
metadata={"help": 'relu type for resnet'}) + resnet_weights: Optional[str] = field(default=None, metadata={"help": 'resnet weights'}) + sim_type: str = field(default='cosine', metadata={"help": 'similarity type'}) + + sub_encoder_layers: int = field(default=0, metadata={'help': 'number of transformer layers for single modality'}) + audio_feat_dim: int = field(default=-1, metadata={'help': 'audio feature dimension'}) + modality_dropout: float = field(default=0, metadata={'help': 'drop one modality'}) + audio_dropout: float = field(default=0, metadata={'help': 'drop audio feature'}) + modality_fuse: str = field(default='concat', metadata={'help': 'fusing two modalities: add,concat'}) + selection_type : str = field(default='same_other_seq', metadata={'help': 'type of selectig images, same_other_seq: replace masked span with span from another sequence, same_seq: repace masked span with span of the same sequence'}) + masking_type : str = field(default='input', metadata={'help': 'input or feature masking'}) + + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field( + default=6, metadata={"help": "num of decoder layers"} + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, + metadata={"help": "apply layernorm before each decoder block"}, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings " + "(outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.1, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.1, + metadata={ + "help": "dropout probability for attention weights " + "inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, + metadata={"help": "share decoder input and output embeddings"}, + ) + no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'}) + + # # new fairseq + # required_seq_len_multiple: int = field( + # default=1, + # metadata={ + # "help": "pad the input to encoder such that the sequence length is divisible by multiple" + # }, + # ) + + # layer_type: LAYER_TYPE_CHOICES = field( + # default="transformer", metadata={"help": "layer type in encoder"} + # ) + +class SubModel(nn.Module): + def __init__(self, resnet=None, input_dim=None, cfg=None): + super().__init__() + self.resnet = resnet + self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim) + self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None + + def forward(self, x): #torch.Size([1, 1, 106, 112, 112]) + if self.resnet is not None: + x = self.resnet(x) #torch.Size([1, 512, 106]) #torch.Size([12, 26, 314]) + x = self.proj(x.transpose(1, 2)) #audio是 Linear(in_features=104, 
out_features=1024, bias=True) + if self.encoder is not None: + x = self.encoder(x)[0].transpose(1, 2) + else: # + x = x.transpose(1, 2) + return x #torch.Size([1, 1024, 106]) + +@register_model("av_hubert", dataclass=AVHubertConfig) +class AVHubertModel(BaseFairseqModel): + def __init__( + self, + cfg: AVHubertConfig, + task_cfg: AVHubertPretrainingConfig, + dictionaries: List[Dictionary], + **kwargs + ) -> None: + super().__init__() + logger.info(f"HubertModel Config: {cfg}") + + feature_ds_rate = 1 + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate + sub_cfg = deepcopy(cfg) + sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers + resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights) + self.feature_extractor_audio = SubModel(resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg) + self.feature_extractor_video = SubModel(resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg) + self.modality_dropout, self.audio_dropout = cfg.modality_dropout, cfg.audio_dropout + self.modality_fuse = cfg.modality_fuse + self.encoder_embed_dim = cfg.encoder_embed_dim + if self.modality_fuse == 'concat': + self.embed = cfg.encoder_embed_dim * 2 + elif self.modality_fuse == 'add': + self.embed = cfg.encoder_embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob_image, self.mask_prob_audio = cfg.mask_prob_image, cfg.mask_prob_audio + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length_image, self.mask_length_audio = cfg.mask_length_image, cfg.mask_length_audio + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + self.sim_type = cfg.sim_type + self.selection_type = cfg.selection_type + self.masking_type = cfg.masking_type + + final_dim = ( + cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + ) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.audio_feat_dim).uniform_() if self.masking_type == 'input' else torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.untie_final_proj = cfg.untie_final_proj + if self.untie_final_proj: + self.final_proj = nn.Linear( + cfg.encoder_embed_dim, final_dim * len(dictionaries) + ) + else: + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + # modules below are not needed during fine-tuning + if any([d is None for d in dictionaries]): + logger.info( + "cannot find dictionary. 
assume will be used for fine-tuning" + ) + else: + self.num_classes = [len(d) for d in dictionaries] + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim) + ) + nn.init.uniform_(self.label_embs_concat) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: AVHubertConfig, task: AVHubertPretrainingTask): + """Build a new model instance.""" + + kwargs = {} + model = AVHubertModel(cfg, task.cfg, task.dictionaries, **kwargs) + return model + + def apply_input_mask(self, x, padding_mask, target_list): + B, C, T = x.shape[:3] + is_audio = True if len(x.shape) == 3 else False + if is_audio: + mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio + else: + mask_prob, mask_length = self.mask_prob_image, self.mask_length_image + if mask_prob > 0: + + mask_indices, starts, ends, batch_indexes = compute_mask_indices( + (B, T), + padding_mask, + mask_prob, + mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices_np = mask_indices + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = x.transpose(1, 2).contiguous() # [B, T, C, H, W] + if B == 1: + x[mask_indices] = 0 + elif is_audio: + x[mask_indices] = self.mask_emb + elif self.selection_type == 'same_other_seq': + perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B + x_perm = x[perm] + x[mask_indices] = x_perm[mask_indices] + elif self.selection_type == 'same_seq': + batch_indexes_, other_indexes = [], [] + for batch_index, start, end in zip(batch_indexes, starts, ends): + length = end-start + other_start = np.setdiff1d(np.arange(T), np.arange(max(0, start-length), end)) + if len(other_start) > 0: + other_start = np.random.choice(other_start, size=1) + else: + other_start = 0 + other_end = other_start + length + other_indexes.append(np.arange(other_start, other_end).clip(max=T-1)) + batch_indexes_.append(np.zeros([length], dtype=np.int64)+batch_index) + batch_indexes, other_indexes = np.concatenate(batch_indexes_), np.concatenate(other_indexes) + x[mask_indices] = x[batch_indexes, other_indexes] + + x = x.transpose(1, 2).contiguous() + else: + mask_indices = None + + if self.mask_channel_prob > 0: + logger.info(f"No mask channel prob for input masking") + return x, mask_indices + + def apply_feature_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + assert self.mask_prob_audio == self.mask_prob_image and self.mask_length_audio == self.mask_length_image, f"masking prob/length for image/audio be same for feature masking" + mask_prob, mask_length = self.mask_prob_audio, self.mask_length_image + if mask_prob > 0: + mask_indices, _, _, _ = compute_mask_indices( + (B, T), + padding_mask, + mask_prob, + mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices, _, _, _ = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + 
) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor: + extractor = eval(f"self.feature_extractor_{modality}") + if self.feature_grad_mult > 0: + features = extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = extractor(source) + return features + + def forward_targets( + self, features: torch.Tensor, mask_indices: torch.Tensor, target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + if mask_indices is not None: + mask_indices = mask_indices[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, mask_indices, target_list + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def compute_logits(self, feats, emb_mat): + # feats: [B, T, F], emb_mat: [V, F] + if self.sim_type == 'dot': + logits = torch.matmul(feats, emb_mat.transpose(0, 1)) + elif self.sim_type == 'cosine': + batch_size, timesteps, emb_dim = feats.size() + feats_ = feats.view(-1, emb_dim) + nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1) # [B*T, V] + denom = (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1) * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0) # [B*T, V] + logits = (nom/denom.clamp(min=1e-6)).view(batch_size, timesteps, -1) + else: + raise NotImplementedError + logits = logits / self.logit_temp + return logits + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None + ) -> Dict[str, torch.Tensor]: + """output layer is 1-based""" + src_audio, src_video = source['audio'], source['video'] + if mask and self.masking_type == 'input': + src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list) + src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list) + mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) + else: + src_audio, src_video, mask_indices = src_audio, src_video, None + + features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] + features_video = self.forward_features(src_video, modality='video') + modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random() + if self.training: + if modality_drop_prob < self.modality_dropout: + if audio_drop_prob < self.audio_dropout: + features_audio = 0 * features_audio + else: + features_video = 0 * features_video + if self.modality_fuse == 'concat': + features = torch.cat([features_audio, 
features_video], dim=1) + elif self.modality_fuse == 'add': + features = features_audio + features_video + if target_list is not None: + features, mask_indices, target_list = self.forward_targets(features, mask_indices, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + if self.masking_type == 'feature' and mask: + x, mask_indices = self.apply_feature_mask(features, padding_mask, target_list) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + proj_x = self.final_proj(x) + if self.untie_final_proj: + proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1) + else: + proj_x_list = [proj_x for _ in self.num_classes] + logit_list = [self.compute_logits(proj, emb).view(-1, num_class) for proj, emb, num_class in zip(proj_x_list, label_embs_list, self.num_classes)] # [[B*T, V]] + mask, unmask = torch.logical_and(mask_indices, ~padding_mask).view(-1), torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T] + logit_m_list, logit_u_list = [logit[mask] for logit in logit_list], [logit[unmask] for logit in logit_list] + target_m_list, target_u_list = [target.view(-1)[mask].long() for target in target_list], [target.view(-1)[unmask].long() for target in target_list] + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "target_m_list": target_m_list, + "target_u_list": target_u_list, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + return result + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None): + src_audio, src_video = source['audio'], source['video'] #torch.Size([1, 1, 106, 112, 112]) + if mask and self.masking_type == 'input': + src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list=None) + src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list=None) + mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) # mask_indices not used in fine-tuning + else: # + src_audio, src_video, mask_indices = src_audio, src_video, None + + if src_audio is not None and src_video is None: + features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] + features_video = features_audio.new_zeros(features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1)) + elif src_audio is None and src_video is not None: + features_video = 
self.forward_features(src_video, modality='video') + features_audio = features_video.new_zeros(features_video.size(0), self.encoder_embed_dim, features_video.size(-1)) #全0! + elif src_audio is not None and src_video is not None: + features_video = self.forward_features(src_video, modality='video') #torch.Size([1, 1024, 106]) #scr torch.Size([12, 1, 314, 88, 88]) + features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] #torch.Size([12, 26, 314]) + + if self.modality_fuse == 'concat': # + features = torch.cat([features_audio, features_video], dim=1) #torch.Size([1, 2048, 106]) + elif self.modality_fuse == 'add': + features = features_audio + features_video + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: #features:torch.Size([1, 106, 2048]) + padding_mask = self.forward_padding_mask(features, padding_mask) #torch.Size([4, 154]) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) #torch.Size([1, 106, 1024]) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + return x, padding_mask #torch.Size([1, 106, 1024]), None + + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None + + def get_logits(self, net_output, is_masked=True): + raise NotImplementedError + + def get_targets(self, net_output, is_masked=True): + raise NotImplementedError + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity( + x.float(), targets.float(), dim=-1 + ).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits diff --git a/slam_llm/models/avhubert/hubert_asr.py b/slam_llm/models/avhubert/hubert_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..0df02900d8dfea733e9f43db1a5861098f60779f --- /dev/null +++ b/slam_llm/models/avhubert/hubert_asr.py @@ -0,0 +1,523 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
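+ +# ASR fine-tuning wrappers around a pretrained AV-HuBERT encoder: AVHubertAsrConfig/AVHubertCtcConfig hold the shared +# fine-tuning hyper-parameters, AVHubertCtc (via HubertEncoder) projects encoder states onto the target dictionary for CTC, +# and AVHubertSeq2Seq couples the encoder (HubertEncoderWrapper) with a TransformerDecoder for sequence-to-sequence decoding.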
+ +import sys,logging +import contextlib +import tempfile +from argparse import Namespace +from typing import Any, Optional + +import torch +import torch.nn as nn +from dataclasses import dataclass, field +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, register_model +from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES +from fairseq.tasks import FairseqTask +from omegaconf import II, MISSING + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + from hubert import AVHubertModel + from decoder import TransformerDecoder +else: + from .hubert import AVHubertModel + from .decoder import TransformerDecoder + +logger = logging.getLogger(__name__) + + +@dataclass +class AVHubertAsrConfig(FairseqDataclass): + w2v_path: str = field( + default=MISSING, metadata={"help": "path to hubert model"} + ) + no_pretrained_weights: bool = field( + default=False, + metadata={"help": "if true, does not load pretrained weights"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + final_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout after transformer and before final projection" + }, + ) + dropout: float = field( + default=0.0, + metadata={"help": "dropout probability inside hubert model"}, + ) + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " + "inside hubert model" + }, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside hubert model" + }, + ) + + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} + ) + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} + ) + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask " + "(normalized by length)" + }, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + freeze_finetune_updates: int = field( + default=0, + metadata={"help": "dont finetune hubert for this many updates"}, + ) + feature_grad_mult: float = field( + default=0.0, + metadata={"help": "reset 
feature grad mult in hubert to this"}, + ) + layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a layer in hubert"}, + ) + normalize: bool = II("task.normalize") + data: str = II("task.data") + + # this holds the loaded hubert args + w2v_args: Any = None + + +@dataclass +class AVHubertCtcConfig(AVHubertAsrConfig): + pass + + +@register_model("av_hubert_ctc", dataclass=AVHubertCtcConfig) +class AVHubertCtc(BaseFairseqModel): + def __init__(self, cfg: AVHubertCtcConfig, w2v_encoder: BaseFairseqModel): + super().__init__() + self.cfg = cfg + self.w2v_encoder = w2v_encoder + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: AVHubertCtcConfig, task: FairseqTask): + """Build a new model instance.""" + w2v_encoder = HubertEncoder(cfg, task.target_dictionary) + return cls(cfg, w2v_encoder) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def get_logits(self, net_output): + logits = net_output["encoder_out"] + padding = net_output["encoder_padding_mask"] + if padding is not None and padding.any(): + padding = padding.T + logits[padding][..., 0] = 0 + logits[padding][..., 1:] = float("-inf") + + return logits + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + +@dataclass +class AVHubertSeq2SeqConfig(AVHubertAsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field( + default=6, metadata={"help": "num of decoder layers"} + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, + metadata={"help": "apply layernorm before each decoder block"}, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings " + "(outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " + "inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, + metadata={"help": "share decoder input and output embeddings"}, + ) + no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'}) + +class HubertEncoder(FairseqEncoder): + def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None): + self.apply_mask = cfg.apply_mask + + arg_overrides = { + "dropout": cfg.dropout, + 
"activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + } + + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.w2v_path, arg_overrides + ) + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + cfg.w2v_args = w2v_args + else: + state = None + w2v_args = cfg.w2v_args + if isinstance(w2v_args, Namespace): + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( + w2v_args + ) + + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. " + "Please check that --normalize is set or unset for " + "both pre-training and here" + ) + + w2v_args.task.data = cfg.data + + task = tasks.setup_task(w2v_args.task) + model = task.build_model(w2v_args.model) + + if state is not None and not cfg.no_pretrained_weights: + # set strict=False because we omit some modules + model.load_state_dict(state["model"], strict=False) + + model.remove_pretraining_modules() + + super().__init__(task.source_dictionary) + + d = model.encoder.embedding_dim + + self.w2v_model = model + + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates + self.num_updates = 0 + + if tgt_dict is not None: + self.proj = Linear(d, len(tgt_dict)) + elif getattr(cfg, "decoder_embed_dim", d) != d: + self.proj = Linear(d, cfg.decoder_embed_dim) + else: + self.proj = None + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + x, padding_mask = self.w2v_model.extract_finetune(**w2v_args) + + if tbc: + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out[ + "encoder_out" + ].index_select(1, new_order) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +class HubertEncoderWrapper(FairseqEncoder): + def __init__(self, w2v_model): + super().__init__(None) + self.w2v_model = w2v_model + + def forward(self, source, 
padding_mask, **kwargs): + w2v_args = { + "source": source, + "padding_mask": padding_mask, + } + + x, padding_mask = self.w2v_model.extract_finetune(**w2v_args) + # B x T x C -> T x B x C + x = x.transpose(0, 1) #torch.Size([106, 1, 1024]) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out[ + "encoder_out" + ].index_select(1, new_order) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + if encoder_out["padding_mask"] is not None: + encoder_out["padding_mask"] = encoder_out[ + "padding_mask" + ].index_select(0, new_order) + return encoder_out + +@register_model("av_hubert_seq2seq", dataclass=AVHubertSeq2SeqConfig) +class AVHubertSeq2Seq(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder, tgt_dict, cfg): + super().__init__(encoder, decoder) + self.cfg = cfg + self.freeze_finetune_updates = cfg.freeze_finetune_updates + + @classmethod + def build_model(cls, cfg, task): + """Build a new model instance.""" + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + } + + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.w2v_path, arg_overrides + ) + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + cfg.w2v_args = w2v_args + else: + state = None + w2v_args = cfg.w2v_args + if isinstance(w2v_args, Namespace): + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( + w2v_args + ) + + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. 
" + "Please check that --normalize is set or unset for " + "both pre-training and here" + ) + + w2v_args.task.data = cfg.data + + task_pretrain = tasks.setup_task(w2v_args.task) + if state is not None: + task_pretrain.load_state_dict(state['task_state']) + + encoder_ = task_pretrain.build_model(w2v_args.model) + + encoder = HubertEncoderWrapper(encoder_) + if state is not None and not cfg.no_pretrained_weights: + # set strict=False because we omit some modules + del state['model']['mask_emb'] + encoder.w2v_model.load_state_dict(state["model"], strict=False) + + encoder.w2v_model.remove_pretraining_modules() + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx) + return emb + + decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim) + decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens) + + return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg) + + + def forward(self, **kwargs): + # ft = self.freeze_finetune_updates <= self.num_updates + # with torch.no_grad() if not ft else contextlib.ExitStack(): + # output = self.encoder(**kwargs) + with torch.no_grad(): + output = self.encoder(**kwargs) #encoder_out,encoder_padding_mask,padding_mask + # decoder_out = self.decoder(prev_output_tokens=kwargs['prev_output_tokens'], encoder_out=output) + return output + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/slam_llm/models/avhubert/hubert_criterion.py b/slam_llm/models/avhubert/hubert_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7262288958455de4189df82020f28f80d58e34 --- /dev/null +++ b/slam_llm/models/avhubert/hubert_criterion.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import math +import re +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class AVHubertCriterionConfig(FairseqDataclass): + pred_masked_weight: float = field( + default=1.0, + metadata={"help": "weight for predictive loss for masked frames"}, + ) + pred_nomask_weight: float = field( + default=0.0, + metadata={"help": "weight for predictive loss for unmasked frames"}, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("av_hubert", dataclass=AVHubertCriterionConfig) +class AVHubertCriterion(FairseqCriterion): + def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None): + super().__init__(task) + self.pred_masked_weight = pred_masked_weight + self.pred_nomask_weight = pred_nomask_weight + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + + def forward(self, model, sample, reduce=True, log_pred=False): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(target_list=sample["target_list"], **sample["net_input"]) + loss = 0. + sample_size = 0 + logging_output = {} + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list, targ_m_list = net_output['logit_m_list'], net_output['target_m_list'] + for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)): + loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_{i}"] = loss_m.detach().item() + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += targ_m_list[0].numel() + + loss_u_list = [] + logp_u_list, targ_u_list = net_output['logit_u_list'], net_output['target_u_list'] + for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)): + loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_{i}"] = loss_u.detach().item() + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += targ_u_list[0].numel() + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses, names = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + **logging_output, + } + + for lk in self.log_keys: + if lk in net_output: + 
logging_output[lk] = float((net_output[lk])) + + with torch.no_grad(): + for i, logp_m in enumerate(logp_m_list): + # corr_m, count_m = compute_correct(logp_m) + if logp_m.numel() == 0: + corr_m, count_m = 0, 0 + else: + corr_m, count_m = (logp_m.argmax(dim=-1)==targ_m_list[i]).sum().item(), len(targ_m_list[i]) + logging_output[f"correct_m_{i}"] = corr_m + logging_output[f"count_m_{i}"] = count_m + + for i, logp_u in enumerate(logp_u_list): + if logp_u.numel() == 0: + corr_u, count_u = 0, 0 + else: + corr_u, count_u = (logp_u.argmax(dim=-1)==targ_u_list[i]).sum().item(), len(targ_u_list[i]) + logging_output[f"correct_u_{i}"] = corr_u + logging_output[f"count_u_{i}"] = count_u + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training (copied from normal cross entropy).""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3) + if sample_size != ntokens: + metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3) + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)) + else: + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)) + + counts = {} + for lk in logging_outputs[0].keys(): + if lk.startswith("count_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val) + counts[lk] = val + + for lk in logging_outputs[0].keys(): + if lk.startswith("loss_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / sample_size / math.log(2), round=3) + elif lk.startswith("correct_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)]) + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + raise NotImplementedError() + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False diff --git a/slam_llm/models/avhubert/hubert_dataset.py b/slam_llm/models/avhubert/hubert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..292ebdec9e409ec29c2fa800727e6b935c85a5be --- /dev/null +++ b/slam_llm/models/avhubert/hubert_dataset.py @@ -0,0 +1,529 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
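+ +# Dataset for paired audio-visual speech: items combine video frames with log filter-bank audio features (stacked +# stack_order_audio frames per step), optionally mix in noise at a configurable SNR, and the collater pads or crops both +# modalities to a shared length before collating frame-level or sequence (s2s) labels.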
+ +import itertools +import logging +import os +import sys +import time +from typing import Any, List, Optional, Union + +import numpy as np + +import torch +import torch.nn.functional as F +from fairseq.data import data_utils +from fairseq.data.fairseq_dataset import FairseqDataset +from python_speech_features import logfbank +from scipy.io import wavfile + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + import utils as custom_utils + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "DEBUG").upper(), + stream=sys.stdout, + ) +else: + from . import utils as custom_utils + +logger = logging.getLogger(__name__) + + +def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1): + def is_audio_label_aligned(audio_dur, label_durs): + return all([abs(audio_dur - label_dur)<tol for label_dur in label_durs]) + + n_long, n_short, n_unaligned = 0, 0, 0 + names, inds, sizes = [], [], [] + dur_from_label_list = [] + is_seq_label = any([x==-1 for x in label_rates]) + for label_path, label_rate in zip(label_paths, label_rates): + label_lengths = [len(line.rstrip().split())/label_rate for line in open(label_path).readlines()] + dur_from_label_list.append(label_lengths) + dur_from_label_list = list(zip(*dur_from_label_list)) + + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + sz = int(items[-2]) # + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + elif (not is_seq_label) and (not is_audio_label_aligned(sz/frame_rate, dur_from_label_list[ind])): + n_unaligned += 1 + else: + video_path = items[1] + audio_path = items[2] + audio_id = items[0] + names.append((video_path, audio_path+':'+audio_id)) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [line.rstrip() for line in f] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + + +def load_label_offset(label_path, inds, tot): + with open(label_path) as f: + code_lengths = [len(line.encode("utf-8")) for line in f] + assert ( + len(code_lengths) == tot + ), f"number of labels does not match ({len(code_lengths)} != {tot})" + offsets = list(itertools.accumulate([0] + code_lengths)) + offsets = [(offsets[i], offsets[i + 1]) for i in inds] + return offsets + + +def verify_label_lengths( + audio_sizes, + audio_rate, + label_path, + label_rate, + inds, + tot, + tol=0.1, # tolerance in seconds +): + if label_rate < 0: + logger.info(f"{label_path} is sequence label. 
skipped") + return + + with open(label_path) as f: + lengths = [len(line.rstrip().split()) for line in f] + assert len(lengths) == tot + lengths = [lengths[i] for i in inds] + num_invalid = 0 + for i, ind in enumerate(inds): + dur_from_audio = audio_sizes[i] / audio_rate + dur_from_label = lengths[i] / label_rate + if abs(dur_from_audio - dur_from_label) > tol: + logger.warning( + ( + f"audio and label duration differ too much " + f"(|{dur_from_audio} - {dur_from_label}| > {tol}) " + f"in line {ind+1} of {label_path}. Check if `label_rate` " + f"is correctly set (currently {label_rate}). " + f"num. of samples = {audio_sizes[i]}; " + f"label length = {lengths[i]}" + ) + ) + num_invalid += 1 + if num_invalid > 0: + logger.warning( + f"total {num_invalid} (audio, label) pairs with mismatched lengths" + ) + + +class AVHubertDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_rates: Union[List[float], float], # -1 for sequence labels + pad_list: List[str], + eos_list: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + max_sample_size: Optional[int] = None, + shuffle: bool = True, + pad_audio: bool = False, + normalize: bool = False, + store_labels: bool = True, + random_crop: bool = False, + single_target: bool = False, + stack_order_audio: int=1, + skip_verify: bool=False, + image_mean: float=0, + image_std: float=1, + image_crop_size: int=88, + image_aug: bool=False, + modalities: Optional[List[str]]=None, + is_s2s=False, + noise_fn=None, + noise_prob=0, + noise_snr=0, + noise_num=1 + ): + self.label_rates = ( + [label_rates for _ in range(len(label_paths))] + if isinstance(label_rates, int) + else label_rates + ) + self.modalities = set(modalities) + self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(manifest_path, max_keep_sample_size, min_keep_sample_size, frame_rate=sample_rate, label_paths=label_paths, label_rates=self.label_rates) + self.sample_rate = sample_rate + self.stack_order_audio = stack_order_audio + self.shuffle = shuffle + self.random_crop = random_crop + + self.num_labels = len(label_paths) + self.pad_list = pad_list + self.eos_list = eos_list + self.label_processors = label_processors + self.single_target = single_target + self.store_labels = store_labels + self.is_s2s = is_s2s + self.noise_wav, self.noise_prob, self.noise_snr, self.noise_num = [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else [], noise_prob, noise_snr, noise_num + + assert self.single_target == (self.label_rates[0] == -1), f"single target should be equivalent to sequence label (label_rate==-1)" + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + assert ( + label_processors is None + or len(label_processors) == self.num_labels + ) + if not skip_verify: + for label_path, label_rate in zip(label_paths, self.label_rates): + verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot) + else: + logger.info(f"Skip label alignment verifying") + + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.pad_audio = pad_audio + self.normalize = normalize + if image_aug: + self.transform = custom_utils.Compose([ + custom_utils.Normalize( 0.0,255.0 ), + 
custom_utils.RandomCrop((image_crop_size, image_crop_size)), + custom_utils.HorizontalFlip(0.5), + custom_utils.Normalize(image_mean, image_std) ]) + else: + self.transform = custom_utils.Compose([ + custom_utils.Normalize( 0.0,255.0 ), + custom_utils.CenterCrop((image_crop_size, image_crop_size)), + custom_utils.Normalize(image_mean, image_std) ]) + logger.info(f"image transform: {self.transform}") + + logger.info( + f"pad_audio={pad_audio}, random_crop={random_crop}, " + f"normalize={normalize}, max_sample_size={self.max_sample_size}, " + f"seqs2seq data={self.is_s2s},") + logger.info( + f"Noise wav: {noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}" + ) + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + with open(self.label_paths[label_idx]) as f: + offset_s, offset_e = self.label_offsets_list[label_idx][index] + f.seek(offset_s) + label = f.read(offset_e - offset_s) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def load_feature(self, mix_name): + """ + Load image and audio feature + Returns: + video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F] + """ + def stacker(feats, stack_order): + """ + Concatenating consecutive audio frames + Args: + feats - numpy.ndarray of shape [T, F] + stack_order - int (number of neighboring frames to concatenate + Returns: + feats - numpy.ndarray of shape [T', F'] + """ + feat_dim = feats.shape[1] + if len(feats) % stack_order != 0: + res = stack_order - len(feats) % stack_order + res = np.zeros([res, feat_dim]).astype(feats.dtype) + feats = np.concatenate([feats, res], axis=0) + feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim) + return feats + video_fn, audio_fn = mix_name + if 'video' in self.modalities: + video_feats = self.load_video(video_fn) # [T, H, W, 1] + else: + video_feats = None + if 'audio' in self.modalities: + audio_fn = audio_fn.split(':')[0] + sample_rate, wav_data = wavfile.read(audio_fn) + assert sample_rate == 16_000 and len(wav_data.shape) == 1 + if np.random.rand() < self.noise_prob: + wav_data = self.add_noise(wav_data) + audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32) # [T, F] + audio_feats = stacker(audio_feats, self.stack_order_audio) # [T/stack_order_audio, F*stack_order_audio] + else: + audio_feats = None + if audio_feats is not None and video_feats is not None: + diff = len(audio_feats) - len(video_feats) + if diff < 0: + audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)]) + elif diff > 0: + audio_feats = audio_feats[:-diff] + return video_feats, audio_feats + + def load_video(self, audio_name): + feats = custom_utils.load_video(os.path.join(self.audio_root, audio_name)) + feats = self.transform(feats) + feats = np.expand_dims(feats, axis=-1) + return feats + + def select_noise(self): + rand_indexes = np.random.randint(0, len(self.noise_wav), size=self.noise_num) + noise_wav = [] + for x in rand_indexes: + noise_wav.append(wavfile.read(self.noise_wav[x])[1].astype(np.float32)) + if self.noise_num == 1: + return noise_wav[0] + else: + min_len = min([len(x) for x in noise_wav]) + noise_wav = [x[:min_len] for x in noise_wav] + noise_wav = 
np.floor(np.stack(noise_wav).mean(axis=0)) + return noise_wav + + def add_noise(self, clean_wav): + clean_wav = clean_wav.astype(np.float32) + noise_wav = self.select_noise() + if type(self.noise_snr) == int or type(self.noise_snr) == float: + snr = self.noise_snr + elif type(self.noise_snr) == tuple: + snr = np.random.randint(self.noise_snr[0], self.noise_snr[1]+1) + clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1)) + if len(clean_wav) > len(noise_wav): + ratio = int(np.ceil(len(clean_wav)/len(noise_wav))) + noise_wav = np.concatenate([noise_wav for _ in range(ratio)]) + if len(clean_wav) < len(noise_wav): + start = 0 + noise_wav = noise_wav[start: start + len(clean_wav)] + noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1)) + adjusted_noise_rms = clean_rms / (10**(snr/20)) + adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms) + mixed = clean_wav + adjusted_noise_wav + + #Avoid clipping noise + max_int16 = np.iinfo(np.int16).max + min_int16 = np.iinfo(np.int16).min + if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16: + if mixed.max(axis=0) >= abs(mixed.min(axis=0)): + reduction_rate = max_int16 / mixed.max(axis=0) + else : + reduction_rate = min_int16 / mixed.min(axis=0) + mixed = mixed * (reduction_rate) + mixed = mixed.astype(np.int16) + return mixed + + def __getitem__(self, index): + video_feats, audio_feats = self.load_feature(self.names[index]) + audio_feats, video_feats = torch.from_numpy(audio_feats.astype(np.float32)) if audio_feats is not None else None, torch.from_numpy(video_feats.astype(np.float32)) if video_feats is not None else None + if self.normalize and 'audio' in self.modalities: + with torch.no_grad(): + audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:]) + labels = self.get_labels(index) + fid = self.names[index][1].split(':')[1] + return {"id": index, 'fid': fid, "video_source": video_feats, 'audio_source': audio_feats, "label_list": labels} + + def __len__(self): + return len(self.sizes) + + def crop_to_max_size(self, wav, target_size, start=None): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + # longer utterances + if start is None: + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = size - diff + start + else: + end = start + target_size + return wav[start:end], start + + def collater(self, samples): + samples = [s for s in samples if s["id"] is not None] + if len(samples) == 0: + return {} + + audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples] + if audio_source[0] is None: + audio_source = None + if video_source[0] is None: + video_source = None + if audio_source is not None: + audio_sizes = [len(s) for s in audio_source] + else: + audio_sizes = [len(s) for s in video_source] + if self.pad_audio: + audio_size = min(max(audio_sizes), self.max_sample_size) + else: + audio_size = min(min(audio_sizes), self.max_sample_size) + if audio_source is not None: + collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size) + else: + collated_audios, audio_starts = None, None + if video_source is not None: + collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts) + else: + collated_videos = None + targets_by_label = [ + [s["label_list"][i] for s in samples] + for i in range(self.num_labels) + ] + targets_list, lengths_list, ntokens_list = self.collater_label( + targets_by_label, audio_size, audio_starts 
+ ) + source = {"audio": collated_audios, "video": collated_videos} + net_input = {"source": source, "padding_mask": padding_mask} + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + "utt_id": [s['fid'] for s in samples] + } + + if self.single_target: + batch["target_lengths"] = lengths_list[0] + batch["ntokens"] = ntokens_list[0] + if self.is_s2s: + batch['target'], net_input['prev_output_tokens'] = targets_list[0][0], targets_list[0][1] + else: + batch["target"] = targets_list[0] + else: + batch["target_lengths_list"] = lengths_list + batch["ntokens_list"] = ntokens_list + batch["target_list"] = targets_list + return batch + + def collater_audio(self, audios, audio_size, audio_starts=None): + audio_feat_shape = list(audios[0].shape[1:]) + collated_audios = audios[0].new_zeros([len(audios), audio_size]+audio_feat_shape) + padding_mask = ( + torch.BoolTensor(len(audios), audio_size).fill_(False) # + ) + start_known = audio_starts is not None + audio_starts = [0 for _ in audios] if not start_known else audio_starts + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat( + [audio, audio.new_full([-diff]+audio_feat_shape, 0.0)] + ) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size, audio_starts[i] if start_known else None + ) + if len(audios[0].shape) == 2: + collated_audios = collated_audios.transpose(1, 2) # [B, T, F] -> [B, F, T] + else: + collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W] + return collated_audios, padding_mask, audio_starts + + def collater_frm_label( + self, targets, audio_size, audio_starts, label_rate, pad + ): + assert label_rate > 0 + s2f = label_rate / self.sample_rate # num label per sample + frm_starts = [int(round(s * s2f)) for s in audio_starts] + frm_size = int(round(audio_size * s2f)) + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] + frm_size = min(frm_size, *rem_size) + targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)] + logger.debug(f"audio_starts={audio_starts}") + logger.debug(f"frame_starts={frm_starts}") + logger.debug(f"frame_size={frm_size}") + + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_seq_label_s2s(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos() + targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False) + prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True) + return (targets_, prev_output_tokens), lengths, ntokens + + def collater_label(self, targets_by_label, audio_size, audio_starts): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, self.label_rates, self.pad_list) + 
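# Collate each label stream according to its rate: label_rate == -1 denotes sequence labels (plain, or s2s targets paired with prev_output_tokens), otherwise frame-level labels are cropped and aligned to the audio window. +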
for targets, label_rate, pad in itr: + if label_rate == -1: + if self.is_s2s: + targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad) + else: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + else: + targets, lengths, ntokens = self.collater_frm_label( + targets, audio_size, audio_starts, label_rate, pad + ) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + if self.pad_audio: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order)[::-1] diff --git a/slam_llm/models/avhubert/hubert_pretraining.py b/slam_llm/models/avhubert/hubert_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..5829e6eadbd1da1ae66242bf4478c93e2161c32e --- /dev/null +++ b/slam_llm/models/avhubert/hubert_pretraining.py @@ -0,0 +1,401 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os, glob +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq import metrics, search +from fairseq.data import Dictionary, encoders +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING, II +import numpy as np +from argparse import Namespace + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + from hubert_dataset import AVHubertDataset + from sequence_generator import SequenceGenerator +else: + from .hubert_dataset import AVHubertDataset + from .sequence_generator import SequenceGenerator + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False, + ) + +class LabelEncoderS2SToken(object): + def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None: + self.bpe_tokenizer = bpe_tokenizer + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + label = self.bpe_tokenizer.encode(label.lower()) + return self.dictionary.encode_line( + label, append_eos=True, add_if_not_exist=False, + ).long() + + def decode(self, tok, symbols_ignore=None): + tok = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore) + if self.bpe_tokenizer: + tok = self.bpe_tokenizer.decode(tok) + return tok + +@dataclass +class AVHubertPretrainingConfig(FairseqDataclass): + input_modality: str = II("task.input_modality") #?? 
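+    # II("...") is an omegaconf interpolation placeholder: the value is resolved
+    # from the referenced config key when the full config is composed at runtime.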
+ data: str = field( + default=MISSING, metadata={"help": "path to data directory"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: int = field( + default=-1, + metadata={"help": "label frame rate. -1 for sequence label"}, + ) + + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={ + "help": "if set, normalizes input to have 0 mean and unit variance" + }, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to keep in training"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to keep in training"}, + ) + max_trim_sample_size: Optional[int] = field( + default=II("task.max_sample_size"), + metadata={"help": "max sample size to trim to for batching"}, + ) + single_target: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, AddTargetDatasets outputs same keys " + "as AddTargetDataset" + }, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + pdb: Optional[bool] = field( + default=False, + metadata={"help": "pdb"}, + ) + stack_order_audio: int = field( + default=1, + metadata={"help": "concatenate n consecutive audio frames for one step"}, + ) + skip_verify: Optional[bool] = field( + default=False, + metadata={"help": "skip verifying label-audio alignment"}, + ) + image_aug: bool = field(default=False, metadata={'help': 'image data augmentation'}) + image_crop_size: int = field( + default=88, metadata={"help": "image ROI size"}) + image_mean: float = field( + default=0.421, metadata={"help": "image mean"}) + image_std: float = field( + default=0.165, metadata={"help": "image std"}) + modalities: Optional[List[str]] = field(default_factory=lambda: ["audio", "video"], metadata={'help': 'modalities to load'}) + is_s2s: bool=field(default=False, metadata={'help': 'seq2seq fine-tuning only'}) + tokenizer_bpe_name: Optional[str] = field(default=None, metadata={'help': 'tokenizer model name'}) + tokenizer_bpe_model: Optional[str] = field(default=None, metadata={'help': 'tokenizer model path'}) + noise_wav: Optional[str] = field(default=None, metadata={'help': 'manifest of noise wav files (one wav file path per line)'}) + noise_prob: float = field(default=0, metadata={'help': 'noise probability'}) + noise_snr: Optional[str] = field(default='0', metadata={'help': 'noise SNR in audio'}) + noise_num: int = field(default=1, metadata={'help': 'number of noise wav files to mix'}) + fine_tuning: bool = field(default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"}) + +@register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig) +class AVHubertPretrainingTask(FairseqTask): + + cfg: AVHubertPretrainingConfig + + def __init__( + self, + cfg: 
AVHubertPretrainingConfig, + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"AVHubertPretrainingTask Config {cfg}") + + self.fine_tuning = cfg.fine_tuning + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", self.load_dictionaries) + if cfg.is_s2s: + self.state.add_factory("s2s_tokenizer", self.load_tokenizer) + else: + self.state.add_factory("dictionaries", self.load_dictionaries) + + self.blank_symbol = "<s>" + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return None # self._source_dictionary + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self.state.target_dictionary # self._target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return self.state.dictionaries + + def load_dictionaries(self): + label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir + dictionaries = [ + Dictionary.load(f"{label_dir}/dict.{label}.txt") + for label in self.cfg.labels + ] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries + + def load_tokenizer(self): + bpe_args = Namespace(**{'bpe': self.cfg.tokenizer_bpe_name, f"{self.cfg.tokenizer_bpe_name}_model": self.cfg.tokenizer_bpe_model}) + bpe_tokenizer = encoders.build_bpe(bpe_args) + return bpe_tokenizer + + @property + def s2s_tokenizer(self): + return self.state.s2s_tokenizer + + @classmethod + def setup_task( + cls, cfg: AVHubertPretrainingConfig, **kwargs + ) -> "AVHubertPretrainingTask": + if cfg.pdb: + import pdb + pdb.set_trace() + return cls(cfg) + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None: + return self.cfg.data + return self.cfg.label_dir + + def load_dataset(self, split: str, **kwargs) -> None: + manifest = f"{self.cfg.data}/{split}.tsv" + dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries + pad_list = [dictionary.pad() for dictionary in dictionaries] + eos_list = [dictionary.eos() for dictionary in dictionaries] + if not self.cfg.is_s2s: + procs = [LabelEncoder(dictionary) for dictionary in dictionaries] + else: + logger.info(f"Using tokenizer") + bpe_tokenizer = self.s2s_tokenizer + procs = [LabelEncoderS2SToken(dictionary, bpe_tokenizer) for dictionary in dictionaries] + paths = [ + f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels + ] + image_aug = self.cfg.image_aug if split == 'train' else False + noise_fn, noise_snr = f"{self.cfg.noise_wav}/{split}.tsv" if self.cfg.noise_wav is not None else None, eval(self.cfg.noise_snr) + noise_num = self.cfg.noise_num # + self.datasets[split] = AVHubertDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, + label_rates=self.cfg.label_rate, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=self.cfg.max_sample_size, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.cfg.max_trim_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=False, + random_crop=self.cfg.random_crop, + single_target=self.cfg.single_target, + stack_order_audio=self.cfg.stack_order_audio, + skip_verify=self.cfg.skip_verify, + image_mean=self.cfg.image_mean, + image_std=self.cfg.image_std, + image_crop_size=self.cfg.image_crop_size, + image_aug=image_aug, + modalities=self.cfg.modalities, + is_s2s=self.cfg.is_s2s, + noise_fn=noise_fn, + noise_prob=self.cfg.noise_prob, + noise_snr=noise_snr, + noise_num=noise_num + ) + + def max_positions(self) -> Tuple[int, 
int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size( + self, indices: np.array, *args, **kwargs + ) -> np.array: + return indices + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None, + ): + """ + Build a :class:`~fairseq.SequenceGenerator` instance for this + task. + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + args (fairseq.dataclass.configs.GenerationConfig): + configuration object (dataclass) for generation + extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass + through to SequenceGenerator + prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): + If provided, this function constrains the beam search to + allowed tokens only at each step. The provided function + should take 2 arguments: the batch ID (`batch_id: int`) + and a unidimensional tensor of token ids (`inputs_ids: + torch.Tensor`). It has to return a `List[int]` with the + allowed tokens for the next generation step conditioned + on the previously generated tokens (`inputs_ids`) and + the batch ID (`batch_id`). This argument is useful for + constrained generation conditioned on the prefix, as + described in "Autoregressive Entity Retrieval" + (https://arxiv.org/abs/2010.00904) and + https://github.com/facebookresearch/GENRE. + """ + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + compute_alignment=getattr(args, "print_alignment", False), + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) + match_source_len = getattr(args, "match_source_len", False) + diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) + if prefix_allowed_tokens_fn is None: + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) + if ( + sum( + int(cond) + for cond in [ + sampling, + diverse_beam_groups > 0, + match_source_len, + diversity_rate > 0, + ] + ) + > 1 + ): + raise ValueError("Provided Search parameters are mutually exclusive.") + assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" + assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" + + if sampling: + search_strategy = search.Sampling( + self.target_dictionary, sampling_topk, sampling_topp + ) + elif diverse_beam_groups > 0: + search_strategy = search.DiverseBeamSearch( + self.target_dictionary, diverse_beam_groups, diverse_beam_strength + ) + elif match_source_len: + # this is useful for tagging applications where the output + # length should match the input length, so we hardcode the + # length constraints for simplicity + search_strategy = search.LengthConstrainedBeamSearch( + self.target_dictionary, + min_len_a=1, + min_len_b=0, + max_len_a=1, + max_len_b=0, + ) + elif diversity_rate > -1: + search_strategy = search.DiverseSiblingsSearch( + self.target_dictionary, diversity_rate + ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch( + self.target_dictionary, args.constraints + ) + elif prefix_allowed_tokens_fn: + search_strategy = search.PrefixConstrainedBeamSearch( + self.target_dictionary, 
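+                # Illustrative (hypothetical) prefix_allowed_tokens_fn sketch:
+                # constrain the first step to a single start token, then allow
+                # the full vocabulary.
+                #
+                #   def allow_tokens(batch_id, input_ids):
+                #       if input_ids.numel() <= 1:
+                #           return [tgt_dict.index("<start>")]
+                #       return list(range(len(tgt_dict)))
+                #
+                # Here tgt_dict and "<start>" are placeholders, not names from
+                # this codebase.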
prefix_allowed_tokens_fn + ) + else: + search_strategy = search.BeamSearch(self.target_dictionary) + + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + if seq_gen_cls is None: + if getattr(args, "print_alignment", False): + seq_gen_cls = SequenceGeneratorWithAlignment + extra_gen_cls_kwargs["print_alignment"] = args.print_alignment + else: + seq_gen_cls = SequenceGenerator + + return seq_gen_cls( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + **extra_gen_cls_kwargs, + ) diff --git a/slam_llm/models/avhubert/infer_s2s.py b/slam_llm/models/avhubert/infer_s2s.py new file mode 100644 index 0000000000000000000000000000000000000000..805b66fdc976073ba61364072bbd41df7547a4b9 --- /dev/null +++ b/slam_llm/models/avhubert/infer_s2s.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import ast +from itertools import chain +import logging +import math +import os +import sys +import json +import hashlib +import editdistance +from argparse import Namespace + +import numpy as np +import torch +from fairseq import checkpoint_utils, options, tasks, utils, distributed_utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.models import FairseqLanguageModel +from omegaconf import DictConfig + +from pathlib import Path +import hydra +from hydra.core.config_store import ConfigStore +from fairseq.dataclass.configs import ( + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + GenerationConfig, + FairseqDataclass, +) +from dataclasses import dataclass, field, is_dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +from omegaconf import OmegaConf + +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +config_path = Path(__file__).resolve().parent / "conf" + +@dataclass +class OverrideConfig(FairseqDataclass): + noise_wav: Optional[str] = field(default=None, metadata={'help': 'noise wav file'}) + noise_prob: float = field(default=0, metadata={'help': 'noise probability'}) + noise_snr: float = field(default=0, metadata={'help': 'noise SNR in audio'}) + modalities: List[str] = field(default_factory=lambda: [""], metadata={'help': 'which modality to use'}) + data: Optional[str] = field(default=None, metadata={'help': 'path to test data directory'}) + label_dir: Optional[str] = field(default=None, metadata={'help': 'path to test label directory'}) + +@dataclass +class InferConfig(FairseqDataclass): + task: Any = None + generation: GenerationConfig = GenerationConfig() + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + distributed_training: DistributedTrainingConfig = 
DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + override: OverrideConfig = OverrideConfig() + is_ax: bool = field( + default=False, + metadata={ + "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume" + }, + ) + + +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for recognition!" + assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" + + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) + output_path = os.path.join(cfg.common_eval.results_path, "decode.log") + with open(output_path, "w", buffering=1, encoding="utf-8") as h: + return _main(cfg, h) + return _main(cfg, sys.stdout) + + +def get_symbols_to_strip_from_output(generator): + if hasattr(generator, "symbols_to_strip_from_output"): + return generator.symbols_to_strip_from_output + else: + return {generator.eos, generator.pad} + +def _main(cfg, output_file): + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=output_file, + ) + logger = logging.getLogger("hybrid.speech_recognize") + if output_file is not sys.stdout: # also print to stdout + logger.addHandler(logging.StreamHandler(sys.stdout)) + + utils.import_user_module(cfg.common) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([cfg.common_eval.path]) + models = [model.eval().cuda() for model in models] #!! + saved_cfg.task.modalities = cfg.override.modalities + task = tasks.setup_task(saved_cfg.task) + + task.build_tokenizer(saved_cfg.tokenizer) + task.build_bpe(saved_cfg.bpe) + + logger.info(cfg) + + # Fix seed for stochastic decoding + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + use_cuda = torch.cuda.is_available() + + # Set dictionary + dictionary = task.target_dictionary + + # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config + task.cfg.noise_prob = cfg.override.noise_prob + task.cfg.noise_snr = cfg.override.noise_snr + task.cfg.noise_wav = cfg.override.noise_wav + if cfg.override.data is not None: + task.cfg.data = cfg.override.data + if cfg.override.label_dir is not None: + task.cfg.label_dir = cfg.override.label_dir + task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) + + lms = [None] + + # Optimize ensemble for generation + for model in chain(models, lms): + if model is None: + continue + if cfg.common.fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), *[m.max_positions() for m in models] + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + 
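+        # num_shards / shard_id split the eval set across distributed ranks; with a
+        # world size of 1 a single rank iterates over the whole gen_subset.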
shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + # Initialize generator + if cfg.generation.match_source_len: + logger.warning( + "The option match_source_len is not applicable to speech recognition. Ignoring it." + ) + gen_timer = StopwatchMeter() + extra_gen_cls_kwargs = { + "lm_model": lms[0], + "lm_weight": cfg.generation.lm_weight, + } + cfg.generation.score_reference = False # + save_attention_plot = cfg.generation.print_alignment is not None + cfg.generation.print_alignment = None # + generator = task.build_generator( + models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + + def decode_fn(x): + symbols_ignore = get_symbols_to_strip_from_output(generator) + symbols_ignore.add(dictionary.pad()) + if hasattr(task.datasets[cfg.dataset.gen_subset].label_processors[0], 'decode'): + return task.datasets[cfg.dataset.gen_subset].label_processors[0].decode(x, symbols_ignore) + chars = dictionary.string(x, extra_symbols_to_ignore=symbols_ignore) + words = " ".join("".join(chars.split()).replace('|', ' ').split()) + return words + + num_sentences = 0 + has_target = True + wps_meter = TimeMeter() + result_dict = {'utt_id': [], 'ref': [], 'hypo': []} + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] + + constraints = None + if "constraints" in sample: + constraints = sample["constraints"] + + gen_timer.start() + hypos = task.inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i in range(len(sample["id"])): + result_dict['utt_id'].append(sample['utt_id'][i]) + ref_sent = decode_fn(sample['target'][i].int().cpu()) + result_dict['ref'].append(ref_sent) + best_hypo = hypos[i][0]['tokens'].int().cpu() + hypo_str = decode_fn(best_hypo) + result_dict['hypo'].append(hypo_str) + logger.info(f"\nREF:{ref_sent}\nHYP:{hypo_str}\n") + wps_meter.update(num_generated_tokens) + progress.log({"wps": round(wps_meter.avg)}) + num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info("Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( + num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. 
/ gen_timer.avg)) + + yaml_str = OmegaConf.to_yaml(cfg.generation) + fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) + fid = fid % 1000000 + result_fn = f"{cfg.common_eval.results_path}/hypo-{fid}.json" + json.dump(result_dict, open(result_fn, 'w'), indent=4) + n_err, n_total = 0, 0 + assert len(result_dict['hypo']) == len(result_dict['ref']) + for hypo, ref in zip(result_dict['hypo'], result_dict['ref']): + hypo, ref = hypo.strip().split(), ref.strip().split() + n_err += editdistance.eval(hypo, ref) + n_total += len(ref) + wer = 100 * n_err / n_total + wer_fn = f"{cfg.common_eval.results_path}/wer.{fid}" + with open(wer_fn, "w") as fo: + fo.write(f"WER: {wer}\n") + fo.write(f"err / num_ref_words = {n_err} / {n_total}\n\n") + fo.write(f"{yaml_str}") + logger.info(f"WER: {wer}%") + return + + +@hydra.main(config_path=config_path, config_name="infer") +def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]: + container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + cfg = OmegaConf.create(container) + OmegaConf.set_struct(cfg, True) + + if cfg.common.reset_logging: + reset_logging() + + wer = float("inf") + + try: + if cfg.common.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + except BaseException as e: # pylint: disable=broad-except + if not cfg.common.suppress_crashes: + raise + else: + logger.error("Crashed! %s", str(e)) + return + + +def cli_main() -> None: + try: + from hydra._internal.utils import ( + get_args, + ) # pylint: disable=import-outside-toplevel + + cfg_name = get_args().config_name or "infer" + except ImportError: + logger.warning("Failed to get config name from hydra args") + cfg_name = "infer" + + cs = ConfigStore.instance() + cs.store(name=cfg_name, node=InferConfig) + + for k in InferConfig.__dataclass_fields__: + if is_dataclass(InferConfig.__dataclass_fields__[k].type): + v = InferConfig.__dataclass_fields__[k].default + cs.store(name=k, node=v) + + hydra_main() # pylint: disable=no-value-for-parameter + + +if __name__ == "__main__": + cli_main() diff --git a/slam_llm/models/avhubert/resnet.py b/slam_llm/models/avhubert/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..0134e111f8892a8c314628e7ffc3543fd9d35c03 --- /dev/null +++ b/slam_llm/models/avhubert/resnet.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
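+
+# NOTE: ResEncoder below uses torch.load and OrderedDict when pretrained
+# frontend/trunk weights are provided, so both need to be importable here.
+import torch
+from collections import OrderedDict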
+ +import logging +import math +import torch.nn as nn +import pdb + + +logger = logging.getLogger(__name__) + +def conv3x3(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +def downsample_basic_block( inplanes, outplanes, stride ): + return nn.Sequential( + nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(outplanes), + ) + +def downsample_basic_block_v2( inplanes, outplanes, stride ): + return nn.Sequential( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), + nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(outplanes), + ) + + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type = 'relu' ): + super(BasicBlock, self).__init__() + + assert relu_type in ['relu','prelu'] + + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + + if relu_type == 'relu': + self.relu1 = nn.ReLU(inplace=True) + self.relu2 = nn.ReLU(inplace=True) + elif relu_type == 'prelu': + self.relu1 = nn.PReLU(num_parameters=planes) + self.relu2 = nn.PReLU(num_parameters=planes) + else: + raise Exception('relu type not implemented') + + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu2(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, relu_type = 'relu', gamma_zero = False, avg_pool_downsample = False): + self.inplanes = 64 + self.relu_type = relu_type + self.gamma_zero = gamma_zero + self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block + + super(ResNet, self).__init__() + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
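+                # He (Kaiming) initialization: zero-mean normal with
+                # std = sqrt(2 / fan_out), where fan_out = k_h * k_w * out_channels.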
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + if self.gamma_zero: + for m in self.modules(): + if isinstance(m, BasicBlock ): + m.bn2.weight.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + + + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = self.downsample_block( inplanes = self.inplanes, + outplanes = planes * block.expansion, + stride = stride ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, relu_type = self.relu_type)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, relu_type = self.relu_type)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + return x + +class ResEncoder(nn.Module): + def __init__(self, relu_type, weights): + super(ResEncoder, self).__init__() + self.frontend_nout = 64 + self.backend_out = 512 + frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU() + self.frontend3D = nn.Sequential( + nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), + nn.BatchNorm3d(self.frontend_nout), + frontend_relu, + nn.MaxPool3d( kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))) + self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type) + if weights is not None: + logger.info(f"Load {weights} for resnet") + std = torch.load(weights, map_location=torch.device('cpu'))['model_state_dict'] + frontend_std, trunk_std = OrderedDict(), OrderedDict() + for key, val in std.items(): + new_key = '.'.join(key.split('.')[1:]) + if 'frontend3D' in key: + frontend_std[new_key] = val + if 'trunk' in key: + trunk_std[new_key] = val + self.frontend3D.load_state_dict(frontend_std) + self.trunk.load_state_dict(trunk_std) + + def forward(self, x): + B, C, T, H, W = x.size() + x = self.frontend3D(x) + Tnew = x.shape[2] + x = self.threeD_to_2D_tensor(x) + x = self.trunk(x) + x = x.view(B, Tnew, x.size(1)) + x = x.transpose(1, 2).contiguous() + return x + + def threeD_to_2D_tensor(self, x): + n_batch, n_channels, s_time, sx, sy = x.shape + x = x.transpose(1, 2).contiguous() + return x.reshape(n_batch*s_time, n_channels, sx, sy) diff --git a/slam_llm/models/avhubert/sequence_generator.py b/slam_llm/models/avhubert/sequence_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..428a5765e5cb2f3d56771eac5df2e31f0484fdf6 --- /dev/null +++ b/slam_llm/models/avhubert/sequence_generator.py @@ -0,0 +1,985 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
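+
+# Beam-search generator adapted for AV-HuBERT: encoder inputs arrive as a
+# net_input["source"] dict with "audio" / "video" entries instead of src_tokens
+# (see _generate below).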
+ +import math +from typing import Dict, List, Optional +import sys + +import torch +import torch.nn as nn +from fairseq import search, utils +from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder +from torch import Tensor +from fairseq.ngram_repeat_block import NGramRepeatBlock + + +class SequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dict, + beam_size=1, + max_len_a=0, + max_len_b=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + ): + """Generates translations of a given source sentence. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) + self.tgt_dict = tgt_dict + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + self.vocab_size = len(tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.max_len = max_len or self.model.max_decoder_positions() + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, "--temperature must be greater than 0" + + self.search = ( + search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy + ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. 
+ self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + # TODO(myleott): unused, deprecate after pytorch-translate migration + def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): + """Iterate over a batched dataset and yield individual translations. + Args: + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + """ + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if "net_input" not in s: + continue + input = s["net_input"] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() if k != "prev_output_tokens" + } + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate(encoder_input) + if timer is not None: + timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) + for i, id in enumerate(s["id"].data): + # remove padding + src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) + ref = ( + utils.strip_pad(s["target"].data[i, :], self.pad) + if s["target"] is not None + else None + ) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + elif "features" in net_input: + src_tokens = net_input["features"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + else: + raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys())) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + if src_tokens['audio'] is not None: + bsz, src_len = src_tokens['audio'].size()[:2] + src_device = src_tokens['audio'].device + else: + bsz, src_len = net_input['padding_mask'].size() + src_device = src_tokens['video'].device + beam_size = self.beam_size + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + self.max_len - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" + # compute the encoder output for each beam + encoder_outs = self.model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. 
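+        # reorder_encoder_out above tiles the encoder output from bsz entries to
+        # bsz * beam_size, since new_order repeats each batch index beam_size times.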
+ assert encoder_outs is not None + + # initialize buffers + scores = ( + torch.zeros(bsz * beam_size, max_len + 1).to(src_device).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2) + .to(src_device) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_device).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_device) + ) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + ) + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size + ) + elif step < self.min_len: + # minimum length 
constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f"{step} < {max_len}" + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. 
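+            # Shrinking the batch: every tensor indexed by sentence is re-gathered
+            # with batch_idxs so later steps only score still-unfinished sentences.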
+ if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. + + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). 
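+            # Flat beam indices follow index = sent_idx * beam_size + beam_idx,
+            # matching the bbsz_offsets layout used above.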
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos + ) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. 
+ A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. + cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + + # The keys here are of the form "{sent}_{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # set() is not supported in script export + sents_seen: Dict[str, Optional[Tensor]] = {} + + # For every finished beam item + for i in range(bbsz_idx.size()[0]): + idx = bbsz_idx[i] + score = eos_scores[i] + # sentence index in the current (possibly reduced) batch + unfin_idx = idx // beam_size + # sentence index in the original (unreduced) batch + sent = unfin_idx + cum_unfin[unfin_idx] + # Cannot create dict for key type '(int, int)' in torchscript. + # The workaround is to cast int to string + seen = str(sent.item()) + "_" + str(unfin_idx.item()) + if seen not in sents_seen: + sents_seen[seen] = None + + if self.match_source_len and step > src_lengths[unfin_idx]: + score = torch.tensor(-math.inf).to(score) + + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent].append( + { + "tokens": tokens_clone[i], + "score": score, + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + + for seen in sents_seen.keys(): + # check termination conditions for this sentence + sent: int = int(float(seen.split("_")[0])) + unfin_idx: int = int(float(seen.split("_")[1])) + + if not finished[sent] and self.is_finished( + step, unfin_idx, max_len, len(finalized[sent]), beam_size + ): + finished[sent] = True + newly_finished.append(unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. 
+ """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize]) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + if hasattr(model, "decoder"): + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + else: + decoder_out = model.forward(tokens) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( + self.models_size + ) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out( + self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order + ): + """ + Reorder encoder output according to *new_order*. 
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order + ) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + def __init__( + self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs + ): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + if print_alignment == "hard": + self.extract_alignment = utils.extract_hard_alignment + elif print_alignment == "soft": + self.extract_alignment = utils.extract_soft_alignment + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + finalized = super()._generate(sample, **kwargs) + + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + ( + src_tokens, + src_lengths, + prev_output_tokens, + tgt_tokens, + ) = self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, "full_context_alignment", False) for m in self.model.models): + attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + if src_tokens.device != "cpu": + src_tokens = src_tokens.to("cpu") + tgt_tokens = tgt_tokens.to("cpu") + attn = [i.to("cpu") for i in attn] + + # Process the attn matrix to extract hard alignments. 
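+        # Note: for "hard" alignments, utils.extract_hard_alignment typically
+        # picks, for each target position, the argmax source position of the
+        # tgt_len x src_len attention matrix while skipping pad/eos tokens;
+        # "soft" alignments keep the full attention distribution instead.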
+ for i in range(bsz * beam_size): + alignment = self.extract_alignment( + attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos + ) + finalized[i // beam_size][i % beam_size]["alignment"] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + src_tokens = ( + src_tokens[:, None, :] + .expand(-1, self.beam_size, -1) + .contiguous() + .view(bsz * self.beam_size, -1) + ) + src_lengths = sample["net_input"]["src_lengths"] + src_lengths = ( + src_lengths[:, None] + .expand(-1, self.beam_size) + .contiguous() + .view(bsz * self.beam_size) + ) + prev_output_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]["attn"][0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn diff --git a/slam_llm/models/avhubert/utils.py b/slam_llm/models/avhubert/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9ffbc397118ba786af646238dac1550f7ec5ca --- /dev/null +++ b/slam_llm/models/avhubert/utils.py @@ -0,0 +1,298 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import random +import numpy as np +from typing import Dict, List, Optional, Tuple + +def load_video(path): + for i in range(3): + try: + cap = cv2.VideoCapture(path) + frames = [] + while True: + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + frames.append(frame) + else: + break + frames = np.stack(frames) + return frames + except Exception: + print(f"failed loading {path} ({i} / 3)") + if i == 2: + raise ValueError(f"Unable to load {path}") + + +class Compose(object): + """Compose several preprocess together. + Args: + preprocess (list of ``Preprocess`` objects): list of preprocess to compose. + """ + + def __init__(self, preprocess): + self.preprocess = preprocess + + def __call__(self, sample): + for t in self.preprocess: + sample = t(sample) + return sample + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.preprocess: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class Normalize(object): + """Normalize a ndarray image with mean and standard deviation. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, frames): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor image. 
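+
+            For example, with mean=0.5 and std=0.25 a pixel value of 0.75 is
+            mapped to (0.75 - 0.5) / 0.25 = 1.0.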
+ """ + frames = (frames - self.mean) / self.std + return frames + + def __repr__(self): + return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std) + +class CenterCrop(object): + """Crop the given image at the center + """ + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = int(round((w - tw))/2.) + delta_h = int(round((h - th))/2.) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + +class RandomCrop(object): + """Crop the given image at the center + """ + + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = random.randint(0, w-tw) + delta_h = random.randint(0, h-th) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + +class HorizontalFlip(object): + """Flip image horizontally. + """ + + def __init__(self, flip_ratio): + self.flip_ratio = flip_ratio + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be flipped with a probability flip_ratio + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + if random.random() < self.flip_ratio: + for index in range(t): + frames[index] = cv2.flip(frames[index], 1) + return frames + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + batch_indexes, starts, ends = [], [], [] + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + vals, run_starts, run_lengths = find_runs(mask[i]) + start_indices, lengths = run_starts[vals == True], run_lengths[vals == True] + starts.append(start_indices) + ends.append(start_indices+lengths) + batch_indexes.append(np.zeros([len(start_indices)])+i) + return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64) + +def find_runs(x): + """Find runs of consecutive items in an array.""" + + # ensure array + x = np.asanyarray(x) + if x.ndim != 
1: + raise ValueError('only 1D array supported') + n = x.shape[0] + + # handle empty array + if n == 0: + return np.array([]), np.array([]), np.array([]) + + else: + # find run starts + loc_run_start = np.empty(n, dtype=bool) + loc_run_start[0] = True + np.not_equal(x[:-1], x[1:], out=loc_run_start[1:]) + run_starts = np.nonzero(loc_run_start)[0] + + # find run values + run_values = x[loc_run_start] + + # find run lengths + run_lengths = np.diff(np.append(run_starts, n)) + + return run_values, run_starts, run_lengths diff --git a/slam_llm/models/encoder.py b/slam_llm/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6810b87446edb8d96a680d273f09e7afb1e66536 --- /dev/null +++ b/slam_llm/models/encoder.py @@ -0,0 +1,158 @@ +import types +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass + +class WhisperWrappedEncoder: + + @classmethod + def load(cls, model_config): + + def extract_variable_length_features(self, x: torch.Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + # x = (x + self.positional_embedding).to(x.dtype) + x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + import whisper + encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder + encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder) + return encoder + + +class BEATsEncoder: + + @classmethod + def load(cls, model_config): + from .BEATs.BEATs import BEATs, BEATsConfig + checkpoint = torch.load(model_config.encoder_path) + cfg = BEATsConfig(checkpoint['cfg']) + BEATs_model = BEATs(cfg) + BEATs_model.load_state_dict(checkpoint['model']) + + return BEATs_model + + +@dataclass +class UserDirModule: + user_dir: str + +class EATEncoder: + + @classmethod + def load(cls, model_config): + import fairseq + model_path = UserDirModule(model_config.encoder_fairseq_dir) + fairseq.utils.import_user_module(model_path) + EATEncoder, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) + EATEncoder = EATEncoder[0] + + return EATEncoder + + def extract_features(self, source, padding_mask): + return self.model.extract_features(source, padding_mask = padding_mask, mask=False, remove_extra_tokens = False)['x'] + +class SpatialASTEncoder: + @classmethod + def load(cls, model_config): + from functools import partial + from .SpatialAST import SpatialAST + binaural_encoder = SpatialAST.BinauralEncoder( + num_classes=355, drop_path_rate=0.1, num_cls_tokens=3, + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) + ) + + checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu') + binaural_encoder.load_state_dict(checkpoint['model'], strict=False) + return binaural_encoder + +class WavLMEncoder(nn.Module): + def __init__(self, config, model): + super().__init__() + self.config = config + self.model = model + + @classmethod + def load(cls, model_config): + from .wavlm.WavLM import WavLM, WavLMConfig + checkpoint = torch.load(model_config.encoder_path) + cfg = WavLMConfig(checkpoint['cfg']) + WavLM_model = WavLM(cfg) + 
WavLM_model.load_state_dict(checkpoint['model']) + assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match" + + return cls(cfg, WavLM_model) + + def extract_features(self, source, padding_mask): + return self.model.extract_features(source, padding_mask)[0] + +class AVHubertEncoder: + + @classmethod + def load(cls, model_config): + import fairseq + from .avhubert import hubert_pretraining, hubert, hubert_asr + models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) + model = models[0] + return model + +class HubertEncoder: + + @classmethod + def load(cls, model_config): + import fairseq + models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) + model = models[0] + if model_config.encoder_type == "pretrain": + pass + elif model_config.encoder_type == "finetune": + model.w2v_encoder.proj = None + model.w2v_encoder.apply_mask = False + else: + assert model_config.encoder_type in ["pretrain", "finetune"], "input_type must be one of [pretrain, finetune]" + return model + + +class HfTextEncoder: + + @classmethod + def load(cls, model_config): + from transformers import AutoModel + model = AutoModel.from_pretrained(model_config.encoder_path) + return model + +class MusicFMEncoder(nn.Module): + def __init__(self, config, model): + super().__init__() + self.config = config + self.model = model + + @classmethod + def load(cls, model_config): + from .musicfm.model.musicfm_25hz import MusicFM25Hz + model = MusicFM25Hz( + stat_path = model_config.encoder_stat_path, + model_path = model_config.encoder_path, + w2v2_config_path = model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft") + ) + return cls(model_config, model) + + def extract_features(self, source, padding_mask=None): + _, hidden_states = self.model.get_predictions(source) + out = hidden_states[self.config.encoder_layer_idx] + return out diff --git a/slam_llm/models/musicfm/model/__init__.py b/slam_llm/models/musicfm/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99a8091366e2ac7f301452706d5cfdff5e320a0d --- /dev/null +++ b/slam_llm/models/musicfm/model/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/slam_llm/models/musicfm/model/musicfm_25hz.py b/slam_llm/models/musicfm/model/musicfm_25hz.py new file mode 100644 index 0000000000000000000000000000000000000000..9acdca859ba3822f82664a74adea7dc0f471ac75 --- /dev/null +++ b/slam_llm/models/musicfm/model/musicfm_25hz.py @@ -0,0 +1,253 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +import json +import random +import torch +from torch import nn +from einops import rearrange + +from ..modules.random_quantizer import RandomProjectionQuantizer +from ..modules.features import MelSTFT +from ..modules.conv import Conv2dSubsampling + + +class MusicFM25Hz(nn.Module): + """ + MusicFM + + Input: 128-band mel spectrogram + Frontend: 2-layer Residual convolution + Backend: 12-layer Conformer + Quantizer: a codebook for mel spectrogram + """ + + def __init__( + self, + num_codebooks=1, + codebook_dim=16, + codebook_size=4096, + features=["melspec_2048"], + hop_length=240, + n_mels=128, + conv_dim=512, + encoder_dim=1024, + encoder_depth=12, + mask_hop=0.4, + mask_prob=0.6, + is_flash=False, + stat_path="./data/fma_stats.json", + model_path="./data/pretrained_fma.pt", + w2v2_config_path="facebook/wav2vec2-conformer-rope-large-960h-ft", + ): + super(MusicFM25Hz, self).__init__() + + # global variables + self.hop_length = hop_length + self.mask_hop = mask_hop + self.mask_prob = mask_prob + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.features = features + + # load feature mean / std stats + with open(stat_path, "r") as f: + self.stat = json.load(f) + + # feature extractor + self.preprocessor_melspec_2048 = MelSTFT( + n_fft=2048, hop_length=hop_length, is_db=True + ) + + # random quantizer + seed = 142 + for feature in self.features: + for i in range(num_codebooks): + setattr( + self, + f"quantizer_{feature}_{i}", + RandomProjectionQuantizer( + n_mels * 4, codebook_dim, codebook_size, seed=seed + i + ), + ) + + # two residual convolution layers + one projection layer + self.conv = Conv2dSubsampling( + 1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels + ) + + # Conformer + if is_flash: + from modules.flash_conformer import ( + Wav2Vec2ConformerEncoder, + Wav2Vec2ConformerConfig, + ) + else: + from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( + Wav2Vec2ConformerEncoder, + Wav2Vec2ConformerConfig, + ) + config = Wav2Vec2ConformerConfig.from_pretrained( + w2v2_config_path + ) + config.num_hidden_layers = encoder_depth + config.hidden_size = encoder_dim + + self.conformer = Wav2Vec2ConformerEncoder(config) + + # projection + self.linear = nn.Linear(encoder_dim, codebook_size) + + # loss function + self.loss = nn.CrossEntropyLoss() + + # cls token (used for sequence classification) + random.seed(seed) + self.cls_token = nn.Parameter(torch.randn(encoder_dim)) + + # load model + if model_path: + S = torch.load(model_path)["state_dict"] + SS = {k[6:]: v for k, v in S.items()} + self.load_state_dict(SS, strict=True) + + def masking(self, x): + """random masking of 400ms with given probability""" + mx = x.clone() + b, t = mx.shape + len_masking_raw = int(24000 * self.mask_hop) + len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop) + + # get random mask indices + start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob + time_domain_masked_indices = torch.nonzero( + start_indices.repeat_interleave(len_masking_raw, dim=1) + ) + token_domain_masked_indices = torch.nonzero( + start_indices.repeat_interleave(len_masking_token, dim=1) + ) + + # mask with random values + masking_noise = ( + 
torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1 + ) # 0 mean 0.1 std + mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device) + + return mx, token_domain_masked_indices + + @torch.no_grad() + def preprocessing(self, x, features): + """extract classic audio features""" + # check precision + if x.dtype == torch.float16: + precision = 16 + else: + precision = 32 + + out = {} + for key in features: + layer = getattr(self, "preprocessor_%s" % key) + out[key] = layer.float()(x.float())[..., :-1] + if precision == 16: + out[key] = out[key].half() + return out + + def encoder(self, x): + """2-layer conv + w2v-conformer""" + x = self.conv(x) + out = self.conformer(x, output_hidden_states=True) + hidden_emb = out["hidden_states"] + last_emb = out["last_hidden_state"] + logits = self.linear(last_emb) + logits = { + key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size] + for i, key in enumerate(self.features) + } + return logits, hidden_emb + + @torch.no_grad() + def normalize(self, x): + """normalize the input audio to have zero mean unit variance""" + for key in x.keys(): + x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] + return x + + @torch.no_grad() + def rearrange(self, x): + """rearrange the batch to flatten every 4 steps""" + for key in x.keys(): + if key == "chromagram": + x[key] = rearrange(x[key], "b f t -> b t f") + else: + x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4) + return x + + @torch.no_grad() + def tokenize(self, x): + out = {} + for key in x.keys(): + layer = getattr(self, "quantizer_%s" % key) + out[key] = layer(x[key]) + return out + + def get_targets(self, x): + x = self.preprocessing(x, features=self.features) + x = self.normalize(x) + x = self.rearrange(x) + target_tokens = self.tokenize(x) + return target_tokens + + def get_predictions(self, x): + # preprocessing + x = self.preprocessing(x, features=["melspec_2048"]) + x = self.normalize(x) + + # encoding + logits, hidden_emb = self.encoder(x["melspec_2048"]) + + return logits, hidden_emb + + def get_latent(self, x, layer_ix=12): + _, hidden_states = self.get_predictions(x) + emb = hidden_states[layer_ix] + return emb + + def get_loss(self, logits, target_tokens, masked_indices): + losses = {} + accuracies = {} + for key in logits.keys(): + masked_logits = logits[key][tuple(masked_indices.t())] + masked_tokens = target_tokens[key][tuple(masked_indices.t())] + losses[key] = self.loss(masked_logits, masked_tokens) + accuracies[key] = ( + torch.sum(masked_logits.argmax(-1) == masked_tokens) + / masked_tokens.numel() + ) + return losses, accuracies + + def forward(self, x): + # get target feature tokens + target_tokens = self.get_targets(x) + + # masking + x, masked_indices = self.masking(x) + + # forward + logits, hidden_emb = self.get_predictions(x) + + # get loss + losses, accuracies = self.get_loss(logits, target_tokens, masked_indices) + + return logits, hidden_emb, losses, accuracies diff --git a/slam_llm/models/musicfm/modules/__init__.py b/slam_llm/models/musicfm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99a8091366e2ac7f301452706d5cfdff5e320a0d --- /dev/null +++ b/slam_llm/models/musicfm/modules/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/slam_llm/models/musicfm/modules/conv.py b/slam_llm/models/musicfm/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3a83cfe0c3dfb2aaba46483615d3f3c75d7565 --- /dev/null +++ 
b/slam_llm/models/musicfm/modules/conv.py @@ -0,0 +1,82 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +from torch import nn +from einops import rearrange + + +class Res2dModule(nn.Module): + def __init__(self, idim, odim, stride=(2, 2)): + super(Res2dModule, self).__init__() + self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride) + self.bn1 = nn.BatchNorm2d(odim) + self.conv2 = nn.Conv2d(odim, odim, 3, padding=1) + self.bn2 = nn.BatchNorm2d(odim) + self.relu = nn.ReLU() + + # residual + self.diff = False + if (idim != odim) or (stride[0] > 1): + self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride) + self.bn3 = nn.BatchNorm2d(odim) + self.diff = True + + def forward(self, x): + out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x))))) + if self.diff: + x = self.bn3(self.conv3(x)) + out = x + out + out = self.relu(out) + return out + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + hdim (int): Hidden dimension. + odim (int): Output dimension. + strides (list): Sizes of strides. + n_bands (int): Number of frequency bands. + """ + + def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling, self).__init__() + + self.conv = nn.Sequential( + Res2dModule(idim, hdim, (2, strides[0])), + Res2dModule(hdim, hdim, (2, strides[1])), + ) + self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim) + + def forward(self, x): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, idim, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + """ + + if x.dim() == 3: + x = x.unsqueeze(1) # (b, c, f, t) + x = self.conv(x) + x = rearrange(x, "b c f t -> b t (c f)") + x = self.linear(x) + return x diff --git a/slam_llm/models/musicfm/modules/features.py b/slam_llm/models/musicfm/modules/features.py new file mode 100644 index 0000000000000000000000000000000000000000..d29f859d32d628497751d02419b3e5f749d2030a --- /dev/null +++ b/slam_llm/models/musicfm/modules/features.py @@ -0,0 +1,45 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +import torchaudio +from torch import nn + + +class MelSTFT(nn.Module): + def __init__( + self, + sample_rate=24000, + n_fft=2048, + hop_length=240, + n_mels=128, + is_db=False, + ): + super(MelSTFT, self).__init__() + + # spectrogram + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels + ) + + # amplitude to decibel + self.is_db = is_db + if is_db: + self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() + + def forward(self, waveform): + if self.is_db: + return self.amplitude_to_db(self.mel_stft(waveform)) + else: + return self.mel_stft(waveform) diff --git a/slam_llm/models/musicfm/modules/flash_conformer.py b/slam_llm/models/musicfm/modules/flash_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..689e2c1c1c705a69ed987ed7507e02fc58e36341 --- /dev/null +++ b/slam_llm/models/musicfm/modules/flash_conformer.py @@ -0,0 +1,2114 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
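+#
+# Note: this file mirrors transformers' Wav2Vec2Conformer implementation; a key
+# difference is that Wav2Vec2ConformerSelfAttention.forward computes attention
+# with torch.nn.functional.scaled_dot_product_attention under the flash SDP
+# kernel instead of materializing softmax(QK^T) explicitly.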
+""" PyTorch Wav2Vec2-Conformer model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F + +from transformers.activations import ACT2FN +from transformers.deepspeed import is_deepspeed_zero3_enabled +from transformers.modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 64.21 + + +WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/wav2vec2-conformer-rel-pos-large", + # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer +] + + +@dataclass +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions. + + Args: + loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official + paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. + projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked + projected quantized states. + projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive + target vectors for contrastive loss. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . + diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . + """ + + loss: Optional[torch.FloatTensor] = None + projected_states: torch.FloatTensor = None + projected_quantized_states: torch.FloatTensor = None + codevector_perplexity: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + contrastive_loss: Optional[torch.FloatTensor] = None + diversity_loss: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
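+
+    Illustrative example: with sequence_length=100, mask_prob=0.065 and
+    mask_length=10, each sequence gets int(0.065 * 100 / 10 + eps) spans, i.e.
+    0 or 1 spans of 10 consecutive positions, where eps is a random value in
+    [0, 1) used for probabilistic rounding (assuming min_masks=0).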
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
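+        # i.e. every row is padded up to max_num_masked_span start indices by
+        # repeating one already-sampled index, keeping the array rectangular;
+        # the duplicate only re-marks a span that is masked anyway.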
+ if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices +def _sample_negative_indices( + features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None +): + """ + Sample `num_negatives` vectors from feature vectors. + """ + batch_size, sequence_length = features_shape + + # generate indices of the positive vectors themselves, repeat them `num_negatives` times + sequence_length_range = np.arange(sequence_length) + + # get `num_negatives` random vector indices from the same utterance + sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) + + mask_time_indices = ( + mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool) + ) + + for batch_idx in range(batch_size): + high = mask_time_indices[batch_idx].sum() - 1 + mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] + + feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives)) + sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives)) + # avoid sampling the same positive vector, but keep the distribution uniform + sampled_indices[sampled_indices >= feature_indices] += 1 + + # remap to actual indices + sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] + + # correct for batch size + sampled_negative_indices[batch_idx] += batch_idx * sequence_length + + return sampled_negative_indices + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + 
bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + deepspeed.zero.register_external_parameter(self, self.conv.weight_v) + deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + else: + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module): + """Rotary positional 
embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.num_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + sin_embeddings = embeddings.sin()[:, None, None, :] + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]) + return self.cached_rotary_positional_embedding + + +class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.max_source_positions + self.d_model = config.hidden_size + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` is the position of query vector and `j` is the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i<j). + pe_positive = torch.zeros(x.size(1), self.d_model) + pe_negative = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model) + ) + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + + # Reverse the order of positive indices and concat both positive and + # negative indices. 
This is used to support the shifting trick + # as in https://arxiv.org/abs/1901.02860 + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + pe_negative = pe_negative[1:].unsqueeze(0) + pe = torch.cat([pe_positive, pe_negative], dim=1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, hidden_states: torch.Tensor): + self.extend_pe(hidden_states) + start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1 + end_idx = self.pe.size(1) // 2 + hidden_states.size(1) + relative_position_embeddings = self.pe[:, start_idx:end_idx] + + return relative_position_embeddings + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [ + Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1) + for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(conv_layer), + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +# Copied from 
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class Wav2Vec2ConformerConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.pointwise_conv1 = torch.nn.Conv1d( + config.hidden_size, + 2 * config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = torch.nn.GLU(dim=1) + self.depthwise_conv = torch.nn.Conv1d( + config.hidden_size, + config.hidden_size, + config.conv_depthwise_kernel_size, + stride=1, + padding=(config.conv_depthwise_kernel_size - 1) // 2, + groups=config.hidden_size, + bias=False, + ) + self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.pointwise_conv2 = torch.nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = torch.nn.Dropout(config.conformer_conv_dropout) + + def forward(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2ConformerSelfAttention(nn.Module): + """Construct an Wav2Vec2ConformerSelfAttention object. + Can be enhanced with rotary or relative position embeddings. 
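+
+    In this flash variant, attention is computed with
+    torch.nn.functional.scaled_dot_product_attention, so per-head attention
+    probabilities are never materialized and the returned `probs` is None.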
+ """ + + def __init__(self, config): + super().__init__() + + self.head_size = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + self.position_embeddings_type = config.position_embeddings_type + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.attention_dropout) + self.dropout_p = config.attention_dropout + + self.is_causal = config.is_causal + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal) + probs = None + + # # apply attention_mask if necessary + # if attention_mask is not None: + # scores = scores + attention_mask + + # # => (batch, head, time1, time2) + # probs = torch.softmax(scores, dim=-1) + # probs = self.dropout(probs) + + # # => (batch, head, time1, d_k) + # hidden_states = torch.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, probs + + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + batch_size, sequence_length, hidden_size = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] 
+ sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., : self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2 :] + rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view( + relative_position_embeddings.size(0), -1, self.num_heads, self.head_size + ) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. shift matrix b and matrix d + zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype) + scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1] + + # 6. sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class Wav2Vec2ConformerEncoderLayer(nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.attention_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = Wav2Vec2ConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = torch.nn.Dropout(dropout) + self.self_attn = Wav2Vec2ConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = Wav2Vec2ConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = Wav2Vec2ConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + hidden_states = hidden_states + + # 1. 
Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states) + hidden_states = residual + hidden_states + + # 4. Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class Wav2Vec2ConformerEncoder(nn.Module): + def __init__(self, config, is_causal=False): + super().__init__() + config.is_causal = is_causal + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states[~attention_mask] = 0.0 + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + 
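+                    # re-enter the layer through torch.utils.checkpoint so its activations are
+                    # recomputed in the backward pass instead of being stored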
layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + relative_position_embeddings, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH + GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. + """ + + def __init__(self, config): + super().__init__() + self.num_groups = config.num_codevector_groups + self.num_vars = config.num_codevectors_per_group + + if config.codevector_dim % self.num_groups != 0: + raise ValueError( + f"`config.codevector_dim {config.codevector_dim} must be divisible " + f"by `config.num_codevector_groups` {self.num_groups} for concatenation" + ) + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 2 + + @staticmethod + def _compute_perplexity(probs, mask=None): + if mask is not None: + mask_extended = mask.flatten()[:, None, None].expand(probs.shape) + probs = torch.where(mask_extended, probs, torch.zeros_like(probs)) + marginal_probs = probs.sum(dim=0) / mask.sum() + else: + marginal_probs = probs.mean(dim=0) + + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states, mask_time_indices=None): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + 
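+            # compute perplexity over the hard one-hot assignments (mirrors the soft-distribution path used in training)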
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + # feature dim might need to be down-projected + if config.output_hidden_size != config.hidden_size: + self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) + else: + self.proj = self.proj_layer_norm = None + + self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + def forward(self, hidden_states): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + +class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Wav2Vec2ConformerConfig + base_model_prefix = "wav2vec2_conformer" + main_input_name = "input_values" + _keys_to_ignore_on_load_missing = [r"position_ids"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. 
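+        # `reset_parameters()` restores the default `nn.Linear` initialization and the
+        # `_is_hf_initialized` flag keeps these projections from being re-initialized later.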
+ if isinstance(module, Wav2Vec2ConformerForPreTraining): + module.project_hid.reset_parameters() + module.project_q.reset_parameters() + module.project_hid._is_hf_initialized = True + module.project_q._is_hf_initialized = True + # gumbel softmax requires special init + elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, Wav2Vec2ConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, Wav2Vec2ConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) + + return input_lengths + + def _get_feature_vector_attention_mask( + self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None + ): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. 
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) + output_lengths = output_lengths.to(torch.long) + + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): + module.gradient_checkpointing = value + + +WAV2VEC2_CONFORMER_START_DOCSTRING = r""" + Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech + Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael + Auli. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a + regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + + Parameters: + config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + <Tip warning={true}> + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, such as + [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large), + `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For + such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware + that these models also yield slightly different results depending on whether `input_values` is padded or + not. 
+ + </Tip> + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.", + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): + def __init__(self, config: Wav2Vec2ConformerConfig): + super().__init__(config) + self.config = config + self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config) + self.feature_projection = Wav2Vec2ConformerFeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + + self.encoder = Wav2Vec2ConformerEncoder(config) + + self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.feature_extractor._freeze_parameters() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). 
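+        Masked time steps are replaced with the learned `masked_spec_embed` vector, while masked
+        feature channels are set to zero.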
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Wav2Vec2BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, 
extract_features) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING +) +class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def __init__(self, config: Wav2Vec2ConformerConfig): + super().__init__(config) + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config) + + self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) + self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature + def set_gumbel_temperature(self, temperature: int): + """ + Set the Gumbel softmax temperature to a given value. Only necessary for training + """ + self.quantizer.temperature = temperature + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + @staticmethod + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits + def compute_contrastive_logits( + target_features: torch.FloatTensor, + negative_features: torch.FloatTensor, + predicted_features: torch.FloatTensor, + temperature: int = 0.1, + ): + """ + Compute logits for contrastive loss based using cosine similarity as the distance measure between + `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied. 
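+        The positive `target_features` are concatenated in front of the negatives, so index 0 of the
+        returned logits corresponds to the true quantized target.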
+ """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as( + target_features + ) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.BoolTensor] = None, + sampled_negative_indices: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. + sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*): + Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss. + Required input for pre-training. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining + >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( + ... _compute_mask_indices, + ... _sample_negative_indices, + ... ) + >>> from datasets import load_dataset + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") + >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1 + + >>> # compute masked indices + >>> batch_size, raw_sequence_length = input_values.shape + >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item() + >>> mask_time_indices = _compute_mask_indices( + ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2 + ... ) + >>> sampled_negative_indices = _sample_negative_indices( + ... features_shape=(batch_size, sequence_length), + ... num_negatives=model.config.num_negatives, + ... mask_time_indices=mask_time_indices, + ... ) + >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long) + >>> sampled_negative_indices = torch.tensor( + ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long + ... ) + + >>> with torch.no_grad(): + ... 
outputs = model(input_values, mask_time_indices=mask_time_indices) + + >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) + >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1) + + >>> # show that cosine similarity is much higher than random + >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5 + tensor(True) + + >>> # for contrastive loss training model should be put into train mode + >>> model = model.train() + >>> loss = model( + ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices + ... ).loss + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if mask_time_indices is not None: + mask_time_indices = mask_time_indices.to(torch.bool) + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + mask_time_indices=mask_time_indices, + return_dict=return_dict, + ) + + # 1. project all transformed features (including masked) to final vq dim + transformer_features = self.project_hid(outputs[0]) + + # 2. quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + + if attention_mask is not None: + # compute reduced attention_mask correponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + quantized_features, codevector_perplexity = self.quantizer( + extract_features, mask_time_indices=mask_time_indices + ) + quantized_features = self.project_q(quantized_features) + + loss = contrastive_loss = diversity_loss = None + if sampled_negative_indices is not None: + batch_size, sequence_length, hidden_size = quantized_features.shape + + # for training, we sample negatives + # 3. sample K negatives (distractors) quantized states for contrastive loss + # if attention_mask is passed, make sure that padded feature vectors cannot be sampled + # sample negative quantized vectors BTC => (BxT)C + negative_quantized_features = quantized_features.view(-1, hidden_size)[ + sampled_negative_indices.long().view(-1) + ] + negative_quantized_features = negative_quantized_features.view( + batch_size, sequence_length, -1, hidden_size + ).permute(2, 0, 1, 3) + + # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` + # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf + logits = self.compute_contrastive_logits( + quantized_features[None, :], + negative_quantized_features, + transformer_features, + self.config.contrastive_logits_temperature, + ) + + # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low), + # its cosine similarity will be masked + neg_is_pos = (quantized_features == negative_quantized_features).all(-1) + + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + + # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) = + # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) + logits = logits.transpose(0, 2).reshape(-1, logits.size(0)) + target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten() + + contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum") + # 7. 
compute diversity loss: \mathbf{L}_d + num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups + diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum() + + # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d + loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss + + if not return_dict: + if loss is not None: + return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return Wav2Vec2ConformerForPreTrainingOutput( + loss=loss, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + contrastive_loss=contrastive_loss, + diversity_loss=diversity_loss, + ) + + +@add_start_docstrings( + """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def __init__(self, config): + super().__init__(config) + + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." + ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. 
Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for + tasks like SUPERB Keyword Spotting. 
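+
+    The pooled output is a mean over time of the projected hidden states (a masked mean when an
+    attention mask is provided).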
+ """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)" + ) + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wav2vec2_conformer.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + hidden_states[~padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization. + """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)" + ) + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.num_labels = config.num_labels + + self.init_weights() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. 
+ """ + for param in self.wav2vec2_conformer.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss +class AMSoftmaxLoss(nn.Module): + def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): + super(AMSoftmaxLoss, self).__init__() + self.scale = scale + self.margin = margin + self.num_labels = num_labels + self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True) + self.loss = nn.CrossEntropyLoss() + + def forward(self, hidden_states, labels): + labels = labels.flatten() + weight = nn.functional.normalize(self.weight, dim=0) + hidden_states = nn.functional.normalize(hidden_states, dim=1) + cos_theta = torch.mm(hidden_states, weight) + psi = cos_theta - self.margin + + onehot = nn.functional.one_hot(labels, self.num_labels) + logits = self.scale * torch.where(onehot.bool(), psi, cos_theta) + loss = self.loss(logits, labels) + + return loss + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer +class TDNNLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id] + self.out_conv_dim = config.tdnn_dim[layer_id] + self.kernel_size = 
config.tdnn_kernel[layer_id] + self.dilation = config.tdnn_dilation[layer_id] + + self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim) + self.activation = nn.ReLU() + + def forward(self, hidden_states): + hidden_states = hidden_states.unsqueeze(1) + hidden_states = nn.functional.unfold( + hidden_states, + (self.kernel_size, self.in_conv_dim), + stride=(1, self.in_conv_dim), + dilation=(self.dilation, 1), + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.kernel(hidden_states) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0]) + + tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))] + self.tdnn = nn.ModuleList(tdnn_layers) + + self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim) + self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim) + + self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) + + self.init_weights() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. 
+ """ + for param in self.wav2vec2_conformer.parameters(): + param.requires_grad = False + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer + def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the TDNN layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size in self.config.tdnn_kernel: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1) + + return input_lengths + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, XVectorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/slam_llm/models/musicfm/modules/random_quantizer.py b/slam_llm/models/musicfm/modules/random_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..257164a951b2137e377369605abb74a2557bd965 --- /dev/null +++ b/slam_llm/models/musicfm/modules/random_quantizer.py @@ -0,0 +1,83 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RandomProjectionQuantizer(nn.Module): + """ + Random projection and codebook lookup module + + Some code is borrowed from: + https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py + But I did normalization using pre-computed global mean & variance instead of using layer norm. + """ + + def __init__( + self, + input_dim, + codebook_dim, + codebook_size, + seed=142, + ): + super().__init__() + + # random seed + torch.manual_seed(seed) + + # randomly initialized projection + random_projection = torch.empty(input_dim, codebook_dim) + nn.init.xavier_normal_(random_projection) + self.register_buffer("random_projection", random_projection) + + # randomly initialized codebook + codebook = torch.empty(codebook_size, codebook_dim) + nn.init.normal_(codebook) + self.register_buffer("codebook", codebook) + + def codebook_lookup(self, x): + # reshape + b = x.shape[0] + x = rearrange(x, "b n e -> (b n) e") + + # L2 normalization + normalized_x = nn.functional.normalize(x, dim=1, p=2) + normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2) + + # compute distances + distances = torch.cdist(normalized_codebook, normalized_x) + + # get nearest + nearest_indices = torch.argmin(distances, dim=0) + + # reshape + xq = rearrange(nearest_indices, "(b n) -> b n", b=b) + + return xq + + @torch.no_grad() + def forward(self, x): + # always eval + self.eval() + + # random projection [batch, length, input_dim] -> [batch, length, codebook_dim] + x = einsum("b n d, d e -> b n e", x, self.random_projection) + + # codebook lookup + xq = self.codebook_lookup(x) + + return xq diff --git a/slam_llm/models/projector.py b/slam_llm/models/projector.py new file mode 100644 index 0000000000000000000000000000000000000000..69825c58c037de5b116dfafa9d86fe437b3c1cc5 --- /dev/null +++ b/slam_llm/models/projector.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn + + +class EncoderProjectorConcat(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.encoder_projector_ds_rate + self.encoder_dim = config.encoder_dim + self.llm_dim = config.llm_dim + self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) + self.relu = nn.ReLU() + self.linear2 = nn.Linear(2048, config.llm_dim) + + def forward(self, x): + batch_size, seq_len, dim = x.size() + num_frames_to_discard = seq_len % self.k + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.k, dim * self.k) + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + +class EncoderProjectorCov1d(nn.Module): + def __init__(self, config): + super().__init__() + self.k = config.encoder_projector_ds_rate + self.encoder_dim = config.encoder_dim + self.llm_dim = config.llm_dim + self.conv1d = nn.Conv1d(in_channels=self.encoder_dim, out_channels=self.encoder_dim, kernel_size=self.k, stride=self.k, padding=0) + self.linear1 = nn.Linear(self.encoder_dim, 2048) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(2048, self.llm_dim) + self.relu2 = nn.ReLU() + + def forward(self, x): + x = x.transpose(1, 2) + x = self.conv1d(x) + x = x.transpose(1, 2) + x = self.relu1(x) + x = self.linear1(x) + x = self.relu2(x) + x = self.linear2(x) + return x + +class EncoderProjectorQFormer(nn.Module): + def __init__(self, config): + super().__init__() + self.encoder_dim = config.encoder_dim + 
self.llm_dim = config.llm_dim + from transformers import Blip2QFormerConfig, Blip2QFormerModel + configuration = Blip2QFormerConfig() + configuration.encoder_hidden_size = self.encoder_dim + configuration.num_hidden_layers = 8 + + self.query_len = 64 + self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size)) + self.query.data.normal_(mean=0.0, std=1.0) + self.qformer = Blip2QFormerModel(configuration) + + self.linear = nn.Linear(configuration.hidden_size, self.llm_dim) + self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5) + + def forward(self, x, atts): + query = self.query.expand(x.shape[0], -1, -1) + + query_output = self.qformer( + query_embeds=query, + encoder_hidden_states=x, + encoder_attention_mask=atts, + return_dict=True, + ) + + query_proj = self.norm(self.linear(query_output.last_hidden_state)) + + return query_proj \ No newline at end of file diff --git a/slam_llm/models/slam_model.py b/slam_llm/models/slam_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d7e032e53f7c2d3f3e741142b360f3cb8710ba --- /dev/null +++ b/slam_llm/models/slam_model.py @@ -0,0 +1,443 @@ +import os +import types +import torch +import soundfile as sf +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from typing import List, Optional, Tuple, Union +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, AutoModelForSeq2SeqLM, T5ForConditionalGeneration +from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training + +from slam_llm.utils.config_utils import generate_peft_config +from slam_llm.utils.train_utils import print_module_size, print_model_size +from peft import PeftModel, PeftConfig +from torch.nn import CrossEntropyLoss +from slam_llm.utils.metric import compute_accuracy + +import logging +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + model = slam_model( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) + return model, tokenizer + + +def setup_tokenizer(train_config, model_config, **kwargs): + # Load the tokenizer and add special tokens + if "vallex" in model_config.llm_name.lower(): + return None + elif "mupt" in model_config.llm_name.lower(): + tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path, + trust_remote_code=True, + use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path) + tokenizer.pad_token_id = tokenizer.eos_token_id + return tokenizer + + +def setup_encoder(train_config, model_config, **kwargs): + encoder_list = model_config.encoder_name.split(",") if 
model_config.encoder_name else [] + if len(encoder_list) == 0: + return None + if len(encoder_list) == 1: + encoder_name = encoder_list[0] + if encoder_name == "whisper" or encoder_name == "qwen-audio": + from slam_llm.models.encoder import WhisperWrappedEncoder + encoder = WhisperWrappedEncoder.load(model_config) + if encoder_name == "beats": + from slam_llm.models.encoder import BEATsEncoder + encoder = BEATsEncoder.load(model_config) + if encoder_name == "eat": + from slam_llm.models.encoder import EATEncoder + encoder = EATEncoder.load(model_config) + if encoder_name == "SpatialAST": + from slam_llm.models.encoder import SpatialASTEncoder + encoder = SpatialASTEncoder.load(model_config) + if encoder_name == "wavlm": + from slam_llm.models.encoder import WavLMEncoder + encoder = WavLMEncoder.load(model_config) + if encoder_name == "av_hubert": + from slam_llm.models.encoder import AVHubertEncoder + encoder = AVHubertEncoder.load(model_config) + if encoder_name == "hubert": + from slam_llm.models.encoder import HubertEncoder + encoder = HubertEncoder.load(model_config) + if encoder_name == "musicfm": + from slam_llm.models.encoder import MusicFMEncoder + encoder = MusicFMEncoder.load(model_config) + + if "llama" in encoder_name.lower(): + from slam_llm.models.encoder import HfTextEncoder + encoder = HfTextEncoder.load(model_config) + print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) + + if train_config.freeze_encoder: + for name, param in encoder.named_parameters(): + param.requires_grad = False + encoder.eval() + print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) + + return encoder + +def setup_llm(train_config, model_config, **kwargs): + from pkg_resources import packaging + use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None + if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.low_cpu_fsdp: + """ + for FSDP, we can save cpu memory by loading pretrained model on rank0 only. + this avoids cpu oom when loading large models like llama 70B, in which case + model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms + overhead and currently requires latest nightly. 
+ """ + # v = packaging.version.parse(torch.__version__) + # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 + # if not verify_latest_nightly: + # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " + # "please install latest nightly.") + rank = int(os.environ["RANK"]) + if rank == 0: + if "vallex" in model_config.llm_name.lower(): + from src.slam_llm.models.vallex.vallex_config import VallexConfig + from src.slam_llm.models.vallex.vallex_model import VALLE + vallex_config = VallexConfig( + **model_config + ) + model = VALLE(vallex_config) + elif "aya" in model_config.llm_name.lower(): + model = AutoModelForSeq2SeqLM.from_pretrained( + model_config.llm_path, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + use_cache=use_cache, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_config.llm_path, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + use_cache=use_cache, + ) + else: + llama_config = AutoConfig.from_pretrained(model_config.llm_path) + llama_config.use_cache = use_cache + # with torch.device("meta"): + if "aya" in model_config.llm_name.lower(): + model = AutoModelForSeq2SeqLM(llama_config) + else: + model = AutoModelForCausalLM(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta` + + else: + if "vallex" in model_config.llm_name.lower(): + from src.slam_llm.models.vallex.vallex_config import VallexConfig + from src.slam_llm.models.vallex.vallex_model import VALLE + vallex_config = VallexConfig( + **model_config + ) + model = VALLE(vallex_config) + elif "aya" in model_config.llm_name.lower(): + model = AutoModelForSeq2SeqLM.from_pretrained( + model_config.llm_path, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + use_cache=use_cache, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_config.llm_path, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + use_cache=use_cache, + ) + if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.use_fast_kernels: + """ + For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable + using of Flash Attention or Xformer memory-efficient kernels + based on the hardware being used. This would speed up fine-tuning. + """ + try: + from optimum.bettertransformer import BetterTransformer + model = BetterTransformer.transform(model) + except ImportError: + logger.warning("Module 'optimum' not found. 
Please install 'optimum' it before proceeding.") + + print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) + + # Prepare the model for int8 training if quantization is enabled + if train_config.quantization: + model = prepare_model_for_kbit_training(model) + + if train_config.freeze_llm: # TODO:to test offical `freeze_layers` and `num_freeze_layers` + for name, param in model.named_parameters(): + param.requires_grad = False + model.eval() + + if kwargs.get("peft_ckpt", None): # (FIX:MZY):reload will get wrong results when decoding + logger.info("loading peft_ckpt from: {}".format(kwargs.get("peft_ckpt"))) + model = PeftModel.from_pretrained(model=model, model_id=kwargs.get("peft_ckpt"), is_trainable=True) + model.print_trainable_parameters() + elif train_config.use_peft: + logger.info("setup peft...") + peft_config = generate_peft_config(train_config) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) + return model + +def setup_encoder_projector(train_config, model_config, **kwargs): + if model_config.encoder_projector == "linear": + from slam_llm.models.projector import EncoderProjectorConcat + encoder_projector = EncoderProjectorConcat(model_config) + elif model_config.encoder_projector == "cov1d-linear": + from slam_llm.models.projector import EncoderProjectorCov1d + encoder_projector = EncoderProjectorCov1d(model_config) + elif model_config.encoder_projector == "q-former": + from slam_llm.models.projector import EncoderProjectorQFormer + encoder_projector = EncoderProjectorQFormer(model_config) + else: + return None + print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) + return encoder_projector + + +class slam_model(nn.Module): + def __init__( + self, + encoder: nn.Module, + llm: nn.Module, + encoder_projector: nn.Module, + tokenizer, + train_config, + model_config, + **kwargs + ): + super().__init__() + # modality encoder + self.encoder = encoder + + # llm + self.llm = llm + + # projector + self.encoder_projector = encoder_projector + + # tokenizer + self.tokenizer = tokenizer + self.metric = kwargs.get("metric", "acc") + + self.train_config = train_config + self.model_config = model_config + + if train_config.get("enable_deepspeed", False): + def new_forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + for item in self.modules(): + if isinstance(item, nn.LayerNorm): + item.forward = types.MethodType(new_forward, item) + + + + def forward(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + audio_mel = kwargs.get("audio_mel", None) + audio_mel_mask = kwargs.get("audio_mel_mask", None) + audio_mel_post_mask = 
kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper + + audio = kwargs.get("audio", None) + audio_mask = kwargs.get("audio_mask", None) + visual = kwargs.get("visual", None) + visual_mask = kwargs.get("visual_mask", None) + + + # for text encoder + instruct_ids = kwargs.get("instruct_ids", None) + instruct_mask = kwargs.get("instruct_mask", None) + + modality_mask = kwargs.get("modality_mask", None) + + zh_data = kwargs.get("zh", None) + en_data = kwargs.get("en", None) + + encoder_outs = None + if audio_mel is not None or audio is not None or visual is not None: + if self.train_config.freeze_encoder: # freeze encoder + self.encoder.eval() + + if self.model_config.encoder_name == "whisper": + encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim + if self.model_config.encoder_name == "beats": + encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim + if self.model_config.encoder_name == "eat": + encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask = None, mask=False, remove_extra_tokens = False)['x'] + if self.model_config.encoder_name == "SpatialAST": + encoder_outs = self.encoder(audio) # output: [bs, seq_len=3+512, dim=768] + if self.model_config.encoder_name == "wavlm": + encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask + if self.model_config.encoder_name == "hubert": + results = self.encoder(source = audio, padding_mask = 1-audio_mask) + if self.model_config.encoder_type == "pretrain": + encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"] + if self.model_config.encoder_type == "finetune": + encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] + encoder_outs = encoder_outs.transpose(0, 1) + if self.model_config.encoder_name == "av_hubert": + results = self.encoder(source={'video':visual, 'audio':audio}, padding_mask=visual_mask) # bs*seq*dim + encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] + encoder_outs = encoder_outs.transpose(0, 1) + audio_mel_post_mask = (~audio_mel_post_mask).float() + if self.model_config.encoder_name == 'musicfm': + encoder_outs = self.encoder.extract_features(audio, padding_mask = None) # MusicFM doesn't support padding mask + if self.encoder is None: + encoder_outs = audio_mel if audio_mel is not None else audio + + if self.model_config.encoder_projector == "q-former": + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + if self.model_config.encoder_projector == "cov1d-linear": + encoder_outs = self.encoder_projector(encoder_outs) + + if instruct_ids is not None: + if self.encoder is not None: + encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state + + if self.model_config.encoder_projector == "q-former": + encoder_outs = self.encoder_projector(encoder_outs, instruct_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + + if input_ids is not None: + input_ids[input_ids == -1] = 0 + if isinstance(self.llm, T5ForConditionalGeneration): + inputs_embeds = self.llm.shared(input_ids) + else: + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(input_ids) + elif 
hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(input_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) + + if modality_mask is not None: + modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1) + modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist() + + encoder_outs_pad = torch.zeros_like(inputs_embeds) + for i in range(encoder_outs.shape[0]): + encoder_outs_pad[ + i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i] + ] = encoder_outs[i][:modality_lengths[i]] + + inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None]) + + if kwargs.get("inference_mode", False): + return inputs_embeds, attention_mask + + if zh_data is not None and en_data is not None: + model_outputs, acc = self.llm(zh=zh_data, en=en_data) + else: + model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels) + acc = -1 + if self.metric: + with torch.no_grad(): + preds = torch.argmax(model_outputs.logits, -1) + acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100) + + return model_outputs, acc + + @torch.no_grad() + def generate(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + kwargs["inference_mode"] = True + + inputs_embeds, attention_mask = self.forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + model_outputs = self.llm.generate( + inputs_embeds=inputs_embeds, + # max_length=kwargs.get("max_length", 200), + max_new_tokens=kwargs.get("max_new_tokens", 200), + num_beams=kwargs.get("num_beams", 4), + do_sample=kwargs.get("do_sample", False), + min_length=kwargs.get("min_length", 1), + top_p=kwargs.get("top_p", 1.0), + repetition_penalty=kwargs.get("repetition_penalty", 1.0), + length_penalty=kwargs.get("length_penalty", 1.0), + temperature=kwargs.get("temperature", 1.0), + attention_mask=attention_mask, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id + ) + + return model_outputs diff --git a/slam_llm/models/vallex/__init__.py b/slam_llm/models/vallex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/slam_llm/models/vallex/activation.py b/slam_llm/models/vallex/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f3e55e87b32b1045a5323c720a93a9efefe13a --- /dev/null +++ b/slam_llm/models/vallex/activation.py @@ -0,0 +1,179 @@ +from typing import Optional, Tuple, List +import math + +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn import functional as F +from torch.nn.modules.linear import 
NonDynamicallyQuantizableLinear + + +class MultiheadAttention(Module): + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = False + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + self.num_heads = num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.k_proj = Linear(self.kdim, embed_dim) + self.v_proj = Linear(self.kdim, embed_dim) + self.q_proj = Linear(self.kdim, embed_dim) + + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self.add_zero_attn = add_zero_attn + self.scaling = self.head_dim**-0.5 + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + ) -> Tuple[Tensor, Optional[Tensor]]: + + # T,B,C + B, T, C = query.size() + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + + attn_weights = q @ k.transpose(-2, -1) # B, nh, T, T + + if attn_mask is not None: + # attn_mask is inf + # attn_mask = attn_mask.unsqueeze(0) + # attn_weights += attn_mask + if torch.is_floating_point(attn_mask): + # print(attn_weights.size(), attn_mask.size()) + attn_weights += attn_mask.unsqueeze(0).unsqueeze(1) + else: + attn_weights = attn_weights.masked_fill(attn_mask, float('-inf')) + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(B, self.num_heads, T, T) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .to(torch.bool), + float("-inf"), + ) + attn_weights_float = F.softmax(attn_weights, dim=-1) + attn = attn_weights_float @ v + + y = attn.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = self.out_proj(y) + return y, attn_weights + + def infer(self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + past_kv = None, + use_cache = False): + + # print("debug:"+str(x.size())) + + B, T, C = x.size() + + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + q *= self.scaling + + # k = k.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs) + # q = 
q.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs) + # v = v.view(T, B*self.num_heads, self.head_dim).transpose(0, 1) # (B, nh, T, hs) + + k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = (k, v) + else: + present = None + + # print(q.size(), k.size()) + attn_weights = q @ k.transpose(-2, -1) + # print(attn_mask.size()) + attn_weights = attn_weights.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf')) + + # if key_padding_mask is not None: + # # don't attend to padding symbols + # attn_weights = attn_weights.view(B, self.num_heads, T, T) + # attn_weights = attn_weights.view(B, -1, self.num_heads, T, T) + # attn_weights = attn_weights.masked_fill( + # key_padding_mask.unsqueeze(1) + # .unsqueeze(2) + # .unsqueeze(3) + # .to(torch.bool), + # float("-inf"), + # ) + attn_weights_float = F.softmax(attn_weights, dim=-1, ) + # attn_weights = attn_weights_float.type_as(attn_weights) + # attn = torch.bmm(attn_weights, v) + attn = attn_weights_float @ v + + y = attn.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = self.out_proj(y) + return (y, present) \ No newline at end of file diff --git a/slam_llm/models/vallex/scaling.py b/slam_llm/models/vallex/scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..a34de303ce471279cdb78003f5409927c6c941a6 --- /dev/null +++ b/slam_llm/models/vallex/scaling.py @@ -0,0 +1,1404 @@ +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
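For incremental decoding, `MultiheadAttention.infer` above appends the current step's keys and values to `past_kv` along the time dimension and slices the causal mask down to the last `T` query rows. The sketch below isolates that cache-update pattern under assumed shapes; `update_kv_cache` and its toy usage are illustrative, not part of the VALL-E X implementation.

```python
import torch

def update_kv_cache(k_new, v_new, past_kv=None, use_cache=True):
    """Concatenate new key/value tensors of shape (B, n_heads, T_new, head_dim)
    onto a cached (past_k, past_v) pair, mirroring the concatenation done in
    MultiheadAttention.infer."""
    if past_kv is not None:
        k_new = torch.cat((past_kv[0], k_new), dim=-2)
        v_new = torch.cat((past_kv[1], v_new), dim=-2)
    present = (k_new, v_new) if use_cache else None
    return k_new, v_new, present

# Toy usage: decode 3 tokens one at a time and watch the cache grow along time.
B, H, D = 1, 2, 4
cache = None
for step in range(3):
    k = torch.randn(B, H, 1, D)
    v = torch.randn(B, H, 1, D)
    k_all, v_all, cache = update_kv_cache(k, v, past_kv=cache)
print(cache[0].shape)  # torch.Size([1, 2, 3, 4])
```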
+ + +import collections +import logging +import random +import math +from functools import reduce +from itertools import repeat +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Embedding as ScaledEmbedding + +class Transpose(nn.Identity): + """(N, T, D) -> (N, D, T)""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.transpose(1, 2) + +class ActivationBalancerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + if sign_factor is None: + ctx.save_for_backward(xgt0, scale_factor) + else: + ctx.save_for_backward(xgt0, scale_factor, sign_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + if len(ctx.saved_tensors) == 3: + xgt0, scale_factor, sign_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + sign_factor = sign_factor.unsqueeze(-1) + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + else: + xgt0, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + scale_factor = scale_factor.unsqueeze(-1) + factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +def _compute_scale_factor( + x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) + + if min_abs == 0.0: + below_threshold = 0.0 + else: + # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if + # x_abs)_mean , min_abs. + below_threshold = ( + (min_abs - x_abs_mean) * (gain_factor / min_abs) + ).clamp(min=0, max=max_factor) + + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( + min=0, max=max_factor + ) + + return below_threshold - above_threshold + + +def _compute_sign_factor( + x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float, +) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + sum_dims = [d for d in range(x.ndim) if d != channel_dim] + proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + if min_positive == 0.0: + factor1 = 0.0 + else: + # 0 if proportion_positive >= min_positive, else can be + # as large as max_factor. + factor1 = ( + (min_positive - proportion_positive) * (gain_factor / min_positive) + ).clamp_(min=0, max=max_factor) + + if max_positive == 1.0: + factor2 = 0.0 + else: + # 0 if self.proportion_positive <= max_positive, else can be + # as large as -max_factor. 
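+        # Worked example with illustrative values (not the values passed at runtime):
+        # with max_positive=0.95, gain_factor=0.01 and max_factor=0.04, a channel that is
+        # positive 99% of the time gets factor2 = clamp((0.99 - 0.95) * (0.01 / (1.0 - 0.95)),
+        # 0, 0.04) = 0.008, i.e. a small corrective term for exceeding max_positive.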
+ factor2 = ( + (proportion_positive - max_positive) + * (gain_factor / (1.0 - max_positive)) + ).clamp_(min=0, max=max_factor) + sign_factor = factor1 - factor2 + # require min_positive != 0 or max_positive != 1: + assert not isinstance(sign_factor, float) + return sign_factor + + +class ActivationScaleBalancerFunction(torch.autograd.Function): + """ + This object is used in class ActivationBalancer when the user specified + min_positive=0, max_positive=1, so there are no constraints on the signs + of the activations and only the absolute value has a constraint. + """ + + @staticmethod + def forward( + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, + ) -> Tensor: + if channel_dim < 0: + channel_dim += x.ndim + ctx.channel_dim = channel_dim + xgt0 = x > 0 + ctx.save_for_backward(xgt0, sign_factor, scale_factor) + return x + + @staticmethod + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + xgt0, sign_factor, scale_factor = ctx.saved_tensors + for _ in range(ctx.channel_dim, x_grad.ndim - 1): + sign_factor = sign_factor.unsqueeze(-1) + scale_factor = scale_factor.unsqueeze(-1) + + factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) + neg_delta_grad = x_grad.abs() * factor + return ( + x_grad - neg_delta_grad, + None, + None, + None, + ) + + +class RandomClampFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float, + ) -> Tensor: + x_clamped = torch.clamp(x, min=min, max=max) + mask = torch.rand_like(x) < prob + ans = torch.where(mask, x_clamped, x) + if x.requires_grad: + ctx.save_for_backward(ans == x) + ctx.reflect = reflect + if reflect != 0.0: + ans = ans * (1.0 + reflect) - (x * reflect) + return ans + + @staticmethod + def backward( + ctx, ans_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None]: + (is_same,) = ctx.saved_tensors + x_grad = ans_grad * is_same.to(ans_grad.dtype) + reflect = ctx.reflect + if reflect != 0.0: + x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) + return x_grad, None, None, None, None + + +def random_clamp( + x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0, +): + return RandomClampFunction.apply(x, min, max, prob, reflect) + + +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: + """ + A randomized way of casting a floating point value to half precision. + """ + if x.dtype == torch.float16: + return x + x_abs = x.abs() + is_too_small = x_abs < min_abs + # for elements where is_too_small is true, random_val will contain +-min_abs with + # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, + # for those elements]. + random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) + return torch.where(is_too_small, random_val, x).to(torch.float16) + + +class RandomGradFunction(torch.autograd.Function): + """ + Does nothing in forward pass; in backward pass, gets rid of very small grads using + randomized approach that preserves expectations (intended to reduce roundoff). 
+ """ + + @staticmethod + def forward(ctx, x: Tensor, min_abs: float) -> Tensor: + ctx.min_abs = min_abs + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: + if ans_grad.dtype == torch.float16: + return ( + random_cast_to_half( + ans_grad.to(torch.float32), min_abs=ctx.min_abs + ), + None, + ) + else: + return ans_grad, None + + +class RandomGrad(torch.nn.Module): + """ + Gets rid of very small gradients using an expectation-preserving method, intended to increase + accuracy of training when using amp (automatic mixed precision) + """ + + def __init__(self, min_abs: float = 5.0e-06): + super(RandomGrad, self).__init__() + self.min_abs = min_abs + + def forward(self, x: Tensor): + if ( + torch.jit.is_scripting() + or not self.training + or torch.jit.is_tracing() + ): + return x + else: + return RandomGradFunction.apply(x, self.min_abs) + + +class SoftmaxFunction(torch.autograd.Function): + """ + Tries to handle half-precision derivatives in a randomized way that should + be more accurate for training than the default behavior. + """ + + @staticmethod + def forward(ctx, x: Tensor, dim: int): + ans = x.softmax(dim=dim) + # if x dtype is float16, x.softmax() returns a float32 because + # (presumably) that op does not support float16, and autocast + # is enabled. + if torch.is_autocast_enabled(): + ans = ans.to(torch.float16) + ctx.save_for_backward(ans) + ctx.x_dtype = x.dtype + ctx.dim = dim + return ans + + @staticmethod + def backward(ctx, ans_grad: Tensor): + (ans,) = ctx.saved_tensors + with torch.cuda.amp.autocast(enabled=False): + ans_grad = ans_grad.to(torch.float32) + ans = ans.to(torch.float32) + x_grad = ans_grad * ans + x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) + return x_grad, None + + +def softmax(x: Tensor, dim: int): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x.softmax(dim) + + return SoftmaxFunction.apply(x, dim) + + +class MaxEigLimiterFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: + ctx.channel_dim = channel_dim + ctx.grad_scale = grad_scale + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + return x + + @staticmethod + def backward(ctx, x_grad, *args): + with torch.enable_grad(): + (x_orig, coeffs, new_direction) = ctx.saved_tensors + x_orig.requires_grad = True + num_channels = x_orig.shape[ctx.channel_dim] + x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) + new_direction.requires_grad = False + x = x - x.mean(dim=0) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. This is to be minimized. + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) + variance_proportion.backward() + x_orig_grad = x_orig.grad + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) + return x_grad + x_extra_grad.detach(), None, None, None, None + + +class BasicNorm(torch.nn.Module): + """ + This is intended to be a simpler, and hopefully cheaper, replacement for + LayerNorm. The observation this is based on, is that Transformer-type + networks, especially with pre-norm, sometimes seem to set one of the + feature dimensions to a large constant value (e.g. 
50), which "defeats" + the LayerNorm because the output magnitude is then not strongly dependent + on the other (useful) features. Presumably the weight and bias of the + LayerNorm are required to allow it to do this. + + So the idea is to introduce this large constant value as an explicit + parameter, that takes the role of the "eps" in LayerNorm, so the network + doesn't have to do this trick. We make the "eps" learnable. + + Args: + num_channels: the number of channels, e.g. 512. + channel_dim: the axis/dimension corresponding to the channel, + interprted as an offset from the input's ndim if negative. + shis is NOT the num_channels; it should typically be one of + {-2, -1, 0, 1, 2, 3}. + eps: the initial "epsilon" that we add as ballast in: + scale = ((input_vec**2).mean() + epsilon)**-0.5 + Note: our epsilon is actually large, but we keep the name + to indicate the connection with conventional LayerNorm. + learn_eps: if true, we learn epsilon; if false, we keep it + at the initial value. + eps_min: float + eps_max: float + """ + + def __init__( + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True, + eps_min: float = -3.0, + eps_max: float = 3.0, + ) -> None: + super(BasicNorm, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + if learn_eps: + self.eps = nn.Parameter(torch.tensor(eps).log().detach()) + else: + self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.eps_min = eps_min + self.eps_max = eps_max + + def forward(self, x: Tensor) -> Tensor: + assert x.shape[self.channel_dim] == self.num_channels + eps = self.eps + if self.training and random.random() < 0.25: + # with probability 0.25, in training mode, clamp eps between the min + # and max; this will encourage it to learn parameters within the + # allowed range by making parameters that are outside the allowed + # range noisy. + + # gradients to allow the parameter to get back into the allowed + # region if it happens to exit it. + eps = eps.clamp(min=self.eps_min, max=self.eps_max) + scales = ( + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + ) ** -0.5 + return x * scales + + +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + """ + Behaves like a constructor of a modified version of nn.Linear + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. + + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Linear(*args, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + return ans + + +def ScaledConv1d( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Conv1d: + """ + Behaves like a constructor of a modified version of nn.Conv1d + that gives an easy way to set the default initial parameter scale. + + Args: + Accepts the standard args and kwargs that nn.Linear accepts + e.g. in_features, out_features, bias=False. 
+ + initial_scale: you can override this if you want to increase + or decrease the initial magnitude of the module's output + (affects the initialization of weight_scale and bias_scale). + Another option, if you want to do something like this, is + to re-initialize the parameters. + """ + ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs) + with torch.no_grad(): + ans.weight[:] *= initial_scale + if ans.bias is not None: + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) + return ans + + +def TransposeScaledConv1d( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Sequential: + """ + Transpose -> ScaledConv1d + """ + return nn.Sequential( + Transpose(), + ScaledConv1d( + *args, + initial_scale=initial_scale, + kernel_size=kernel_size, + padding=padding, + **kwargs, + ), + ) + + +def ScaledConv1dTranspose( + *args, + initial_scale: float = 1.0, + kernel_size: int = 3, + padding: str = "same", + **kwargs, +) -> nn.Sequential: + """ + Transpose -> ScaledConv1d + """ + return nn.Sequential( + ScaledConv1d( + *args, + initial_scale=initial_scale, + kernel_size=kernel_size, + padding=padding, + **kwargs, + ), + Transpose(), + ) + + +def TransposeConv1d( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + Transpose -> Conv1d + """ + return nn.Sequential( + Transpose(), + nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + ) + + +def Conv1dTranspose( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + ScaledConv1d -> Transpose + """ + return nn.Sequential( + nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + Transpose(), + ) + + +class SRLinear(nn.Linear): + """https://arxiv.org/abs/2303.06296 + Stabilizing Transformer Training by Preventing Attention Entropy Collapse + """ + + def __init__(self, in_features, out_features, bias=True, **kwargs): + super().__init__(in_features, out_features, bias=bias, **kwargs) + self.register_buffer( + "u", nn.functional.normalize(torch.randn(in_features), dim=0) + ) + with torch.no_grad(): + sigma = self.get_sigma() + self.register_buffer("spectral_norm", sigma) + self.sigma = nn.Parameter(torch.ones(1)) + + def get_sigma(self): + with torch.no_grad(): + u = self.u + v = self.weight.mv(u) + v = nn.functional.normalize(v, dim=0) + u = self.weight.T.mv(v) + u = nn.functional.normalize(u, dim=0) + self.u.data.copy_(u) + return torch.einsum("c,cd,d->", v, self.weight, u) + + def get_weight(self): + sigma = self.get_sigma() + if self.training: + self.spectral_norm.data.copy_(sigma) + weight = (self.sigma / sigma) * self.weight + return weight + + def forward(self, x): + return nn.functional.linear(x, self.get_weight(), self.bias) + + +class SRConv1d(SRLinear): + def __init__( + self, + in_features, + out_features, + kernel_size, + stride: int = 1, + padding: str = "same", + bias: bool = True, + **kwargs, + ): + in_features = in_features * kernel_size + super().__init__(in_features, out_features, bias=bias, **kwargs) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + def forward(self, x): + in_features = self.in_features // self.kernel_size + weight = self.get_weight().view( + self.out_features, in_features, self.kernel_size + ) + return nn.functional.conv1d( + x, weight, bias=self.bias, stride=self.stride, padding=self.padding + ) + + +def 
TransposeSRConv1d( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + Transpose -> SRConv1d + """ + return nn.Sequential( + Transpose(), + SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + ) + + +def SRConv1dTranspose( + *args, kernel_size: int = 3, padding: str = "same", **kwargs +) -> nn.Sequential: + """ + SRConv1d -> Transpose + """ + return nn.Sequential( + SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs), + Transpose(), + ) + + +class ActivationBalancer(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to encourage, for + each channel, that it is positive at least a proportion `threshold` of the + time. It does this by multiplying negative derivative values by up to + (1+max_factor), and positive derivative values by up to (1-max_factor), + interpolated from 1 at the threshold to those extremal values when none + of the inputs are positive. + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + min_positive: the minimum, per channel, of the proportion of the time + that (x > 0), below which we start to modify the derivatives. + max_positive: the maximum, per channel, of the proportion of the time + that (x > 0), above which we start to modify the derivatives. + max_factor: the maximum factor by which we modify the derivatives for + either the sign constraint or the magnitude constraint; + e.g. with max_factor=0.02, the the derivatives would be multiplied by + values in the range [0.98..1.02]. + sign_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_positive and max_positive + are violated. + scale_gain_factor: determines the 'gain' with which we increase the + change in gradient once the constraints on min_abs and max_abs + are violated. + min_abs: the minimum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + max_abs: the maximum average-absolute-value difference from the mean + value per channel, which we allow, before we start to modify + the derivatives to prevent this. + min_prob: determines the minimum probability with which we modify the + gradients for the {min,max}_positive and {min,max}_abs constraints, + on each forward(). This is done randomly to prevent all layers + from doing it at the same time. Early in training we may use + higher probabilities than this; it will decay to this value. + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, + ): + super(ActivationBalancer, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.min_positive = min_positive + self.max_positive = max_positive + self.max_factor = max_factor + self.min_abs = min_abs + self.max_abs = max_abs + self.min_prob = min_prob + self.sign_gain_factor = sign_gain_factor + self.scale_gain_factor = scale_gain_factor + + # count measures how many times the forward() function has been called. 
+ # We occasionally sync this to a tensor called `count`, that exists to + # make sure it is synced to disk when we load and save the model. + self.cpu_count = 0 + self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or not x.requires_grad + or torch.jit.is_tracing() + ): + return _no_op(x) + + count = self.cpu_count + self.cpu_count += 1 + + if random.random() < 0.01: + # Occasionally sync self.cpu_count with self.count. + # count affects the decay of 'prob'. don't do this on every iter, + # because syncing with the GPU is slow. + self.cpu_count = max(self.cpu_count, self.count.item()) + self.count.fill_(self.cpu_count) + + # the prob of doing some work exponentially decreases from 0.5 till it hits + # a floor at min_prob (==0.1, by default) + prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + + if random.random() < prob: + sign_gain_factor = 0.5 + if self.min_positive != 0.0 or self.max_positive != 1.0: + sign_factor = _compute_sign_factor( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor, + ) + else: + sign_factor = None + + scale_factor = _compute_scale_factor( + x.detach(), + self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor, + ) + return ActivationBalancerFunction.apply( + x, + scale_factor, + sign_factor, + self.channel_dim, + ) + else: + return _no_op(x) + + +def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: + """ + Returns x unmodified, but in backprop will put a penalty for the excess of + the absolute values of elements of x over the limit "limit". E.g. if + limit == 10.0, then if x has any values over 10 it will get a penalty. + + Caution: the value of this penalty will be affected by grad scaling used + in automatic mixed precision training. For this reasons we use this, + it shouldn't really matter, or may even be helpful; we just use this + to disallow really implausible values of scores to be given to softmax. + """ + x_sign = x.sign() + over_limit = (x.abs() - limit) > 0 + # The following is a memory efficient way to penalize the absolute values of + # x that's over the limit. (The memory efficiency comes when you think + # about which items torch needs to cache for the autograd, and which ones it + # can throw away). The numerical value of aux_loss as computed here will + # actually be larger than it should be, by limit * over_limit.sum(), but it + # has the same derivative as the real aux_loss which is penalty * (x.abs() - + # limit).relu(). + aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) + # note: we don't do sum() here on aux)_loss, but it's as if we had done + # sum() due to how with_loss() works. + x = with_loss(x, aux_loss) + # you must use x for something, or this will be ineffective. + return x + + +def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. + if x.ndim == 2: + return x.diag() + else: + (batch, dim, dim) = x.shape + x = x.reshape(batch, dim * dim) + x = x[:, :: dim + 1] + assert x.shape == (batch, dim) + return x + + +def _whitening_metric(x: Tensor, num_groups: int): + """ + Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of + of the centered feature covariance are the same within each group's covariance matrix + and also between groups. 
+ Args: + x: a Tensor of shape (*, num_channels) + num_groups: the number of groups of channels, a number >=1 that divides num_channels + Returns: + Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and + greater than 1.0 otherwise. + """ + assert x.dtype != torch.float16 + x = x.reshape(-1, x.shape[-1]) + (num_frames, num_channels) = x.shape + assert num_channels % num_groups == 0 + channels_per_group = num_channels // num_groups + x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) + # x now has shape (num_groups, num_frames, channels_per_group) + # subtract the mean so we use the centered, not uncentered, covariance. + # My experience has been that when we "mess with the gradients" like this, + # it's better not do anything that tries to move the mean around, because + # that can easily cause instability. + x = x - x.mean(dim=1, keepdim=True) + # x_covar: (num_groups, channels_per_group, channels_per_group) + x_covar = torch.matmul(x.transpose(1, 2), x) + x_covar_mean_diag = _diag(x_covar).mean() + # the following expression is what we'd get if we took the matrix product + # of each covariance and measured the mean of its trace, i.e. + # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). + x_covarsq_mean_diag = (x_covar ** 2).sum() / ( + num_groups * channels_per_group + ) + # this metric will be >= 1.0; the larger it is, the less 'white' the data was. + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + return metric + + +class WhiteningPenaltyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float, + ) -> Tensor: + ctx.save_for_backward(x) + ctx.num_groups = num_groups + ctx.whitening_limit = whitening_limit + ctx.grad_scale = grad_scale + return x + + @staticmethod + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors + with torch.enable_grad(): + with torch.cuda.amp.autocast(enabled=False): + x_detached = x_orig.to(torch.float32).detach() + x_detached.requires_grad = True + + metric = _whitening_metric(x_detached, ctx.num_groups) + + if random.random() < 0.005 or __name__ == "__main__": + logging.info( + f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" + ) + + (metric - ctx.whitening_limit).relu().backward() + penalty_grad = x_detached.grad + scale = ctx.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) + penalty_grad = penalty_grad * scale + return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + + +class Whiten(nn.Module): + def __init__( + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float, float]], + grad_scale: float, + ): + """ + Args: + num_groups: the number of groups to divide the channel dim into before + whitening. We will attempt to make the feature covariance + within each group, after mean subtraction, as "white" as possible, + while having the same trace across all groups. + whitening_limit: a value greater than 1.0, that dictates how much + freedom we have to violate the constraints. 1.0 would mean perfectly + white, with exactly the same trace across groups; larger values + give more freedom. E.g. 2.0. + prob: the probability with which we apply the gradient modification + (also affects the grad scale). 
May be supplied as a float, + or as a pair (min_prob, max_prob) + + grad_scale: determines the scale on the gradient term from this object, + relative to the rest of the gradient on the attention weights. + E.g. 0.02 (you may want to use smaller values than this if prob is large) + """ + super(Whiten, self).__init__() + assert num_groups >= 1 + assert whitening_limit >= 1 + assert grad_scale >= 0 + self.num_groups = num_groups + self.whitening_limit = whitening_limit + if isinstance(prob, float): + assert 0 < prob <= 1 + self.prob = prob + else: + (self.min_prob, self.max_prob) = prob + assert 0 < self.min_prob < self.max_prob <= 1 + self.prob = self.max_prob + + self.grad_scale = grad_scale + + def forward(self, x: Tensor) -> Tensor: + """ + In the forward pass, this function just returns the input unmodified. + In the backward pass, it will modify the gradients to ensure that the + distribution in each group has close to (lambda times I) as the covariance + after mean subtraction, with the same lambda across groups. + For whitening_limit > 1, there will be more freedom to violate this + constraint. + + Args: + x: the input of shape (*, num_channels) + + Returns: + x, unmodified. You should make sure + you use the returned value, or the graph will be freed + and nothing will happen in backprop. + """ + if ( + not x.requires_grad + or random.random() > self.prob + or self.grad_scale == 0 + ): + return _no_op(x) + else: + if hasattr(self, "min_prob") and random.random() < 0.25: + # occasionally switch between min_prob and max_prob, based on whether + # we are above or below the threshold. + if ( + _whitening_metric(x.to(torch.float32), self.num_groups) + > self.whitening_limit + ): + # there would be a change to the grad. + self.prob = self.max_prob + else: + self.prob = self.min_prob + + return WhiteningPenaltyFunction.apply( + x, self.num_groups, self.whitening_limit, self.grad_scale + ) + + +class WithLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor, y: Tensor): + ctx.y_shape = y.shape + return x + + @staticmethod + def backward(ctx, ans_grad: Tensor): + return ans_grad, torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ) + + +def with_loss(x, y): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + # returns x but adds y.sum() to the loss function. + return WithLoss.apply(x, y) + + +def _no_op(x: Tensor) -> Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x + else: + # a no-op function that will have a node in the autograd graph, + # to avoid certain bugs relating to backward hooks + return x.chunk(1, dim=-1)[0] + + +class Identity(torch.nn.Module): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, x): + return _no_op(x) + + +class MaxEig(torch.nn.Module): + """ + Modifies the backpropped derivatives of a function to try to discourage + that any given direction in activation space accounts for more than + a specified proportion of the covariance (e.g. 0.2). + + + Args: + num_channels: the number of channels + channel_dim: the dimension/axis corresponding to the channel, e.g. + -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. + max_var_per_eig: the maximum proportion of the variance of the + features/channels, after mean subtraction, that can come from + any given eigenvalue. 
+ min_prob: the minimum probability with which we apply this during any invocation + of forward(), assuming last time we applied the constraint it was + not active; supplied for speed. + scale: determines the scale with which we modify the gradients, relative + to the existing / unmodified gradients + """ + + def __init__( + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, + ): + super(MaxEig, self).__init__() + self.num_channels = num_channels + self.channel_dim = channel_dim + self.scale = scale + assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels + self.max_var_per_eig = max_var_per_eig + + # we figure out the dominant direction using the power method: starting with + # a random vector, keep multiplying by the covariance and renormalizing. + with torch.no_grad(): + # arbitrary.. would use randn() but want to leave the rest of the model's + # random parameters unchanged for comparison + direction = torch.arange(num_channels).to(torch.float) + direction = direction / direction.norm() + self.register_buffer("max_eig_direction", direction) + + self.min_prob = min_prob + # cur_prob is the current probability we'll use to apply the ActivationBalancer. + # We'll regress this towards prob, each time we try to apply it and it is not + # active. + self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: + if ( + torch.jit.is_scripting() + or self.max_var_per_eig <= 0 + or random.random() > self.cur_prob + or torch.jit.is_tracing() + ): + return _no_op(x) + + with torch.cuda.amp.autocast(enabled=False): + eps = 1.0e-20 + orig_x = x + x = x.to(torch.float32) + with torch.no_grad(): + x = x.transpose(self.channel_dim, -1).reshape( + -1, self.num_channels + ) + x = x - x.mean(dim=0) + new_direction, coeffs = self._find_direction_coeffs( + x, self.max_eig_direction + ) + x_var = (x ** 2).mean() + x_residual = x - coeffs * new_direction + x_residual_var = (x_residual ** 2).mean() + + # `variance_proportion` is the proportion of the variance accounted for + # by the top eigen-direction. + variance_proportion = (x_var - x_residual_var) / ( + x_var + 1.0e-20 + ) + + # ensure new direction is nonzero even if x == 0, by including `direction`. + self._set_direction( + 0.1 * self.max_eig_direction + new_direction + ) + + if random.random() < 0.01 or __name__ == "__main__": + logging.info( + f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" + ) + + if variance_proportion >= self.max_var_per_eig: + # The constraint is active. Note, we should quite rarely + # reach here, only near the beginning of training if we are + # starting to diverge, should this constraint be active. + cur_prob = self.cur_prob + self.cur_prob = ( + 1.0 # next time, do the update with probability 1.0. + ) + return MaxEigLimiterFunction.apply( + orig_x, coeffs, new_direction, self.channel_dim, self.scale + ) + else: + # let self.cur_prob exponentially approach self.min_prob, as + # long as the constraint is inactive. 
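+ # Note: the update below is a geometric decay of cur_prob towards min_prob.
+ # For example, with cur_prob=1.0 and min_prob=0.01, one inactive step gives
+ # 0.75 * 1.0 + 0.25 * 0.01 = 0.7525, and after roughly 16 consecutive inactive
+ # steps cur_prob is within about 0.01 of min_prob, so the eigen-direction check
+ # above runs only rarely once the constraint stays inactive.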
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob + return orig_x + + def _set_direction(self, direction: Tensor): + """ + Sets self.max_eig_direction to a normalized version of `direction` + """ + direction = direction.detach() + direction = direction / direction.norm() + direction_sum = direction.sum().item() + if direction_sum - direction_sum == 0: # no inf/nan + self.max_eig_direction[:] = direction + else: + logging.info( + f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}" + ) + + def _find_direction_coeffs( + self, x: Tensor, prev_direction: Tensor + ) -> Tuple[Tensor, Tensor, Tensor]: + """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ + (num_frames, num_channels) = x.shape + assert num_channels > 1 and num_frames > 1 + assert prev_direction.shape == (num_channels,) + # `coeffs` are the coefficients of `prev_direction` in x. + # actually represent the coeffs up to a constant positive factor. + coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 + cur_direction = (x * coeffs).sum(dim=0) / ( + (coeffs ** 2).sum() + 1.0e-20 + ) + return cur_direction, coeffs + + +class DoubleSwishFunction(torch.autograd.Function): + """ + double_swish(x) = x * torch.sigmoid(x-1) + This is a definition, originally motivated by its close numerical + similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). + + Memory-efficient derivative computation: + double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) + double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). + Now, s'(x) = s(x) * (1-s(x)). + double_swish'(x) = x * s'(x) + s(x). + = x * s(x) * (1-s(x)) + s(x). + = double_swish(x) * (1-s(x)) + s(x) + ... so we just need to remember s(x) but not x itself. + """ + + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + requires_grad = x.requires_grad + x_dtype = x.dtype + if x.dtype == torch.float16: + x = x.to(torch.float32) + + s = torch.sigmoid(x - 1.0) + y = x * s + + if requires_grad: + deriv = y * (1 - s) + s + # notes on derivative of x * sigmoid(x - 1): + # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 + # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund + # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. + # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which + # floors), should be expectation-preserving. + floor = -0.043637 + ceil = 1.2 + d_scaled = (deriv - floor) * ( + 255.0 / (ceil - floor) + ) + torch.rand_like(deriv) + if __name__ == "__main__": + # for self-testing only. 
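+ # Note: by the bounds above, deriv lies (approximately) in [floor, ceil], so
+ # (deriv - floor) * (255 / (ceil - floor)) lands in [0, 255]; the added U[0, 1)
+ # noise keeps d_scaled in [0, 256), which is what the two asserts below verify
+ # before the value is floored into uint8.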
+ assert d_scaled.min() >= 0.0 + assert d_scaled.max() < 256.0 + d_int = d_scaled.to(torch.uint8) + ctx.save_for_backward(d_int) + if x.dtype == torch.float16 or torch.is_autocast_enabled(): + y = y.to(torch.float16) + return y + + @staticmethod + def backward(ctx, y_grad: Tensor) -> Tensor: + (d,) = ctx.saved_tensors + # the same constants as used in forward pass. + floor = -0.043637 + ceil = 1.2 + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d + + +class DoubleSwish(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return double-swish activation function which is an approximation to Swish(Swish(x)), + that we approximate closely with x * sigmoid(x-1). + """ + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return x * torch.sigmoid(x - 1.0) + return DoubleSwishFunction.apply(x) + + +def BalancedDoubleSwish( + d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25 +) -> nn.Sequential: + """ + ActivationBalancer -> DoubleSwish + """ + balancer = ActivationBalancer( + d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob + ) + return nn.Sequential( + balancer, + DoubleSwish(), + ) + + +def _test_max_eig(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = MaxEig( + num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad, atol=1.0e-02) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_whiten(): + for proportion in [0.1, 0.5, 10.0]: + logging.info(f"_test_whiten(): proportion = {proportion}") + x = torch.randn(100, 128) + direction = torch.randn(128) + coeffs = torch.randn(100, 1) + x += proportion * direction * coeffs + + x.requires_grad = True + + num_channels = 128 + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale + + for _ in range(4): + y = m(x) + + y_grad = torch.randn_like(x) + y.backward(gradient=y_grad) + + if proportion < 0.2: + assert torch.allclose(x.grad, y_grad) + elif proportion > 1.0: + assert not torch.allclose(x.grad, y_grad) + + +def _test_activation_balancer_sign(): + probs = torch.arange(0, 1, 0.01) + N = 1000 + x = 1.0 * ( + (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + probs.numel(), + channel_dim=0, + min_positive=0.05, + max_positive=0.95, + max_factor=0.2, + min_abs=0.0, + ) + + y_grad = torch.sign(torch.randn(probs.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + print("_test_activation_balancer_sign: x = ", x) + print("_test_activation_balancer_sign: y grad = ", y_grad) + print("_test_activation_balancer_sign: x grad = ", x.grad) + + +def _test_activation_balancer_magnitude(): + magnitudes = torch.arange(0, 1, 0.01) + N = 1000 + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) + x = x.detach() + x.requires_grad = True + m = ActivationBalancer( + magnitudes.numel(), + channel_dim=0, + min_positive=0.0, + max_positive=1.0, + max_factor=0.2, + min_abs=0.2, + max_abs=0.8, + min_prob=1.0, + ) + + y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) + + y = m(x) + y.backward(gradient=y_grad) + 
print("_test_activation_balancer_magnitude: x = ", x) + print("_test_activation_balancer_magnitude: y grad = ", y_grad) + print("_test_activation_balancer_magnitude: x grad = ", x.grad) + + +def _test_basic_norm(): + num_channels = 128 + m = BasicNorm(num_channels=num_channels, channel_dim=1) + + x = torch.randn(500, num_channels) + + y = m(x) + + assert y.shape == x.shape + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() + print("x rms = ", x_rms) + print("y rms = ", y_rms) + assert y_rms < x_rms + assert y_rms > 0.5 * x_rms + + +def _test_double_swish_deriv(): + x = torch.randn(10, 12, dtype=torch.double) * 3.0 + x.requires_grad = True + m = DoubleSwish() + + tol = (1.2 - (-0.043637)) / 255.0 + torch.autograd.gradcheck(m, x, atol=tol) + + # for self-test. + x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 + x.requires_grad = True + y = m(x) + + +def _test_softmax(): + a = torch.randn(2, 10, dtype=torch.float64) + b = a.clone() + a.requires_grad = True + b.requires_grad = True + a.softmax(dim=1)[:, 0].sum().backward() + print("a grad = ", a.grad) + softmax(b, dim=1)[:, 0].sum().backward() + print("b grad = ", b.grad) + assert torch.allclose(a.grad, b.grad) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_softmax() + _test_whiten() + _test_max_eig() + _test_activation_balancer_sign() + _test_activation_balancer_magnitude() + _test_basic_norm() + _test_double_swish_deriv() diff --git a/slam_llm/models/vallex/transformers.py b/slam_llm/models/vallex/transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..183b65ca3190146fb195a8c590a76db91ad49ea2 --- /dev/null +++ b/slam_llm/models/vallex/transformers.py @@ -0,0 +1,613 @@ +import copy +import numbers +from functools import partial +from typing import Any, Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from .activation import MultiheadAttention +from .scaling import BasicNorm as _BasicNorm + +_shape_t = Union[int, List[int], torch.Size] + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] 
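+ # Note: unlike torch.nn.LayerNorm, forward() below also accepts an
+ # (input, embedding) tuple and passes the embedding through unchanged, so this
+ # norm can be threaded through layers that carry a stage embedding
+ # (see AdaptiveLayerNorm below).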
+ eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +class BasicNorm(_BasicNorm): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ): + super(BasicNorm, self).__init__(d_model, eps=eps) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + super(BasicNorm, self).forward(input), + embedding, + ) + + assert embedding is None + return super(BasicNorm, self).forward(input) + + +class BalancedBasicNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ): + super(BalancedBasicNorm, self).__init__() + self.balancer = ActivationBalancer( + d_model, + channel_dim=-1, + min_positive=0.45, + max_positive=0.55, + max_abs=6.0, + ) + self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return self.norm((self.balancer(input), embedding)) + + assert embedding is None + return self.norm(self.balancer(input)) + + +class IdentityNorm(nn.Module): + def __init__( + self, + d_model: int, + eps: float = 1e-5, + device=None, + dtype=None, + ) -> None: + 
super(IdentityNorm, self).__init__() + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + return input + + assert embedding is None + return input + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + + # Implementation of Feedforward model + + self.dropout = nn.Dropout(dropout) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + else: + norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + self.norm1 = norm1 + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + self.norm2 = norm2 + + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. 
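+
+ Example (illustrative sketch only; sizes are arbitrary, and it assumes the
+ bundled MultiheadAttention accepts plain tensors like torch.nn.MultiheadAttention):
+
+ layer = TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True, norm_first=True)
+ src = torch.randn(2, 50, 256) # (batch, time, d_model) because batch_first=True
+ out = layer(src) # same shape as src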
+ """ + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + ) + + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + return (x, stage_embedding) + return x + + def infer( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + past_kv: Optional[Tensor] = None, + use_cache: bool = False, + ): + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x_attn_out, kv = self.self_attn.infer( + self.norm1(x, stage_embedding), + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + need_weights=False, + past_kv=past_kv, + use_cache=use_cache, + ) + x = x + x_attn_out + x = x + self._ff_block(self.norm2(x, stage_embedding)) + + if is_src_tuple: + return (x, stage_embedding) + return (x, kv) + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerEncoder(nn.Module): + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + ) -> Tensor: + output = src + for i, mod in enumerate(self.layers): + output = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + # print(i, output.mean()) + if self.norm is not None: + output = self.norm(output) + + return output + + def infer( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + past_kv: Optional[Tensor] = None, + use_cache: bool = False, + ): + if past_kv is None: + past_length = 0 + past_kv = tuple([None] * self.num_layers) + else: + past_length = past_kv[0][0].size(-2) + new_kv = () if use_cache else None + output = src + for i, (mod, past_layer_kv) in enumerate(zip(self.layers, past_kv)): + output, kv = mod.infer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, 
use_cache=use_cache + ) + # print(i, output.mean()) + if use_cache: + new_kv = new_kv + (kv,) + + if self.norm is not None: + output = self.norm(output) + + return output, new_kv + + +class TransformerDecoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + self.multihead_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + # Legacy string support for activation function. 
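+ # A string ("relu"/"gelu") is resolved via _get_activation_fn; a functools.partial
+ # is called with d_model so modules that need the model width (e.g. a partial of
+ # BalancedDoubleSwish from scaling.py) can be constructed here; any other callable
+ # is used as-is.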
+ if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + self.activation = activation(d_model) + else: + self.activation = activation + + if adaptive_layer_norm: + norm1 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + norm3 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + self.norm3 = AdaptiveLayerNorm(d_model, norm3) + else: + self.norm1 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + self.norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + if layer_norm_cls == IdentityNorm: + self.norm3 = BalancedBasicNorm( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + else: + self.norm3 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + tgt_is_tuple = False + if isinstance(tgt, tuple): + x, stage_embedding = tgt + tgt_is_tuple = True + else: + x, stage_embedding = tgt, None + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask + ) + x = x + self._mha_block( + self.norm2(x, stage_embedding), + memory, + memory_mask, + memory_key_padding_mask, + ) + x = x + self._ff_block(self.norm3(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, tgt_mask, tgt_key_padding_mask), + stage_embedding, + ) + x = self.norm2( + x + + self._mha_block( + x, memory, memory_mask, memory_key_padding_mask + ), + stage_embedding, + ) + x = self.norm3(x + self._ff_block(x), stage_embedding) + + if tgt_is_tuple: + return (x, stage_embedding) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # multihead attention block + def _mha_block( + self, + x: Tensor, + mem: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.multihead_attn( + x, + mem, + mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout2(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) diff --git a/slam_llm/models/vallex/vallex_config.py b/slam_llm/models/vallex/vallex_config.py new file mode 100644 index 0000000000000000000000000000000000000000..242218f2f4559b74ba90a76c31bfce6bbac1cdc7 --- /dev/null +++ b/slam_llm/models/vallex/vallex_config.py @@ -0,0 +1,56 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils 
import logging +from fairseq.data import Dictionary +from transformers import AutoConfig, AutoModel, AutoModelForImageClassification + +logger = logging.get_logger(__name__) + + + +class VallexConfig(PretrainedConfig): + + model_type = "vallex" + + def __init__(self, + n_layer=24, + n_head=16, + n_dim=1024, + prefix_mode=1, + num_quantizers=8, + sample_rate=24000, + ar_at_dict="", + ar_st_dict="", + nar_at_dict="", + nar_st_dict="", + nar_scale_factor=1.0, + prepend_bos=True, + norm_first=True, + eps=0.0, + only_ar=False, + only_nar=False, + **kwargs + ): + self.n_layer = n_layer + self.n_head = n_head + self.n_dim = n_dim + self.prefix_mode = prefix_mode + self.num_quantizers = num_quantizers + self.sample_rate = sample_rate + self.nar_scale_factor = nar_scale_factor + self.prepend_bos = prepend_bos + self.norm_first = norm_first + + self.ar_at_dict = ar_at_dict + self.ar_st_dict = ar_st_dict + self.nar_at_dict = nar_at_dict + self.nar_st_dict = nar_st_dict + self.eps = eps + self.only_ar = only_ar + self.only_nar = only_nar + + super().__init__( + **kwargs + ) + + +AutoConfig.register("vallex", VallexConfig) \ No newline at end of file diff --git a/slam_llm/models/vallex/vallex_model.py b/slam_llm/models/vallex/vallex_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a57426e0f0ffb40a89b1b39b323284c7638b15 --- /dev/null +++ b/slam_llm/models/vallex/vallex_model.py @@ -0,0 +1,772 @@ +import random +from typing import Dict, Iterator, List, Tuple, Union +from fairseq import utils +import numpy as np +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from fairseq.data import Dictionary +from src.slam_llm.models.vallex.transformers import ( + LayerNorm, + TransformerEncoder, + TransformerEncoderLayer, +) +from src.slam_llm.models.vallex.vallex_config import VallexConfig +from transformers.modeling_utils import PreTrainedModel +from transformers import AutoConfig, AutoModel, AutoModelForImageClassification +from dataclasses import dataclass + +@dataclass +class ModelOutput: + logits: torch.Tensor + loss: torch.Tensor + acc: torch.Tensor + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, scale=1, prob_mask=None): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + if prob_mask is not None: + lprobs = lprobs.masked_fill(prob_mask, 0.0) + n_class = (1-prob_mask.float()).sum() + else: + n_class = lprobs.size(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + # nll_loss = nll_loss * scale + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) * scale + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + pad_mask_float = (1 - pad_mask.to(torch.float)).sum() + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = epsilon / (n_class - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + \ + eps_i * smooth_loss + return loss / pad_mask_float, nll_loss / pad_mask_float + + +class SinusoidalPositionalEmbedding(nn.Module): + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx if padding_idx is not None else 0 + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, embedding_dim, padding_idx + ) + self.onnx_trace = False + self.register_buffer("_float_tensor", 
torch.FloatTensor(1)) + self.max_positions = int(1e5) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + @staticmethod + def get_embedding( + num_embeddings: int, embedding_dim: int, padding_idx = None + ): + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( + 1 + ) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( + num_embeddings, -1 + ) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward( + self, + input, + incremental_state = None, + timestep = None, + positions = None, + ): + bspair = torch.onnx.operators.shape_as_tensor(input) + bsz, seq_len = bspair[0], bspair[1] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, self.embedding_dim, self.padding_idx + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + if self.onnx_trace: + return ( + self.weights.index_select(index=self.padding_idx + pos, dim=0) + .unsqueeze(1) + .repeat(bsz, 1, 1) + ) + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = utils.make_positions( + input, self.padding_idx, onnx_trace=self.onnx_trace + ) + if self.onnx_trace: + flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) + embedding_shape = torch.cat( + (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) + ) + embeddings = torch.onnx.operators.reshape_from_tensor_shape( + flat_embeddings, embedding_shape + ) + return embeddings + return ( + self.weights.index_select(0, positions.view(-1)) + .view(bsz, seq_len, -1) + .detach() + ) + + +class Transpose(nn.Identity): + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.transpose(1, 2) + + +class VALLF(PreTrainedModel): + config_class = VallexConfig + + def __init__( + self, + config: VallexConfig + ): + super().__init__(config) + + self.ar_at_dict = Dictionary.load(self.config.ar_at_dict) + self.ar_st_dict = Dictionary.load(self.config.ar_st_dict) + self.nar_at_dict = Dictionary.load(self.config.nar_at_dict) + self.nar_st_dict = Dictionary.load(self.config.nar_st_dict) + + self.ar_at_dict.tts_flag = self.ar_at_dict.add_symbol("<TTS>") + self.ar_st_dict.asr_flag = self.ar_st_dict.add_symbol("<ASR>") + self.ar_st_dict.mt_flag = self.ar_st_dict.add_symbol("<MT>") + + self.padding_idx = self.ar_at_dict.pad() + self.config = config + d_model = self.config.n_dim + nar_scale_factor = self.config.nar_scale_factor + prepend_bos = self.config.prepend_bos + + norm_first = self.config.norm_first + num_layers = self.config.n_layer + self.NUM_AUDIO_TOKENS = self.ar_at_dict.eos() + + nar_d_model = int(d_model * nar_scale_factor) + + self.ar_text_embedding = nn.Embedding(len(self.ar_st_dict), d_model, self.ar_st_dict.pad()) # W_x + if config.only_ar: + pass + else: + self.nar_text_embedding = nn.Embedding(len(self.nar_st_dict), d_model, self.nar_st_dict.pad()) + + # ID self.NUM_AUDIO_TOKENS -> PAD + # ID self.NUM_AUDIO_TOKENS + 1 -> BOS + 
self.ar_audio_prepend_bos = prepend_bos + self.ar_audio_embedding = EncodecDecoderLstm( + dictionary=self.ar_at_dict, emb_dim=d_model + ) + + self.ar_text_prenet = nn.Identity() + self.ar_audio_prenet = nn.Identity() + + self.ar_text_position = SinusoidalPositionalEmbedding( + d_model, + padding_idx=self.ar_at_dict.pad(), + init_size=1024+self.ar_at_dict.pad()+1 + ) + self.ar_audio_position = SinusoidalPositionalEmbedding( + d_model, + padding_idx=self.ar_at_dict.pad(), + init_size=1024+self.ar_at_dict.pad()+1 + ) + + self.ar_decoder = TransformerEncoder( + TransformerEncoderLayer( + d_model, + self.config.n_head, + dim_feedforward=d_model * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=num_layers, + norm=LayerNorm(d_model) if norm_first else None, + ) + self.ar_predict_layer = nn.Linear( + d_model, len(self.ar_at_dict), bias=False + ) + + self.rng = random.Random(0) + self.num_heads = self.config.n_head + self.prefix_mode = self.config.prefix_mode + self.num_quantizers = self.config.num_quantizers + + assert self.num_quantizers >= 1 + if config.only_ar: + pass + else: + if self.num_quantizers > 1: + self.nar_audio_embeddings = NATEncodecDecoderLstm( + codecs=[0, 1, 2, 3, 4, 5, 6, 7], dictionary=self.nar_at_dict, emb_dim=d_model + ) # W_a + + self.nar_text_prenet = nn.Identity() + self.nar_audio_prenet = nn.Identity() + + self.nar_text_position = SinusoidalPositionalEmbedding( + d_model, + padding_idx=self.nar_at_dict.pad(), + init_size=1024+self.nar_at_dict.pad()+1 + ) + self.nar_audio_position = SinusoidalPositionalEmbedding( + d_model, + padding_idx=self.nar_at_dict.pad(), + init_size=1024+self.nar_at_dict.pad()+1 + ) + + self.nar_decoder = TransformerEncoder( + TransformerEncoderLayer( + nar_d_model, + int(self.num_heads * nar_scale_factor), + dim_feedforward=nar_d_model * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + adaptive_layer_norm=True, + ), + num_layers=int(num_layers * nar_scale_factor), + norm=nn.LayerNorm(nar_d_model) + if norm_first + else None, + ) + self.nar_predict_layers = nn.ModuleList( + [ + nn.Linear(nar_d_model, len(self.nar_at_dict), bias=False) + for i in range(self.num_quantizers) + ] + ) + self.nar_stage_embeddings = None + + def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]: + assert stage > 0 + if stage == 1: + for name, param in self.named_parameters(): + if name.startswith("ar_"): + print(f" AR parameter: {name}") + yield param + + if stage == 2: + for name, param in self.named_parameters(): + if name.startswith("nar_"): + print(f"NAR parameter: {name}") + yield param + + def stage_named_parameters( + self, stage: int = 1 + ) -> Iterator[Tuple[str, nn.Parameter]]: + assert stage > 0 + if stage == 1: + for pair in self.named_parameters(): + if pair[0].startswith("ar_"): + yield pair + + if stage == 2: + for pair in self.named_parameters(): + if pair[0].startswith("nar_"): + yield pair + + def pad_y_eos(self, y, y_mask_int, eos_id): + targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( + y_mask_int, (0, 1), value=1 + ) + # inputs, targets + if self.ar_audio_prepend_bos: + return ( + F.pad(targets[:, :-1], (1, 0), value=self.NUM_AUDIO_TOKENS + 1), + targets, + ) + + return targets[:, :-1], targets[:, 1:] + +class VALLE(VALLF): + config_class = VallexConfig + + def __init__( + self, + config: VallexConfig, + **kwargs, + ): + super(VALLE, self).__init__( + config, + **kwargs, + ) + print(config) + self.config = config + d_model = self.config.n_dim + self.eps = config.eps + + 
self.language_ID = { + 'en': 0, + 'zh': 1, + } + self.ar_language_embedding = nn.Embedding(3, d_model, padding_idx=2) + self.nar_language_embedding = nn.Embedding(3, d_model, padding_idx=2) + self.embed_scale = 32.0 + + def forward( + self, + zh, + en + ): + """ + "zh": { + "st_tokens": zh_st, + "at_tokens_wbos": zh_prev_at, + "at_tokens_tgt": zh_tgt_at, + "self_atten_mask": zh_self_atten_mask, + "padding_mask": zh_padding_mask, + "langid": zh_id.long() + }, + "en": { + "st_tokens": en_st, + "at_tokens_wbos": en_prev_at, + "at_tokens_tgt": en_tgt_at, + "self_atten_mask": en_self_atten_mask, + "padding_mask": en_padding_mask, + "langid": en_id.long() + } + """ + flag = (np.random.randint(low=0, high=1000) % 2 == 0) # zh or en + if flag: + data = zh + else: + data = en + + st_tokens = data["st_tokens"] + at_tokens_wbos = data["at_tokens_wbos"] + at_tokens_tgt = data["at_tokens_tgt"] + self_atten_mask = data["self_atten_mask"] + padding_mask = data["padding_mask"] + langid = data["langid"] + + st_len = st_tokens.size(1) + st_emb = self.embed_scale * self.ar_text_embedding(st_tokens) + src_lang_emb = self.embed_scale * self.ar_language_embedding(langid) + st_emb += src_lang_emb + st_pos = self.ar_text_position(st_tokens) + st_emb += st_pos + + at_emb, _ = self.ar_audio_embedding(at_tokens_wbos, None) + at_emb = self.embed_scale * at_emb + tgt_lang_emb = self.embed_scale * self.ar_language_embedding(langid) + at_emb += tgt_lang_emb + at_pos = self.ar_audio_position(at_tokens_wbos) + at_emb += at_pos + + x = torch.concat([st_emb, at_emb], dim=1) + + x = self.ar_decoder( + x, + mask=self_atten_mask, + src_key_padding_mask=padding_mask + ) + x = self.ar_predict_layer(x) + x = x[:, st_len:, :] + loss, nll_loss, lprob, right_rate = self.calculate_loss( + x, at_tokens_tgt + ) + return ModelOutput(logits=lprob, loss=loss, acc=right_rate), right_rate + + def calculate_loss(self, encoder_out, target, reduce=True, scale=1.0, prob_mask=None, acc=True): + lprob = self.get_normalized_probs(encoder_out, log_probs=True) + with torch.no_grad(): + mask = target.ne(self.padding_idx) + n_correct = torch.sum( + lprob.argmax(-1).masked_select(mask).eq(target.masked_select(mask)) + ) + total = torch.sum(mask) + right_rate = n_correct * 100.0 / total + + lprob, target = lprob.view(-1, lprob.size(-1)), target.view(-1) + loss, nll_loss = label_smoothed_nll_loss( + lprob, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + scale=scale, + prob_mask=prob_mask + ) + + return loss, nll_loss, lprob, right_rate + + def get_normalized_probs(self, encoder_out, log_probs, sample=None): + if torch.is_tensor(encoder_out): + logits = encoder_out.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + + + def inference_24L( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + enroll_x_lens: torch.Tensor, + top_k: int = -100, + temperature: float = 1.0, + prompt_language: str = None, + text_language: str = None, + best_of: int = 1, + length_penalty: float = 1.0, + return_worst: bool = False, + at_eos: int = -1 + ) -> torch.Tensor: + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + assert y.shape[0] == 1, y.shape + + assert torch.all(x_lens > 0) + self.NUM_AUDIO_TOKENS = at_eos + text = x + x = self.embed_scale * self.ar_text_embedding(text) + prompt_language_id = prompt_language.to(x.device) + text_language_id = text_language.to(x.device) + src_lang_emb = self.embed_scale * 
self.ar_language_embedding(prompt_language_id) + tgt_lang_emb = self.embed_scale * self.ar_language_embedding(text_language_id) + x[:, :enroll_x_lens, :] += src_lang_emb + x[:, enroll_x_lens:, :] += tgt_lang_emb + x = self.ar_text_prenet(x) + x_pos = self.ar_text_position(text) + + text_len = x_lens.max() + prompts = y + prefix_len = y.shape[1] + + # AR Decoder + # TODO: Managing decoder steps avoid repetitive computation + y = prompts[..., 0] + if self.ar_audio_prepend_bos: + y = F.pad(y, (1, 0), value=self.ar_at_dict.tts_flag) + + x_len = x_lens.max() + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + kv_cache = None + use_kv_caching = True + + sum_logprobs = torch.zeros(best_of, device=y.device) # implement batch decoding here + x = x.repeat(best_of, 1, 1) + y = y.repeat(best_of, 1) + lstm_h = None + first_ar = True + while True: + if first_ar: + y_emb, lstm_h = self.ar_audio_embedding(y, lstm_h) + y_emb = y_emb * self.embed_scale + y_emb = self.ar_audio_prenet(y_emb) + y_pos = self.ar_audio_position(y) + y_emb[:, :prefix_len] = y_emb[:, :prefix_len] + src_lang_emb + y_emb[:, prefix_len:] = y_emb[:, prefix_len:] + tgt_lang_emb + xy_pos_token = torch.concat([x_pos+x, y_pos+y_emb], dim=1) + first_ar = False + else: + y_emb_cur, lstm_h = self.ar_audio_embedding(y[:, -1:], lstm_h) + y_emb_cur = y_emb_cur * self.embed_scale + y_emb_cur = self.ar_audio_prenet(y_emb_cur) + y_pos_cur = self.ar_audio_position(y)[:, -1:] + y_emb_cur = y_emb_cur + src_lang_emb + y_emb_cur = y_emb_cur + tgt_lang_emb + xy_pos_token = torch.concat([xy_pos_token, y_pos_cur+y_emb_cur], dim=1) + # print(xy_pos_token.size()) + + y_len = y.shape[1] + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu( + torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1 + ), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat( + [x_attn_mask_pad, y_attn_mask], dim=0 + ).to(y.device) + + + if use_kv_caching and kv_cache is not None: + xy_pos = xy_pos_token[:, [-1]] + xy_attn_mask = xy_attn_mask[:, [-1]] + else: + xy_pos = xy_pos_token + + xy_dec, kv_cache = self.ar_decoder.infer( + xy_pos, + mask=xy_attn_mask, + past_kv=kv_cache, + use_cache=use_kv_caching, + ) + + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples, current_logprobs = topk_sampling( + logits, top_k=top_k, top_p=1, temperature=temperature + ) + sum_logprobs += current_logprobs * (y[:, -1] != self.NUM_AUDIO_TOKENS) + samples[y[:, -1] == self.NUM_AUDIO_TOKENS] = self.NUM_AUDIO_TOKENS + completed = (samples[:, -1] == self.NUM_AUDIO_TOKENS).all() + if ( + completed + or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 32 + ): + if prompts.shape[1] == y.shape[1]: + raise SyntaxError( + "well trained model shouldn't reach here." 
+ ) + lengths = torch.sum(y != self.NUM_AUDIO_TOKENS, dim=1) + avg_logprobs = sum_logprobs / lengths ** length_penalty + # choose the best beam according to sum_logprobs + best_beam = y[torch.argmax(avg_logprobs), :] + worst_beam = y[torch.argmin(avg_logprobs), :] + # strip all eos tokens + best_beam = best_beam[best_beam != self.NUM_AUDIO_TOKENS] + worst_beam = worst_beam[worst_beam != self.NUM_AUDIO_TOKENS] + if return_worst: + y = worst_beam.unsqueeze(0) + else: + y = best_beam.unsqueeze(0) + print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]") + break + + y = torch.concat([y, samples], dim=1) + + codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] + if self.num_quantizers == 1: + return torch.stack(codes, dim=-1) + + if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes + enrolled_len = enroll_x_lens.max().item() + # SOS + Synthesis Text + EOS + text = torch.concat( + [ + text[:, :1], + text[:, enrolled_len - 1 :], + ], + dim=1, + ) + text_len = text_len - (enrolled_len - 2) + assert text.shape[0] == 1 + + x = self.embed_scale * self.nar_text_embedding(text) + # Add language embedding + prompt_language_id = prompt_language.to(x.device) + text_language_id = text_language.to(x.device) + src_lang_emb = self.embed_scale * self.nar_language_embedding(prompt_language_id) + tgt_lang_emb = self.embed_scale * self.nar_language_embedding(text_language_id) + x[:, :enroll_x_lens, :] += src_lang_emb + x[:, enroll_x_lens:, :] += tgt_lang_emb + x = self.nar_text_prenet(x) + x_pos = self.nar_text_position(text) + + if self.prefix_mode == 0: + for i, predict_layer in enumerate( + self.nar_predict_layers + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < self.num_quantizers - 2: + y_emb[:, :prefix_len] += self.embed_scale * self.nar_audio_embeddings( + prompts[..., i + 1] + )[0] + y_emb[:, prefix_len:] += self.embed_scale * self.nar_audio_embeddings(samples)[0] + else: + y_pos = self.nar_audio_position(y[:, int(self.ar_audio_prepend_bos):]) + + ref_at_emb = self.embed_scale * self.nar_audio_embeddings(prompts)[0] + src_lang_emb + est_at = y[:, prefix_len+int(self.ar_audio_prepend_bos):].unsqueeze(-1) + # + for i in range(1, 8): + y_emb, _ = self.nar_audio_embeddings(est_at) + y_emb = self.embed_scale * y_emb + tgt_lang_emb + + y_emb = torch.concat([ref_at_emb, y_emb], dim=1) + xy_pos = torch.concat([x+x_pos, y_emb+y_pos], dim=1) + + xy_dec = self.nar_decoder( + xy_pos + ) + logits = self.nar_predict_layers[i-1](xy_dec[:, text_len + prefix_len :]) + # print(logits.size(), xy_pos.size(), xy_dec.size()) + samples = torch.argmax(logits, dim=-1) + est_at = torch.concat([est_at, samples.unsqueeze(-1)], dim=-1) + codes.append(samples) + + assert len(codes) == self.num_quantizers + return torch.stack(codes, dim=-1) + +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + if top_k > 0: + top_k = min( + max(top_k, min_tokens_to_keep), logits.size(-1) + ) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, 
descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + # Sample + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + logprobs = F.log_softmax(logits.float(), dim=-1) + current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)] + return token, current_logprobs + +class SLSTM(nn.Module): + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional=False): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional) + if bidirectional: + self.out_fc = nn.Linear(dimension*2, dimension) + else: + self.out_fc = None + + def forward(self, x, hidden=None): + x = x.permute(2, 0, 1) + y, hidden = self.lstm(x, hidden) + if self.out_fc is not None: + y = self.out_fc(y) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y, hidden + +class EncodecDecoderLstm(nn.Module): + def __init__(self, dictionary, emb_dim, + out_dim=None, + num_layers=3, lstm_skip=True, lstm_bidire=False, + activation_param={'alpha': 1.0}, **kwargs): + super().__init__() + + # Identity() + if out_dim is None: + out_dim = emb_dim + self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire) + self.elu = nn.ELU(**activation_param) + self.embedding_dim = emb_dim + self.padding_idx = dictionary.pad() + self.emb = nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index) + + def forward(self, x, hidden=None): + """ + Args: + x (_type_): B,T,D + """ + # print(x.size()) + quantized_out = self.emb(x) + out, hidden = self.slstm(quantized_out.permute(0,2,1), hidden) + out = self.elu(out) + return out.permute(0,2,1), hidden + +class NATEncodecDecoderLstm(nn.Module): + def __init__(self, codecs, dictionary, emb_dim, + out_dim=None, + num_layers=3, lstm_skip=True, lstm_bidire=False, + activation_param={'alpha': 1.0}, **kwargs): + super().__init__() + + # Identity() + if out_dim is None: + out_dim = emb_dim + self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire) + self.elu = nn.ELU(**activation_param) + self.codecs = codecs + self.embedding_dim = emb_dim + self.padding_idx = dictionary.pad() + self.emb_list = nn.ModuleList( + [nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index) for i in range(len(self.codecs))] + ) + + def forward(self, x, hidden=None): + """ + Args: + x (_type_): B,T,D + """ + if len(x.size()) == 2: + x = x.unsqueeze(-1) + + if x.size(2) != len(self.codecs) and x.size(1) == 
len(self.codecs): + x = x.permute(0, 2, 1) + + quantized_out = 0 + for i in range(x.size(2)): + quantized = self.emb_list[i](x[: , :, i]) + quantized_out = quantized_out + quantized + # quantized_out = quantized_out / len(self.codecs) + + out, hidden = self.slstm(quantized_out.permute(0,2,1), hidden) + out = self.elu(out) + return out.permute(0,2,1), hidden + +AutoModel.register(VallexConfig, VALLE) \ No newline at end of file diff --git a/slam_llm/models/wavlm/WavLM.py b/slam_llm/models/wavlm/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..f0134c660c0bc13bd35ee16867ef6bd71ce1d362 --- /dev/null +++ b/slam_llm/models/wavlm/WavLM.py @@ -0,0 +1,743 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from .modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + 
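In passing, the conv_feature_layers string evaluated here is worth unpacking: the default spec from WavLMConfig describes seven Conv1d blocks whose strides compound to a 320-sample hop, i.e. one feature frame per 20 ms at 16 kHz. A standalone sketch, using only the default string from this file:

# Standalone sketch: expand the default conv_feature_layers spec and compute
# the overall hop implied by the stacked strides.
conv_feature_layers = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
layers = eval(conv_feature_layers)          # [(dim, kernel_size, stride), ...]

total_stride = 1
for _dim, _kernel_size, stride in layers:
    total_stride *= stride

print(len(layers))      # 7 blocks, each with 512 channels
print(total_stride)     # 5 * 2**6 = 320 samples per output frame (~20 ms at 16 kHz)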
self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), 
bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + 
kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + diff --git a/slam_llm/models/wavlm/modules.py b/slam_llm/models/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8f19ab8a48665d5fa5db87371ba2b8e524fcfbf7 --- /dev/null +++ b/slam_llm/models/wavlm/modules.py @@ -0,0 +1,827 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +from torch.nn import Parameter +import torch.nn.functional as F + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + 
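SamePad exists because a Conv1d with an even kernel and padding = kernel_size // 2 emits one frame more than it receives; the positional convolution in WavLM's TransformerEncoder (kernel conv_pos = 128, padding 64) relies on it to restore the original length so that x + pos_conv(x) lines up. A minimal check with arbitrary sizes:

import torch
import torch.nn as nn

# An even-kernel "same" conv adds one frame; SamePad (defined above) trims it off.
conv = nn.Conv1d(16, 16, kernel_size=4, padding=2)
same_pad = SamePad(kernel_size=4)

x = torch.randn(2, 16, 50)        # (batch, channels, time)
y = conv(x)
print(y.shape)                    # torch.Size([2, 16, 51]) -- one extra frame
print(same_pad(y).shape)          # torch.Size([2, 16, 50]) -- back to the input length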
+class Swish(nn.Module): + """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
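GLU_Linear above widens the projection to 2 * output_dim and gates one half with the activated other half, which is how the "glu" FFN variant in the WavLM encoder layer keeps its nominal width. A quick shape check with arbitrary sizes:

import torch

# Quick shape check: the gated projection still yields output_dim features.
glu = GLU_Linear(input_dim=768, output_dim=3072, glu_type="swish")
x = torch.randn(4, 10, 768)           # (batch, time, dim)
y = glu(x)
print(glu.linear.weight.shape)        # torch.Size([6144, 768]) -- projects to 2 * output_dim
print(y.shape)                        # torch.Size([4, 10, 3072])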
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. 
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/slam_llm/policies/__init__.py b/slam_llm/policies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..309b87277c583123820458295626de24cd1d10fb --- /dev/null +++ b/slam_llm/policies/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
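The helpers re-exported by this package are the FSDP building blocks used by the training code. As an orientation sketch only (the wrap_for_fsdp helper and the exact wiring are assumptions, not this repo's actual training loop), they would typically be combined like this:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from slam_llm.policies import (
    bfSixteen,                  # MixedPrecision policy (bf16 params/grads/buffers)
    get_llama_wrapper,          # auto-wrap policy keyed on LlamaDecoderLayer
    apply_fsdp_checkpointing,   # activation checkpointing on decoder layers
    AnyPrecisionAdamW,
)

def wrap_for_fsdp(model):
    # Hypothetical helper: shows how the exported pieces fit together.
    model = FSDP(
        model,
        auto_wrap_policy=get_llama_wrapper(),
        mixed_precision=bfSixteen,
        device_id=torch.cuda.current_device(),
    )
    apply_fsdp_checkpointing(model)
    optimizer = AnyPrecisionAdamW(model.parameters(), lr=1e-4, use_kahan_summation=True)
    return model, optimizer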
+ +from slam_llm.policies.mixed_precision import * +from slam_llm.policies.wrapping import * +from slam_llm.policies.activation_checkpointing_functions import apply_fsdp_checkpointing +from slam_llm.policies.anyprecision_optimizer import AnyPrecisionAdamW diff --git a/slam_llm/policies/activation_checkpointing_functions.py b/slam_llm/policies/activation_checkpointing_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa98fe68c3a0581d134f4038d2cdde9d0985003 --- /dev/null +++ b/slam_llm/policies/activation_checkpointing_functions.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from functools import partial + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, +) +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +non_reentrant_wrapper = partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) + +check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) + + +def apply_fsdp_checkpointing(model): + """apply activation checkpointing to model + returns None as model is updated directly + """ + print(f"--> applying fsdp activation checkpointing...") + + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + ) diff --git a/slam_llm/policies/anyprecision_optimizer.py b/slam_llm/policies/anyprecision_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0172229656e611616ff8c05d07ecfabdac05ce3c --- /dev/null +++ b/slam_llm/policies/anyprecision_optimizer.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# AnyPrecisionAdamW: a flexible precision AdamW optimizer +# with optional Kahan summation for high precision weight updates. +# Allows direct control over momentum, variance and auxiliary compensation +# buffer dtypes. +# Optional Kahan summation is used to offset precision reduction for +# the weight updates. This allows full training in BFloat16 (equal or +# better than FP32 results in many cases) due to high precision weight upates. 
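The value of the compensation buffer is easiest to see on a toy scalar: around 1.0, bfloat16 values are spaced 2**-7 ≈ 0.0078 apart, so a 1e-4 update vanishes when added directly, while the Kahan-style pattern used in step() below carries the shortfall forward until it is large enough to land. Standalone illustration, not part of the optimizer:

import torch

# Toy scalar version of the compensated update performed per parameter below.
p_naive = torch.ones((), dtype=torch.bfloat16)
p_kahan = torch.ones((), dtype=torch.bfloat16)
comp = torch.zeros((), dtype=torch.bfloat16)
update = torch.tensor(1e-4, dtype=torch.bfloat16)

for _ in range(1000):
    p_naive += update             # rounds straight back to 1.0 every step

    comp += update                # mirrors compensation.addcdiv_(...)
    old = p_kahan.clone()
    p_kahan += comp               # mirrors p.data.add_(compensation)
    comp += old - p_kahan         # mirrors compensation.add_(temp_buffer.sub_(p.data))

print(p_naive.item())             # 1.0 -- all 1000 updates were lost
print(p_kahan.item())             # ~1.1 (up to bf16 rounding); comp carried the shortfall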
+ +import torch +from torch.optim.optimizer import Optimizer + + +class AnyPrecisionAdamW(Optimizer): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + use_kahan_summation=False, + momentum_dtype=torch.bfloat16, + variance_dtype=torch.bfloat16, + compensation_buffer_dtype=torch.bfloat16, + ): + """ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + + # Any Precision specific + use_kahan_summation = creates auxiliary buffer to ensure high precision + model param updates (default: False) + momentum_dtype = dtype for momentum (default: BFloat32) + variance_dtype = dtype for uncentered variance (default: BFloat16) + compensation_buffer_dtype = dtype for Kahan summation + buffer (default: BFloat16) + + # Usage + This optimizer implements optimizer states, and Kahan summation + for high precision updates, all in user controlled dtypes. + Defaults are variance in BF16, Momentum in FP32. + This can be run in FSDP mixed precision, amp, or full precision, + depending on what training pipeline you wish to work with. + + Setting to use_kahan_summation = False, and changing momentum and + variance dtypes to FP32, reverts this to a standard AdamW optimizer. + + """ + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + use_kahan_summation=use_kahan_summation, + momentum_dtype=momentum_dtype, + variance_dtype=variance_dtype, + compensation_buffer_dtype=compensation_buffer_dtype, + ) + + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + if closure is not None: + with torch.enable_grad(): + # to fix linter, we do not keep the returned loss for use atm. 
+ closure() + + for group in self.param_groups: + + beta1, beta2 = group["betas"] + lr = group["lr"] + weight_decay = group["weight_decay"] + eps = group["eps"] + use_kahan_summation = group["use_kahan_summation"] + + momentum_dtype = group["momentum_dtype"] + variance_dtype = group["variance_dtype"] + compensation_buffer_dtype = group["compensation_buffer_dtype"] + + for p in group["params"]: + if p.grad is None: + continue + + if p.grad.is_sparse: + raise RuntimeError( + "AnyPrecisionAdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + + state["step"] = torch.tensor(0.0) + + # momentum - EMA of gradient values + state["exp_avg"] = torch.zeros_like( + p, + dtype=momentum_dtype, + ) + + # variance uncentered - EMA of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, + dtype=variance_dtype, + ) + + # optional Kahan summation - accumulated error tracker + if use_kahan_summation: + state["compensation"] = torch.zeros_like( + p, + dtype=compensation_buffer_dtype, + ) + + # main processing ------------------------- + + # update the steps for each param group update + state["step"] += 1 + step = state["step"] + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + + grad = p.grad + + # weight decay, AdamW style + if weight_decay: + p.data.mul_(1 - lr * weight_decay) + + # update momentum + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + # update uncentered variance + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # adjust using bias1 + bias_correction1 = 1 - beta1**step + + step_size = lr / bias_correction1 + + # adjust using bias2 + denom_correction = (1 - beta2**step) ** 0.5 # avoids math import + + centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( + eps, alpha=1 + ) + + # lr update to compensation + if use_kahan_summation: + compensation = state["compensation"] + + compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) + + # update weights with compensation (Kahan summation) + # save error back to compensation for next iteration + temp_buffer = p.detach().clone() + p.data.add_(compensation) + compensation.add_(temp_buffer.sub_(p.data)) + + else: + # usual AdamW updates + p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) \ No newline at end of file diff --git a/slam_llm/policies/mixed_precision.py b/slam_llm/policies/mixed_precision.py new file mode 100644 index 0000000000000000000000000000000000000000..5175b2ad9690f67571c1ac7ac27d3c2495d3f914 --- /dev/null +++ b/slam_llm/policies/mixed_precision.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import torch + +from torch.distributed.fsdp import ( + MixedPrecision, +) + +# requires grad scaler in main loop +fpSixteen = MixedPrecision( + param_dtype=torch.float16, + # Gradient communication precision. + reduce_dtype=torch.float16, + # Buffer precision. + buffer_dtype=torch.float16, +) + +bfSixteen = MixedPrecision( + param_dtype=torch.bfloat16, + # Gradient communication precision. + reduce_dtype=torch.bfloat16, + # Buffer precision. 
+ buffer_dtype=torch.bfloat16, + cast_forward_inputs=True, +) + +bfSixteen_mixed = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, +) + +fp32_policy = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, +) diff --git a/slam_llm/policies/wrapping.py b/slam_llm/policies/wrapping.py new file mode 100644 index 0000000000000000000000000000000000000000..2703ebc684936ac79f87b01fee44af2cec7e4c1b --- /dev/null +++ b/slam_llm/policies/wrapping.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import functools + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from torch.distributed.fsdp.wrap import ( + transformer_auto_wrap_policy, + size_based_auto_wrap_policy, +) + + +def get_size_policy(min_params=1e8): + num_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=min_params + ) + return num_wrap_policy + + +def get_llama_wrapper(): + """we register our main layer class and use the fsdp transformer wrapping policy + ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers + """ + # ==== use new transformer wrapper + + llama_auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + return llama_auto_wrap_policy diff --git a/slam_llm/utils/__init__.py b/slam_llm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a46f9a8ca7dc578bbe8ac822a8f98e5880351a70 --- /dev/null +++ b/slam_llm/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from slam_llm.utils.memory_utils import MemoryTrace +from slam_llm.utils.dataset_utils import * +from slam_llm.utils.fsdp_utils import fsdp_auto_wrap_policy +from slam_llm.utils.train_utils import * \ No newline at end of file diff --git a/slam_llm/utils/checkpoint_handler.py b/slam_llm/utils/checkpoint_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..50449e232a7481519125feb9354b27b0b6933538 --- /dev/null +++ b/slam_llm/utils/checkpoint_handler.py @@ -0,0 +1,333 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +import os +from pathlib import Path +from datetime import datetime +import torch +import time +from collections import OrderedDict + +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + FullStateDictConfig, # general model non-sharded, non-flattened params + LocalStateDictConfig, # flattened params, usable only by FSDP + # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. 
+) + +from torch.distributed.checkpoint import ( + FileSystemReader, + FileSystemWriter, + save_state_dict, + load_state_dict, +) +from torch.distributed.checkpoint.default_planner import ( + DefaultSavePlanner, + DefaultLoadPlanner, +) + + +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +import torch.distributed.checkpoint as dist_cp +import torch.distributed as dist + + +import logging +logger = logging.getLogger(__name__) + + +def get_date_of_run(): + """create date and time for file save uniqueness + example: 2022-05-07-08:31:12_PM' + """ + date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") + logger.info(f"--> current date and time of run = {date_of_run}") + return date_of_run + + +# create singleton saving policies to avoid making over and over +fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + +def load_model_sharded(model, rank, cfg): + # torch.manual_seed(103) + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + + load_dir = Path.cwd() / folder_name + + if not load_dir.exists(): + if rank == 0: + logger.info(f"No sharded_state_dict checkpoint directory found...skipping") + return + if rank == 0: + logger.info(f"loading model from model path: {load_dir} ") + reader = FileSystemReader(load_dir) + + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + checkpoint = {"model": model.state_dict()} + if rank == 0: + ck = checkpoint.keys() + logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + + dist_cp.load_state_dict( + state_dict=checkpoint, + storage_reader=reader, + ) + if rank == 0: + logger.info(f"checkpoint after load_state_dict()") + ck = checkpoint.keys() + logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + model.load_state_dict(checkpoint["model"]) + if rank == 0: + logger.info(f"Sharded state checkpoint loaded from {load_dir}") + + +def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): + """save model and optimizer via sharded_state_dict to save_dir""" + + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + + save_dir = Path.cwd() / folder_name + if rank == 0: + logger.info(f"Saving model to {save_dir}") + + distributed_writer = dist_cp.FileSystemWriter( + save_dir, + ) + t0 = time.perf_counter() + + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + + state_dict = {"model": model.state_dict()} + if optim is not None: + state_dict["optim"] = FSDP.optim_state_dict(model, optim) + + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=distributed_writer, + planner=DefaultSavePlanner(), + + ) + dist.barrier() + t1 = time.perf_counter() + if rank == 0: + logger.info(f"Sharded state checkpoint saved to {save_dir}") + logger.info( + f"Checkpoint Time = {t1-t0:.4f}\n" + ) +def save_model_checkpoint( + model, + optimizer, + rank, + cfg, + epoch=1, +): + """saving model via rank0 cpu streaming and full_state_dict""" + + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, fullstate_save_policy + ): + cpu_state = model.state_dict() + + logger.info(f"saving process: rank {rank} done w model state_dict\n") + + + if rank == 0: + logger.info(f"--> saving model ...") + # create save path + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + save_dir = Path.cwd() / folder_name + 
save_dir.mkdir(parents=True, exist_ok=True) + save_name = cfg.model_name + "-" + str(epoch) + ".pt" + save_full_path = str(save_dir) + "/" + save_name + + # save model + torch.save(cpu_state, save_full_path) + + + logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") + +def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"): + logger.info(f"--> saving model ...") + save_dir = os.path.join(cfg.output_dir, checkpoint_name) + os.makedirs(save_dir, exist_ok=True) + # save_full_path = os.path.join(save_dir, "model.pt") + save_full_path = save_dir + model.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True) + logger.info(f"encoder saved at {save_full_path}") + +def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True): + logger.info(f"--> saving model ...") + save_dir = os.path.join(cfg.output_dir, checkpoint_name) + os.makedirs(save_dir, exist_ok=True) + save_full_path = os.path.join(save_dir, "model.pt") + if cfg.enable_ddp: + model = model.module + cpu_state = model.state_dict() + if save_trainable_only: + state_dict = OrderedDict() + for name, para in model.named_parameters(): + if para.requires_grad: + state_dict[name] = cpu_state[name] + else: + state_dict = cpu_state + torch.save(state_dict, save_full_path) + logger.info(f"encoder saved at {save_full_path}") + +def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0): + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, fullstate_save_policy + ): + cpu_state = model.state_dict() + logger.info(f"saving process: rank {rank} done w model state_dict\n") + + if rank == 0: + logger.info(f"--> saving model ...") + save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1)) + os.makedirs(save_dir, exist_ok=True) + + if not cfg.freeze_llm: + llm_dict = {} + for key in cpu_state.keys(): + if key.startswith("llm."): + llm_dict[key] = cpu_state[key] + model.llm.save_pretrained(save_directory=save_dir, state_dict=llm_dict) + logger.info(f"llm saved at {save_dir}") + + save_full_path = os.path.join(save_dir, "model.pt") + encoder_dict = {} + if not cfg.freeze_encoder: + for key in cpu_state.keys(): + if key.startswith("encoder."): + encoder_dict[key] = cpu_state[key] + for key in cpu_state.keys(): + if key.startswith("encoder_projector."): + encoder_dict[key] = cpu_state[key] + torch.save(encoder_dict, save_full_path) + logger.info(f"encoder saved at {save_full_path}") + + logger.info(f"model checkpoint saved for epoch {epoch+1}\n") + + dist.barrier() + +def load_model_checkpoint(model, rank, cfg): + """load local checkpoint to rank0 cpu + must be called * before * passing to FSDP""" + + if rank != 0: + return + + # where is the checkpoint at... + full_state_dict_model_path = ( + Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename + ) + # is it present... + if not full_state_dict_model_path.is_file(): + logger.info( + f"model checkpoint {full_state_dict_model_path} not present. Returning..." + ) + return + + + model_checkpoint = torch.load(full_state_dict_model_path) + # integrate into loaded model + model.load_state_dict(model_checkpoint) + + + logger.info(f"model checkpoint loaded to rank0 cpu") + + +def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): + """save optimizer state via full state dict""" + + + logger.info(f"--> optim state call on rank {rank}\n") + + # pull all sharded optimizer states to rank0 cpu... 
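+    # FSDP.full_optim_state_dict is a collective call: every rank must enter
+    # it, even though only rank 0 writes the consolidated dict to disk below.
+    # load_optimizer_checkpoint is the inverse, re-sharding the full dict via
+    # FSDP.scatter_full_optim_state_dict.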
+ + optim_state = FSDP.full_optim_state_dict(model, optimizer) + + + logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") + + if rank == 0: + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + cfg.model_name + ) + save_dir = Path.cwd() / folder_name + save_dir.mkdir(parents=True, exist_ok=True) + + opt_save_name = ( + "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt" + ) + opt_save_full_path = save_dir / opt_save_name + + logger.info(f"--> saving optimizer state...") + + torch.save(optim_state, opt_save_full_path) + + logger.info(f"--> saved {opt_save_full_path} to disk") + + +def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): + """load an fsdp optimizer full_state checkpoint using scatter method + this ensures only rank 0 loads the optimizer state dict and scatters to other ranks + """ + + + if not optimizer_checkpoint_path.is_file(): + logger.info( + f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. " + ) + return + + full_osd = None + + if rank == 0: + full_osd = torch.load(optimizer_checkpoint_path) + + # called from all ranks, though only rank0 has a valid param for full_osd + sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) + + logger.info(f"optimizer shard loaded on rank {rank}") + +def load_sharded_model_single_gpu(model,model_path): + + reader = FileSystemReader(model_path) + + state_dict = { + "model": model.state_dict() + } + + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader= FileSystemReader(model_path), + no_dist=True, + ) + + model.load_state_dict(state_dict["model"]) + + logger.info(f"Sharded state checkpoint loaded from {model_path}") + return model diff --git a/slam_llm/utils/compute_aac_metrics.py b/slam_llm/utils/compute_aac_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..6390fd038cefe10c3d0af698d120d29037deb109 --- /dev/null +++ b/slam_llm/utils/compute_aac_metrics.py @@ -0,0 +1,38 @@ +from aac_metrics import evaluate +import sys + +def compute_wer(ref_file, + hyp_file): + pred_captions = [] + gt_captions = [] + + with open(hyp_file, 'r') as hyp_reader: + for line in hyp_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + pred_captions.append(value) + with open(ref_file, 'r') as ref_reader: + for line in ref_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + gt_captions.append(value) + + print('Used lines:', len(pred_captions)) + candidates: list[str] = pred_captions + mult_references: list[list[str]] = [[gt] for gt in gt_captions] + + corpus_scores, _ = evaluate(candidates, mult_references) + print(corpus_scores) + # dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider" + # {"bleu_1": tensor(0.4278), "bleu_2": ..., ...} + + +if __name__ == '__main__': + if len(sys.argv) != 3: + print("usage : python compute_aac_metrics.py test.ref test.hyp") + sys.exit(0) + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + cer_detail_file = sys.argv[3] + compute_wer(ref_file, hyp_file, cer_detail_file) \ No newline at end of file diff --git a/slam_llm/utils/compute_ppl.py b/slam_llm/utils/compute_ppl.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd0f159548609f291842def106d64c59c50daff --- /dev/null +++ b/slam_llm/utils/compute_ppl.py @@ -0,0 +1,45 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +import 
torch +from tqdm import tqdm +import json + +# MODEL_PATH = "/nfs/maziyang.mzy/models/vicuna-7b-v1.5" +MODEL_PATH = "/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf" +# MODEL_PATH = "/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf" +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) +model = AutoModelForCausalLM.from_pretrained(MODEL_PATH) + +device = 'cuda:7' +model.to(device) +model.eval() + +corpus_path = "/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl" +corpus = [] +with open(corpus_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + corpus.append(data_dict.get("target", None)) + +cumulative_log_likelihood = 0 +total_tokens = 0 + +for sentence in tqdm(corpus): + inputs = tokenizer(sentence.strip().lower(), return_tensors="pt").to(device) + + input_ids = inputs["input_ids"] + # input_len = input_ids.size(1) + input_len = len(sentence.split(" ")) + total_tokens += input_len + + with torch.no_grad(): + outputs = model(**inputs, labels=input_ids) + log_likelihood = outputs.loss * input_len + cumulative_log_likelihood += log_likelihood.item() + + +average_log_likelihood = cumulative_log_likelihood / total_tokens +corpus_ppl = torch.exp(torch.tensor(average_log_likelihood)).item() + +print(f"Model: {MODEL_PATH}") +print(f"Corpus: {corpus_path}") +print(f"Corpus Perplexity: {corpus_ppl}") diff --git a/slam_llm/utils/compute_utils.py b/slam_llm/utils/compute_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4c85d63ae10708aa3f12a397522ec2aa8de0f3cf --- /dev/null +++ b/slam_llm/utils/compute_utils.py @@ -0,0 +1,3 @@ + +def calculate_output_length_1d(L_in, kernel_size, stride, padding=0): + return (L_in + 2 * padding - kernel_size) // stride + 1 \ No newline at end of file diff --git a/slam_llm/utils/compute_wer.py b/slam_llm/utils/compute_wer.py new file mode 100644 index 0000000000000000000000000000000000000000..4759b3f37e528444bce10123f6816672d22dd2d1 --- /dev/null +++ b/slam_llm/utils/compute_wer.py @@ -0,0 +1,197 @@ +import os +import numpy as np +import sys + +def build_diff(ref, hyp, path): + result = [] + ref = list(map(lambda x: x.lower(), ref)) + hyp = list(map(lambda x: x.lower(), hyp)) + r_record = -1 + h_record = -1 + # path = path+[(len(ref), len(hyp))] + + for rpointer, hpointer in path: + if rpointer!=r_record+1 or hpointer!=h_record+1: + r_buffer = ' '.join(ref[r_record+1:rpointer]) + r_buffer = r_buffer if len(r_buffer)>0 else "*" + h_buffer = ' '.join(hyp[h_record+1:hpointer]) + h_buffer = h_buffer if len(h_buffer)>0 else "*" + result.append(f"({r_buffer}->{h_buffer})") + + result.append(ref[rpointer]) + r_record = rpointer + h_record = hpointer + + if r_record<len(ref)-1 or h_record<len(hyp)-1: + r_buffer = ' '.join(ref[r_record+1:]) + r_buffer = r_buffer if len(r_buffer)>0 else "*" + h_buffer = ' '.join(hyp[h_record+1:]) + h_buffer = h_buffer if len(h_buffer)>0 else "*" + result.append(f"({r_buffer}->{h_buffer})") + return ' '.join(result) + + + + + + +def compute_wer(ref_file, + hyp_file, + cer_detail_file): + rst = { + 'Wrd': 0, + 'Corr': 0, + 'Ins': 0, + 'Del': 0, + 'Sub': 0, + 'Snt': 0, + 'Err': 0.0, + 'S.Err': 0.0, + 'wrong_words': 0, + 'wrong_sentences': 0 + } + + hyp_dict = {} + ref_dict = {} + with open(hyp_file, 'r') as hyp_reader: + for line in hyp_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + hyp_dict[key] = value + with open(ref_file, 'r') as ref_reader: + for line in ref_reader: + key = line.strip().split()[0] + value = line.strip().split()[1:] + 
ref_dict[key] = value + + cer_detail_writer = open(cer_detail_file, 'w') + for hyp_key in hyp_dict: + if hyp_key in ref_dict: + out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key]) + # if out_item['ins'] > 10 or out_item['del'] > 10: + # print(hyp_key + print_cer_detail(out_item)) + # print("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key])))) + # print("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key])))) + rst['Wrd'] += out_item['nwords'] + rst['Corr'] += out_item['cor'] + rst['wrong_words'] += out_item['wrong'] + rst['Ins'] += out_item['ins'] + rst['Del'] += out_item['del'] + rst['Sub'] += out_item['sub'] + rst['Snt'] += 1 + if out_item['wrong'] > 0: + rst['wrong_sentences'] += 1 + cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n') + cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n') + cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n') + cer_detail_writer.write("diff:" + '\t' + build_diff(ref_dict[hyp_key], hyp_dict[hyp_key], out_item['path']) + '\n') + + if rst['Wrd'] > 0: + rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) + if rst['Snt'] > 0: + rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) + + cer_detail_writer.write('\n') + cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) + + ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n') + cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n') + cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." 
+ '\n') + + +def compute_wer_by_line(hyp, + ref): + hyp = list(map(lambda x: x.lower(), hyp)) + ref = list(map(lambda x: x.lower(), ref)) + + len_hyp = len(hyp) + len_ref = len(ref) + + cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) + + ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) + + for i in range(len_hyp + 1): + cost_matrix[i][0] = i + for j in range(len_ref + 1): + cost_matrix[0][j] = j + + for i in range(1, len_hyp + 1): + for j in range(1, len_ref + 1): + if hyp[i - 1] == ref[j - 1]: + cost_matrix[i][j] = cost_matrix[i - 1][j - 1] + else: + substitution = cost_matrix[i - 1][j - 1] + 1 + insertion = cost_matrix[i - 1][j] + 1 + deletion = cost_matrix[i][j - 1] + 1 + + compare_val = [substitution, insertion, deletion] + + min_val = min(compare_val) + operation_idx = compare_val.index(min_val) + 1 + cost_matrix[i][j] = min_val + ops_matrix[i][j] = operation_idx + + match_idx = [] + i = len_hyp + j = len_ref + rst = { + 'nwords': len_ref, + 'cor': 0, + 'wrong': 0, + 'ins': 0, + 'del': 0, + 'sub': 0, + 'path': [] + } + while i >= 0 or j >= 0: + i_idx = max(0, i) + j_idx = max(0, j) + + if ops_matrix[i_idx][j_idx] == 0: # correct + if i - 1 >= 0 and j - 1 >= 0: + match_idx.append((j - 1, i - 1)) + rst['cor'] += 1 + + i -= 1 + j -= 1 + + elif ops_matrix[i_idx][j_idx] == 2: # insert + i -= 1 + rst['ins'] += 1 + + elif ops_matrix[i_idx][j_idx] == 3: # delete + j -= 1 + rst['del'] += 1 + + elif ops_matrix[i_idx][j_idx] == 1: # substitute + i -= 1 + j -= 1 + rst['sub'] += 1 + + if i < 0 and j >= 0: + rst['del'] += 1 + elif j < 0 and i >= 0: + rst['ins'] += 1 + + match_idx.reverse() + wrong_cnt = cost_matrix[len_hyp][len_ref] + rst['wrong'] = wrong_cnt + rst['path'] = match_idx + + return rst + +def print_cer_detail(rst): + return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor']) + + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub=" + + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords']) + + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords'])) + +if __name__ == '__main__': + if len(sys.argv) != 4: + print("usage : python compute-wer.py test.ref test.hyp test.wer") + sys.exit(0) + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + cer_detail_file = sys.argv[3] + compute_wer(ref_file, hyp_file, cer_detail_file) diff --git a/slam_llm/utils/config_utils.py b/slam_llm/utils/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c0057ce4d2d247ffdc8e123b21b60c4cbdd9d521 --- /dev/null +++ b/slam_llm/utils/config_utils.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
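+# Typical use (illustrative, example values only): train_config.peft_config is
+# an OmegaConf node such as
+#     {peft_method: lora, r: 8, lora_alpha: 32, target_modules: [q_proj, v_proj]}
+# generate_peft_config() drops the peft_method key and instantiates the
+# matching peft config class (LoraConfig by default), e.g.
+#     peft_cfg = generate_peft_config(train_config)
+#     model = get_peft_model(model, peft_cfg)   # peft.get_peft_model
+# get_dataloader_kwargs() then picks the sampler/collator according to
+# train_config.batching_strategy ("padding", "packing", or the dataset's own
+# collator otherwise).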
+ +import inspect +# from dataclasses import asdict + +import torch.distributed as dist +from torch.utils.data import DistributedSampler +from peft import ( + LoraConfig, + AdaptionPromptConfig, + PrefixTuningConfig, +) +from transformers import default_data_collator +from transformers.data import DataCollatorForSeq2Seq + +# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config +from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler + +from omegaconf import OmegaConf + +import logging +logger = logging.getLogger(__name__) + +# def update_config(config, **kwargs): +# if isinstance(config, (tuple, list)): +# for c in config: +# update_config(c, **kwargs) +# else: +# for k, v in kwargs.items(): +# if hasattr(config, k): +# setattr(config, k, v) +# elif "." in k: +# # allow --some_config.some_param=True +# config_name, param_name = k.split(".") +# if type(config).__name__ == config_name: +# if hasattr(config, param_name): +# setattr(config, param_name, v) +# else: +# # In case of specialized config we can warm user +# logger.warning(f"Warning: {config_name} does not accept parameter: {k}") +# elif isinstance(config, train_config): +# logger.warning(f"Warning: unknown parameter {k}") + + +def generate_peft_config(train_config): + # configs = (lora_config, llama_adapter_config, prefix_config) + # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) + peft_configs = {"lora": LoraConfig, + "llama_adapter": AdaptionPromptConfig, + "prefix": PrefixTuningConfig + } + # names = tuple(c.__name__.rstrip("_config") for c in configs) + # + # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" + # + # config = configs[names.index(train_config.peft_method)]() + config = train_config.peft_config + + params = OmegaConf.to_container(config, resolve=True) + # peft_config = peft_configs[names.index(train_config.peft_method)](**params) + params.pop("peft_method", None) #(FIX:MZY): remove peft_method from params to avoid error + peft_config = peft_configs[config.get("peft_method", "lora")](**params) + + return peft_config + + +def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): + kwargs = {} + batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size + if train_config.batching_strategy == "padding": + if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed: + kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler( + dataset, + batch_size=batch_size, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + shuffle=mode=="train", + ) + else: + kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train") + kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer) + elif train_config.batching_strategy == "packing": + if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed: + kwargs["sampler"] = DistributedSampler( + dataset, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + shuffle=mode=="train", + ) + kwargs["batch_size"] = batch_size + kwargs["drop_last"] = True + kwargs["collate_fn"] = default_data_collator + else: + # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") + if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed: + kwargs["sampler"] = DistributedSampler( + dataset, + rank=dist.get_rank(), + 
num_replicas=dist.get_world_size(), + shuffle=mode=="train", + ) + kwargs["batch_size"] = batch_size + kwargs["drop_last"] = True + kwargs["collate_fn"] = dataset.collator + logger.info(f"Using batching strategy: {train_config.batching_strategy}") + + return kwargs diff --git a/slam_llm/utils/custom_utils.py b/slam_llm/utils/custom_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9ffbc397118ba786af646238dac1550f7ec5ca --- /dev/null +++ b/slam_llm/utils/custom_utils.py @@ -0,0 +1,298 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import random +import numpy as np +from typing import Dict, List, Optional, Tuple + +def load_video(path): + for i in range(3): + try: + cap = cv2.VideoCapture(path) + frames = [] + while True: + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + frames.append(frame) + else: + break + frames = np.stack(frames) + return frames + except Exception: + print(f"failed loading {path} ({i} / 3)") + if i == 2: + raise ValueError(f"Unable to load {path}") + + +class Compose(object): + """Compose several preprocess together. + Args: + preprocess (list of ``Preprocess`` objects): list of preprocess to compose. + """ + + def __init__(self, preprocess): + self.preprocess = preprocess + + def __call__(self, sample): + for t in self.preprocess: + sample = t(sample) + return sample + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.preprocess: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class Normalize(object): + """Normalize a ndarray image with mean and standard deviation. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, frames): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor image. + """ + frames = (frames - self.mean) / self.std + return frames + + def __repr__(self): + return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std) + +class CenterCrop(object): + """Crop the given image at the center + """ + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = int(round((w - tw))/2.) + delta_h = int(round((h - th))/2.) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + +class RandomCrop(object): + """Crop the given image at the center + """ + + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = random.randint(0, w-tw) + delta_h = random.randint(0, h-th) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + +class HorizontalFlip(object): + """Flip image horizontally. 
+ """ + + def __init__(self, flip_ratio): + self.flip_ratio = flip_ratio + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be flipped with a probability flip_ratio + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + if random.random() < self.flip_ratio: + for index in range(t): + frames[index] = cv2.flip(frames[index], 1) + return frames + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for 
length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + batch_indexes, starts, ends = [], [], [] + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + vals, run_starts, run_lengths = find_runs(mask[i]) + start_indices, lengths = run_starts[vals == True], run_lengths[vals == True] + starts.append(start_indices) + ends.append(start_indices+lengths) + batch_indexes.append(np.zeros([len(start_indices)])+i) + return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64) + +def find_runs(x): + """Find runs of consecutive items in an array.""" + + # ensure array + x = np.asanyarray(x) + if x.ndim != 1: + raise ValueError('only 1D array supported') + n = x.shape[0] + + # handle empty array + if n == 0: + return np.array([]), np.array([]), np.array([]) + + else: + # find run starts + loc_run_start = np.empty(n, dtype=bool) + loc_run_start[0] = True + np.not_equal(x[:-1], x[1:], out=loc_run_start[1:]) + run_starts = np.nonzero(loc_run_start)[0] + + # find run values + run_values = x[loc_run_start] + + # find run lengths + run_lengths = np.diff(np.append(run_starts, n)) + + return run_values, run_starts, run_lengths diff --git a/slam_llm/utils/dataset_utils.py b/slam_llm/utils/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9726a7aad5e7b388381719b666b8993c6b62401d --- /dev/null +++ b/slam_llm/utils/dataset_utils.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
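+# dataset_config.file selects the dataset factory using the convention
+# "<path>.py[:<function>]", e.g. (hypothetical path)
+#     file: examples/asr/speech_dataset.py:get_speech_dataset
+# If no ":<function>" suffix is given, a function named get_custom_dataset is
+# looked up inside that file.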
+ +import importlib +from pathlib import Path + +import torch + +import logging +logger = logging.getLogger(__name__) + + +def load_module_from_py_file(py_file: str) -> object: + """ + This method loads a module from a py file which is not in the Python path + """ + module_name = Path(py_file).name + loader = importlib.machinery.SourceFileLoader(module_name, py_file) + spec = importlib.util.spec_from_loader(module_name, loader) + module = importlib.util.module_from_spec(spec) + + loader.exec_module(module) + + return module + + +def get_custom_dataset(dataset_config, tokenizer, split: str): + if ":" in dataset_config.file: + module_path, func_name = dataset_config.file.split(":") + else: + module_path, func_name = dataset_config.file, "get_custom_dataset" + + if not module_path.endswith(".py"): + raise ValueError(f"Dataset file {module_path} is not a .py file.") + + module_path = Path(module_path) + if not module_path.is_file(): + raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") + + module = load_module_from_py_file(module_path.as_posix()) + try: + return getattr(module, func_name)(dataset_config, tokenizer, split) + except AttributeError as e: + logger.info(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).") + raise e + + +def get_preprocessed_dataset( + tokenizer, dataset_config, split: str = "train" +) -> torch.utils.data.Dataset: + + def get_split(): + return ( + dataset_config.train_split + if split == "train" + else dataset_config.test_split + ) + + return get_custom_dataset( + dataset_config, + tokenizer, + get_split(), + ) diff --git a/slam_llm/utils/deepspeed_utils.py b/slam_llm/utils/deepspeed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..66c49e593a65253ba8aca52b702ef4b4e54676d7 --- /dev/null +++ b/slam_llm/utils/deepspeed_utils.py @@ -0,0 +1,601 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
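+# The train()/evaluation() loops in this file assume the model is a DeepSpeed
+# engine (e.g. built with deepspeed.initialize), so the update is driven via
+# model.backward(loss) / model.step() instead of an explicit optimizer, and
+# gradient accumulation is delegated to the engine.  deepspeed_main_wrapper
+# mirrors hydra.main() but additionally accepts the --local_rank argument
+# injected by the DeepSpeed launcher.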
+ +import os +import time +import yaml +from contextlib import nullcontext +from pathlib import Path +from pkg_resources import packaging + + +import functools +import hydra +import torch +import torch.cuda.nccl as nccl +import torch.distributed as dist +from omegaconf import DictConfig +from tqdm import tqdm +from transformers import LlamaTokenizer +from typing import Any, Callable, List, Optional +from textwrap import dedent +from hydra import version +from hydra.main import _UNSPECIFIED_, _get_rerun_conf +from hydra._internal.deprecation_warning import deprecation_warning +from hydra._internal.utils import _run_hydra, get_args_parser +from hydra.types import TaskFunction +from hydra.core.utils import _flush_loggers, configure_log + + +from slam_llm.utils.checkpoint_handler import ( + save_model_checkpoint, + save_model_checkpoint_deepspeed, + save_model_and_optimizer_sharded, + save_optimizer_checkpoint, + save_model_checkpoint_peft, + save_model_checkpoint_peft_full_shard, +) +from slam_llm.policies import fpSixteen, bfSixteen_mixed, get_llama_wrapper +from slam_llm.utils.memory_utils import MemoryTrace +from slam_llm.utils.metric import compute_accuracy + +import wandb +import logging + +logger = logging.getLogger(__name__) + +# For deepspeed --local_rank argument +def deepspeed_main_wrapper( + config_path: Optional[str] = _UNSPECIFIED_, + config_name: Optional[str] = None, + version_base: Optional[str] = _UNSPECIFIED_, +) -> Callable[[TaskFunction], Any]: + """ + :param config_path: The config path, a directory where Hydra will search for + config files. This path is added to Hydra's searchpath. + Relative paths are interpreted relative to the declaring python + file. Alternatively, you can use the prefix `pkg://` to specify + a python package to add to the searchpath. + If config_path is None no directory is added to the Config search path. + :param config_name: The name of the config (usually the file name without the .yaml extension) + """ + + version.setbase(version_base) + + if config_path is _UNSPECIFIED_: + if version.base_at_least("1.2"): + config_path = None + elif version_base is _UNSPECIFIED_: + url = "https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path" + deprecation_warning( + message=dedent( + f""" + config_path is not specified in @hydra.main(). + See {url} for more information.""" + ), + stacklevel=2, + ) + config_path = "." + else: + config_path = "." 
+ + def main_decorator(task_function: TaskFunction) -> Callable[[], None]: + @functools.wraps(task_function) + def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any: + if cfg_passthrough is not None: + return task_function(cfg_passthrough) + else: + args_parser = get_args_parser() + args_parser.add_argument("--local_rank", type=int, default=-1) + args = args_parser.parse_args() + if args.experimental_rerun is not None: + cfg = _get_rerun_conf(args.experimental_rerun, args.overrides) + task_function(cfg) + _flush_loggers() + else: + # no return value from run_hydra() as it may sometime actually run the task_function + # multiple times (--multirun) + _run_hydra( + args=args, + args_parser=args_parser, + task_function=task_function, + config_path=config_path, + config_name=config_name, + ) + + return decorated_main + + return main_decorator + + + +def set_tokenizer_params(tokenizer: LlamaTokenizer): + tokenizer.pad_token_id = 0 + tokenizer.padding_side = "left" + + +# Converting Bytes to Megabytes +def byte2mb(x): + return int(x / 2**20) + + +def train( + model, + train_dataloader, + eval_dataloader, + tokenizer, + gradient_accumulation_steps, + train_config, + log_config, + local_rank=None, + rank=None, +): + """ + Trains the model on the given dataloader + + Args: + model: The model to be trained + train_dataloader: The dataloader containing the training data + optimizer: The optimizer used for training + lr_scheduler: The learning rate scheduler + gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation + num_epochs: The number of epochs to train for + local_rank: The rank of the current node in a distributed setting + train_config: The training configuration + log_config: The logging configuration + eval_dataloader: The dataloader containing the eval data + tokenizer: tokenizer used in the eval for decoding the predicitons + + Returns: results dictionary containing average training and validation perplexity and loss + """ + # Create a gradient scaler for fp16 + # if train_config.use_fp16 and train_config.enable_fsdp: + # scaler = ShardedGradScaler() + # elif train_config.use_fp16 and not train_config.enable_fsdp: + # scaler = torch.cuda.amp.GradScaler() + if train_config.enable_ddp: + world_size = int(os.environ["WORLD_SIZE"]) + autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext + + train_prep = [] + train_loss = [] + train_acc = [] + val_prep = [] + val_loss = [] + val_acc = [] + epoch_times = [] + checkpoint_times = [] + results = {} + best_val_loss = float("inf") + best_val_acc = 0.0 + for epoch in range(train_config.num_epochs): + epoch_start_time = time.perf_counter() + with MemoryTrace() as memtrace: # track the memory usage + model.train() + total_loss = 0.0 + total_acc = 0.0 + total_length = len(train_dataloader) // gradient_accumulation_steps + pbar = tqdm( + colour="blue", + desc=f"Training Epoch: {epoch+1}", + total=total_length, + dynamic_ncols=True, + ) + for step, batch in enumerate(train_dataloader): + for key in batch.keys(): + batch[key] = ( + batch[key].to(local_rank).half() + if isinstance(batch[key], torch.Tensor) + and batch[key].dtype == torch.float32 + else ( + batch[key].to(local_rank) + if isinstance(batch[key], torch.Tensor) + else batch[key] + ) + ) + # with autocast(): + outputs, *rest = model(**batch) + acc = rest[0] if rest else -1 + loss = outputs.loss + + loss = loss / gradient_accumulation_steps + acc = acc / gradient_accumulation_steps + + if 
log_config.use_wandb and step % log_config.log_interval == 0: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank == 0: + wandb.log( + { + "train_inner/train_inner_loss": loss, + "train_inner/train_inner_accuracy": acc, + }, + step=(epoch * total_length + step), + ) + else: + wandb.log( + { + "train_inner/train_inner_loss": loss, + "train_inner/train_inner_accuracy": acc, + }, + step=(epoch * total_length + step), + ) + + total_loss += loss.detach().float() + total_acc += acc + + # deepspeed should handle gradient accumulate + model.backward(loss) + model.step() + + if (step + 1) % gradient_accumulation_steps == 0 or step == len( + train_dataloader + ) - 1: + pbar.update(1) + + pbar.set_description( + f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})" + ) + + if ( + (epoch * total_length + step + 1) % train_config.validation_interval + == 0 + and train_config.run_validation + ): + eval_ppl, eval_epoch_loss, *rest = evaluation( + model, train_config, eval_dataloader, local_rank, tokenizer + ) + eval_epoch_acc = rest[0] if rest else -1 + checkpoint_start_time = time.perf_counter() + + if train_config.save_model and (eval_epoch_loss < best_val_loss): + checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}" + save_model_checkpoint_deepspeed( + model, train_config, checkpoint_name + ) + + checkpoint_end_time = time.perf_counter() - checkpoint_start_time + checkpoint_times.append(checkpoint_end_time) + if eval_epoch_loss < best_val_loss: + best_val_loss = eval_epoch_loss + if rank == 0: + logger.info( + f"best eval loss on epoch {epoch+1} is {best_val_loss}" + ) + val_loss.append(eval_epoch_loss) + val_prep.append(eval_ppl) + if rest: + if eval_epoch_acc > best_val_acc: + best_val_acc = eval_epoch_acc + if rank == 0: + logger.info( + f"best eval acc on epoch {epoch+1} is {best_val_acc}" + ) + val_acc.append(rest[0]) + else: + val_acc.append(-1) + + if log_config.use_wandb: + if rank == 0: + wandb.log( + { + "valid/val_epoch_loss": eval_epoch_loss, + "valid/val_perplexity": eval_ppl, + "valid/best_val_loss": best_val_loss, + "valid/val_accuracy": val_acc[-1], + "valid/val_best_accuracy": best_val_acc, + } + ) + + if train_config.run_test_during_validation: + if rank == 0: + logger.info("=====================================") + logger.info( + f"Test the file {train_config.run_test_during_validation_file} during validation:" + ) + with autocast(): + logger.info( + model.inference( + train_config.run_test_during_validation_file, + train_config.run_test_during_validation_prompt, + ) + ) + logger.info("=====================================") + dist.barrier() + pbar.close() + + epoch_end_time = time.perf_counter() - epoch_start_time + epoch_times.append(epoch_end_time) + # Reducing total_loss across all devices if there's more than one CUDA device + if torch.cuda.device_count() > 1 and ( + train_config.enable_fsdp or train_config.enable_ddp + ): + dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(total_acc, op=dist.ReduceOp.SUM) + train_epoch_loss = total_loss / len(train_dataloader) + train_epoch_acc = total_acc / len(train_dataloader) + if train_config.enable_fsdp or train_config.enable_ddp: + train_epoch_loss = train_epoch_loss / world_size + train_epoch_acc = train_epoch_acc / world_size + train_perplexity = torch.exp(train_epoch_loss) + + train_prep.append(train_perplexity) + train_loss.append(train_epoch_loss) + 
train_acc.append(train_epoch_acc) + + if log_config.use_wandb: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank == 0: + wandb.log( + { + "train/train_perplexity": train_perplexity, + "train/train_epoch_loss": train_epoch_loss, + "train/train_epoch_acc": train_epoch_acc, + } + ) + else: + wandb.log( + { + "train/train_perplexity": train_perplexity, + "train/train_epoch_loss": train_epoch_loss, + "train/train_epoch_acc": train_epoch_acc, + } + ) + + if rank == 0: + logger.info( + f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s" + ) + + if rank == 0: + logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB") + logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") + logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB") + logger.info(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}") + logger.info( + f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB" + ) + + # Update the learning rate as needed + # lr_scheduler.step() + + avg_epoch_time = sum(epoch_times) / len(epoch_times) + avg_checkpoint_time = ( + sum(checkpoint_times) / len(checkpoint_times) + if len(checkpoint_times) > 0 + else 0 + ) + avg_train_prep = sum(train_prep) / len(train_prep) + avg_train_loss = sum(train_loss) / len(train_loss) + avg_train_acc = sum(train_acc) / len(train_acc) + if train_config.run_validation: + avg_eval_prep = sum(val_prep) / len(val_prep) + avg_eval_loss = sum(val_loss) / len(val_loss) + avg_eval_acc = sum(val_acc) / len(val_acc) + + results["avg_train_prep"] = avg_train_prep + results["avg_train_loss"] = avg_train_loss + results["avg_train_acc"] = avg_train_acc + if train_config.run_validation: + results["avg_eval_prep"] = avg_eval_prep + results["avg_eval_loss"] = avg_eval_loss + results["avg_eval_acc"] = avg_eval_acc + results["avg_epoch_time"] = avg_epoch_time + results["avg_checkpoint_time"] = avg_checkpoint_time + + # saving the training params including fsdp setting for reference. 
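+    # Note: the perplexities above are exp(mean cross-entropy loss); the
+    # avg_eval_* entries assume at least one validation pass actually ran
+    # (i.e. val_loss is non-empty).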
+ # if (train_config.enable_fsdp or train_config.enable_ddp)and not train_config.use_peft: + # save_train_params(train_config, fsdp_config, rank) + + return results + + +def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer): + """ + Evaluates the model on the given dataloader + + Args: + model: The model to evaluate + eval_dataloader: The dataloader containing the evaluation data + local_rank: The rank of the current node in a distributed setting + tokenizer: The tokenizer used to decode predictions + + Returns: eval_ppl, eval_epoch_loss + """ + world_size = int(os.environ["WORLD_SIZE"]) + model.eval() + eval_preds = [] + eval_loss = 0.0 # Initialize evaluation loss + eval_acc = 0.0 + autocast = ( + torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext + ) # (Fix:MZY): fix expected scalar type mismatch in norm + + with MemoryTrace() as memtrace: + total_length = len(eval_dataloader) + pbar = tqdm( + colour="green", + desc=f"Evaluating Epoch", + total=total_length, + dynamic_ncols=True, + ) + for step, batch in enumerate(eval_dataloader): + for key in batch.keys(): + batch[key] = ( + batch[key].to(local_rank).half() + if isinstance(batch[key], torch.Tensor) and batch[key].dtype==torch.float32 + else ( + batch[key].to(local_rank) if isinstance(batch[key], torch.Tensor) else batch[key] + ) + ) + # Ensure no gradients are computed for this scope to save memory + with torch.no_grad(): + # Forward pass and compute loss + with autocast(): # (Fix:MZY): fix expected scalar type mismatch in norm + outputs, *rest = model(**batch) + acc = rest[0] if rest else -1 + loss = outputs.loss + + eval_loss += loss.detach().float() + eval_acc += acc + # Decode predictions and add to evaluation predictions list + preds = torch.argmax(outputs.logits, -1) + eval_preds.extend( + tokenizer.batch_decode( + preds.detach().cpu().numpy(), skip_special_tokens=True + ) + ) + pbar.update(1) + pbar.set_description( + f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}" + ) + + # If there's more than one CUDA device, reduce evaluation loss across all devices + if ( + torch.cuda.device_count() > 1 + ): + dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM) + + # Compute average loss and perplexity + eval_epoch_loss = eval_loss / len(eval_dataloader) + eval_epoch_acc = eval_acc / len(eval_dataloader) + eval_epoch_loss = eval_epoch_loss / world_size + eval_epoch_acc = eval_epoch_acc / world_size + eval_ppl = torch.exp(eval_epoch_loss) + + # Print evaluation metrics + if local_rank == 0: + logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") + + model.train() + return eval_ppl, eval_epoch_loss, eval_epoch_acc + + +def freeze_transformer_layers(model, num_layer): + for i, layer in enumerate(model.model.layers): + if i < num_layer: + for param in layer.parameters(): + param.requires_grad = False + + +def check_frozen_layers_peft_model(model): + for i, layer in enumerate(model.base_model.model.model.layers): + for name, param in layer.named_parameters(): + logger.info( + f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}" + ) + + +def setup(): + """Initialize the process group for distributed training""" + dist.init_process_group("nccl") + + +def setup_environ_flags(rank): + """Set environment flags for debugging purposes""" + os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) + # os.environ["TORCH_DISTRIBUTED_DEBUG"] = 
"DETAIL" + # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases. + # Note this is only availble in PyTorch Nighlies (as of July 30 2023) + # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' + if rank == 0: + logger.info(f"--> Running with torch dist debug set to detail") + + +def cleanup(): + """Clean up the process group after training""" + dist.destroy_process_group() + + +def clear_gpu_cache(rank=None): + """Clear the GPU cache for all ranks""" + if rank == 0: + logger.info(f"Clearing GPU cache for all ranks") + torch.cuda.empty_cache() + + +def get_parameter_dtypes(model): + """Get the data types of model parameters""" + parameter_dtypes = {} + for name, parameter in model.named_parameters(): + parameter_dtypes[name] = parameter.dtype + return parameter_dtypes + + +def print_model_size(model, config, rank: int = 0) -> None: + """ + log model name, the number of trainable parameters and initialization time. + + Args: + model: The PyTorch model. + model_name (str): Name of the model. + init_time_start (float): Initialization start time. + init_time_end (float): Initialization end time. + rank (int, optional): Current process's rank. Defaults to 0. + """ + if rank == 0: + logger.info(f"--> Model {config.model_name}") + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info( + f"--> {config.model_name} has {total_params / 1e6} Million params\n" + ) + + +def print_module_size(module, module_name, rank: int = 0) -> None: + """ + Print module name, the number of trainable parameters and initialization time. + + Args: + module: The PyTorch module. + module_name (str): Name of the model. + rank (int, optional): Current process's rank. Defaults to 0. + """ + if rank == 0: + logger.info(f"--> Module {module_name}") + total_params = sum(p.numel() for p in module.parameters() if p.requires_grad) + logger.info(f"--> {module_name} has {total_params / 1e6} Million params\n") + + +def save_train_params(train_config, fsdp_config, rank): + """ + This function saves the train_config and FSDP config into a train_params.yaml. + This will be used by converter script in the inference folder to fetch the HF model name or path. + It also would be hepful as a log for future references. 
+ """ + # Convert the train_config and fsdp_config objects to dictionaries, + # converting all values to strings to ensure they can be serialized into a YAML file + train_config_dict = { + k: str(v) for k, v in vars(train_config).items() if not k.startswith("__") + } + fsdp_config_dict = { + k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith("__") + } + # Merge the two dictionaries into one + train_params_dict = {**train_config_dict, **fsdp_config_dict} + # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object + folder_name = ( + train_config.dist_checkpoint_root_folder + + "/" + + train_config.dist_checkpoint_folder + + "-" + + train_config.model_name + ) + + save_dir = Path.cwd() / folder_name + # If the directory does not exist, create it + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # Convert the dictionary to a YAML string + config_yaml = yaml.dump(train_params_dict, indent=4) + file_name = os.path.join(save_dir, "train_params.yaml") + + # Check if there's a directory with the same name as the file + if os.path.isdir(file_name): + logger.info(f"Error: {file_name} is a directory, not a file.") + else: + # Write the YAML string to the file + with open(file_name, "w") as f: + f.write(config_yaml) + if rank == 0: + logger.info(f"training params are saved in {file_name}") diff --git a/slam_llm/utils/fsdp_utils.py b/slam_llm/utils/fsdp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e1f7d2ec544915f0288e674dd39752bc3bd0b55 --- /dev/null +++ b/slam_llm/utils/fsdp_utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +def fsdp_auto_wrap_policy(model, transformer_layer_name): + import functools + + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=( + PrefixEncoder, + PromptEncoder, + PromptEmbedding, + transformer_layer_name, + # FullyShardedDataParallelPlugin.get_module_class_from_name( + # model, transformer_layer_name + # ), + ), + ) + + auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) + return auto_wrap_policy \ No newline at end of file diff --git a/slam_llm/utils/llm_tn.py b/slam_llm/utils/llm_tn.py new file mode 100644 index 0000000000000000000000000000000000000000..7967ab4be997036565934d4214df8f65166d7dd4 --- /dev/null +++ b/slam_llm/utils/llm_tn.py @@ -0,0 +1,34 @@ +import sys +import os +import re +import string +from whisper_normalizer.english import EnglishTextNormalizer + +english_normalizer = EnglishTextNormalizer() + +def reduce_repeated_words(text): + pattern ="." 
+ for i in range(1, 50): + p = pattern * i + text = re.sub(f'({p})' + r'\1{4,200}', r'\1', text) + for i in range (50, 100): + p = pattern * i + text = re.sub(f'({p})' + r'\1{3,200}', r'\1', text) + return text + +def normalize_text(srcfn, dstfn): + with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write: + all_lines = f_read.readlines() + for line in all_lines: + line = line.strip() + line_arr = line.split() + key = line_arr[0] + conts = " ".join(line_arr[1:]) + normalized_conts = english_normalizer(conts) + reduced_conts = reduce_repeated_words(normalized_conts) + f_write.write("{0}\t{1}\n".format(key, reduced_conts)) + +if __name__ == "__main__": + srcfn = sys.argv[1] + dstfn = sys.argv[2] + normalize_text(srcfn, dstfn) \ No newline at end of file diff --git a/slam_llm/utils/memory_utils.py b/slam_llm/utils/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19e3e7b16c59d5bc77ebe710e39bba8c708d0956 --- /dev/null +++ b/slam_llm/utils/memory_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import gc +import psutil +import threading + +import torch + +def byte2gb(x): + return int(x / 2**30) +# This context manager is used to track the peak memory usage of the process +class MemoryTrace: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = byte2gb(torch.cuda.memory_allocated()) + self.process = psutil.Process() + self.cpu_begin = byte2gb(self.cpu_mem_used()) + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + torch.cuda.empty_cache() + self.end = byte2gb(torch.cuda.memory_allocated()) + self.peak = byte2gb(torch.cuda.max_memory_allocated()) + cuda_info = torch.cuda.memory_stats() + self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) + self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0) + self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) + self.m_cuda_ooms = cuda_info.get("num_ooms", 0) + self.used = byte2gb(self.end - self.begin) + self.peaked = byte2gb(self.peak - self.begin) + self.max_reserved = byte2gb(torch.cuda.max_memory_reserved()) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") \ No newline at end of file diff --git a/slam_llm/utils/metric.py b/slam_llm/utils/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..44fd563e3aaabdf7236997fbaae2594ec9d20187 --- /dev/null +++ b/slam_llm/utils/metric.py @@ -0,0 +1,20 @@ +import torch + +def compute_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. 
+ + Args: + pad_outputs (LongTensor): Prediction tensors (B, Lmax). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_outputs.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type \ No newline at end of file diff --git a/slam_llm/utils/model_utils.py b/slam_llm/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f61a6ce081d30e61d46c2b8e2a1c26f1d1edfc --- /dev/null +++ b/slam_llm/utils/model_utils.py @@ -0,0 +1,29 @@ +from slam_llm.utils.dataset_utils import load_module_from_py_file +from pathlib import Path + +def get_custom_model_factory(model_config): + costom_model_path = model_config.get( + "file", None + ) + if costom_model_path is None: + from slam_llm.models.slam_model import model_factory + return model_factory + + if ":" in costom_model_path: + module_path, func_name = costom_model_path.split(":") + else: + module_path, func_name = costom_model_path, "model_factory" + + if not module_path.endswith(".py"): + raise ValueError(f"Dataset file {module_path} is not a .py file.") + + module_path = Path(module_path) + if not module_path.is_file(): + raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") + + module = load_module_from_py_file(module_path.as_posix()) + try: + return getattr(module, func_name) + except AttributeError as e: + raise e + diff --git a/slam_llm/utils/num2word.py b/slam_llm/utils/num2word.py new file mode 100644 index 0000000000000000000000000000000000000000..ab6b87812a18d2760cc28a64afd4c5ab1f483a29 --- /dev/null +++ b/slam_llm/utils/num2word.py @@ -0,0 +1,19 @@ + +import sys +from num2words import num2words + +file = sys.argv[1] +out_file = sys.argv[2] + +with open(file) as f: + lines = f.readlines() + +with open(out_file, "w") as fw: + for line in lines: + key, content = line.strip().split(maxsplit=1) + new_content = "" + for ct in content.split(): + if ct.isdigit(): + ct = num2words(ct) + new_content += ct + " " + fw.write(key + " " + new_content + "\n") \ No newline at end of file diff --git a/slam_llm/utils/preprocess_text.py b/slam_llm/utils/preprocess_text.py new file mode 100644 index 0000000000000000000000000000000000000000..87a22f0b66ce11852faff02d280a7ff48ea3aff8 --- /dev/null +++ b/slam_llm/utils/preprocess_text.py @@ -0,0 +1,37 @@ + +import sys +import re +import string + +in_f = sys.argv[1] +out_f = sys.argv[2] + + +with open(in_f, "r", encoding="utf-8") as f: + lines = f.readlines() + +with open(out_f, "w", encoding="utf-8") as f: + for line in lines: + outs = line.strip().split("\t", 1) + if len(outs) == 2: + idx, text = outs + text = re.sub("<|", "", text) + text = re.sub("|>", "", text) + text = re.sub("—", "", text) + # text = re.sub("<s>", "", text) + # text = re.sub("@@", "", text) + # text = re.sub("@", "", text) + # text = re.sub("<unk>", "", text) + # text = re.sub(" ", "", text) + # text = text.lower() + translator = str.maketrans('', '', string.punctuation.replace("'", "")) + result = text.translate(translator) + text = result.upper() + else: + idx = outs[0] + text = " " + + # text = [x for x in text] + # text = " ".join(text) + out = "{} {}\n".format(idx, text) + f.write(out) diff --git a/slam_llm/utils/train_utils.py b/slam_llm/utils/train_utils.py new file 
mode 100644 index 0000000000000000000000000000000000000000..7bf160cacec9bfa2fe78f9f705a41fd0dc9c563e --- /dev/null +++ b/slam_llm/utils/train_utils.py @@ -0,0 +1,628 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +import time +import yaml +from contextlib import nullcontext +from pathlib import Path +from pkg_resources import packaging + + +import torch +import torch.cuda.nccl as nccl +import torch.distributed as dist +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from tqdm import tqdm +from transformers import LlamaTokenizer + + +from slam_llm.utils.checkpoint_handler import ( + save_model_checkpoint, + save_model_and_optimizer_sharded, + save_optimizer_checkpoint, + save_model_checkpoint_peft, + save_model_checkpoint_peft_full_shard +) +from slam_llm.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper +from slam_llm.utils.memory_utils import MemoryTrace +from slam_llm.utils.metric import compute_accuracy + +import wandb +import logging +logger = logging.getLogger(__name__) + + +def set_tokenizer_params(tokenizer: LlamaTokenizer): + tokenizer.pad_token_id = 0 + tokenizer.padding_side = "left" + +# Converting Bytes to Megabytes +def byte2mb(x): + return int(x / 2**20) + +def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, log_config, fsdp_config=None, local_rank=None, rank=None): + """ + Trains the model on the given dataloader + + Args: + model: The model to be trained + train_dataloader: The dataloader containing the training data + optimizer: The optimizer used for training + lr_scheduler: The learning rate scheduler + gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation + num_epochs: The number of epochs to train for + local_rank: The rank of the current node in a distributed setting + train_config: The training configuration + log_config: The logging configuration + eval_dataloader: The dataloader containing the eval data + tokenizer: tokenizer used in the eval for decoding the predicitons + + Returns: results dictionary containing average training and validation perplexity and loss + """ + # Create a gradient scaler for fp16 + # if train_config.use_fp16 and train_config.enable_fsdp: + # scaler = ShardedGradScaler() + # elif train_config.use_fp16 and not train_config.enable_fsdp: + # scaler = torch.cuda.amp.GradScaler() + if train_config.use_fp16: + scaler = torch.cuda.amp.GradScaler() + if train_config.enable_fsdp: + scaler = ShardedGradScaler() + if train_config.enable_fsdp or train_config.enable_ddp: + world_size = int(os.environ["WORLD_SIZE"]) + autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext + + train_prep = [] + train_loss = [] + train_acc = [] + val_prep = [] + val_loss =[] + val_acc = [] + epoch_times = [] + checkpoint_times = [] + results = {} + best_val_loss = float("inf") + best_val_acc = 0.0 + for epoch in range(train_config.num_epochs): + epoch_start_time = time.perf_counter() + with MemoryTrace() as memtrace: # track the memory usage + model.train() + total_loss = 0.0 + total_acc = 0.0 + total_length = len(train_dataloader)//gradient_accumulation_steps + pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", 
total=total_length, dynamic_ncols=True) + for step, batch in enumerate(train_dataloader): + for key in batch.keys(): + if train_config.enable_fsdp or train_config.enable_ddp: + batch[key] = batch[key].to(local_rank) if isinstance(batch[key], torch.Tensor) else batch[key] + if isinstance(batch[key], dict): + for k2 in batch[key].keys(): + batch[key][k2] = batch[key][k2].to(local_rank) if isinstance(batch[key][k2], torch.Tensor) else batch[key][k2] + else: + batch[key] = batch[key].to('cuda:0') if isinstance(batch[key], torch.Tensor) else batch[key] + if isinstance(batch[key], dict): + for k2 in batch[key].keys(): + batch[key][k2] = batch[key][k2].to('cuda:0') if isinstance(batch[key][k2], torch.Tensor) else batch[key][k2] + with autocast(): + outputs, *rest = model(**batch) + acc = rest[0] if rest else -1 + audio_acc = rest[1] if rest else -1 # seven layers of audio acc + layer_loss = rest[2] if rest else -1 # eight layers of loss (seven audio and one text) + loss = outputs.loss + + loss = loss / gradient_accumulation_steps + layer_loss = [l / gradient_accumulation_steps for l in layer_loss] + acc = acc / gradient_accumulation_steps + audio_acc = [a / gradient_accumulation_steps for a in audio_acc] + + if log_config.use_wandb and step % log_config.log_interval == 0: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_text_accuracy":acc}, step=(epoch * total_length + step)) + for layer, acc in enumerate(audio_acc): + wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step)) + for layer, l in enumerate(layer_loss[:-1]): + wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step)) + wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step)) + else: + wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step)) + for layer, acc in enumerate(audio_acc): + wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step)) + for layer, l in enumerate(layer_loss[:-1]): + wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step)) + wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step)) + + total_loss += loss.detach().float() + total_acc += acc + if train_config.use_fp16: + # if fp16 is enabled, use gradient scaler to handle gradient update + scaler.scale(loss).backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + scaler.step(optimizer) + scaler.update() + if lr_scheduler is not None: + lr_scheduler.step() + current_lr = lr_scheduler.get_last_lr()[0] + else: + current_lr = optimizer.param_groups[0]["lr"] + if current_lr == 0: + break + if log_config.use_wandb and step % log_config.log_interval == 0: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + else: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + optimizer.zero_grad() + pbar.update(1) + else: + # regular backpropagation when fp16 is not used + loss.backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + if lr_scheduler is not None: + lr_scheduler.step() + current_lr = 
lr_scheduler.get_last_lr()[0] + else: + current_lr = optimizer.param_groups[0]["lr"] + if current_lr == 0: + break + if log_config.use_wandb and step % log_config.log_interval == 0: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + else: + wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step)) + optimizer.zero_grad() + pbar.update(1) + + pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})") + + if (epoch * total_length + step + 1) % train_config.validation_interval == 0 and train_config.run_validation: + eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) + eval_epoch_acc = rest[0] if rest else -1 + checkpoint_start_time = time.perf_counter() + if train_config.save_model and (eval_epoch_loss < best_val_loss): + checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}" + if train_config.enable_fsdp or train_config.enable_ddp: + dist.barrier() + if train_config.use_peft: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + logger.info(f"we are about to save the PEFT modules") + else: + logger.info(f"we are about to save the PEFT modules") + if train_config.enable_fsdp: + if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: + save_model_checkpoint_peft_full_shard( + model, optimizer, rank, train_config, epoch=epoch + ) + elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + dist.barrier() + elif train_config.enable_ddp: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + dist.barrier() + else: + # model.save_pretrained(train_config.output_dir) + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") + else: + logger.info(f"PEFT modules are saved in {train_config.output_dir} directory") + + elif not train_config.use_peft and train_config.freeze_llm: + logger.info(f"llm is frozen, we are about to save other parts.") + if train_config.enable_fsdp: + if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD: + save_model_checkpoint_peft_full_shard( + model, optimizer, rank, train_config, epoch=epoch + ) + elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + dist.barrier() + elif train_config.enable_ddp: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + dist.barrier() + else: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + + else: + if train_config.enable_fsdp: + if getattr(StateDictType, fsdp_config.checkpoint_type) == StateDictType.FULL_STATE_DICT: + save_model_checkpoint( + model, optimizer, rank, train_config, epoch=epoch + ) + elif getattr(StateDictType, fsdp_config.checkpoint_type) == StateDictType.SHARDED_STATE_DICT: + logger.info(" Saving the FSDP model checkpoints using 
SHARDED_STATE_DICT") + logger.info("=====================================================") + + save_model_and_optimizer_sharded(model, rank, train_config) + if train_config.save_optimizer: + save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) + logger.info(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") + logger.info("=====================================================") + + if train_config.save_optimizer: + save_optimizer_checkpoint( + model, optimizer, rank, train_config, epoch=epoch + ) + logger.info(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") + logger.info("=====================================================") + + elif train_config.enable_ddp: + if rank==0: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + dist.barrier() + + else: + save_model_checkpoint_peft( + model, optimizer, rank, train_config, checkpoint_name=checkpoint_name + ) + + if train_config.enable_fsdp or train_config.enable_ddp: + dist.barrier() + checkpoint_end_time = time.perf_counter() - checkpoint_start_time + checkpoint_times.append(checkpoint_end_time) + if eval_epoch_loss < best_val_loss: + best_val_loss = eval_epoch_loss + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") + else: + logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}") + val_loss.append(eval_epoch_loss) + val_prep.append(eval_ppl) + if rest: + if eval_epoch_acc > best_val_acc: + best_val_acc = eval_epoch_acc + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + logger.info(f"best eval acc on epoch {epoch+1} is {best_val_acc}") + else: + logger.info(f"best eval acc on epoch {epoch+1} is {best_val_acc}") + val_acc.append(rest[0]) + else: + val_acc.append(-1) + + if log_config.use_wandb: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1], "valid/val_best_accuracy":best_val_acc}) + else: + wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1], "valid/val_best_accuracy":best_val_acc}) + + if train_config.run_test_during_validation: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + logger.info("=====================================") + logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") + with autocast(): + logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) + logger.info("=====================================") + dist.barrier() + else: + logger.info("=====================================") + logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:") + with autocast(): + logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt)) + logger.info("=====================================") + pbar.close() + + epoch_end_time = time.perf_counter()-epoch_start_time + epoch_times.append(epoch_end_time) + # Reducing total_loss across all devices if there's more than one CUDA device + if torch.cuda.device_count() > 1 and (train_config.enable_fsdp or train_config.enable_ddp): + 
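# sum the per-rank loss/accuracy totals across devices; the epoch averages are divided by world_size just below +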
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(total_acc, op=dist.ReduceOp.SUM) + train_epoch_loss = total_loss / len(train_dataloader) + train_epoch_acc = total_acc / len(train_dataloader) + if train_config.enable_fsdp or train_config.enable_ddp: + train_epoch_loss = train_epoch_loss/world_size + train_epoch_acc = train_epoch_acc/world_size + train_perplexity = torch.exp(train_epoch_loss) + + train_prep.append(train_perplexity) + train_loss.append(train_epoch_loss) + train_acc.append(train_epoch_acc) + + if log_config.use_wandb: + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) + else: + wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc}) + + if train_config.enable_fsdp or train_config.enable_ddp: + if rank==0: + logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + else: + logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") + + if train_config.enable_fsdp: + if rank==0: + logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB") + logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") + logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB") + logger.info(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}") + logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") + else: + logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB") + logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB") + logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB") + logger.info(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}") + logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB") + + # Update the learning rate as needed + # lr_scheduler.step() + + avg_epoch_time = sum(epoch_times)/ len(epoch_times) + avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0 + avg_train_prep = sum(train_prep)/len(train_prep) + avg_train_loss = sum(train_loss)/len(train_loss) + avg_train_acc = sum(train_acc)/len(train_acc) + if train_config.run_validation: + avg_eval_prep = sum(val_prep)/len(val_prep) + avg_eval_loss = sum(val_loss)/len(val_loss) + avg_eval_acc = sum(val_acc)/len(val_acc) + + results['avg_train_prep'] = avg_train_prep + results['avg_train_loss'] = avg_train_loss + results['avg_train_acc'] = avg_train_acc + if train_config.run_validation: + results['avg_eval_prep'] = avg_eval_prep + results['avg_eval_loss'] = avg_eval_loss + results['avg_eval_acc'] = avg_eval_acc + results["avg_epoch_time"] = avg_epoch_time + results["avg_checkpoint_time"] = avg_checkpoint_time + + #saving the training params including fsdp setting for reference. 
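+ # NOTE: the call below is left commented out; save_train_params() itself is defined at the end of this file.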
+ # if (train_config.enable_fsdp or train_config.enable_ddp) and not train_config.use_peft: + # save_train_params(train_config, fsdp_config, rank) + + return results + +def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer): + """ + Evaluates the model on the given dataloader + + Args: + model: The model to evaluate + eval_dataloader: The dataloader containing the evaluation data + local_rank: The rank of the current node in a distributed setting + tokenizer: The tokenizer used to decode predictions + + Returns: eval_ppl, eval_epoch_loss, eval_epoch_acc + """ + if train_config.enable_fsdp or train_config.enable_ddp: + world_size = int(os.environ["WORLD_SIZE"]) + model.eval() + eval_preds = [] + eval_loss = 0.0 # Initialize evaluation loss + eval_acc = 0.0 + autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext # (Fix:MZY): fix expected scalar type mismatch in norm + + with MemoryTrace() as memtrace: + total_length = len(eval_dataloader) + pbar = tqdm(colour="green", desc=f"Evaluating Epoch", total=total_length, dynamic_ncols=True) + for step, batch in enumerate(eval_dataloader): + for key in batch.keys(): + if train_config.enable_fsdp or train_config.enable_ddp: + batch[key] = batch[key].to(local_rank) if isinstance(batch[key], torch.Tensor) else batch[key] + else: + batch[key] = batch[key].to('cuda:0') if isinstance(batch[key], torch.Tensor) else batch[key] + # Ensure no gradients are computed for this scope to save memory + with torch.no_grad(): + # Forward pass and compute loss + with autocast(): # (Fix:MZY): fix expected scalar type mismatch in norm + outputs, *rest = model(**batch) + acc = rest[0] if rest else -1 + loss = outputs.loss + + eval_loss += loss.detach().float() + eval_acc += acc + # Decode predictions and add to evaluation predictions list + try: + preds = torch.argmax(outputs.logits, -1) + eval_preds.extend( + tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True) + ) + except Exception: + pass # vallex does not need to show its result (we can't view anything from abstract acoustic tokens) + pbar.update(1) + pbar.set_description(f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}") + + # If there's more than one CUDA device, reduce evaluation loss across all devices + if torch.cuda.device_count() > 1 and (train_config.enable_fsdp or train_config.enable_ddp): + dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM) + + # Compute average loss and perplexity + eval_epoch_loss = eval_loss / len(eval_dataloader) + eval_epoch_acc = eval_acc / len(eval_dataloader) + if train_config.enable_fsdp or train_config.enable_ddp: + eval_epoch_loss = eval_epoch_loss/world_size + eval_epoch_acc = eval_epoch_acc/world_size + eval_ppl = torch.exp(eval_epoch_loss) + + # Print evaluation metrics + if train_config.enable_fsdp or train_config.enable_ddp: + if local_rank==0: + logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") + else: + logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}") + + return eval_ppl, eval_epoch_loss, eval_epoch_acc + +def freeze_transformer_layers(model, num_layer): + for i, layer in enumerate(model.model.layers): + if i < num_layer: + for param in layer.parameters(): + param.requires_grad = False + + +def check_frozen_layers_peft_model(model): + for i, layer in enumerate(model.base_model.model.model.layers): + for name, param in layer.named_parameters(): + logger.info(f"Layer {i}, parameter {name}: 
requires_grad = {param.requires_grad}") + + +def setup(): + """Initialize the process group for distributed training""" + dist.init_process_group("nccl") + + +def setup_environ_flags(rank): + """Set environment flags for debugging purposes""" + os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) + # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" + # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases. + # Note this is only availble in PyTorch Nighlies (as of July 30 2023) + # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' + if rank == 0: + logger.info(f"--> Running with torch dist debug set to detail") + + +def cleanup(): + """Clean up the process group after training""" + dist.destroy_process_group() + + +def clear_gpu_cache(rank=None): + """Clear the GPU cache for all ranks""" + if rank == 0: + logger.info(f"Clearing GPU cache for all ranks") + torch.cuda.empty_cache() + + +def get_parameter_dtypes(model): + """Get the data types of model parameters""" + parameter_dtypes = {} + for name, parameter in model.named_parameters(): + parameter_dtypes[name] = parameter.dtype + return parameter_dtypes + +def print_model_size(model, config, rank: int = 0) -> None: + """ + log model name, the number of trainable parameters and initialization time. + + Args: + model: The PyTorch model. + model_name (str): Name of the model. + init_time_start (float): Initialization start time. + init_time_end (float): Initialization end time. + rank (int, optional): Current process's rank. Defaults to 0. + """ + if rank == 0: + logger.info(f"--> Model {config.model_name}") + total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"--> {config.model_name} has {total_params / 1e6} Million params\n") + +def print_module_size(module, module_name, rank: int = 0) -> None: + """ + Print module name, the number of trainable parameters and initialization time. + + Args: + module: The PyTorch module. + module_name (str): Name of the model. + rank (int, optional): Current process's rank. Defaults to 0. + """ + if rank == 0: + logger.info(f"--> Module {module_name}") + total_params = sum(p.numel() for p in module.parameters() if p.requires_grad) + logger.info(f"--> {module_name} has {total_params / 1e6} Million params\n") + + +def get_policies(cfg, rank): + """Get the policies for mixed precision and fsdp wrapping""" + + verify_bfloat_support = ( + torch.version.cuda + and torch.cuda.is_bf16_supported() + and packaging.version.parse(torch.version.cuda).release >= (11, 0) + and dist.is_nccl_available() + and nccl.version() >= (2, 10) + ) + + + mixed_precision_policy = None + wrapping_policy = None + + # Mixed precision + if cfg.mixed_precision: + bf16_ready = verify_bfloat_support + + if bf16_ready and not cfg.use_fp16: + mixed_precision_policy = bfSixteen_mixed + if rank == 0: + logger.info(f"bFloat16 enabled for mixed precision - using bfSixteen policy") + elif cfg.use_fp16: + mixed_precision_policy = fpSixteen + if rank == 0: + logger.info(f"FP16 enabled") + else: + logger.info(f"bFloat16 support not present. Using FP32, and not mixed precision") + wrapping_policy = get_llama_wrapper() + return mixed_precision_policy, wrapping_policy + +def save_train_params(train_config, fsdp_config, rank): + """ + This function saves the train_config and FSDP config into a train_params.yaml. + This will be used by converter script in the inference folder to fetch the HF model name or path. 
+ It also would be hepful as a log for future references. + """ + # Convert the train_config and fsdp_config objects to dictionaries, + # converting all values to strings to ensure they can be serialized into a YAML file + train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')} + fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')} + # Merge the two dictionaries into one + train_params_dict = {**train_config_dict, **fsdp_config_dict} + # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object + folder_name = ( + train_config.dist_checkpoint_root_folder + + "/" + + train_config.dist_checkpoint_folder + + "-" + + train_config.model_name + ) + + save_dir = Path.cwd() / folder_name + # If the directory does not exist, create it + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # Convert the dictionary to a YAML string + config_yaml = yaml.dump(train_params_dict, indent=4) + file_name = os.path.join(save_dir,'train_params.yaml') + + # Check if there's a directory with the same name as the file + if os.path.isdir(file_name): + logger.info(f"Error: {file_name} is a directory, not a file.") + else: + # Write the YAML string to the file + with open(file_name, 'w') as f: + f.write(config_yaml) + if rank==0: + logger.info(f"training params are saved in {file_name}") diff --git a/slam_llm/utils/whisper_tn.py b/slam_llm/utils/whisper_tn.py new file mode 100644 index 0000000000000000000000000000000000000000..6748d6920f0fa2bbb42aee87848718507e488eb3 --- /dev/null +++ b/slam_llm/utils/whisper_tn.py @@ -0,0 +1,23 @@ +import sys +import os +import re +import string +from whisper_normalizer.english import EnglishTextNormalizer + +english_normalizer = EnglishTextNormalizer() + +def normalize_text(srcfn, dstfn): + with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write: + all_lines = f_read.readlines() + for line in all_lines: + line = line.strip() + line_arr = line.split() + key = line_arr[0] + conts = " ".join(line_arr[1:]) + normalized_conts = english_normalizer(conts) + f_write.write("{0}\t{1}\n".format(key, normalized_conts)) + +if __name__ == "__main__": + srcfn = sys.argv[1] + dstfn = sys.argv[2] + normalize_text(srcfn, dstfn) \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/codec_utils.py b/utils/codec_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b576414a59db176a84d548e0eae83b2da601a503 --- /dev/null +++ b/utils/codec_utils.py @@ -0,0 +1,12 @@ +from snac import SNAC +from slam_llm.utils.train_utils import print_module_size +import os + +def setup_codec(train_config, model_config, **kwargs): + if model_config.codec_decoder_type == "SNAC": + codec_decoder = SNAC.from_pretrained(model_config.codec_decoder_path).eval() + else: + raise NotImplementedError + print_module_size(codec_decoder, model_config.codec_decoder_type, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) + + return codec_decoder \ No newline at end of file diff --git a/utils/snac_utils.py b/utils/snac_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05d74a0176deac04c77ecc522ece0f4c20bbd4be --- /dev/null +++ b/utils/snac_utils.py @@ -0,0 +1,168 @@ +import torch +import time +import numpy as np + + +class SnacConfig: + audio_vocab_size = 4096 
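+ # each SNAC layer's 4096-id codebook is padded to 4160 slots; 4160 is also the default per-layer stride used by layershift() below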
+ padded_vocab_size = 4160 + end_of_audio = 4096 + padding_token = 4097 + + +snac_config = SnacConfig() + + +def get_time_str(): + time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + return time_str + + +def layershift(input_id, layer, stride=4160, shift=152000): + return input_id + shift + layer * stride + + +def generate_audio_data(snac_tokens, snacmodel, device=None): + audio = reconstruct_tensors(snac_tokens, device) + with torch.inference_mode(): + audio_hat = snacmodel.decode(audio) + audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0 + audio_data = audio_data.astype(np.int16) + audio_data = audio_data.tobytes() + return audio_data + + +def get_snac(list_output, index, nums_generate): + + snac = [] + start = index + for i in range(nums_generate): + snac.append("#") + for j in range(7): + snac.append(list_output[j][start - nums_generate - 5 + j + i]) + return snac + + +def reconscruct_snac(output_list): + if len(output_list) == 8: + output_list = output_list[:-1] + output = [] + for i in range(7): + output_list[i] = output_list[i][i + 1 :] + for i in range(len(output_list[-1])): + output.append("#") + for j in range(7): + output.append(output_list[j][i]) + return output + + +def get_snac_answer_token(snac_tokens_str): + snac_tokens = snac_tokens_str.split() + audio_length = len(snac_tokens) // 8 + 8 # here the additional 8 is due to parallel generation, 7 padding tokens and 1 end of audio token + snac_config = SnacConfig() + eoa = snac_config.end_of_audio + padding_token = snac_config.padding_token + result = [] + + for layer in range(1, 8): # layers 1 through 7 + layer_tokens = [] + layer_tokens.extend([padding_token] * layer) + layer_tokens.extend([snac_tokens[i] for i in range(len(snac_tokens)) if i % 8 == layer]) + layer_tokens.append(eoa) + if layer < 7: + layer_tokens.extend([padding_token] * (7 - layer)) + result.append(torch.tensor([int(token) for token in layer_tokens])) + + result_tensor = torch.stack(result) + return result_tensor, audio_length + + +def reconstruct_tensors(flattened_output, device=None): + """Reconstructs the list of tensors from the flattened output.""" + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def count_elements_between_hashes(lst): + try: + # Find the index of the first '#' + first_index = lst.index("#") + # Find the index of the second '#' after the first + second_index = lst.index("#", first_index + 1) + # Count the elements between the two indices + return second_index - first_index - 1 + except ValueError: + # Handle the case where there aren't enough '#' symbols + return "List does not contain two '#' symbols" + + def remove_elements_before_hash(flattened_list): + try: + # Find the index of the first '#' + first_hash_index = flattened_list.index("#") + # Return the list starting from the first '#' + return flattened_list[first_hash_index:] + except ValueError: + # Handle the case where there is no '#' + return "List does not contain the symbol '#'" + + def list_to_torch_tensor(tensor1): + # Convert the list to a torch tensor + tensor = torch.tensor(tensor1) + # Reshape the tensor to have size (1, n) + tensor = tensor.unsqueeze(0) + return tensor + + flattened_output = remove_elements_before_hash(flattened_output) + codes = [] + tensor1 = [] + tensor2 = [] + tensor3 = [] + tensor4 = [] + + n_tensors = count_elements_between_hashes(flattened_output) + if n_tensors == 7: + for i in range(0, len(flattened_output), 8): + + tensor1.append(flattened_output[i + 1]) + 
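# each 8-token frame is '#' followed by 7 codes, routed 1/2/4 per frame to the three codebooks by these appends +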
tensor2.append(flattened_output[i + 2]) + tensor3.append(flattened_output[i + 3]) + tensor3.append(flattened_output[i + 4]) + + tensor2.append(flattened_output[i + 5]) + tensor3.append(flattened_output[i + 6]) + tensor3.append(flattened_output[i + 7]) + codes = [ + list_to_torch_tensor(tensor1).to(device), + list_to_torch_tensor(tensor2).to(device), + list_to_torch_tensor(tensor3).to(device), + ] + + if n_tensors == 15: + for i in range(0, len(flattened_output), 16): + + tensor1.append(flattened_output[i + 1]) + tensor2.append(flattened_output[i + 2]) + tensor3.append(flattened_output[i + 3]) + tensor4.append(flattened_output[i + 4]) + tensor4.append(flattened_output[i + 5]) + tensor3.append(flattened_output[i + 6]) + tensor4.append(flattened_output[i + 7]) + tensor4.append(flattened_output[i + 8]) + + tensor2.append(flattened_output[i + 9]) + tensor3.append(flattened_output[i + 10]) + tensor4.append(flattened_output[i + 11]) + tensor4.append(flattened_output[i + 12]) + tensor3.append(flattened_output[i + 13]) + tensor4.append(flattened_output[i + 14]) + tensor4.append(flattened_output[i + 15]) + + codes = [ + list_to_torch_tensor(tensor1).to(device), + list_to_torch_tensor(tensor2).to(device), + list_to_torch_tensor(tensor3).to(device), + list_to_torch_tensor(tensor4).to(device), + ] + + return codes + diff --git a/utils/trick_utils.py b/utils/trick_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ad178272b059301174a832e7d82f5e8ac9e537 --- /dev/null +++ b/utils/trick_utils.py @@ -0,0 +1,45 @@ +import torch +import os +import logging + +logger = logging.getLogger(__name__) + +def partial_freeze_weights(model, original_vocabsize, total_vocabsize): + if int(os.environ.get("RANK", "0")) == 0: + logger.info("Only training partial embedding layer") + + trainable_range = (original_vocabsize, total_vocabsize) + + # Define a hook to zero out the gradient for weights outside the trainable range during the backward pass + def zero_out_gradient(grad): + grad[:trainable_range[0], :] = 0 + grad[trainable_range[1] + 1:, :] = 0 + return grad + + # Freeze all layers first + for param in model.parameters(): + param.requires_grad = False + + # Assuming the output layer is `lm_head` + for param in model.llm.lm_head.parameters(): + # Compute the standard deviation for He initialization + std_dev = (2.0 / param.size(1)) ** 0.5 + + # Initialize the specific rows with He initialization + param[original_vocabsize:total_vocabsize] = ( + torch.randn((trainable_range[1] - trainable_range[0], param.size(1))) * std_dev + ) + param.requires_grad = True + + # Register the hook on the weight tensor + param.register_hook(zero_out_gradient) + +def train_embedding_layer_only(model): + if int(os.environ.get("RANK", "0")) == 0: + logger.info("Only training embedding layer") + + for param in model.parameters(): + param.requires_grad = False + + for param in model.llm.lm_head.parameters(): + param.requires_grad = True \ No newline at end of file diff --git a/utils/tts_adapter_utils.py b/utils/tts_adapter_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..866257130ecdea134cf723b1ce38246f77a2033b --- /dev/null +++ b/utils/tts_adapter_utils.py @@ -0,0 +1,323 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Full definition of a decoder-only transformer-based language model, all of it in this single file. 
+ +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and +https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. +""" + +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn + +def setup_tts_adapter(adapter_config, model_config, **kwargs): + return nn.ModuleDict( + dict( + post_adapter=nn.ModuleList( + Block(adapter_config) for _ in range(adapter_config.n_layer) + ), + post_adapter_audio_ln=adapter_config.norm_class( + model_config.llm_dim, eps=adapter_config.norm_eps + ), + post_adapter_audio_lm_head=nn.Linear( + model_config.llm_dim, model_config.vocab_config.total_audio_vocabsize, bias=adapter_config.lm_head_bias + ), + ) + ) + +class Block(nn.Module): + + def __init__(self, config) -> None: + super().__init__() + if not config.parallel_residual and config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + + if config.norm_class_name == "RMSNorm": + self.norm_class = RMSNorm + + self.norm_1 = self.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + self.norm_2 = ( + None + if config.shared_attention_norm + else self.norm_class(config.n_embd, eps=config.norm_eps) + ) + + if config.mlp_class_name == "GptNeoxMLP": + self.mlp_class = GptNeoxMLP + + self.mlp = self.mlp_class(config) + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Non-parallel residual Parallel residual + ┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True, + │ ↓ │ ↓ ↓ the output from `norm_1` is reused + │ norm_1 │ norm_1 ───► norm_2 + │ ↓ │ ↓ ↓ + │ attn │ attn mlp + │ ↓ │ ↓ │ + ┌─ └► + └► + ◄───────────┘ + │ norm_2 + │ ↓ + │ mlp + │ ↓ + └───► + + """ + + x_normed = self.norm_1(x) + attention_output = self.attn(x_normed, cos, sin, mask, input_pos) + + if self.config.parallel_residual: + x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x) + x = self.mlp(x_normed) + attention_output + x + else: + x = attention_output + x + x = self.mlp(self.norm_2(x)) + x + return x + + +class CausalSelfAttention(nn.Module): + def __init__(self, config) -> None: + super().__init__() + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias) + # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` + self.proj = nn.Linear( + config.head_size * config.n_head, config.n_embd, bias=config.bias + ) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, T, C = ( + x.size() + ) # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.attn(x) + + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view( + B, T, 
self.config.n_query_groups, total_qkv, self.config.head_size + ) + qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # maybe repeat k and v if for the non multi-head attention cases + # training: flash attention requires it + # inference: multi-query would require a full kv cache so avoid it to limit its memory usage + if self.config.n_query_groups != self.config.n_head and ( + input_pos is None or self.config.n_query_groups != 1 + ): + k = k.expand( + B, self.config.n_query_groups, q_per_kv, T, self.config.head_size + ) + v = v.expand( + B, self.config.n_query_groups, q_per_kv, T, self.config.head_size + ) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + + if input_pos is not None: + if not isinstance(self.kv_cache, KVCache): + raise TypeError("You need to call `gpt.set_kv_cache()`") + k, v = self.kv_cache(input_pos, k, v) + + y = self.scaled_dot_product_attention(q, k, v, mask) + + y = y.reshape( + B, T, self.config.head_size * self.config.n_head + ) # re-assemble all head outputs side by side + + # output projection + return self.proj(y) + + def scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + scale = 1.0 / math.sqrt(self.config.head_size) + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None + ) + return y.transpose(1, 2) + + def build_kv_cache( + self, + batch_size: int, + max_seq_length: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> "KVCache": + heads = 1 if self.config.n_query_groups == 1 else self.config.n_head + v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + if rope_cache_length is None: + if self.config.rotary_percentage != 1.0: + raise TypeError( + "Please pass the `rope_cache_length=gpt.cos.size(-1)` value" + ) + k_shape = v_shape + else: + k_shape = ( + batch_size, + heads, + max_seq_length, + rope_cache_length + self.config.head_size - self.config.rope_n_elem, + ) + return KVCache(k_shape, v_shape, device=device, dtype=dtype) + + + + +def build_rope_cache( + seq_len: int, + n_elem: int, + device: Optional[torch.device] = None, + base: int = 10000, + condense_ratio: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
+ """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + + return torch.cos(idx_theta), torch.sin(idx_theta) + + +def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + roped = (x * cos) + (rotated * sin) + return roped.to(dtype=x.dtype) + + +class KVCache(nn.Module): + def __init__( + self, + k_shape: Tuple[int, int, int, int], + v_shape: Tuple[int, int, int, int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.register_buffer( + "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False + ) + self.register_buffer( + "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False + ) + + def forward( + self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(k.dtype) + self.v = self.v.to(v.dtype) + # update the cache + k = self.k.index_copy_(2, input_pos, k) + v = self.v.index_copy_(2, input_pos, v) + return k, v + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.k) + torch.nn.init.zeros_(self.v) + + + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. + """ + + def __init__( + self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + self.add_unit_offset = add_unit_offset + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + x = x.float() + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + x_normed = x_normed.to(dtype=dtype) + if self.add_unit_offset: + # Gemma model requires a unit offset + # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176 + return x_normed * (1 + self.weight) + return x_normed * self.weight + + def reset_parameters(self) -> None: + torch.nn.init.ones_(self.weight) + + +class GptNeoxMLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate) + return self.proj(x) \ No newline at end of file
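
A minimal usage sketch for a few of the utilities introduced in this patch (MemoryTrace, compute_accuracy, layershift, and the RoPE helpers). It is illustrative only and not part of the diff: the tensor sizes and values are made up, and MemoryTrace assumes a CUDA-enabled PyTorch build.

import torch

from slam_llm.utils.memory_utils import MemoryTrace
from slam_llm.utils.metric import compute_accuracy
from utils.snac_utils import layershift, snac_config
from utils.tts_adapter_utils import build_rope_cache, apply_rope

# Track peak GPU/CPU memory around a block of work, as train()/evaluation() do.
with MemoryTrace() as memtrace:
    x = torch.randn(1024, 1024, device="cuda") @ torch.randn(1024, 1024, device="cuda")
print(f"peak CUDA memory: {memtrace.peak} GB (peak active: {memtrace.peak_active_gb} GB)")

# Token-level accuracy over positions whose target is not the ignore label.
preds = torch.tensor([[5, 7, 9, 0]])
targets = torch.tensor([[5, 2, 9, -100]])
print(compute_accuracy(preds, targets, ignore_label=-100).item())  # 2 of 3 valid positions match -> ~0.667

# layershift() maps a per-layer SNAC code id into the padded LLM vocabulary:
# input_id + shift + layer * stride (defaults: shift=152000, stride=4160).
print(layershift(0, 0))                         # first audio id of layer 0 -> 152000
print(layershift(snac_config.end_of_audio, 3))  # end-of-audio id for layer 3

# RoPE helpers: rotate the full head dimension of a (B, n_head, T, head_size) tensor.
cos, sin = build_rope_cache(seq_len=16, n_elem=64)
q = torch.randn(1, 8, 16, 64)
q_roped = apply_rope(q, cos, sin)               # same shape, positions 0..15 encoded

The import paths follow the file locations added in this patch; adjust them if the package is laid out differently in your checkout.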