diff --git a/model/slam_model_s2s.py b/model/slam_model_s2s.py
new file mode 100644
index 0000000000000000000000000000000000000000..65ab83aa911f642aacb837d713fbe9f43801fcf2
--- /dev/null
+++ b/model/slam_model_s2s.py
@@ -0,0 +1,444 @@
+import torch
+import os
+import logging
+import torch.nn.functional as F
+from slam_llm.models.slam_model import (
+    slam_model,
+    setup_tokenizer,
+    setup_encoder,
+    setup_encoder_projector,
+    setup_llm,
+)
+from slam_llm.utils.train_utils import print_model_size
+from typing import List, Optional
+from slam_llm.utils.metric import compute_accuracy
+from transformers import T5ForConditionalGeneration
+from tqdm import tqdm
+from utils.tts_adapter_utils import setup_tts_adapter
+from utils.codec_utils import setup_codec
+from utils.trick_utils import partial_freeze_weights, train_embedding_layer_only
+from utils.snac_utils import layershift
+
+logger = logging.getLogger(__name__)
+
+
+def model_factory(train_config, model_config, ckpt_path, **kwargs):
+    # return necessary components for training
+    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
+
+    if train_config.task_type == "s2s" or train_config.task_type == "asr":
+        encoder = setup_encoder(train_config, model_config, **kwargs)
+    elif train_config.task_type == "tts":
+        encoder = None
+    else:
+        raise NotImplementedError
+
+    # llm
+    llm = setup_llm(train_config, model_config, **kwargs)
+
+    # projector
+    if encoder is not None:
+        encoder_projector = setup_encoder_projector(
+            train_config, model_config, **kwargs
+        )
+    else:
+        encoder_projector = None
+
+    codec_decoder = None
+    if model_config.codec_decode:
+        codec_decoder = setup_codec(train_config, model_config, **kwargs)
+
+    tts_adapter = None
+    if model_config.tts_adapter:
+        adapter_config = model_config.tts_adapter_config
+        tts_adapter = setup_tts_adapter(adapter_config, model_config, **kwargs)
+
+    model = slam_model_s2s(
+        encoder,
+        llm,
+        encoder_projector,
+        tokenizer,
+        tts_adapter,
+        codec_decoder,
+        train_config,
+        model_config,
+        **kwargs,
+    )
+
+    if ckpt_path is not None:
+        logger.info("loading other parts from: {}".format(ckpt_path))
+        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
+        model.load_state_dict(ckpt_dict, strict=False)
+
+    if train_config.train_audio_embed_only:
+        partial_freeze_weights(model, model_config.vocab_config.padded_text_vocabsize, model_config.vocab_config.total_vocabsize)
+
+    if train_config.train_embed_only:
+        train_embedding_layer_only(model)
+
+    print_model_size(
+        model,
+        train_config,
+        (
+            int(os.environ["RANK"])
+            if train_config.enable_fsdp or train_config.enable_ddp
+            else 0
+        ),
+    )
+    return model, tokenizer
+
+
+class slam_model_s2s(slam_model):
+    def __init__(
+        self,
+        encoder,
+        llm,
+        encoder_projector,
+        tokenizer,
+        tts_adapter,
+        codec_decoder,
+        train_config,
+        model_config,
+        **kwargs,
+    ):
+        super().__init__(
+            encoder,
+            llm,
+            encoder_projector,
+            tokenizer,
+            train_config,
+            model_config,
+            **kwargs,
+        )
+
+        # resize llm embedding layer
+        self.original_vocabsize = self.llm.lm_head.weight.size(0)
+        if self.model_config.vocab_config.total_vocabsize != self.original_vocabsize:
+            self.llm.resize_token_embeddings(self.model_config.vocab_config.total_vocabsize)
+
+            if int(os.environ.get("RANK", "0")) == 0:
+                logger.info("Resize llm embedding layer's vocab size to {}".format(self.model_config.vocab_config.total_vocabsize))
+
+        self.codec_decoder = codec_decoder
+        self.tts_adapter = tts_adapter
+        self.code_layer = self.model_config.vocab_config.code_layer
+
+
+    def forward(self,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.LongTensor] = None,
+                past_key_values: Optional[List[torch.FloatTensor]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,
+                **kwargs,
+                ):
+        audio_mel = kwargs.get("audio_mel", None)
+        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper
+
+        audio = kwargs.get("audio", None)
+        audio_mask = kwargs.get("audio_mask", None)
+
+        modality_mask = kwargs.get("modality_mask", None)
+
+        encoder_outs = None
+        if audio_mel is not None or audio is not None:
+            if self.train_config.freeze_encoder: # freeze encoder
+                self.encoder.eval()
+
+            if self.model_config.encoder_name == "whisper":
+                encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim
+            if self.model_config.encoder_name == "wavlm":
+                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
+            if self.model_config.encoder_name == "hubert":
+                results = self.encoder(source = audio, padding_mask = 1-audio_mask)
+                if self.model_config.encoder_type == "pretrain":
+                    encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
+                if self.model_config.encoder_type == "finetune":
+                    encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
+                    encoder_outs = encoder_outs.transpose(0, 1)
+            if self.encoder is None:
+                encoder_outs = audio_mel if audio_mel is not None else audio
+
+            if self.model_config.encoder_projector == "q-former":
+                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
+            if self.model_config.encoder_projector == "linear":
+                encoder_outs = self.encoder_projector(encoder_outs)
+            if self.model_config.encoder_projector == "cov1d-linear": 
+                encoder_outs = self.encoder_projector(encoder_outs)
+
+        if input_ids is not None:
+            input_ids[input_ids == -1] = 0  # [btz, 8, seq_length]
+
+            if isinstance(self.llm, T5ForConditionalGeneration):
+                inputs_embeds = self.llm.shared(input_ids)
+            else:
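+                # The LLM's embedding table can sit at different nesting depths depending
+                # on how the model is wrapped (e.g. PEFT adds an extra .model level), so
+                # probe each level until embed_tokens is found.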
+                if hasattr(self.llm.model, "embed_tokens"):
+                    inputs_embeds = self.llm.model.embed_tokens(input_ids)  # [btz, 8, seq_length, emb_dim]
+                elif hasattr(self.llm.model.model, "embed_tokens"):
+                    inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+                else:
+                    inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+        if modality_mask is not None and encoder_outs is not None:
+            modality_mask = modality_mask.unsqueeze(1).repeat(1, self.code_layer, 1)  # [btz, 8, seq_length]
+            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=2)
+            modality_lengths = torch.clamp(modality_mask.sum(dim=2), max=encoder_outs.shape[1]).tolist()
+
+            encoder_outs_pad = torch.zeros_like(inputs_embeds)
+            for i in range(encoder_outs.shape[0]):
+                for j in range(self.code_layer):
+                    start_idx = modality_mask_start_indices[i, j].item()
+                    length = modality_lengths[i][j]
+                    encoder_outs_pad[i, j, start_idx:start_idx+length] = encoder_outs[i, :length]
+            
+            inputs_embeds[:, :self.code_layer, :, :] = encoder_outs_pad[:, :self.code_layer, :, :] + inputs_embeds[:, :self.code_layer, :, :] * (~modality_mask[:, :, :, None])
+        
+        inputs_embeds = torch.mean(inputs_embeds, dim=1)  # [btz, seq_length, emb_dim], average over the 8 layers
+
+        if kwargs.get("inference_mode", False):
+            return inputs_embeds, attention_mask
+
+        text_labels = labels[:,self.code_layer] if labels is not None else None
+        audio_labels = labels[:, :self.code_layer] if labels is not None else None
+        model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=text_labels)    # here we use the text token layer as the target label
+
+        # parallel generation
+        # TODO: add tts adapter forward
+        x_ori = model_outputs.logits
+        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
+        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
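+        # The resized LM head spans one concatenated vocabulary: the first
+        # text_vocab_size logits are the text layer, followed by code_layer
+        # contiguous audio sub-vocabularies of audio_vocab_size each.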
+        xt = x_ori[..., :text_vocab_size]
+        xa = []
+        for i in range(self.code_layer):
+            xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
+
+        loss_recorder = []
+        total_loss, loss_recorder = self.compute_parallel_loss(xt, text_labels, xa, audio_labels)
+        model_outputs.loss = total_loss
+
+        text_acc = -1
+        audio_acc = [-1 for _ in range(self.code_layer)]
+        if self.metric:
+            with torch.no_grad():
+                preds = torch.argmax(xt, -1)
+                text_acc = compute_accuracy(preds.detach()[:, :-1], text_labels.detach()[:, 1:], ignore_label=-100)
+
+                preds_audio = [torch.argmax(xa[i], -1) for i in range(self.code_layer)]
+                audio_acc = [compute_accuracy(preds_audio[i].detach()[:, :-1], audio_labels[:, i, 1:], ignore_label=-100) for i in range(self.code_layer)]
+
+        # metrics = {"text_acc": text_acc, "audio_acc": audio_acc, "layer_loss": loss_recorder}
+        return model_outputs, text_acc, audio_acc, loss_recorder
+
+
+
+    def compute_parallel_loss(self, xt, text_labels, xa, audio_labels):
+        """
+        Compute the parallel loss for text and audio layers.
+        """
+        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
+        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
+        layer_loss = [0 for _ in range(self.code_layer + 1)]
+        
+        if text_labels is not None:
+            # text_loss = F.cross_entropy(xt.reshape(-1, text_vocab_size), text_labels.reshape(-1), ignore_index=-100)
+            text_loss = F.cross_entropy(xt[:, :-1, :].reshape(-1, text_vocab_size), text_labels[:, 1:].reshape(-1), ignore_index=-100)
+            layer_loss[self.code_layer] = text_loss
+        else:
+            text_loss = 0
+
+        total_audio_loss = 0
+        single_audio_loss = 0
+        for i in range(self.code_layer):
+            if audio_labels is not None:
+                # audio_loss += F.cross_entropy(xa[i].reshape(-1, audio_vocab_size), audio_labels[:,i].reshape(-1), ignore_index=-100)
+                single_audio_loss = F.cross_entropy(xa[i][:, :-1, :].reshape(-1, audio_vocab_size), audio_labels[:, i, 1:].reshape(-1), ignore_index=-100)
+                layer_loss[i] = single_audio_loss
+                total_audio_loss += single_audio_loss
+
+        total_loss = (text_loss + total_audio_loss) / (self.code_layer+1)
+        return total_loss, layer_loss
+
+
+    @torch.no_grad()
+    def generate(self,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.LongTensor] = None,
+                past_key_values: Optional[List[torch.FloatTensor]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,
+                **kwargs,
+                ):
+        kwargs["inference_mode"] = True
+
+        inputs_embeds, attention_mask = self.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+        generated_ids = [[] for _ in range((self.code_layer+1))]
+        current_input_text = None
+        current_audio_tokens = [None for _ in range(self.code_layer)]
+        # input_pos = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0)
+        past_key_values = None
+
+        text_vocab_size = self.model_config.vocab_config.padded_text_vocabsize
+        audio_vocab_size = self.model_config.vocab_config.padded_audio_vocabsize
+
+        max_new_tokens = kwargs.get("max_new_tokens", 360)
+        repetition_penalty = kwargs.get("repetition_penalty", 1.0)
+        decode_text_only = kwargs.get("decode_text_only", False)
+
+        pad_t = self.model_config.vocab_config.pad_t
+        pad_a = self.model_config.vocab_config.pad_a
+        eot = self.model_config.vocab_config.eot
+        eoa = self.model_config.vocab_config.eoa
+
+        text_end = False     # Track whether text generation has ended
+        audio_end = False    # Track whether audio generation has ended
+
+        # NOTE: currently, we only support greedy decoding and sampling for parallel generation, no beam search
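+        # Each step emits one text token plus code_layer audio tokens in parallel; their
+        # embeddings are layer-shifted, averaged into a single position, and fed back
+        # with the KV cache (past_key_values) for the next step.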
+        for step in tqdm(range(max_new_tokens), desc="Generating"):
+            if current_input_text is not None:
+                audio_tokens = torch.cat([layershift(current_audio_tokens[i], i).unsqueeze(1) for i in range(self.code_layer)], dim=1)
+                combined_input_ids = torch.cat([audio_tokens, current_input_text.unsqueeze(1)], dim=1)
+                inputs_embeds = self.llm.model.embed_tokens(combined_input_ids)
+                inputs_embeds = torch.mean(inputs_embeds, dim=1).unsqueeze(1)
+            
+            outputs = self.llm(
+                inputs_embeds=inputs_embeds,                  # [btz, seq_len / 1, emb_dim]
+                attention_mask=attention_mask,                # single-sample decoding; the mask is extended by one position each step below
+                past_key_values=past_key_values,
+                # position_ids=input_pos,
+                use_cache=True,
+            )
+            
+            logits = outputs.logits
+            past_key_values = outputs.past_key_values       # Update past_key_values for the next step
+
+            # Split logits into text and audio layers based on vocab size
+            xt_logits = logits[..., :text_vocab_size]
+            xa_logits = [logits[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)] for i in range(self.code_layer)]
+
+            # Apply repetition penalty to the logits
+            if repetition_penalty != 1.0:
+                xt_logits = self.repetition_penalty(xt_logits, generated_ids[self.code_layer], repetition_penalty)
+                for i in range(self.code_layer):
+                    xa_logits[i] = self.repetition_penalty(xa_logits[i], generated_ids[i], repetition_penalty)
+
+            if not text_end:
+                next_token_text = self.sample_next_token(xt_logits[:, -1, :], **kwargs)
+            else:
+                next_token_text = torch.tensor([pad_t], device=input_ids.device)
+
+            next_tokens_audio = []
+            for i in range(self.code_layer):
+                if not audio_end and not decode_text_only:
+                    next_token_audio = self.sample_next_token(xa_logits[i][:, -1, :], **kwargs)
+                else:
+                    next_token_audio = torch.full((input_ids.size(0),), pad_a, device=input_ids.device)
+                next_tokens_audio.append(next_token_audio)
+
+            if next_tokens_audio[-1] == eoa or decode_text_only:
+                audio_end = True
+            if next_token_text == eot:
+                text_end = True
+            
+            # Update input_ids for the next step
+            current_input_text = next_token_text
+            for i in range(self.code_layer):
+                current_audio_tokens[i] = next_tokens_audio[i]
+
+            # if input_pos.size(-1) > 1:
+            #     input_pos = torch.tensor(input_pos.size(-1), device=input_ids.device).unsqueeze(0)
+            # else:
+            #     input_pos = input_pos.add_(1)
+            attention_mask = torch.cat([attention_mask, torch.ones((input_ids.size(0), 1), device=input_ids.device)], dim=1)
+
+            if audio_end and text_end:
+                break
+
+            # Append generated tokens to the list
+            for i in range(self.code_layer):
+                generated_ids[i].append(next_tokens_audio[i].clone().tolist()[0])  # Audio layers
+            generated_ids[self.code_layer].append(next_token_text.clone().tolist()[0])  # Text layer
+
+        # Concatenate the generated tokens to form the complete sequence
+        text_tokens = generated_ids[-1]
+        generated_ids[-1] = text_tokens[: text_tokens.index(eot)] if eot in text_tokens else text_tokens
+        generated_ids = [torch.tensor(layer) for layer in generated_ids] 
+        return generated_ids
+
+
+    @torch.no_grad()
+    def sample_next_token(self, logits, **kwargs):
+        """
+        Generate the next token based on the model output logits.
+        Supports greedy decoding, top-k sampling, and top-p (nucleus) sampling.
+        """
+        do_sample = kwargs.get("do_sample", False)
+        temperature = kwargs.get("temperature", 1.0)
+        top_k = kwargs.get("top_k", 50)
+        top_p = kwargs.get("top_p", 1.0)
+        num_samples = kwargs.get("num_samples", 1)
+
+        # Adjust logits with temperature
+        logits = logits.squeeze(0)
+        logits = logits / temperature
+
+        # Top-k filtering
+        if top_k > 0:
+            top_k = min(top_k, logits.size(-1))  # Make sure top_k is within the vocab size
+            values, indices = torch.topk(logits, top_k)
+            logits[logits < values[..., [-1]]] = -float('Inf')  # Filter tokens not in top_k
+
+        # Top-p filtering (nucleus sampling)
+        if top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+            # Remove tokens with cumulative probability above the threshold
+            sorted_indices_to_remove = cumulative_probs > top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            logits[indices_to_remove] = -float('Inf')
+
+        if do_sample:
+            # Perform sampling
+            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=num_samples)
+        else:
+            # Greedy decoding (argmax)
+            return torch.argmax(logits, dim=-1, keepdim=True)
+
+
+    def repetition_penalty(self, logits, generated_ids, repetition_penalty):
+        """
+        Apply repetition penalty to the logits.
+        """
+        for token_id in set(generated_ids):
+            if logits[0, -1, token_id] < 0:
+                logits[0, -1, token_id] *= repetition_penalty
+            else:
+                logits[0, -1, token_id] /= repetition_penalty
+
+        return logits
\ No newline at end of file
diff --git a/s2s.py b/s2s.py
new file mode 100644
index 0000000000000000000000000000000000000000..23018d5b50b0d0b5e42376b7485dd0d0b5c61c84
--- /dev/null
+++ b/s2s.py
@@ -0,0 +1,178 @@
+import random
+import torch
+from slam_llm.utils.model_utils import get_custom_model_factory
+from utils.snac_utils import reconscruct_snac, reconstruct_tensors, layershift
+import whisper
+import numpy as np
+from s2s_config import InferenceConfig, CKPT_PATH, CKPT_REPO, CKPT_LOCAL_DIR, CKPT_NAME
+import os
+from omegaconf import OmegaConf
+from huggingface_hub import hf_hub_download
+from typing import Callable
+
+
+def update_progress(progress_callback: Callable[[str], None] | None, message: str):
+    if progress_callback:
+        progress_callback(message)
+
+
+def pull_model_ckpt():
+    if not os.path.exists(CKPT_LOCAL_DIR):
+        os.makedirs(CKPT_LOCAL_DIR)
+    if os.path.exists(CKPT_PATH):
+        return
+    hf_hub_download(
+        repo_id=CKPT_REPO,
+        filename=CKPT_NAME,
+        local_dir=CKPT_LOCAL_DIR,
+        token=os.getenv("HF_TOKEN"),
+    )
+
+
+pull_model_ckpt()
+
+
+def extract_audio_feature(audio_path, mel_size):
+    print("Extracting audio features from", audio_path)
+    audio_raw = whisper.load_audio(audio_path)
+    audio_raw = whisper.pad_or_trim(audio_raw)
+    audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
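+    # Whisper's encoder downsamples the mel frames by 2 and the linear projector by a
+    # further factor of 5 (encoder_projector_ds_rate), giving the effective audio token length.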
+    audio_length = (audio_mel.shape[0] + 1) // 2
+    audio_length = audio_length // 5
+    audio_res = audio_mel
+
+    return audio_res, audio_length
+
+
+def get_input_ids(length, special_token_a, special_token_t, vocab_config):
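+    # Build one parallel template per audio codebook layer plus a final text layer:
+    #   audio layer i: <input_a> <pad_a>*length <eoa> <special_token_a>, layer-shifted by i
+    #   text layer:    <input_t> <pad_t>*length <eot> <special_token_t>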
+    input_ids = []
+    for i in range(vocab_config.code_layer):
+        input_ids_item = []
+        input_ids_item.append(layershift(vocab_config.input_a, i))
+        input_ids_item += [layershift(vocab_config.pad_a, i)] * length
+        input_ids_item += [
+            (layershift(vocab_config.eoa, i)),
+            layershift(special_token_a, i),
+        ]
+        input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
+    input_id_T = torch.tensor(
+        [vocab_config.input_t]
+        + [vocab_config.pad_t] * length
+        + [vocab_config.eot, special_token_t]
+    )
+    input_ids.append(input_id_T.unsqueeze(0))
+    return input_ids
+
+
+def generate_from_wav(
+    wav_path, model, codec_decoder, dataset_config, decode_config, device
+):
+    mel_size = dataset_config.mel_size
+    prompt = dataset_config.prompt
+    prompt_template = "USER: {}\n ASSISTANT: "
+    vocab_config = dataset_config.vocab_config
+    special_token_a = vocab_config.answer_a
+    special_token_t = vocab_config.answer_t
+    code_layer = vocab_config.code_layer
+    task_type = dataset_config.task_type
+
+    audio_mel, audio_length = extract_audio_feature(wav_path, mel_size)
+
+    prompt = prompt_template.format(prompt)
+    prompt_ids = model.tokenizer.encode(prompt)
+    prompt_length = len(prompt_ids)
+    prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
+
+    example_ids = get_input_ids(
+        audio_length + prompt_length, special_token_a, special_token_t, vocab_config
+    )
+    text_layer = example_ids[code_layer]
+    text_layer = torch.cat(
+        (
+            text_layer[:, : audio_length + 1],
+            prompt_ids.unsqueeze(0),
+            text_layer[:, -2:],
+        ),
+        dim=1,
+    )  # <bos> <audio> <prompt> <eos> <task>
+    example_ids[code_layer] = text_layer
+
+    input_length = audio_length
+    example_mask = example_ids[0][0].ge(-1)
+    example_ids = torch.stack(example_ids).squeeze()
+
+    input_ids = example_ids.unsqueeze(0).to(device)
+    attention_mask = example_mask.unsqueeze(0).to(device)
+    audio_mel = audio_mel.unsqueeze(0).to(device)
+    input_length = torch.tensor([input_length]).to(device)
+    audio_length = torch.tensor([audio_length]).to(device)
+    task_type = [task_type]
+
+    modality_mask = torch.zeros_like(attention_mask)
+    padding_left = 1  # +1 for <bos>
+    modality_mask[0, padding_left : padding_left + audio_length] = True
+
+    batch = {
+        "input_ids": input_ids,
+        "attention_mask": attention_mask,
+        "audio_mel": audio_mel,
+        "input_length": input_length,
+        "audio_length": audio_length,
+        "modality_mask": modality_mask,
+        "task_types": task_type,
+    }
+
+    model_outputs = model.generate(**batch, **decode_config)
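+    # model.generate returns code_layer + 1 token streams: indices 0..6 hold the 7 SNAC
+    # audio codebook layers (vocab_config.code_layer == 7), index 7 is the text layer.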
+    text_outputs = model_outputs[7]
+    audio_outputs = model_outputs[:7]
+    output_text = model.tokenizer.decode(
+        text_outputs, add_special_tokens=False, skip_special_tokens=True
+    )
+
+    if decode_config.decode_text_only:
+        return None, output_text
+
+    audio_tokens = [audio_outputs[layer] for layer in range(7)]
+    audiolist = reconscruct_snac(audio_tokens)
+    audio = reconstruct_tensors(audiolist)
+    with torch.inference_mode():
+        audio_hat = codec_decoder.decode(audio)
+
+    return audio_hat, output_text
+
+
+def generate(
+    wav_path: str, progress_callback: Callable[[str], None] | None = None
+) -> tuple[np.ndarray, int | float]:
+    config = OmegaConf.structured(InferenceConfig())
+    train_config, model_config, dataset_config, decode_config = (
+        config.train_config,
+        config.model_config,
+        config.dataset_config,
+        config.decode_config,
+    )
+
+    torch.cuda.manual_seed(train_config.seed)
+    torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
+
+    update_progress(progress_callback, "Loading model")
+
+    model_factory = get_custom_model_factory(model_config)
+    model, _ = model_factory(train_config, model_config, CKPT_PATH)
+    codec_decoder = model.codec_decoder
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+    model.eval()
+
+    update_progress(progress_callback, "Generating")
+    output_wav, output_text = generate_from_wav(
+        wav_path, model, codec_decoder, dataset_config, decode_config, device
+    )
+
+    return output_wav.squeeze().cpu().numpy(), 24000
+
+
+if __name__ == "__main__":
+    wav_path = "sample.wav"
+    generate(wav_path)
diff --git a/s2s_config.py b/s2s_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea029518de41c02cdde5bf91f96099c529e224fd
--- /dev/null
+++ b/s2s_config.py
@@ -0,0 +1,272 @@
+from dataclasses import dataclass, field
+from typing import Optional, List
+import os
+
+CKPT_NAME = "model.pt"
+CKPT_LOCAL_DIR = "model_ckpts"
+CKPT_PATH = os.path.join(CKPT_LOCAL_DIR, CKPT_NAME)
+CKPT_REPO = "xcczach/mini-omni"
+
+
+@dataclass
+class VocabConfig:
+    text_vocabsize: int = 151936
+    text_specialtokens: int = 64
+    audio_vocabsize: int = 4096
+    audio_specialtokens: int = 64
+    total_vocabsize: int = 181120
+    code_layer: int = 7
+
+    padded_text_vocabsize: int = field(init=False)
+    padded_audio_vocabsize: int = field(init=False)
+    total_audio_vocabsize: int = field(init=False)
+
+    eot: int = field(init=False)  # end of text token
+    pad_t: int = field(init=False)  # padding text token
+    input_t: int = field(init=False)  # input text token
+    answer_t: int = field(init=False)  # answer text token
+    asr: int = field(init=False)  # ASR token
+
+    eoa: int = field(init=False)  # end of audio token
+    pad_a: int = field(init=False)  # padding audio token
+    input_a: int = field(init=False)  # input audio token
+    answer_a: int = field(init=False)  # answer audio token
+    split: int = field(init=False)  # split token
+
+    def __post_init__(self):
+        self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
+        self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens
+        self.total_audio_vocabsize = self.padded_audio_vocabsize * self.code_layer
+
+        self.eot = self.text_vocabsize
+        self.pad_t = self.text_vocabsize + 1
+        self.input_t = self.text_vocabsize + 2
+        self.answer_t = self.text_vocabsize + 3
+        self.asr = self.text_vocabsize + 4
+
+        self.eoa = self.audio_vocabsize
+        self.pad_a = self.audio_vocabsize + 1
+        self.input_a = self.audio_vocabsize + 2
+        self.answer_a = self.audio_vocabsize + 3
+        self.split = self.audio_vocabsize + 4
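+        # With the defaults: padded_text = 151936 + 64 = 152000, padded_audio = 4096 + 64 = 4160,
+        # total_audio = 4160 * 7 = 29120, and 152000 + 29120 = 181120 == total_vocabsize above.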
+
+
+@dataclass
+class TTSAdapterConfig:
+    add_qkv_bias: Optional[bool] = True
+    bias: bool = False
+    gelu_approximate: Optional[str] = None
+    head_size: Optional[int] = 64
+    intermediate_size: Optional[int] = 4864
+    lm_head_bias: bool = False
+    mlp_class_name: str = "GptNeoxMLP"
+    n_layer: int = 6
+    n_head: int = 14
+    n_embd: int = 896
+    n_query_groups: Optional[int] = 2
+    norm_class_name: str = "RMSNorm"
+    norm_eps: float = 1e-6
+    parallel_residual: bool = False
+    rotary_percentage: float = 1
+    shared_attention_norm: bool = False
+
+    def __post_init__(self):
+        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
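+        # rope_n_elem is the number of head dimensions rotated by RoPE (1.0 * 64 = 64 here).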
+
+
+@dataclass
+class ModelConfig:
+    file: str = "model/slam_model_s2s.py:model_factory"
+    llm_name: str = "qwen2-0.5b"
+    llm_path: str = "Qwen/Qwen2-0.5B"
+    llm_type: str = "decoder_only"
+    llm_dim: int = 896
+    encoder_name: Optional[str] = "whisper"
+    encoder_ds_rate: int = 2
+    encoder_path: Optional[str] = "small"
+    encoder_dim: int = 768
+    encoder_projector: str = "linear"
+    encoder_projector_ds_rate: int = 5
+    modal: str = "audio"
+    normalize: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether input is normalized, used for models such as wavlm"},
+    )
+    encoder_type: str = field(
+        default="finetune",
+        metadata={
+            "help": "whether model is only pretrained or finetuned, used for models such as hubert"
+        },
+    )
+    vocab_config: VocabConfig = field(default_factory=VocabConfig)
+    codec_decode: bool = True
+    codec_decoder_type: str = "SNAC"
+    codec_decoder_path: Optional[str] = "hubertsiuzdak/snac_24khz"
+    tts_adapter: bool = False
+    tts_adapter_config: TTSAdapterConfig = field(default_factory=TTSAdapterConfig)
+
+
+@dataclass
+class PeftConfig:
+    peft_method: str = "lora"  # None , llama_adapter, prefix
+    r: int = 8
+    lora_alpha: int = 32
+    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
+    bias: str = "none"
+    task_type: str = "CAUSAL_LM"
+    lora_dropout: float = 0.05
+    inference_mode: bool = False
+
+
+@dataclass
+class TrainConfig:
+    model_name: str = "s2s"
+    enable_ddp: bool = False
+    enable_deepspeed: bool = False
+    enable_fsdp: bool = False
+    low_cpu_fsdp: bool = False
+    run_validation: bool = True
+    batch_size_training: int = 4
+    batching_strategy: str = field(
+        default="custom", metadata={"help": "alternative: padding"}
+    )
+    context_length: int = 4096
+    gradient_accumulation_steps: int = 1
+    num_epochs: int = 1
+    num_workers_dataloader: int = 2
+    warmup_steps: int = 1000
+    total_steps: int = 100000
+    validation_interval: int = 1000
+    lr: float = 1e-4
+    weight_decay: float = 0.0
+    gamma: float = 0.85
+    seed: int = 42
+    use_fp16: bool = False
+    mixed_precision: bool = True
+    val_batch_size: int = 1
+
+    use_peft: bool = False
+    peft_config: PeftConfig = field(default_factory=PeftConfig)
+    output_dir: str = "PATH/to/save/PEFT/model"
+    freeze_layers: bool = False
+    num_freeze_layers: int = 1
+    quantization: bool = False
+    one_gpu: bool = False
+    save_model: bool = True
+    dist_checkpoint_root_folder: str = (
+        "PATH/to/save/FSDP/model"  # will be used if using FSDP
+    )
+    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
+    save_optimizer: bool = False  # will be used if using FSDP
+    use_fast_kernels: bool = (
+        False  # Enable SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and xFormers memory-efficient kernels
+    )
+    run_test_during_validation: bool = False
+    run_test_during_validation_file: str = "test.wav"
+    run_test_during_validation_prompt: str = "<|S2S|>"
+    freeze_llm: bool = field(
+        default=True,
+        metadata={
+            "help": "whether to freeze llm when finetuning, should be true when use peft finetuning"
+        },
+    )
+    freeze_encoder: bool = True
+    train_embed_only: bool = False
+    train_audio_embed_only: bool = False
+    task_type: str = "s2s"
+
+
+@dataclass
+class DataConfig:
+    dataset: str = "speech_dataset_s2s"
+    file: str = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"
+    train_data_path: Optional[str] = None
+    val_data_path: Optional[str] = None
+    train_split: str = "train"
+    test_split: str = "validation"
+    prompt: Optional[str] = None
+    data_path: Optional[str] = None
+    max_words: Optional[int] = None
+    max_mel: Optional[float] = None
+    fix_length_audio: int = -1
+    inference_mode: bool = True
+    input_type: str = field(
+        default="mel",
+        metadata={"help": "Use raw when input is wav, mel when for whisper"},
+    )
+    mel_size: int = field(
+        default=80, metadata={"help": "80 for whisper large v1 and v2, 128 for v3"}
+    )
+    normalize: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether input is normalized, used for models such as wavlm"},
+    )
+    seed: int = 42
+    manifest_format: str = field(
+        default="datasets", metadata={"help": "alternative: jsonl"}
+    )
+    split_size: float = 0.1
+
+    vocab_config: VocabConfig = field(default_factory=VocabConfig)
+    load_from_cache_file: bool = False
+    task_type: str = "s2s"
+
+
+@dataclass
+class DecodeConfig:
+    do_sample: bool = False
+    max_new_tokens: int = 300
+    min_length: int = 10
+    temperature: float = 1.0
+    top_k: int = 50
+    top_p: float = 0.9
+    num_beams: int = 1
+    num_return_sequences: int = 1
+    num_samples: int = 1
+    max_time: float = 0.0
+    repetition_penalty: float = 1.0
+    length_penalty: float = 1.0
+    early_stopping: bool = False
+    no_repeat_ngram_size: int = 0
+    bad_words_ids: List = field(default_factory=list)
+    num_beam_groups: int = 1
+    diversity_penalty: float = 0.0
+    task_type: str = "s2s"
+    decode_text_only: bool = False
+
+
+@dataclass
+class FSDPConfig:
+    mixed_precision: bool = True
+    use_fp16: bool = False
+    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
+    sharding_strategy: str = (
+        "NO_SHARD"  # ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP
+    )
+    checkpoint_type: str = (
+        "SHARDED_STATE_DICT"  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
+    )
+    fsdp_activation_checkpointing: bool = True
+    fsdp_cpu_offload: bool = False
+    pure_bf16: bool = False
+    optimizer: str = "AdamW"
+
+
+@dataclass
+class LogConfig:
+    use_wandb: bool = False
+    wandb_dir: str = "/valleblob/v-wenxichen/exp/wandb_log"
+    wandb_entity_name: str = "project_name"
+    wandb_project_name: str = "project_name"
+    wandb_exp_name: str = "exp_name"
+    log_file: str = "/valleblob/v-wenxichen/exp/log/test.log"
+    log_interval: int = 10
+    online_output_dir: Optional[str] = None
+
+
+@dataclass
+class InferenceConfig:
+    dataset_config: DataConfig = field(default_factory=DataConfig)
+    model_config: ModelConfig = field(default_factory=ModelConfig)
+    train_config: TrainConfig = field(default_factory=TrainConfig)
+    decode_config: DecodeConfig = field(default_factory=DecodeConfig)
diff --git a/slam_llm/__init__.py b/slam_llm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/slam_llm/data/__init__.py b/slam_llm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..47bc944e530156d72412524d05a2804983a80c8d
--- /dev/null
+++ b/slam_llm/data/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
\ No newline at end of file
diff --git a/slam_llm/data/concatenator.py b/slam_llm/data/concatenator.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f49052ff84a960c367ebd16ea9b53db7fbbf2d
--- /dev/null
+++ b/slam_llm/data/concatenator.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from tqdm import tqdm
+from itertools import chain
+
+from torch.utils.data import Dataset
+
+
+class ConcatDataset(Dataset):
+    def __init__(self, dataset, chunk_size=4096):
+        self.dataset = dataset
+        self.chunk_size = chunk_size
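+        # Pack consecutive samples into fixed chunks of chunk_size tokens; leftover tokens
+        # are carried over in the buffer and any final partial chunk is dropped.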
+
+        self.samples = []
+
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            }
+
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
+            buffer = {k: v + sample[k] for k,v in buffer.items()}
+
+            while len(next(iter(buffer.values()))) > self.chunk_size:
+                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
+                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
+
+    def __getitem__(self, idx):
+        return self.samples[idx]
+
+    def __len__(self):
+        return len(self.samples)
diff --git a/slam_llm/data/sampler.py b/slam_llm/data/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c91af5b946b2216e2a3003fe37a1f32d072fa52
--- /dev/null
+++ b/slam_llm/data/sampler.py
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            
\ No newline at end of file
diff --git a/slam_llm/models/BEATs/BEATs.py b/slam_llm/models/BEATs/BEATs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1a2bbcb336bc89e2d8a0298d8165e41c87a1790
--- /dev/null
+++ b/slam_llm/models/BEATs/BEATs.py
@@ -0,0 +1,181 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+import torchaudio.compliance.kaldi as ta_kaldi
+
+from .backbone import (
+    TransformerEncoder,
+)
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class BEATsConfig:
+    def __init__(self, cfg=None):
+        self.input_patch_size: int = -1  # patch size of patch embedding
+        self.embed_dim: int = 512  # patch embedding dimension
+        self.conv_bias: bool = False  # include bias in conv encoder
+
+        self.encoder_layers: int = 12  # num encoder layers in the transformer
+        self.encoder_embed_dim: int = 768  # encoder embedding dimension
+        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
+        self.encoder_attention_heads: int = 12  # num encoder attention heads
+        self.activation_fn: str = "gelu"  # activation function to use
+
+        self.layer_wise_gradient_decay_ratio: float = 1.0  # ratio for layer-wise gradient decay
+        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
+        self.deep_norm: bool = False  # apply deep_norm first in the transformer
+
+        # dropouts
+        self.dropout: float = 0.1  # dropout probability for the transformer
+        self.attention_dropout: float = 0.1  # dropout probability for attention weights
+        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
+        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
+        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
+
+        # positional embeddings
+        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
+        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding
+
+        # relative position embedding
+        self.relative_position_embedding: bool = False  # apply relative position embedding
+        self.num_buckets: int = 320  # number of buckets for relative position embedding
+        self.max_distance: int = 1280  # maximum distance for relative position embedding
+        self.gru_rel_pos: bool = False  # apply gated relative position embedding
+
+        # label predictor
+        self.finetuned_model: bool = False  # whether the model is a fine-tuned model.
+        self.predictor_dropout: float = 0.1  # dropout probability for the predictor
+        self.predictor_class: int = 527  # target class number for the predictor
+
+        if cfg is not None:
+            self.update(cfg)
+
+    def update(self, cfg: dict):
+        self.__dict__.update(cfg)
+
+
+class BEATs(nn.Module):
+    def __init__(
+            self,
+            cfg: BEATsConfig,
+    ) -> None:
+        super().__init__()
+        logger.info(f"BEATs Config: {cfg.__dict__}")
+
+        self.cfg = cfg
+
+        self.embed = cfg.embed_dim
+        self.post_extract_proj = (
+            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            if self.embed != cfg.encoder_embed_dim
+            else None
+        )
+
+        self.input_patch_size = cfg.input_patch_size
+        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
+                                         bias=cfg.conv_bias)
+
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+        assert not cfg.deep_norm or not cfg.layer_norm_first
+        self.encoder = TransformerEncoder(cfg)
+        self.layer_norm = LayerNorm(self.embed)
+
+        if cfg.finetuned_model:
+            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
+            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
+        else:
+            self.predictor = None
+
+    def forward_padding_mask(
+            self,
+            features: torch.Tensor,
+            padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        extra = padding_mask.size(1) % features.size(1)
+        if extra > 0:
+            padding_mask = padding_mask[:, :-extra]
+        padding_mask = padding_mask.view(
+            padding_mask.size(0), features.size(1), -1
+        )
+        padding_mask = padding_mask.all(-1)
+        return padding_mask
+
+    @classmethod
+    def preprocess(
+            cls,
+            source: torch.Tensor,
+            fbank_mean: float = 15.41663,
+            fbank_std: float = 6.55582,
+    ) -> torch.Tensor:
+        if len(source.shape) > 1: # batch
+            fbanks = []
+            for waveform in source:
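+                # scale the float waveform (roughly [-1, 1]) to 16-bit range, as the
+                # Kaldi-compatible fbank expects integer-scale samples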
+                waveform = waveform.unsqueeze(0) * 2 ** 15
+                fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
+                fbanks.append(fbank)
+            fbank = torch.stack(fbanks, dim=0)
+        else: # single
+            waveform = source.unsqueeze(0) * 2 ** 15
+            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
+        
+        fbank = (fbank - fbank_mean) / (2 * fbank_std)
+        return fbank
+
+    def extract_features(
+            self,
+            fbank: torch.Tensor,
+            padding_mask: Optional[torch.Tensor] = None,
+    ):
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(fbank, padding_mask)
+
+        fbank = fbank.unsqueeze(1)
+        features = self.patch_embedding(fbank)
+        features = features.reshape(features.shape[0], features.shape[1], -1)
+        features = features.transpose(1, 2)
+        features = self.layer_norm(features)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(features, padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        x = self.dropout_input(features)
+
+        x, layer_results = self.encoder(
+            x,
+            padding_mask=padding_mask,
+        )
+
+        if self.predictor is not None:
+            x = self.predictor_dropout(x)
+            logits = self.predictor(x)
+
+            if padding_mask is not None and padding_mask.any():
+                logits[padding_mask] = 0
+                logits = logits.sum(dim=1)
+                logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
+            else:
+                logits = logits.mean(dim=1)
+
+            lprobs = torch.sigmoid(logits)
+
+            return lprobs, padding_mask
+        else:
+            return x, padding_mask
diff --git a/slam_llm/models/BEATs/Tokenizers.py b/slam_llm/models/BEATs/Tokenizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e0f6ae711360bfa49777fd3c427eb73f2b6762
--- /dev/null
+++ b/slam_llm/models/BEATs/Tokenizers.py
@@ -0,0 +1,173 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+import torchaudio.compliance.kaldi as ta_kaldi
+
+from .backbone import (
+    TransformerEncoder,
+)
+from .quantizer import (
+    NormEMAVectorQuantizer,
+)
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+class TokenizersConfig:
+    def __init__(self, cfg=None):
+        self.input_patch_size: int = -1  # patch size of patch embedding
+        self.embed_dim: int = 512  # patch embedding dimension
+        self.conv_bias: bool = False  # include bias in conv encoder
+
+        self.encoder_layers: int = 12  # num encoder layers in the transformer
+        self.encoder_embed_dim: int = 768  # encoder embedding dimension
+        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
+        self.encoder_attention_heads: int = 12  # num encoder attention heads
+        self.activation_fn: str = "gelu"  # activation function to use
+
+        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
+        self.deep_norm: bool = False  # apply deep_norm first in the transformer
+
+        # dropouts
+        self.dropout: float = 0.1  # dropout probability for the transformer
+        self.attention_dropout: float = 0.1  # dropout probability for attention weights
+        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
+        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
+        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
+
+        # positional embeddings
+        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
+        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding
+
+        # relative position embedding
+        self.relative_position_embedding: bool = False  # apply relative position embedding
+        self.num_buckets: int = 320  # number of buckets for relative position embedding
+        self.max_distance: int = 1280  # maximum distance for relative position embedding
+        self.gru_rel_pos: bool = False  # apply gated relative position embedding
+
+        # quantizer
+        self.quant_n: int = 1024 # codebook number in quantizer
+        self.quant_dim: int = 256    # codebook dimension in quantizer
+
+        if cfg is not None:
+            self.update(cfg)
+
+    def update(self, cfg: dict):
+        self.__dict__.update(cfg)
+
+
+class Tokenizers(nn.Module):
+    def __init__(
+            self,
+            cfg: TokenizersConfig,
+    ) -> None:
+        super().__init__()
+        logger.info(f"Tokenizers Config: {cfg.__dict__}")
+
+        self.cfg = cfg
+
+        self.embed = cfg.embed_dim
+        self.post_extract_proj = (
+            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            if self.embed != cfg.encoder_embed_dim
+            else None
+        )
+
+        self.input_patch_size = cfg.input_patch_size
+        self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
+                                         bias=cfg.conv_bias)
+
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+
+        assert not cfg.deep_norm or not cfg.layer_norm_first
+        self.encoder = TransformerEncoder(cfg)
+        self.layer_norm = LayerNorm(self.embed)
+
+        self.quantize = NormEMAVectorQuantizer(
+            n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
+        )
+        self.quant_n = cfg.quant_n
+        self.quantize_layer = nn.Sequential(
+            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
+            nn.Tanh(),
+            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim)  # for quantize
+        )
+
+    def forward_padding_mask(
+            self,
+            features: torch.Tensor,
+            padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        extra = padding_mask.size(1) % features.size(1)
+        if extra > 0:
+            padding_mask = padding_mask[:, :-extra]
+        padding_mask = padding_mask.view(
+            padding_mask.size(0), features.size(1), -1
+        )
+        padding_mask = padding_mask.all(-1)
+        return padding_mask
+
+    def preprocess(
+            self,
+            source: torch.Tensor,
+            fbank_mean: float = 15.41663,
+            fbank_std: float = 6.55582,
+    ) -> torch.Tensor:
+        fbanks = []
+        for waveform in source:
+            waveform = waveform.unsqueeze(0) * 2 ** 15
+            fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
+            fbanks.append(fbank)
+        fbank = torch.stack(fbanks, dim=0)
+        fbank = (fbank - fbank_mean) / (2 * fbank_std)
+        return fbank
+
+    def extract_labels(
+            self,
+            source: torch.Tensor,
+            padding_mask: Optional[torch.Tensor] = None,
+            fbank_mean: float = 15.41663,
+            fbank_std: float = 6.55582,
+    ):
+        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(fbank, padding_mask)
+
+        fbank = fbank.unsqueeze(1)
+        features = self.patch_embedding(fbank)
+        features = features.reshape(features.shape[0], features.shape[1], -1)
+        features = features.transpose(1, 2)
+        features = self.layer_norm(features)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(features, padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        x = self.dropout_input(features)
+
+        x, layer_results = self.encoder(
+            x,
+            padding_mask=padding_mask,
+        )
+
+        quantize_input = self.quantize_layer(x)
+        quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
+
+        return embed_ind
+
diff --git a/slam_llm/models/BEATs/backbone.py b/slam_llm/models/BEATs/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..97caa0673f677f10a8ea1bd756bca6e208c439b7
--- /dev/null
+++ b/slam_llm/models/BEATs/backbone.py
@@ -0,0 +1,783 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import numpy as np
+from typing import Dict, Optional, Tuple
+import torch
+from torch import Tensor, nn
+import torch.nn.functional as F
+from torch.nn import LayerNorm, Parameter
+from .modules import (
+    GradMultiply,
+    SamePad,
+    get_activation_fn,
+    GLU_Linear,
+    quant_noise,
+)
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+
+        self.dropout = args.dropout
+        self.embedding_dim = args.encoder_embed_dim
+
+        self.pos_conv = nn.Conv1d(
+            self.embedding_dim,
+            self.embedding_dim,
+            kernel_size=args.conv_pos,
+            padding=args.conv_pos // 2,
+            groups=args.conv_pos_groups,
+        )
+        dropout = 0
+        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+        nn.init.constant_(self.pos_conv.bias, 0)
+
+        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+        if hasattr(args, "relative_position_embedding"):
+            self.relative_position_embedding = args.relative_position_embedding
+            self.num_buckets = args.num_buckets
+            self.max_distance = args.max_distance
+        else:
+            self.relative_position_embedding = False
+            self.num_buckets = 0
+            self.max_distance = 0
+
+        self.layers = nn.ModuleList(
+            [
+                TransformerSentenceEncoderLayer(
+                    embedding_dim=self.embedding_dim,
+                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
+                    num_attention_heads=args.encoder_attention_heads,
+                    dropout=self.dropout,
+                    attention_dropout=args.attention_dropout,
+                    activation_dropout=args.activation_dropout,
+                    activation_fn=args.activation_fn,
+                    layer_norm_first=args.layer_norm_first,
+                    deep_norm=args.deep_norm,
+                    has_relative_attention_bias=self.relative_position_embedding,
+                    num_buckets=self.num_buckets,
+                    max_distance=self.max_distance,
+                    gru_rel_pos=args.gru_rel_pos,
+                    encoder_layers=args.encoder_layers,
+                )
+                for i in range(args.encoder_layers)
+            ]
+        )
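+        # Share one relative-position bias table across layers: only layer 0 keeps its embedding,
+        # and every other layer's self_attn reuses that module (see below).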
+        if self.relative_position_embedding:
+            for i in range(1, args.encoder_layers):
+                del self.layers[i].self_attn.relative_attention_bias
+                self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias
+
+        self.layer_norm_first = args.layer_norm_first
+        self.layer_norm = LayerNorm(self.embedding_dim)
+        self.layerdrop = args.encoder_layerdrop
+
+        self.apply(init_bert_params)
+
+        if args.deep_norm:
+            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
+            for i in range(args.encoder_layers):
+                nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
+                nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
+                nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
+                nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
+                nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
+                nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
+
+        self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)
+
+    def forward(self, x, padding_mask=None, layer=None):
+        x, layer_results = self.extract_features(x, padding_mask, layer)
+
+        if self.layer_norm_first and layer is None:
+            x = self.layer_norm(x)
+
+        return x, layer_results
+
+    def extract_features(self, x, padding_mask=None, tgt_layer=None):
+
+        if padding_mask is not None:
+            x[padding_mask] = 0
+
+        x_conv = self.pos_conv(x.transpose(1, 2))
+        x_conv = x_conv.transpose(1, 2)
+        x = x + x_conv
+
+        if not self.layer_norm_first:
+            x = self.layer_norm(x)
+
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        layer_results = []
+        z = None
+        if tgt_layer is not None:
+            layer_results.append((x, z))
+        r = None
+        pos_bias = None
+        for i, layer in enumerate(self.layers):
+            if self.layer_wise_gradient_decay_ratio != 1.0:
+                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
+            dropout_probability = np.random.random()
+            if not self.training or (dropout_probability > self.layerdrop):
+                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
+            if tgt_layer is not None:
+                layer_results.append((x, z))
+            if i == tgt_layer:
+                r = x
+                break
+
+        if r is not None:
+            x = r
+
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+
+        return x, layer_results
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+    def __init__(
+            self,
+            embedding_dim: int = 768,
+            ffn_embedding_dim: int = 3072,
+            num_attention_heads: int = 8,
+            dropout: float = 0.1,
+            attention_dropout: float = 0.1,
+            activation_dropout: float = 0.1,
+            activation_fn: str = "relu",
+            layer_norm_first: bool = False,
+            deep_norm: bool = False,
+            has_relative_attention_bias: bool = False,
+            num_buckets: int = 0,
+            max_distance: int = 0,
+            rescale_init: bool = False,
+            gru_rel_pos: bool = False,
+            encoder_layers: int = 0,
+    ) -> None:
+
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.dropout = dropout
+        self.activation_dropout = activation_dropout
+
+        self.activation_name = activation_fn
+        self.activation_fn = get_activation_fn(activation_fn)
+        self.self_attn = MultiheadAttention(
+            self.embedding_dim,
+            num_attention_heads,
+            dropout=attention_dropout,
+            self_attention=True,
+            has_relative_attention_bias=has_relative_attention_bias,
+            num_buckets=num_buckets,
+            max_distance=max_distance,
+            rescale_init=rescale_init,
+            gru_rel_pos=gru_rel_pos,
+        )
+
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(self.activation_dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.layer_norm_first = layer_norm_first
+
+        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+
+        if self.activation_name == "glu":
+            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
+        else:
+            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+        self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+        self.deep_norm = deep_norm
+        if self.deep_norm:
+            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
+        else:
+            self.deep_norm_alpha = 1
+
+    def forward(
+            self,
+            x: torch.Tensor,
+            self_attn_mask: torch.Tensor = None,
+            self_attn_padding_mask: torch.Tensor = None,
+            need_weights: bool = False,
+            pos_bias=None
+    ):
+        residual = x
+
+        if self.layer_norm_first:
+            x = self.self_attn_layer_norm(x)
+            x, attn, pos_bias = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=self_attn_padding_mask,
+                need_weights=False,
+                attn_mask=self_attn_mask,
+                position_bias=pos_bias
+            )
+            x = self.dropout1(x)
+            x = residual + x
+
+            residual = x
+            x = self.final_layer_norm(x)
+            if self.activation_name == "glu":
+                x = self.fc1(x)
+            else:
+                x = self.activation_fn(self.fc1(x))
+            x = self.dropout2(x)
+            x = self.fc2(x)
+            x = self.dropout3(x)
+            x = residual + x
+        else:
+            x, attn, pos_bias = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=self_attn_padding_mask,
+                need_weights=need_weights,
+                attn_mask=self_attn_mask,
+                position_bias=pos_bias
+            )
+
+            x = self.dropout1(x)
+            x = residual * self.deep_norm_alpha + x
+
+            x = self.self_attn_layer_norm(x)
+
+            residual = x
+            if self.activation_name == "glu":
+                x = self.fc1(x)
+            else:
+                x = self.activation_fn(self.fc1(x))
+            x = self.dropout2(x)
+            x = self.fc2(x)
+            x = self.dropout3(x)
+            x = residual * self.deep_norm_alpha + x
+            x = self.final_layer_norm(x)
+
+        return x, attn, pos_bias
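+    # Note: `layer_norm_first` selects the pre-norm path above; otherwise the post-norm path is
+    # used, with residuals scaled by `deep_norm_alpha` (1 unless DeepNorm is enabled).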
+
+
+class MultiheadAttention(nn.Module):
+    """Multi-headed attention.
+
+    See "Attention Is All You Need" for more details.
+    """
+
+    def __init__(
+            self,
+            embed_dim,
+            num_heads,
+            kdim=None,
+            vdim=None,
+            dropout=0.0,
+            bias=True,
+            add_bias_kv=False,
+            add_zero_attn=False,
+            self_attention=False,
+            encoder_decoder_attention=False,
+            q_noise=0.0,
+            qn_block_size=8,
+            has_relative_attention_bias=False,
+            num_buckets=32,
+            max_distance=128,
+            gru_rel_pos=False,
+            rescale_init=False,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        self.num_heads = num_heads
+        self.dropout_module = nn.Dropout(dropout)
+
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+        self.head_dim = embed_dim // num_heads
+        self.q_head_dim = self.head_dim
+        self.k_head_dim = self.head_dim
+        assert (
+                self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+        self.scaling = self.head_dim ** -0.5
+
+        self.self_attention = self_attention
+        self.encoder_decoder_attention = encoder_decoder_attention
+
+        assert not self.self_attention or self.qkv_same_dim, (
+            "Self-attention requires query, key and " "value to be of the same size"
+        )
+
+        k_bias = True
+        if rescale_init:
+            k_bias = False
+
+        k_embed_dim = embed_dim
+        q_embed_dim = embed_dim
+
+        self.k_proj = quant_noise(
+            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
+        )
+        self.v_proj = quant_noise(
+            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
+        self.q_proj = quant_noise(
+            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
+        )
+
+        self.out_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
+
+        if add_bias_kv:
+            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
+        self.gru_rel_pos = gru_rel_pos
+        if self.gru_rel_pos:
+            self.grep_linear = nn.Linear(self.q_head_dim, 8)
+            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+        else:
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)
+
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.init.constant_(self.out_proj.bias, 0.0)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
+        if self.has_relative_attention_bias:
+            nn.init.xavier_normal_(self.relative_attention_bias.weight)
+
+    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
+        num_buckets = self.num_buckets
+        max_distance = self.max_distance
+        relative_buckets = 0
+
+        if bidirectional:
+            num_buckets = num_buckets // 2
+            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
+            relative_positions = torch.abs(relative_positions)
+        else:
+            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+
+        max_exact = num_buckets // 2
+        is_small = relative_positions < max_exact
+
+        relative_position_if_large = max_exact + (
+                torch.log(relative_positions.float() / max_exact)
+                / math.log(max_distance / max_exact)
+                * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length):
+        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_positions_bucket(
+            relative_position,
+            bidirectional=True
+        )
+        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        values = self.relative_attention_bias(relative_position_bucket)
+        values = values.permute([2, 0, 1])
+        return values
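+    # Illustrative note (an assumption, not in the original code): with query_length = key_length = 4
+    # the relative offsets span [-3, 3]; _relative_positions_bucket maps them to bucket ids
+    # (half the buckets per sign, exact for small offsets, log-spaced up to max_distance), and
+    # compute_bias returns a (num_heads, query_length, key_length) bias added to the attention logits.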
+
+    def forward(
+            self,
+            query,
+            key: Optional[Tensor],
+            value: Optional[Tensor],
+            key_padding_mask: Optional[Tensor] = None,
+            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+            need_weights: bool = True,
+            static_kv: bool = False,
+            attn_mask: Optional[Tensor] = None,
+            before_softmax: bool = False,
+            need_head_weights: bool = False,
+            position_bias: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+        """Input shape: Time x Batch x Channel
+
+        Args:
+            key_padding_mask (ByteTensor, optional): mask to exclude
+                keys that are pads, of shape `(batch, src_len)`, where
+                padding elements are indicated by 1s.
+            need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: False).
+            attn_mask (ByteTensor, optional): typically used to
+                implement causal attention, where the mask prevents the
+                attention from looking forward in time (default: None).
+            before_softmax (bool, optional): return the raw attention
+                weights and values before the attention softmax.
+            need_head_weights (bool, optional): return the attention
+                weights for each head. Implies *need_weights*. Default:
+                return the average attention weights over all heads.
+        """
+        if need_head_weights:
+            need_weights = True
+
+        is_tpu = query.device.type == "xla"
+
+        tgt_len, bsz, embed_dim = query.size()
+        src_len = tgt_len
+        assert embed_dim == self.embed_dim
+        assert list(query.size()) == [tgt_len, bsz, embed_dim]
+        if key is not None:
+            src_len, key_bsz, _ = key.size()
+            if not torch.jit.is_scripting():
+                assert key_bsz == bsz
+                assert value is not None
+                assert (src_len, bsz) == value.shape[:2]
+
+        if self.has_relative_attention_bias and position_bias is None:
+            position_bias = self.compute_bias(tgt_len, src_len)
+            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
+
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if saved_state is not None and "prev_key" in saved_state:
+                # previous time steps are cached - no need to recompute
+                # key and value if they are static
+                if static_kv:
+                    assert self.encoder_decoder_attention and not self.self_attention
+                    key = value = None
+        else:
+            saved_state = None
+
+        if self.self_attention:
+            q = self.q_proj(query)
+            k = self.k_proj(query)
+            v = self.v_proj(query)
+        elif self.encoder_decoder_attention:
+            # encoder-decoder attention
+            q = self.q_proj(query)
+            if key is None:
+                assert value is None
+                k = v = None
+            else:
+                k = self.k_proj(key)
+                v = self.v_proj(key)
+
+        else:
+            assert key is not None and value is not None
+            q = self.q_proj(query)
+            k = self.k_proj(key)
+            v = self.v_proj(value)
+        q *= self.scaling
+        alpha = 32
+        q *= 1 / alpha
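+        # Stability trick used by BEATs: q is scaled down by alpha (=32) here, and the attention
+        # logits are multiplied back by alpha after subtracting their row-wise max further below,
+        # keeping intermediate values in a safe range (e.g. for fp16).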
+
+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+            if attn_mask is not None:
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [
+                        key_padding_mask,
+                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+                    ],
+                    dim=1,
+                )
+
+        q = (
+            q.contiguous()
+                .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
+                .transpose(0, 1)
+        )
+        if k is not None:
+            k = (
+                k.contiguous()
+                    .view(-1, bsz * self.num_heads, self.k_head_dim)
+                    .transpose(0, 1)
+            )
+        if v is not None:
+            v = (
+                v.contiguous()
+                    .view(-1, bsz * self.num_heads, self.head_dim)
+                    .transpose(0, 1)
+            )
+
+        if saved_state is not None:
+            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+            if "prev_key" in saved_state:
+                _prev_key = saved_state["prev_key"]
+                assert _prev_key is not None
+                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    k = prev_key
+                else:
+                    assert k is not None
+                    k = torch.cat([prev_key, k], dim=1)
+                src_len = k.size(1)
+            if "prev_value" in saved_state:
+                _prev_value = saved_state["prev_value"]
+                assert _prev_value is not None
+                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    v = prev_value
+                else:
+                    assert v is not None
+                    v = torch.cat([prev_value, v], dim=1)
+            prev_key_padding_mask: Optional[Tensor] = None
+            if "prev_key_padding_mask" in saved_state:
+                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+            assert k is not None and v is not None
+            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+                key_padding_mask=key_padding_mask,
+                prev_key_padding_mask=prev_key_padding_mask,
+                batch_size=bsz,
+                src_len=k.size(1),
+                static_kv=static_kv,
+            )
+
+            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state["prev_key_padding_mask"] = key_padding_mask
+            # In this branch incremental_state is never None
+            assert incremental_state is not None
+            incremental_state = self._set_input_buffer(incremental_state, saved_state)
+        assert k is not None
+        assert k.size(1) == src_len
+
+        # This is part of a workaround to get around fork/join parallelism
+        # not supporting Optional types.
+        if key_padding_mask is not None and key_padding_mask.dim() == 0:
+            key_padding_mask = None
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz
+            assert key_padding_mask.size(1) == src_len
+
+        if self.add_zero_attn:
+            assert v is not None
+            src_len += 1
+            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+            if attn_mask is not None:
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [
+                        key_padding_mask,
+                        torch.zeros(key_padding_mask.size(0), 1).type_as(
+                            key_padding_mask
+                        ),
+                    ],
+                    dim=1,
+                )
+
+        attn_weights = torch.bmm(q, k.transpose(1, 2))
+        attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_weights += attn_mask
+
+        if key_padding_mask is not None:
+            # don't attend to padding symbols
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            if not is_tpu:
+                attn_weights = attn_weights.masked_fill(
+                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                    float("-inf"),
+                )
+            else:
+                attn_weights = attn_weights.transpose(0, 2)
+                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+                attn_weights = attn_weights.transpose(0, 2)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if before_softmax:
+            return attn_weights, v, position_bias
+
+        if position_bias is not None:
+            attn_mask_rel_pos = position_bias
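+            # Gated relative position bias (as in WavLM): when gru_rel_pos is enabled, a gate
+            # computed from the query modulates the shared position bias before it is added to
+            # the attention logits.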
+            if self.gru_rel_pos == 1:
+                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
+                _B, _H, _L, __ = query_layer.size()
+                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
+                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
+                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
+
+            attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
+
+            attn_weights = attn_weights + attn_mask_rel_pos
+
+        attn_weights_float = F.softmax(
+            attn_weights, dim=-1
+        )
+        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_probs = self.dropout_module(attn_weights)
+
+        assert v is not None
+        attn = torch.bmm(attn_probs, v)
+        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = self.out_proj(attn)
+        attn_weights: Optional[Tensor] = None
+        if need_weights:
+            attn_weights = attn_weights_float.view(
+                bsz, self.num_heads, tgt_len, src_len
+            ).transpose(1, 0)
+            if not need_head_weights:
+                # average attention weights over heads
+                attn_weights = attn_weights.mean(dim=0)
+
+        return attn, attn_weights, position_bias
+
+    @staticmethod
+    def _append_prev_key_padding_mask(
+            key_padding_mask: Optional[Tensor],
+            prev_key_padding_mask: Optional[Tensor],
+            batch_size: int,
+            src_len: int,
+            static_kv: bool,
+    ) -> Optional[Tensor]:
+        # saved key padding masks have shape (bsz, seq_len)
+        if prev_key_padding_mask is not None and static_kv:
+            new_key_padding_mask = prev_key_padding_mask
+        elif prev_key_padding_mask is not None and key_padding_mask is not None:
+            new_key_padding_mask = torch.cat(
+                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+            )
+        # During incremental decoding, as the padding token enters and
+        # leaves the frame, there will be a time when prev or current
+        # is None
+        elif prev_key_padding_mask is not None:
+            if src_len > prev_key_padding_mask.size(1):
+                filler = torch.zeros(
+                    (batch_size, src_len - prev_key_padding_mask.size(1)),
+                    device=prev_key_padding_mask.device,
+                )
+                new_key_padding_mask = torch.cat(
+                    [prev_key_padding_mask.float(), filler.float()], dim=1
+                )
+            else:
+                new_key_padding_mask = prev_key_padding_mask.float()
+        elif key_padding_mask is not None:
+            if src_len > key_padding_mask.size(1):
+                filler = torch.zeros(
+                    (batch_size, src_len - key_padding_mask.size(1)),
+                    device=key_padding_mask.device,
+                )
+                new_key_padding_mask = torch.cat(
+                    [filler.float(), key_padding_mask.float()], dim=1
+                )
+            else:
+                new_key_padding_mask = key_padding_mask.float()
+        else:
+            new_key_padding_mask = prev_key_padding_mask
+        return new_key_padding_mask
+
+    def _get_input_buffer(
+            self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+    ) -> Dict[str, Optional[Tensor]]:
+        result = self.get_incremental_state(incremental_state, "attn_state")
+        if result is not None:
+            return result
+        else:
+            empty_result: Dict[str, Optional[Tensor]] = {}
+            return empty_result
+
+    def _set_input_buffer(
+            self,
+            incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+            buffer: Dict[str, Optional[Tensor]],
+    ):
+        return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+        return attn_weights
+
+
+def init_bert_params(module):
+    """
+    Initialize the weights specific to the BERT Model.
+    This overrides the default initializations depending on the specified arguments.
+        1. If normal_init_linear_weights is set then weights of linear
+           layer will be initialized using the normal distribution and
+           bias will be set to the specified value.
+        2. If normal_init_embed_weights is set then weights of embedding
+           layer will be initialized using the normal distribution.
+        3. If normal_init_proj_weights is set then the weights of
+           in_project_weight for MultiheadAttention are initialized using
+           the normal distribution (to be validated).
+    """
+
+    def normal_(data):
+        # with FSDP, module params will be on CUDA, so we cast them back to CPU
+        # so that the RNG is consistent with and without FSDP
+        data.copy_(
+            data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
+        )
+
+    if isinstance(module, nn.Linear):
+        normal_(module.weight.data)
+        if module.bias is not None:
+            module.bias.data.zero_()
+    if isinstance(module, nn.Embedding):
+        normal_(module.weight.data)
+        if module.padding_idx is not None:
+            module.weight.data[module.padding_idx].zero_()
+    if isinstance(module, MultiheadAttention):
+        normal_(module.q_proj.weight.data)
+        normal_(module.k_proj.weight.data)
+        normal_(module.v_proj.weight.data)
diff --git a/slam_llm/models/BEATs/modules.py b/slam_llm/models/BEATs/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..f32ef58ad00c809f0b134cd556adea7ff5fc0503
--- /dev/null
+++ b/slam_llm/models/BEATs/modules.py
@@ -0,0 +1,219 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import warnings
+import torch
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class GradMultiply(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, scale):
+        ctx.scale = scale
+        res = x.new(x)
+        return res
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad * ctx.scale, None
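+# Illustrative note (an assumption, not in the original code): GradMultiply is the identity in the
+# forward pass and scales gradients in the backward pass, e.g.
+#
+#     x = torch.ones(3, requires_grad=True)
+#     GradMultiply.apply(x, 0.5).sum().backward()
+#     # x.grad is now filled with 0.5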
+
+
+class SamePad(nn.Module):
+    def __init__(self, kernel_size, causal=False):
+        super().__init__()
+        if causal:
+            self.remove = kernel_size - 1
+        else:
+            self.remove = 1 if kernel_size % 2 == 0 else 0
+
+    def forward(self, x):
+        if self.remove > 0:
+            x = x[:, :, : -self.remove]
+        return x
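+# SamePad trims the trailing frame that a padded convolution with an even kernel size produces,
+# so the temporal length of pos_conv's output matches its input (causal mode removes kernel_size - 1).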
+
+
+class Swish(nn.Module):
+    def __init__(self):
+        super(Swish, self).__init__()
+        self.act = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return x * self.act(x)
+
+
+class GLU_Linear(nn.Module):
+    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+        super(GLU_Linear, self).__init__()
+
+        self.glu_type = glu_type
+        self.output_dim = output_dim
+
+        if glu_type == "sigmoid":
+            self.glu_act = torch.nn.Sigmoid()
+        elif glu_type == "swish":
+            self.glu_act = Swish()
+        elif glu_type == "relu":
+            self.glu_act = torch.nn.ReLU()
+        elif glu_type == "gelu":
+            self.glu_act = torch.nn.GELU()
+
+        if bias_in_glu:
+            self.linear = nn.Linear(input_dim, output_dim * 2, True)
+        else:
+            self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+    def forward(self, x):
+        # we assume the input always has the channel (dim) in the last dimension of the tensor;
+        # a 1D-conv layout would need its dimensions switched first
+        x = self.linear(x)
+
+        if self.glu_type == "bilinear":
+            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+        else:
+            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+
+        return x
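+# A GLU projection: the linear layer maps to 2 * output_dim, and the second half (passed through
+# `glu_act`) gates the first, i.e. out = x_a * act(x_b).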
+
+
+def gelu_accurate(x):
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str):
+    """Returns the activation function corresponding to `activation`"""
+
+    if activation == "relu":
+        return F.relu
+    elif activation == "gelu":
+        return gelu
+    elif activation == "gelu_fast":
+        warnings.warn(
+            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
+        )
+        return gelu_accurate
+    elif activation == "gelu_accurate":
+        return gelu_accurate
+    elif activation == "tanh":
+        return torch.tanh
+    elif activation == "linear":
+        return lambda x: x
+    elif activation == "glu":
+        return lambda x: x
+    else:
+        raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def quant_noise(module, p, block_size):
+    """
+    Wraps modules and applies quantization noise to the weights for
+    subsequent quantization with Iterative Product Quantization as
+    described in "Training with Quantization Noise for Extreme Model Compression"
+
+    Args:
+        - module: nn.Module
+        - p: amount of Quantization Noise
+        - block_size: size of the blocks for subsequent quantization with iPQ
+
+    Remarks:
+        - Module weights must have the right sizes wrt the block size
+        - Only Linear, Embedding and Conv2d modules are supported for the moment
+        - For more detail on how to quantize by blocks with convolutional weights,
+          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+        - We implement the simplest form of noise here as stated in the paper
+          which consists in randomly dropping blocks
+    """
+
+    # if no quantization noise, don't register hook
+    if p <= 0:
+        return module
+
+    # supported modules
+    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+    # test whether module.weight has the right sizes wrt block_size
+    is_conv = module.weight.ndim == 4
+
+    # 2D matrix
+    if not is_conv:
+        assert (
+            module.weight.size(1) % block_size == 0
+        ), "Input features must be a multiple of block sizes"
+
+    # 4D matrix
+    else:
+        # 1x1 convolutions
+        if module.kernel_size == (1, 1):
+            assert (
+                module.in_channels % block_size == 0
+            ), "Input channels must be a multiple of block sizes"
+        # regular convolutions
+        else:
+            k = module.kernel_size[0] * module.kernel_size[1]
+            assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+    def _forward_pre_hook(mod, input):
+        # no noise for evaluation
+        if mod.training:
+            if not is_conv:
+                # gather weight and sizes
+                weight = mod.weight
+                in_features = weight.size(1)
+                out_features = weight.size(0)
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                mask = torch.zeros(
+                    in_features // block_size * out_features, device=weight.device
+                )
+                mask.bernoulli_(p)
+                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+            else:
+                # gather weight and sizes
+                weight = mod.weight
+                in_channels = mod.in_channels
+                out_channels = mod.out_channels
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                if mod.kernel_size == (1, 1):
+                    mask = torch.zeros(
+                        int(in_channels // block_size * out_channels),
+                        device=weight.device,
+                    )
+                    mask.bernoulli_(p)
+                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+                else:
+                    mask = torch.zeros(
+                        weight.size(0), weight.size(1), device=weight.device
+                    )
+                    mask.bernoulli_(p)
+                    mask = (
+                        mask.unsqueeze(2)
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    )
+
+            # scale weights and apply mask
+            mask = mask.to(
+                torch.bool
+            )  # x.bool() is not currently supported in TorchScript
+            s = 1 / (1 - p)
+            mod.weight.data = s * weight.masked_fill(mask, 0)
+
+    module.register_forward_pre_hook(_forward_pre_hook)
+    return module
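+# Illustrative usage (an assumption, not part of the original code):
+#
+#     proj = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)
+#     # while proj.training is True, ~10% of the 8-wide weight blocks are zeroed before each
+#     # forward pass and the surviving weights are rescaled by 1 / (1 - p)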
+
diff --git a/slam_llm/models/BEATs/quantizer.py b/slam_llm/models/BEATs/quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6686444ebb9f58f6bb7c8c12c4580d26cecdf2ec
--- /dev/null
+++ b/slam_llm/models/BEATs/quantizer.py
@@ -0,0 +1,215 @@
+# --------------------------------------------------------
+# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
+# Github source: https://github.com/microsoft/unilm/tree/master/beats
+# Copyright (c) 2022 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on VQGAN code bases
+# https://github.com/CompVis/taming-transformers
+# --------------------------------------------------------'
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as distributed
+
+try:
+    from einops import rearrange, repeat
+except ImportError:
+    pass
+
+
+def l2norm(t):
+    return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg, new, decay):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def sample_vectors(samples, num):
+    num_samples, device = samples.shape[0], samples.device
+
+    if num_samples >= num:
+        indices = torch.randperm(num_samples, device=device)[:num]
+    else:
+        indices = torch.randint(0, num_samples, (num,), device=device)
+
+    return samples[indices]
+
+
+def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
+    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
+
+    means = sample_vectors(samples, num_clusters)
+
+    for _ in range(num_iters):
+        if use_cosine_sim:
+            dists = samples @ means.t()
+        else:
+            diffs = rearrange(samples, 'n d -> n () d') \
+                    - rearrange(means, 'c d -> () c d')
+            dists = -(diffs ** 2).sum(dim=-1)
+
+        buckets = dists.max(dim=-1).indices
+        bins = torch.bincount(buckets, minlength=num_clusters)
+        zero_mask = bins == 0
+        bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
+        new_means = new_means / bins_min_clamped[..., None]
+
+        if use_cosine_sim:
+            new_means = l2norm(new_means)
+
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return means, bins
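+# Illustrative usage (an assumption, not part of the original code):
+#
+#     feats = l2norm(torch.randn(1000, 256))
+#     means, bins = kmeans(feats, num_clusters=64, num_iters=10, use_cosine_sim=True)
+#     # means: (64, 256) centroids, bins: per-cluster counts from the final assignment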
+
+
+class EmbeddingEMA(nn.Module):
+    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
+        super().__init__()
+        self.num_tokens = num_tokens
+        self.codebook_dim = codebook_dim
+        self.decay = decay
+        self.eps = eps
+        if codebook_init_path == '':
+            if not kmeans_init:
+                weight = torch.randn(num_tokens, codebook_dim)
+                weight = l2norm(weight)
+            else:
+                weight = torch.zeros(num_tokens, codebook_dim)
+            self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+        else:
+            print(f"load init codebook weight from {codebook_init_path}")
+            codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
+            weight = codebook_ckpt_weight.clone()
+            self.register_buffer('initted', torch.Tensor([True]))
+
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+        # self.register_buffer('initted', torch.Tensor([not kmeans_init]))
+        self.update = True
+
+    @torch.jit.ignore
+    def init_embed_(self, data):
+        if self.initted:
+            return
+        print("Performing Kemans init for codebook")
+        embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
+        self.weight.data.copy_(embed)
+        self.cluster_size.data.copy_(cluster_size)
+        self.initted.data.copy_(torch.Tensor([True]))
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight)
+
+    def cluster_size_ema_update(self, new_cluster_size):
+        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+    def embed_avg_ema_update(self, new_embed_avg):
+        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+    def weight_update(self, num_tokens):
+        n = self.cluster_size.sum()
+        smoothed_cluster_size = (
+                (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+        )
+        # normalize embedding average with smoothed cluster size
+        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+        # embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
+        self.weight.data.copy_(embed_normalized)
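+        # The cluster sizes are Laplace-smoothed above so that rarely used (near-empty) codes do
+        # not cause a division by zero when the EMA sum is normalised into codebook vectors.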
+
+
+def norm_ema_inplace(moving_avg, new, decay):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+    moving_avg.data.copy_(l2norm(moving_avg.data))
+
+
+class NormEMAVectorQuantizer(nn.Module):
+    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+                 statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
+        super().__init__()
+        self.codebook_dim = embedding_dim
+        self.num_tokens = n_embed
+        self.beta = beta
+        self.decay = decay
+
+        # learnable = True if orthogonal_reg_weight > 0 else False
+        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
+
+        self.statistic_code_usage = statistic_code_usage
+        if statistic_code_usage:
+            self.register_buffer('cluster_size', torch.zeros(n_embed))
+        if distributed.is_available() and distributed.is_initialized():
+            print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
+            self.all_reduce_fn = distributed.all_reduce
+        else:
+            self.all_reduce_fn = nn.Identity()
+
+    def reset_cluster_size(self, device):
+        if self.statistic_code_usage:
+            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
+            self.cluster_size = self.cluster_size.to(device)
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        # z, 'b c h w -> b h w c'
+        # z = rearrange(z, 'b c h w -> b h w c')
+        # z = z.transpose(1, 2)
+        z = l2norm(z)
+        z_flattened = z.reshape(-1, self.codebook_dim)
+
+        self.embedding.init_embed_(z_flattened)
+
+        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'
+
+        encoding_indices = torch.argmin(d, dim=1)
+
+        z_q = self.embedding(encoding_indices).view(z.shape)
+
+        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+
+        if not self.training:
+            with torch.no_grad():
+                cluster_size = encodings.sum(0)
+                self.all_reduce_fn(cluster_size)
+                ema_inplace(self.cluster_size, cluster_size, self.decay)
+
+        if self.training and self.embedding.update:
+            # EMA cluster size
+
+            bins = encodings.sum(0)
+            self.all_reduce_fn(bins)
+
+            # self.embedding.cluster_size_ema_update(bins)
+            ema_inplace(self.cluster_size, bins, self.decay)
+
+            zero_mask = (bins == 0)
+            bins = bins.masked_fill(zero_mask, 1.)
+
+            embed_sum = z_flattened.t() @ encodings
+            self.all_reduce_fn(embed_sum)
+
+            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
+            embed_normalized = l2norm(embed_normalized)
+
+            embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
+                                           embed_normalized)
+            norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
+
+        # compute loss for embedding
+        loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
+
+        # reshape back to match original input shape
+        # z_q, 'b h w c -> b c h w'
+        # z_q = rearrange(z_q, 'b h w c -> b c h w')
+        # z_q = z_q.transpose(1, 2)
+        return z_q, loss, encoding_indices
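+    # Illustrative usage (an assumption, not part of the original code):
+    #
+    #     vq = NormEMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=1.0)
+    #     z = torch.randn(8, 100, 256)        # (batch, time, dim) features
+    #     z_q, loss, idx = vq(z)              # quantised features, commitment loss, flat code ids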
diff --git a/slam_llm/models/EAT/EAT.py b/slam_llm/models/EAT/EAT.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a42ba5ffa45d5fdd148174c8cb79824dd397434
--- /dev/null
+++ b/slam_llm/models/EAT/EAT.py
@@ -0,0 +1,32 @@
+import torch
+import torchaudio
+import random
+
+def EAT_preprocess(source, norm_mean=-4.268, norm_std=4.569, target_length=1024, fixed_length=False, random_crop=False):
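+    """Convert a mono waveform into a normalized 128-bin Kaldi log-mel fbank for EAT.
+
+    The waveform is mean-subtracted and converted to a mel filterbank with a 10 ms frame shift.
+    When `fixed_length` is False the target length is the input length rounded up to a multiple
+    of 16 (so the spectrogram is only zero-padded); otherwise it is zero-padded or cropped
+    (optionally at a random start) to `target_length` frames. Finally it is normalized as
+    (fbank - norm_mean) / (2 * norm_std).
+    """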
+    source = source - source.mean()
+    source = source.unsqueeze(dim=0)
+    
+    source = torchaudio.compliance.kaldi.fbank(source, htk_compat=True, sample_frequency=16000, use_energy=False,
+                                                window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10).unsqueeze(dim=0)
+    
+    n_frames = source.shape[1]
+    if not fixed_length:
+        target_length = n_frames
+        if target_length % 16 != 0:
+            target_length = n_frames + (16 - n_frames % 16)
+    diff = target_length - n_frames
+    if diff > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, diff)) 
+        source = m(source)
+    elif diff < 0:
+        if random_crop: 
+            start_index = random.randint(0, n_frames - target_length)
+            source = source[:,start_index: start_index+target_length, :]
+        else: 
+            source = source[:,0:target_length, :]
+    
+    # Normalize the mel spectrogram
+    source = (source - norm_mean) / (norm_std * 2)
+    source = source.squeeze()
+    
+    return source
\ No newline at end of file
diff --git a/slam_llm/models/SpatialAST/SpatialAST.py b/slam_llm/models/SpatialAST/SpatialAST.py
new file mode 100644
index 0000000000000000000000000000000000000000..d08dceb147bcf8e487dee36b5fd08056175b5259
--- /dev/null
+++ b/slam_llm/models/SpatialAST/SpatialAST.py
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+
+from torchlibrosa.stft import STFT, LogmelFilterBank
+from timm.models.layers import to_2tuple
+
+from .vision_transformer import VisionTransformer as _VisionTransformer
+
+def conv3x3(in_channels, out_channels, stride=1):
+    "3x3 convolution with padding"
+    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+
+class PatchEmbed_new(nn.Module):
+    """ Flexible Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        stride = to_2tuple(stride)
+        
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
+        _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
+        self.patch_hw = (h, w)
+        self.num_patches = h*w
+
+    def get_output_shape(self, img_size):
+        return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape 
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+
+        x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
+        x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
+        x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
+        return x
+
+class BinauralEncoder(_VisionTransformer):
+    """ Spatial Audio Spectrogram Transformer designed for Sound Event Localization and Detection
+        --------------------------------------------------------
+        References:
+        Spatial-AST from BAT: https://github.com/zszheng147/Spatial-AST and https://arxiv.org/abs/2402.01591
+        --------------------------------------------------------
+    """
+    def __init__(self, num_cls_tokens=3, **kwargs):
+        super(BinauralEncoder, self).__init__(**kwargs)
+        img_size = (1024, 128) # 1024, 128
+        in_chans = 1
+        emb_dim = 768
+
+        del self.cls_token
+        self.num_cls_tokens = num_cls_tokens
+        self.cls_tokens = nn.Parameter(torch.zeros(1, num_cls_tokens, emb_dim))
+
+        self.patch_embed = PatchEmbed_new(
+            img_size=img_size, patch_size=(16, 16), 
+            in_chans=in_chans, embed_dim=emb_dim, stride=16
+        ) # no overlap. stride=img_size=16
+
+        num_patches = self.patch_embed.num_patches
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False)  # fixed sin-cos embedding
+
+        self.spectrogram_extractor = STFT(
+            n_fft=1024, hop_length=320, win_length=1024, window='hann', 
+            center=True, pad_mode='reflect', freeze_parameters=True
+        )
+        self.logmel_extractor = LogmelFilterBank(
+            sr=32000, n_fft=1024, n_mels=128, fmin=50, 
+            fmax=14000, ref=1.0, amin=1e-10, top_db=None, freeze_parameters=True
+        )
+        
+        self.conv_downsample = nn.Sequential(
+            conv3x3(4, 1), 
+            nn.BatchNorm2d(1),
+            nn.GELU(),
+        )
+
+        self.bn = nn.BatchNorm2d(2, affine=False)
+        del self.norm  # remove the original norm
+
+        self.target_frame = 1024
+
+    def forward_features_mask(self, x):
+        B = x.shape[0] #bsz, 512, 768 (unmasked)
+
+        x = x + self.pos_embed[:, 1:, :]
+        
+        cls_tokens = self.cls_tokens
+        cls_tokens = cls_tokens.expand(B, -1, -1)
+        x = torch.cat([cls_tokens, x], dim=1)   # bsz, 512 + 2 + 10, 768 
+        x = self.pos_drop(x)
+        
+        for blk in self.blocks:
+            x = blk(x)
+
+        return x
+
+    @torch.no_grad()
+    def forward(self, waveforms):
+        B, C, T = waveforms.shape
+
+        waveforms = waveforms.reshape(B * C, T)
+        real, imag = self.spectrogram_extractor(waveforms) 
+
+        log_mel = self.logmel_extractor(torch.sqrt(real**2 + imag**2)).reshape(B, C, -1, 128)
+        log_mel = self.bn(log_mel)
+        
+        IPD = torch.atan2(imag[1::2], real[1::2]) - torch.atan2(imag[::2], real[::2])
+        x = torch.cat([log_mel, torch.matmul(torch.cat([torch.cos(IPD), torch.sin(IPD)], dim=1), self.logmel_extractor.melW)], dim=1)
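+        # IPD is the interaural phase difference between the two channels of each binaural pair;
+        # its cos/sin maps are projected onto the mel filterbank and concatenated with the
+        # normalised log-mel channels as extra input planes.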
+
+        if x.shape[2] < self.target_frame:
+            x = nn.functional.interpolate(x, (self.target_frame, x.shape[3]), mode="bicubic", align_corners=True)
+    
+        x = self.conv_downsample(x)
+        x = self.patch_embed(x)
+        x = self.forward_features_mask(x)
+
+        return x
\ No newline at end of file
diff --git a/slam_llm/models/SpatialAST/vision_transformer.py b/slam_llm/models/SpatialAST/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ade970ab00bd506ff4c5078d9a4361afeaf7294e
--- /dev/null
+++ b/slam_llm/models/SpatialAST/vision_transformer.py
@@ -0,0 +1,239 @@
+import torch
+from torch import nn
+
+from timm.models.layers import to_2tuple, DropPath, trunc_normal_
+
+
+class HybridEmbed(nn.Module):
+    """ CNN Feature Map Embedding
+    Extract feature map from CNN, flatten, project to embedding dim.
+    """
+    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+        super().__init__()
+        assert isinstance(backbone, nn.Module)
+        img_size = to_2tuple(img_size)
+        self.img_size = img_size
+        self.backbone = backbone
+        if feature_size is None:
+            with torch.no_grad():
+                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
+                # map for all networks, the feature metadata has reliable channel and stride info, but using
+                # stride to calc feature dim requires info about padding of each stage that isn't captured.
+                training = backbone.training
+                if training:
+                    backbone.eval()
+                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                feature_size = o.shape[-2:]
+                feature_dim = o.shape[1]
+                backbone.train(training)
+        else:
+            feature_size = to_2tuple(feature_size)
+            feature_dim = self.backbone.feature_info.channels()[-1]
+        self.num_patches = feature_size[0] * feature_size[1]
+        self.proj = nn.Linear(feature_dim, embed_dim)
+
+    def forward(self, x):
+        x = self.backbone(x)[-1]
+        x = x.flatten(2).transpose(1, 2)
+        x = self.proj(x)
+        return x
+    
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+      
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
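+
+    # Pre-norm residual block: each sub-layer (attention, then MLP) sees a
+    # LayerNorm'd copy of its input and is added back through a DropPath
+    # (stochastic depth) residual connection, so with drop_path=0. the block
+    # reduces to two plain residual sub-layers.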
+    
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
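+
+    # Worked example (illustrative): with the defaults img_size=224 and
+    # patch_size=16, num_patches = (224 // 16) * (224 // 16) = 196, so a
+    # (B, 3, 224, 224) input becomes a (B, 196, 768) sequence of patch tokens.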
+
+
+class PatchEmbed_new(nn.Module):
+    """ Flexible Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        stride = to_2tuple(stride)
+        
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
+        _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
+        self.patch_hw = (h, w)
+        self.num_patches = h*w
+
+    def get_output_shape(self, img_size):
+        return self.proj(torch.randn(1, self.in_chans, img_size[0], img_size[1])).shape 
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+
+        x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
+        x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
+        x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
+        return x
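+
+    # The shapes in the comments above assume a (B, 1, 1024, 128) spectrogram-like
+    # input (B=32 there) with patch_size=16 and stride=10: the conv output is
+    # floor((1024 - 16) / 10) + 1 = 101 by floor((128 - 16) / 10) + 1 = 12,
+    # i.e. 101 * 12 = 1212 overlapping patches of dimension embed_dim.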
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        if hybrid_backbone is not None:
+            self.patch_embed = HybridEmbed(
+                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+        else:
+            self.patch_embed = PatchEmbed(
+                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+            for i in range(depth)])
+        
+        self.norm = norm_layer(embed_dim)
+
+        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
+        #self.repr = nn.Linear(embed_dim, representation_size)
+        #self.repr_act = nn.Tanh()
+
+        # Classifier head
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.norm(x)
+        return x[:, 0]
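+
+    # forward_features returns only the first token (the prepended cls_token)
+    # after the final norm layer; forward() below applies the classifier head
+    # to that single pooled embedding.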
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
\ No newline at end of file
diff --git a/slam_llm/models/avhubert/__init__.py b/slam_llm/models/avhubert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..994cc7f52f5b785ea61720b40f698dc48d77af34
--- /dev/null
+++ b/slam_llm/models/avhubert/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .hubert import *  # noqa
+from .hubert_asr import *  # noqa
+from .hubert_dataset import *
+from .hubert_pretraining import *
+from .hubert_criterion import *
diff --git a/slam_llm/models/avhubert/decoder.py b/slam_llm/models/avhubert/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e40868502d37cc4dddd9d7cfe784873930ef19dc
--- /dev/null
+++ b/slam_llm/models/avhubert/decoder.py
@@ -0,0 +1,243 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from argparse import Namespace
+import contextlib
+import copy
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dataclasses import dataclass, field
+from omegaconf import MISSING, II, open_dict
+from typing import Any, Optional
+
+from fairseq import checkpoint_utils, tasks, utils
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.tasks import FairseqTask
+from fairseq.models import (
+    BaseFairseqModel,
+    FairseqEncoder,
+    FairseqEncoderDecoderModel,
+    FairseqIncrementalDecoder,
+    register_model,
+)
+# from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
+from fairseq.modules import (
+    LayerNorm,
+    PositionalEmbedding,
+    TransformerDecoderLayer,
+)
+
+
+class TransformerDecoder(FairseqIncrementalDecoder):
+    """
+    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
+    is a :class:`TransformerDecoderLayer`.
+
+    Args:
+        args (argparse.Namespace): parsed command-line arguments
+        dictionary (~fairseq.data.Dictionary): decoding dictionary
+        embed_tokens (torch.nn.Embedding): output embedding
+        no_encoder_attn (bool, optional): whether to attend to encoder outputs
+            (default: False).
+    """
+
+    def __init__(
+        self,
+        cfg,
+        dictionary,
+        embed_tokens,
+        no_encoder_attn=False,
+    ):
+        super().__init__(dictionary)
+
+        self.dropout = cfg.decoder_dropout
+        self.share_input_output_embed = cfg.share_decoder_input_output_embed
+
+        input_embed_dim = embed_tokens.embedding_dim
+        embed_dim = cfg.decoder_embed_dim
+        self.output_embed_dim = cfg.decoder_embed_dim
+
+        self.layerdrop = cfg.decoder_layerdrop
+
+        padding_idx = embed_tokens.padding_idx
+        self.max_target_positions = cfg.max_target_positions
+
+        self.embed_tokens = embed_tokens
+        # self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim
+        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
+
+        self.project_in_dim = (
+            Linear(input_embed_dim, embed_dim, bias=False)
+            if embed_dim != input_embed_dim
+            else None
+        )
+
+        self.embed_positions = (
+            PositionalEmbedding(
+                cfg.max_target_positions,
+                embed_dim,
+                padding_idx,
+                learned=cfg.decoder_learned_pos,
+            )
+            if not cfg.no_token_positional_embeddings
+            else None
+        )
+
+        # TODO: update this when transformer gets converted to dataclass configs
+        transformer_cfg = copy.deepcopy(cfg)
+        # with open_dict(transformer_cfg):
+        transformer_cfg.dropout = transformer_cfg.decoder_dropout
+        transformer_cfg.attention_dropout = (
+            transformer_cfg.decoder_attention_dropout
+        )
+        transformer_cfg.activation_dropout = (
+            transformer_cfg.decoder_activation_dropout
+        )
+
+        self.layers = nn.ModuleList([])
+        self.layers.extend(
+            [
+                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
+                for _ in range(transformer_cfg.decoder_layers)
+            ]
+        )
+
+        if not self.share_input_output_embed:
+            self.embed_out = nn.Parameter(
+                torch.Tensor(len(dictionary), self.output_embed_dim)
+            )
+            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
+
+        if transformer_cfg.decoder_normalize_before:
+            self.layer_norm = LayerNorm(embed_dim)
+        else:
+            self.layer_norm = None
+
+    def forward(
+        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
+    ):
+        """
+        Args:
+            prev_output_tokens (LongTensor): previous decoder outputs of shape
+                `(batch, tgt_len)`, for teacher forcing
+            encoder_out (Tensor, optional): output from the encoder, used for
+                encoder-side attention
+            incremental_state (dict): dictionary used for storing state during
+                :ref:`Incremental decoding`
+
+        Returns:
+            tuple:
+                - the decoder's output of shape `(batch, tgt_len, vocab)`
+                - a dictionary with any model-specific outputs
+        """
+        prev_output_tokens = prev_output_tokens.long()
+        x, extra = self.extract_features(
+            prev_output_tokens, encoder_out, incremental_state
+        )
+        x = self.output_layer(x)
+        return x, extra
+
+    def extract_features(
+        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
+    ):
+        """
+        Similar to *forward* but only return features.
+
+        Returns:
+            tuple:
+                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+                - a dictionary with any model-specific outputs
+        """
+
+        # embed positions
+        positions = (
+            self.embed_positions(
+                prev_output_tokens, incremental_state=incremental_state
+            )
+            if self.embed_positions is not None
+            else None
+        )
+
+        if incremental_state is not None:
+            prev_output_tokens = prev_output_tokens[:, -1:]
+            if positions is not None:
+                positions = positions[:, -1:]
+
+        # embed tokens and positions
+        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
+
+        if self.project_in_dim is not None:
+            x = self.project_in_dim(x)
+
+        if positions is not None:
+            x += positions
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+        attn = None
+
+        inner_states = [x]
+
+        # decoder layers
+        for layer in self.layers:
+            dropout_probability = np.random.random()
+            if not self.training or (dropout_probability > self.layerdrop):
+                x, attn, _ = layer(
+                    x,
+                    encoder_out["encoder_out"] if encoder_out is not None else None,
+                    encoder_out["padding_mask"] if encoder_out is not None else None,
+                    incremental_state,
+                    self_attn_mask=self.buffered_future_mask(x)
+                    if incremental_state is None
+                    else None,
+                )
+                inner_states.append(x)
+
+        if self.layer_norm:
+            x = self.layer_norm(x)
+
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+
+        return x, {"attn": attn, "inner_states": inner_states}
+
+    def output_layer(self, features, **kwargs):
+        """Project features to the vocabulary size."""
+        # project back to size of vocabulary
+        emb_mat = self.embed_tokens.weight if self.share_input_output_embed else self.embed_out
+        return torch.matmul(features, emb_mat.transpose(0, 1))
+        # if self.share_input_output_embed:
+        #     return F.linear(features, self.embed_tokens.weight)
+        # else:
+        #     return F.linear(features, self.embed_out)
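+
+    # Note: with share_input_output_embed the output projection is tied to the
+    # token embedding matrix, so (B, T, C) features are multiplied by the
+    # transposed (V, C) embedding matrix to give (B, T, V) vocabulary logits.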
+
+    def max_positions(self):
+        """Maximum output length supported by the decoder."""
+        if self.embed_positions is None:
+            return self.max_target_positions
+        return min(self.max_target_positions, self.embed_positions.max_positions)
+
+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if (
+            not hasattr(self, "_future_mask")
+            or self._future_mask is None
+            or self._future_mask.device != tensor.device
+            or self._future_mask.size(0) < dim
+        ):
+            self._future_mask = torch.triu(
+                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
+            )
+        return self._future_mask[:dim, :dim]
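+
+    # Illustration (assuming dim == 3), the buffered mask is
+    #   [[0., -inf, -inf],
+    #    [0.,   0., -inf],
+    #    [0.,   0.,   0.]]
+    # which, added to the attention scores, blocks attention to future positions
+    # during teacher-forced decoding.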
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        return state_dict
+
diff --git a/slam_llm/models/avhubert/hubert.py b/slam_llm/models/avhubert/hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..96273879528ad13600c928dbb2cce5cb9cc2b962
--- /dev/null
+++ b/slam_llm/models/avhubert/hubert.py
@@ -0,0 +1,792 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os,sys
+import logging
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+from dataclasses import dataclass, field
+from fairseq import utils
+from fairseq.data.data_utils import compute_mask_indices
+from fairseq.data.dictionary import Dictionary
+from fairseq.dataclass import ChoiceEnum, FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.models.wav2vec.wav2vec2 import (
+    ConvFeatureExtractionModel,
+    TransformerEncoder,
+)
+from fairseq.modules import GradMultiply, LayerNorm
+from copy import deepcopy
+
+DBG = len(sys.argv) == 1
+
+if DBG:
+    from hubert_pretraining import (
+        AVHubertPretrainingConfig,
+        AVHubertPretrainingTask,
+    )
+    from resnet import ResEncoder
+    logging.basicConfig(
+        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        level=os.environ.get("LOGLEVEL", "INFO").upper(),
+        stream=sys.stdout,
+    )
+    from utils import compute_mask_indices
+    from decoder import TransformerDecoder
+
+else:
+    from .hubert_pretraining import (
+        AVHubertPretrainingConfig,
+        AVHubertPretrainingTask,
+    )
+    from .resnet import ResEncoder
+    from .utils import compute_mask_indices
+    from .decoder import TransformerDecoder
+
+from omegaconf import II
+
+logger = logging.getLogger(__name__)
+
+EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
+MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
+    ["static", "uniform", "normal", "poisson"]
+)
+# LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"])
+
+
+@dataclass
+class AVHubertConfig(FairseqDataclass):
+    label_rate: int = II("task.label_rate")
+    input_modality: str = II("task.input_modality")
+    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
+        default="default",
+        metadata={
+            "help": "mode for feature extractor. default has a single group "
+            "norm with d groups in the first conv block, whereas layer_norm "
+            "has layer norms in every block (meant to use with normalize=True)"
+        },
+    )
+    encoder_layers: int = field(
+        default=12, metadata={"help": "num encoder layers in the transformer"}
+    )
+    encoder_embed_dim: int = field(
+        default=768, metadata={"help": "encoder embedding dimension"}
+    )
+    encoder_ffn_embed_dim: int = field(
+        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
+    )
+    encoder_attention_heads: int = field(
+        default=12, metadata={"help": "num encoder attention heads"}
+    )
+    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
+        default="gelu", metadata={"help": "activation function to use"}
+    )
+
+    # dropouts
+    dropout: float = field(
+        default=0.1,
+        metadata={"help": "dropout probability for the transformer"},
+    )
+    attention_dropout: float = field(
+        default=0.1,
+        metadata={"help": "dropout probability for attention weights"},
+    )
+    activation_dropout: float = field(
+        default=0.0,
+        metadata={"help": "dropout probability after activation in FFN"},
+    )
+    encoder_layerdrop: float = field(
+        default=0.0,
+        metadata={"help": "probability of dropping a tarnsformer layer"},
+    )
+    dropout_input: float = field(
+        default=0.0,
+        metadata={"help": "dropout to apply to the input (after feat extr)"},
+    )
+    dropout_features: float = field(
+        default=0.0,
+        metadata={
+            "help": "dropout to apply to the features (after feat extr)"
+        },
+    )
+
+    final_dim: int = field(
+        default=0,
+        metadata={
+            "help": "project final representations and targets to this many "
+            "dimensions. set to encoder_embed_dim is <= 0"
+        },
+    )
+    untie_final_proj: bool = field(
+        default=False,
+        metadata={"help": "use separate projection for each target"},
+    )
+    layer_norm_first: bool = field(
+        default=False,
+        metadata={"help": "apply layernorm first in the transformer"},
+    )
+    conv_feature_layers: str = field(
+        default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
+        metadata={
+            "help": "string describing convolutional feature extraction "
+            "layers in form of a python list that contains "
+            "[(dim, kernel_size, stride), ...]"
+        },
+    )
+    conv_bias: bool = field(
+        default=False, metadata={"help": "include bias in conv encoder"}
+    )
+    logit_temp: float = field(
+        default=0.1, metadata={"help": "temperature to divide logits by"}
+    )
+    target_glu: bool = field(
+        default=False, metadata={"help": "adds projection + glu to targets"}
+    )
+    feature_grad_mult: float = field(
+        default=1.0,
+        metadata={"help": "multiply feature extractor var grads by this"},
+    )
+
+    # masking
+    mask_length_audio: int = field(default=10, metadata={"help": "mask length"})
+    mask_prob_audio: float = field(
+        default=0.65,
+        metadata={"help": "probability of replacing a token with mask"},
+    )
+    mask_length_image: int = field(default=10, metadata={"help": "mask length"})
+    mask_prob_image: float = field(
+        default=0.65,
+        metadata={"help": "probability of replacing a token with mask"},
+    )
+    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
+        default="static", metadata={"help": "how to choose mask length"}
+    )
+    mask_other: float = field(
+        default=0,
+        metadata={
+            "help": "secondary mask argument "
+            "(used for more complex distributions), "
+            "see help in compute_mask_indicesh"
+        },
+    )
+    no_mask_overlap: bool = field(
+        default=False, metadata={"help": "whether to allow masks to overlap"}
+    )
+    mask_min_space: int = field(
+        default=1,
+        metadata={
+            "help": "min space between spans (if no overlap is enabled)"
+        },
+    )
+
+    # channel masking
+    mask_channel_length: int = field(
+        default=10,
+        metadata={"help": "length of the mask for features (channels)"},
+    )
+    mask_channel_prob: float = field(
+        default=0.0,
+        metadata={"help": "probability of replacing a feature with 0"},
+    )
+    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
+        default="static",
+        metadata={"help": "how to choose mask length for channel masking"},
+    )
+    mask_channel_other: float = field(
+        default=0,
+        metadata={
+            "help": "secondary mask argument "
+            "(used for more complex distributions), "
+            "see help in compute_mask_indicesh"
+        },
+    )
+    no_mask_channel_overlap: bool = field(
+        default=False,
+        metadata={"help": "whether to allow channel masks to overlap"},
+    )
+    mask_channel_min_space: int = field(
+        default=1,
+        metadata={
+            "help": "min space between spans (if no overlap is enabled)"
+        },
+    )
+
+    # positional embeddings
+    conv_pos: int = field(
+        default=128,
+        metadata={
+            "help": "number of filters for convolutional positional embeddings"
+        },
+    )
+    conv_pos_groups: int = field(
+        default=16,
+        metadata={
+            "help": "number of groups for convolutional positional embedding"
+        },
+    )
+
+    latent_temp: Tuple[float, float, float] = field(
+        default=(2, 0.5, 0.999995),
+        metadata={"help": "legacy (to be removed)"},
+    )
+
+    # loss computation
+    skip_masked: bool = field(
+        default=False,
+        metadata={"help": "skip computing losses over masked frames"},
+    )
+    skip_nomask: bool = field(
+        default=False,
+        metadata={"help": "skip computing losses over unmasked frames"},
+    )
+    resnet_relu_type: str = field(default='prelu', metadata={"help": 'relu type for resnet'})
+    resnet_weights: Optional[str] = field(default=None, metadata={"help": 'resnet weights'})
+    sim_type: str = field(default='cosine', metadata={"help": 'similarity type'})
+
+    sub_encoder_layers: int = field(default=0, metadata={'help': 'number of transformer layers for single modality'})
+    audio_feat_dim: int = field(default=-1, metadata={'help': 'audio feature dimension'})
+    modality_dropout: float = field(default=0, metadata={'help': 'drop one modality'})
+    audio_dropout: float = field(default=0, metadata={'help': 'drop audio feature'})
+    modality_fuse: str = field(default='concat', metadata={'help': 'fusing two modalities: add,concat'})
+    selection_type: str = field(default='same_other_seq', metadata={'help': 'type of selecting images, same_other_seq: replace masked span with span from another sequence, same_seq: replace masked span with span of the same sequence'})
+    masking_type: str = field(default='input', metadata={'help': 'input or feature masking'})
+
+    decoder_embed_dim: int = field(
+        default=768, metadata={"help": "decoder embedding dimension"}
+    )
+    decoder_ffn_embed_dim: int = field(
+        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
+    )
+    decoder_layers: int = field(
+        default=6, metadata={"help": "num of decoder layers"}
+    )
+    decoder_layerdrop: float = field(
+        default=0.0, metadata={"help": "decoder layerdrop chance"}
+    )
+    decoder_attention_heads: int = field(
+        default=4, metadata={"help": "num decoder attention heads"}
+    )
+    decoder_learned_pos: bool = field(
+        default=False,
+        metadata={"help": "use learned positional embeddings in the decoder"},
+    )
+    decoder_normalize_before: bool = field(
+        default=False,
+        metadata={"help": "apply layernorm before each decoder block"},
+    )
+    no_token_positional_embeddings: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, disables positional embeddings "
+            "(outside self attention)"
+        },
+    )
+    decoder_dropout: float = field(
+        default=0.1, metadata={"help": "dropout probability in the decoder"}
+    )
+    decoder_attention_dropout: float = field(
+        default=0.1,
+        metadata={
+            "help": "dropout probability for attention weights "
+            "inside the decoder"
+        },
+    )
+    decoder_activation_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "dropout probability after activation in FFN "
+            "inside the decoder"
+        },
+    )
+    max_target_positions: int = field(
+        default=2048, metadata={"help": "max target positions"}
+    )
+    share_decoder_input_output_embed: bool = field(
+        default=False,
+        metadata={"help": "share decoder input and output embeddings"},
+    )
+    no_scale_embedding: bool = field(default=True, metadata={'help': 'if set, do not scale embeddings by sqrt(embed_dim)'})
+
+    # # new fairseq
+    # required_seq_len_multiple: int = field(
+    #     default=1,
+    #     metadata={
+    #         "help": "pad the input to encoder such that the sequence length is divisible by multiple"
+    #     },
+    # )
+
+    # layer_type: LAYER_TYPE_CHOICES = field(
+    #     default="transformer", metadata={"help": "layer type in encoder"}
+    # )
+
+class SubModel(nn.Module):
+    def __init__(self, resnet=None, input_dim=None, cfg=None):
+        super().__init__()
+        self.resnet = resnet
+        self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim)
+        self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None
+
+    def forward(self, x): #torch.Size([1, 1, 106, 112, 112])
+        if self.resnet is not None:
+            x = self.resnet(x)  #torch.Size([1, 512, 106])  #torch.Size([12, 26, 314])
+        x = self.proj(x.transpose(1, 2))   # for audio this is Linear(in_features=104, out_features=1024, bias=True), which looks surprisingly direct
+        if self.encoder is not None:
+            x = self.encoder(x)[0].transpose(1, 2)
+        else: #
+            x = x.transpose(1, 2)
+        return x #torch.Size([1, 1024, 106])
+
+@register_model("av_hubert", dataclass=AVHubertConfig)
+class AVHubertModel(BaseFairseqModel):
+    def __init__(
+        self,
+        cfg: AVHubertConfig,
+        task_cfg: AVHubertPretrainingConfig,
+        dictionaries: List[Dictionary],
+        **kwargs
+    ) -> None:
+        super().__init__()
+        logger.info(f"HubertModel Config: {cfg}")
+
+        feature_ds_rate = 1
+        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
+        sub_cfg = deepcopy(cfg)
+        sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers
+        resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights)
+        self.feature_extractor_audio = SubModel(resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg)
+        self.feature_extractor_video = SubModel(resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg)
+        self.modality_dropout, self.audio_dropout = cfg.modality_dropout, cfg.audio_dropout
+        self.modality_fuse = cfg.modality_fuse
+        self.encoder_embed_dim = cfg.encoder_embed_dim
+        if self.modality_fuse == 'concat':
+            self.embed = cfg.encoder_embed_dim * 2
+        elif self.modality_fuse == 'add':
+            self.embed = cfg.encoder_embed_dim
+        self.post_extract_proj = (
+            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            if self.embed != cfg.encoder_embed_dim
+            else None
+        )
+
+        self.mask_prob_image, self.mask_prob_audio = cfg.mask_prob_image, cfg.mask_prob_audio
+        self.mask_selection = cfg.mask_selection
+        self.mask_other = cfg.mask_other
+        self.mask_length_image, self.mask_length_audio = cfg.mask_length_image, cfg.mask_length_audio
+        self.no_mask_overlap = cfg.no_mask_overlap
+        self.mask_min_space = cfg.mask_min_space
+
+        self.mask_channel_prob = cfg.mask_channel_prob
+        self.mask_channel_selection = cfg.mask_channel_selection
+        self.mask_channel_other = cfg.mask_channel_other
+        self.mask_channel_length = cfg.mask_channel_length
+        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+        self.mask_channel_min_space = cfg.mask_channel_min_space
+
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+        self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+        self.feature_grad_mult = cfg.feature_grad_mult
+        self.logit_temp = cfg.logit_temp
+        self.skip_masked = cfg.skip_masked
+        self.skip_nomask = cfg.skip_nomask
+        self.sim_type = cfg.sim_type
+        self.selection_type = cfg.selection_type
+        self.masking_type = cfg.masking_type
+
+        final_dim = (
+            cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
+        )
+
+        self.mask_emb = nn.Parameter(
+            torch.FloatTensor(cfg.audio_feat_dim).uniform_() if self.masking_type == 'input' else torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+        )
+
+        self.encoder = TransformerEncoder(cfg)
+        self.layer_norm = LayerNorm(self.embed)
+
+        self.target_glu = None
+        if cfg.target_glu:
+            self.target_glu = nn.Sequential(
+                nn.Linear(final_dim, final_dim * 2), nn.GLU()
+            )
+
+        self.untie_final_proj = cfg.untie_final_proj
+        if self.untie_final_proj:
+            self.final_proj = nn.Linear(
+                cfg.encoder_embed_dim, final_dim * len(dictionaries)
+            )
+        else:
+            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
+
+        # modules below are not needed during fine-tuning
+        if any([d is None for d in dictionaries]):
+            logger.info(
+                "cannot find dictionary. assume will be used for fine-tuning"
+            )
+        else:
+            self.num_classes = [len(d) for d in dictionaries]
+            self.label_embs_concat = nn.Parameter(
+                torch.FloatTensor(sum(self.num_classes), final_dim)
+            )
+            nn.init.uniform_(self.label_embs_concat)
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        """Upgrade a (possibly old) state dict for new versions of fairseq."""
+
+        super().upgrade_state_dict_named(state_dict, name)
+        return state_dict
+
+    @classmethod
+    def build_model(cls, cfg: AVHubertConfig, task: AVHubertPretrainingTask):
+        """Build a new model instance."""
+
+        kwargs = {}
+        model = AVHubertModel(cfg, task.cfg, task.dictionaries, **kwargs)
+        return model
+
+    def apply_input_mask(self, x, padding_mask, target_list):
+        B, C, T = x.shape[:3]
+        is_audio = len(x.shape) == 3
+        if is_audio:
+            mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio
+        else:
+            mask_prob, mask_length = self.mask_prob_image, self.mask_length_image
+        if mask_prob > 0:
+
+            mask_indices, starts, ends, batch_indexes = compute_mask_indices(
+                (B, T),
+                padding_mask,
+                mask_prob,
+                mask_length,
+                self.mask_selection,
+                self.mask_other,
+                min_masks=2,
+                no_overlap=self.no_mask_overlap,
+                min_space=self.mask_min_space,
+            )
+            mask_indices_np = mask_indices
+            mask_indices = torch.from_numpy(mask_indices).to(x.device)
+            x = x.transpose(1, 2).contiguous() # [B, T, C, H, W]
+            if B == 1:
+                x[mask_indices] = 0
+            elif is_audio:
+                x[mask_indices] = self.mask_emb
+            elif self.selection_type == 'same_other_seq':
+                perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B
+                x_perm = x[perm]
+                x[mask_indices] = x_perm[mask_indices]
+            elif self.selection_type == 'same_seq':
+                batch_indexes_, other_indexes = [], []
+                for batch_index, start, end in zip(batch_indexes, starts, ends):
+                    length = end-start
+                    other_start = np.setdiff1d(np.arange(T), np.arange(max(0, start-length), end))
+                    if len(other_start) > 0:
+                        other_start = np.random.choice(other_start, size=1)
+                    else:
+                        other_start = 0
+                    other_end = other_start + length
+                    other_indexes.append(np.arange(other_start, other_end).clip(max=T-1))
+                    batch_indexes_.append(np.zeros([length], dtype=np.int64)+batch_index)
+                batch_indexes, other_indexes = np.concatenate(batch_indexes_), np.concatenate(other_indexes)
+                x[mask_indices] = x[batch_indexes, other_indexes]
+
+            x = x.transpose(1, 2).contiguous()
+        else:
+            mask_indices = None
+
+        if self.mask_channel_prob > 0:
+            logger.info(f"No mask channel prob for input masking")
+        return x, mask_indices
+
+    def apply_feature_mask(self, x, padding_mask, target_list):
+        B, T, C = x.shape
+        assert self.mask_prob_audio == self.mask_prob_image and self.mask_length_audio == self.mask_length_image, "masking prob/length for image/audio must be the same for feature masking"
+        mask_prob, mask_length = self.mask_prob_audio, self.mask_length_image
+        if mask_prob > 0:
+            mask_indices, _, _, _ = compute_mask_indices(
+                (B, T),
+                padding_mask,
+                mask_prob,
+                mask_length,
+                self.mask_selection,
+                self.mask_other,
+                min_masks=2,
+                no_overlap=self.no_mask_overlap,
+                min_space=self.mask_min_space,
+            )
+            mask_indices = torch.from_numpy(mask_indices).to(x.device)
+            x[mask_indices] = self.mask_emb
+        else:
+            mask_indices = None
+
+        if self.mask_channel_prob > 0:
+            mask_channel_indices, _, _, _ = compute_mask_indices(
+                (B, C),
+                None,
+                self.mask_channel_prob,
+                self.mask_channel_length,
+                self.mask_channel_selection,
+                self.mask_channel_other,
+                no_overlap=self.no_mask_channel_overlap,
+                min_space=self.mask_channel_min_space,
+            )
+            mask_channel_indices = (
+                torch.from_numpy(mask_channel_indices)
+                .to(x.device)
+                .unsqueeze(1)
+                .expand(-1, T, -1)
+            )
+            x[mask_channel_indices] = 0
+
+        return x, mask_indices
+
+    def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor:
+        extractor = getattr(self, f"feature_extractor_{modality}")
+        if self.feature_grad_mult > 0:
+            features = extractor(source)
+            if self.feature_grad_mult != 1.0:
+                features = GradMultiply.apply(features, self.feature_grad_mult)
+        else:
+            with torch.no_grad():
+                features = extractor(source)
+        return features
+
+    def forward_targets(
+            self, features: torch.Tensor, mask_indices: torch.Tensor, target_list: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Trim features to ensure labels exist and then get aligned labels
+        feat_tsz = features.size(2)
+        targ_tsz = min([t.size(1) for t in target_list])
+        if self.feat2tar_ratio * feat_tsz > targ_tsz:
+            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
+            features = features[..., :feat_tsz]
+            if mask_indices is not None:
+                mask_indices = mask_indices[..., :feat_tsz]
+        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
+        target_list = [t[:, target_inds.long()] for t in target_list]
+        return features, mask_indices, target_list
+
+    def forward_padding_mask(
+        self, features: torch.Tensor, padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        extra = padding_mask.size(1) % features.size(1)
+        if extra > 0:
+            padding_mask = padding_mask[:, :-extra]
+        padding_mask = padding_mask.view(
+            padding_mask.size(0), features.size(1), -1
+        )
+        padding_mask = padding_mask.all(-1)
+        return padding_mask
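+
+    # Downsampling sketch: the frame-level padding_mask (B, T_in) is trimmed to a
+    # multiple of the feature length T_feat, reshaped to (B, T_feat, T_in // T_feat),
+    # and a feature step is marked as padding only if every raw frame it covers
+    # is padding (the .all(-1) reduction above).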
+
+    def compute_logits(self, feats, emb_mat):
+        # feats: [B, T, F], emb_mat: [V, F]
+        if self.sim_type == 'dot':
+            logits = torch.matmul(feats, emb_mat.transpose(0, 1))
+        elif self.sim_type == 'cosine':
+            batch_size, timesteps, emb_dim = feats.size()
+            feats_ = feats.view(-1, emb_dim)
+            nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1) # [B*T, V]
+            denom = (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1) * (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0) # [B*T, V]
+            logits = (nom/denom.clamp(min=1e-6)).view(batch_size, timesteps, -1)
+        else:
+            raise NotImplementedError
+        logits = logits / self.logit_temp
+        return logits
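+
+    # Cosine branch sketch: feats is flattened to (B*T, F), the numerator is the
+    # dot product against the (V, F) label embeddings, the denominator is the
+    # product of their L2 norms (clamped at 1e-6), and the (B, T, V) similarities
+    # are finally divided by logit_temp.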
+
+    def forward(
+        self,
+        source: torch.Tensor,
+        target_list: Optional[List[torch.Tensor]] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        mask: bool = True,
+        features_only: bool = False,
+        output_layer: Optional[int] = None
+    ) -> Dict[str, torch.Tensor]:
+        """output layer is 1-based"""
+        src_audio, src_video = source['audio'], source['video']
+        if mask and self.masking_type == 'input':
+            src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list)
+            src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list)
+            mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video)
+        else:
+            src_audio, src_video, mask_indices = src_audio, src_video, None
+
+        features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
+        features_video = self.forward_features(src_video, modality='video')
+        modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random()
+        if self.training:
+            if modality_drop_prob < self.modality_dropout:
+                if audio_drop_prob < self.audio_dropout:
+                    features_audio = 0 * features_audio
+                else:
+                    features_video = 0 * features_video
+        if self.modality_fuse == 'concat':
+            features = torch.cat([features_audio, features_video], dim=1)
+        elif self.modality_fuse == 'add':
+            features = features_audio + features_video
+        if target_list is not None:
+            features, mask_indices, target_list = self.forward_targets(features, mask_indices, target_list)
+
+        features_pen = features.float().pow(2).mean()
+
+        features = features.transpose(1, 2)
+        features = self.layer_norm(features)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(features, padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        features = self.dropout_input(features)
+        if self.masking_type == 'feature' and mask:
+            x, mask_indices = self.apply_feature_mask(features, padding_mask, target_list)
+        else:
+            x = features
+
+        # feature: (B, T, D), float
+        # target: (B, T), long
+        # x: (B, T, D), float
+        # padding_mask: (B, T), bool
+        # mask_indices: (B, T), bool
+        x, _ = self.encoder(
+            x,
+            padding_mask=padding_mask,
+            layer=None if output_layer is None else output_layer - 1
+        )
+
+        if features_only:
+            return {"x": x, "padding_mask": padding_mask, "features": features}
+
+        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
+        proj_x = self.final_proj(x)
+        if self.untie_final_proj:
+            proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1)
+        else:
+            proj_x_list = [proj_x for _ in self.num_classes]
+        logit_list = [self.compute_logits(proj, emb).view(-1, num_class) for proj, emb, num_class in zip(proj_x_list, label_embs_list, self.num_classes)] # [[B*T, V]]
+        mask, unmask = torch.logical_and(mask_indices, ~padding_mask).view(-1), torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T]
+        logit_m_list, logit_u_list = [logit[mask] for logit in logit_list], [logit[unmask] for logit in logit_list]
+        target_m_list, target_u_list = [target.view(-1)[mask].long() for target in target_list], [target.view(-1)[unmask].long() for target in target_list]
+        result = {
+            "logit_m_list": logit_m_list,
+            "logit_u_list": logit_u_list,
+            "target_m_list": target_m_list,
+            "target_u_list": target_u_list,
+            "padding_mask": padding_mask,
+            "features_pen": features_pen,
+        }
+        return result
+
+    def extract_features(
+        self,
+        source: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        mask: bool = False,
+        ret_conv: bool = False,
+        output_layer: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        res = self.forward(
+            source,
+            padding_mask=padding_mask,
+            mask=mask,
+            features_only=True,
+            output_layer=output_layer,
+        )
+        feature = res["features"] if ret_conv else res["x"]
+        return feature, res["padding_mask"]
+
+    def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None):
+        src_audio, src_video = source['audio'], source['video']  #torch.Size([1, 1, 106, 112, 112])
+        if mask and self.masking_type == 'input':
+            src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list=None)
+            src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list=None)
+            mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) # mask_indices not used in fine-tuning
+        else: #
+            src_audio, src_video, mask_indices = src_audio, src_video, None
+
+        if src_audio is not None and src_video is None:
+            features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]
+            features_video = features_audio.new_zeros(features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1))
+        elif src_audio is None and src_video is not None:
+            features_video = self.forward_features(src_video, modality='video')
+            features_audio = features_video.new_zeros(features_video.size(0), self.encoder_embed_dim, features_video.size(-1))  # all zeros!
+        elif src_audio is not None and src_video is not None:
+            features_video = self.forward_features(src_video, modality='video') #torch.Size([1, 1024, 106])  #scr torch.Size([12, 1, 314, 88, 88])
+            features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T]  #torch.Size([12, 26, 314])
+
+        if self.modality_fuse == 'concat': #
+            features = torch.cat([features_audio, features_video], dim=1)  #torch.Size([1, 2048, 106])
+        elif self.modality_fuse == 'add':
+            features = features_audio + features_video
+        features_pen = features.float().pow(2).mean()
+
+        features = features.transpose(1, 2)
+        features = self.layer_norm(features)
+        unmasked_features = features.clone()
+
+        if padding_mask is not None:  #features:torch.Size([1, 106, 2048])
+            padding_mask = self.forward_padding_mask(features, padding_mask) #torch.Size([4, 154])
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features) #torch.Size([1, 106, 1024])
+
+        features = self.dropout_input(features)
+        unmasked_features = self.dropout_features(unmasked_features)
+        x = features
+        mask_indices = None
+
+        # feature: (B, T, D), float
+        # target: (B, T), long
+        # x: (B, T, D), float
+        # padding_mask: (B, T), bool
+        # mask_indices: (B, T), bool
+        x, _ = self.encoder(
+            x,
+            padding_mask=padding_mask,
+            layer=None if output_layer is None else output_layer - 1
+        )
+
+        return x, padding_mask  #torch.Size([1, 106, 1024]), None
+
+
+    def get_extra_losses(self, net_output):
+        extra_losses = []
+        names = []
+        if "features_pen" in net_output:
+            extra_losses.append(net_output["features_pen"])
+            names.append("features_pen")
+
+        return extra_losses, names
+
+    def remove_pretraining_modules(self):
+        self.target_glu = None
+        self.final_proj = None
+
+    def get_logits(self, net_output, is_masked=True):
+        raise NotImplementedError
+
+    def get_targets(self, net_output, is_masked=True):
+        raise NotImplementedError
+
+    def compute_nce(self, x, pos, negs):
+        neg_is_pos = (pos == negs).all(-1)
+        pos = pos.unsqueeze(0)
+        targets = torch.cat([pos, negs], dim=0)
+
+        logits = torch.cosine_similarity(
+            x.float(), targets.float(), dim=-1
+        ).type_as(x)
+        logits /= self.logit_temp
+        if neg_is_pos.any():
+            logits[1:][neg_is_pos] = float("-inf")
+        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
+        return logits
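+
+    # Usage sketch (hypothetical tensors, for illustration only):
+    #   model = AVHubertModel.build_model(cfg, task)
+    #   feats, padding = model.extract_finetune(
+    #       source={"audio": audio_feats, "video": video_frames},
+    #       padding_mask=padding_mask,
+    #   )  # feats: (B, T, encoder_embed_dim)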
diff --git a/slam_llm/models/avhubert/hubert_asr.py b/slam_llm/models/avhubert/hubert_asr.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df02900d8dfea733e9f43db1a5861098f60779f
--- /dev/null
+++ b/slam_llm/models/avhubert/hubert_asr.py
@@ -0,0 +1,523 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys,logging
+import contextlib
+import tempfile
+from argparse import Namespace
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+from dataclasses import dataclass, field
+from fairseq import checkpoint_utils, tasks, utils
+from fairseq.dataclass import FairseqDataclass
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.models import BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, register_model
+from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES
+from fairseq.tasks import FairseqTask
+from omegaconf import II, MISSING
+
+DBG = len(sys.argv) == 1
+
+if DBG:
+    from hubert import AVHubertModel
+    from decoder import TransformerDecoder
+else:
+    from .hubert import AVHubertModel
+    from .decoder import TransformerDecoder
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AVHubertAsrConfig(FairseqDataclass):
+    w2v_path: str = field(
+        default=MISSING, metadata={"help": "path to hubert model"}
+    )
+    no_pretrained_weights: bool = field(
+        default=False,
+        metadata={"help": "if true, does not load pretrained weights"},
+    )
+    dropout_input: float = field(
+        default=0.0,
+        metadata={"help": "dropout to apply to the input (after feat extr)"},
+    )
+    final_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "dropout after transformer and before final projection"
+        },
+    )
+    dropout: float = field(
+        default=0.0,
+        metadata={"help": "dropout probability inside hubert model"},
+    )
+    attention_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "dropout probability for attention weights "
+            "inside hubert model"
+        },
+    )
+    activation_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "dropout probability after activation in FFN "
+            "inside hubert model"
+        },
+    )
+
+    # masking
+    apply_mask: bool = field(
+        default=False, metadata={"help": "apply masking during fine-tuning"}
+    )
+    mask_length: int = field(
+        default=10, metadata={"help": "repeat the mask indices multiple times"}
+    )
+    mask_prob: float = field(
+        default=0.5,
+        metadata={
+            "help": "probability of replacing a token with mask "
+            "(normalized by length)"
+        },
+    )
+    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
+        default="static", metadata={"help": "how to choose masks"}
+    )
+    mask_other: float = field(
+        default=0,
+        metadata={
+            "help": "secondary mask argument "
+            "(used for more complex distributions), "
+            "see help in compute_mask_indices"
+        },
+    )
+    no_mask_overlap: bool = field(
+        default=False, metadata={"help": "whether to allow masks to overlap"}
+    )
+
+    # channel masking
+    mask_channel_length: int = field(
+        default=10,
+        metadata={"help": "length of the mask for features (channels)"},
+    )
+    mask_channel_prob: float = field(
+        default=0.0,
+        metadata={"help": "probability of replacing a feature with 0"},
+    )
+    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
+        default="static",
+        metadata={"help": "how to choose mask length for channel masking"},
+    )
+    mask_channel_other: float = field(
+        default=0,
+        metadata={
+            "help": "secondary mask argument "
+            "(used for more complex distributions), "
+            "see help in compute_mask_indices"
+        },
+    )
+    no_mask_channel_overlap: bool = field(
+        default=False,
+        metadata={"help": "whether to allow channel masks to overlap"},
+    )
+    freeze_finetune_updates: int = field(
+        default=0,
+        metadata={"help": "dont finetune hubert for this many updates"},
+    )
+    feature_grad_mult: float = field(
+        default=0.0,
+        metadata={"help": "reset feature grad mult in hubert to this"},
+    )
+    layerdrop: float = field(
+        default=0.0,
+        metadata={"help": "probability of dropping a layer in hubert"},
+    )
+    normalize: bool = II("task.normalize")
+    data: str = II("task.data")
+
+    # this holds the loaded hubert args
+    w2v_args: Any = None
+
+
+@dataclass
+class AVHubertCtcConfig(AVHubertAsrConfig):
+    pass
+
+
+@register_model("av_hubert_ctc", dataclass=AVHubertCtcConfig)
+class AVHubertCtc(BaseFairseqModel):
+    def __init__(self, cfg: AVHubertCtcConfig, w2v_encoder: BaseFairseqModel):
+        super().__init__()
+        self.cfg = cfg
+        self.w2v_encoder = w2v_encoder
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        super().upgrade_state_dict_named(state_dict, name)
+        return state_dict
+
+    @classmethod
+    def build_model(cls, cfg: AVHubertCtcConfig, task: FairseqTask):
+        """Build a new model instance."""
+        w2v_encoder = HubertEncoder(cfg, task.target_dictionary)
+        return cls(cfg, w2v_encoder)
+
+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+
+        logits = net_output["encoder_out"]
+        if log_probs:
+            return utils.log_softmax(logits.float(), dim=-1)
+        else:
+            return utils.softmax(logits.float(), dim=-1)
+
+    def get_logits(self, net_output):
+        logits = net_output["encoder_out"]
+        padding = net_output["encoder_padding_mask"]
+        if padding is not None and padding.any():
+            padding = padding.T
+            logits[padding][..., 0] = 0
+            logits[padding][..., 1:] = float("-inf")
+
+        return logits
+
+    def forward(self, **kwargs):
+        x = self.w2v_encoder(**kwargs)
+        return x
+
+
+@dataclass
+class AVHubertSeq2SeqConfig(AVHubertAsrConfig):
+    decoder_embed_dim: int = field(
+        default=768, metadata={"help": "decoder embedding dimension"}
+    )
+    decoder_ffn_embed_dim: int = field(
+        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
+    )
+    decoder_layers: int = field(
+        default=6, metadata={"help": "num of decoder layers"}
+    )
+    decoder_layerdrop: float = field(
+        default=0.0, metadata={"help": "decoder layerdrop chance"}
+    )
+    decoder_attention_heads: int = field(
+        default=4, metadata={"help": "num decoder attention heads"}
+    )
+    decoder_learned_pos: bool = field(
+        default=False,
+        metadata={"help": "use learned positional embeddings in the decoder"},
+    )
+    decoder_normalize_before: bool = field(
+        default=False,
+        metadata={"help": "apply layernorm before each decoder block"},
+    )
+    no_token_positional_embeddings: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, disables positional embeddings "
+            "(outside self attention)"
+        },
+    )
+    decoder_dropout: float = field(
+        default=0.0, metadata={"help": "dropout probability in the decoder"}
+    )
+    decoder_attention_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "dropout probability for attention weights "
+            "inside the decoder"
+        },
+    )
+    decoder_activation_dropout: float = field(
+        default=0.0,
+        metadata={
+            "help": "dropout probability after activation in FFN "
+            "inside the decoder"
+        },
+    )
+    max_target_positions: int = field(
+        default=2048, metadata={"help": "max target positions"}
+    )
+    share_decoder_input_output_embed: bool = field(
+        default=False,
+        metadata={"help": "share decoder input and output embeddings"},
+    )
+    no_scale_embedding: bool = field(default=True, metadata={'help': 'if set, do not scale embeddings by sqrt(embed_dim)'})
+
+class HubertEncoder(FairseqEncoder):
+    def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None):
+        self.apply_mask = cfg.apply_mask
+
+        arg_overrides = {
+            "dropout": cfg.dropout,
+            "activation_dropout": cfg.activation_dropout,
+            "dropout_input": cfg.dropout_input,
+            "attention_dropout": cfg.attention_dropout,
+            "mask_length": cfg.mask_length,
+            "mask_prob": cfg.mask_prob,
+            "mask_selection": cfg.mask_selection,
+            "mask_other": cfg.mask_other,
+            "no_mask_overlap": cfg.no_mask_overlap,
+            "mask_channel_length": cfg.mask_channel_length,
+            "mask_channel_prob": cfg.mask_channel_prob,
+            "mask_channel_selection": cfg.mask_channel_selection,
+            "mask_channel_other": cfg.mask_channel_other,
+            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
+            "encoder_layerdrop": cfg.layerdrop,
+            "feature_grad_mult": cfg.feature_grad_mult,
+        }
+
+        if cfg.w2v_args is None:
+            state = checkpoint_utils.load_checkpoint_to_cpu(
+                cfg.w2v_path, arg_overrides
+            )
+            w2v_args = state.get("cfg", None)
+            if w2v_args is None:
+                w2v_args = convert_namespace_to_omegaconf(state["args"])
+            cfg.w2v_args = w2v_args
+        else:
+            state = None
+            w2v_args = cfg.w2v_args
+            if isinstance(w2v_args, Namespace):
+                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
+                    w2v_args
+                )
+
+        assert cfg.normalize == w2v_args.task.normalize, (
+            "Fine-tuning works best when data normalization is the same. "
+            "Please check that --normalize is set or unset for "
+            "both pre-training and here"
+        )
+
+        w2v_args.task.data = cfg.data
+
+        task = tasks.setup_task(w2v_args.task)
+        model = task.build_model(w2v_args.model)
+
+        if state is not None and not cfg.no_pretrained_weights:
+            # set strict=False because we omit some modules
+            model.load_state_dict(state["model"], strict=False)
+
+        model.remove_pretraining_modules()
+
+        super().__init__(task.source_dictionary)
+
+        d = model.encoder.embedding_dim
+
+        self.w2v_model = model
+
+        self.final_dropout = nn.Dropout(cfg.final_dropout)
+        self.freeze_finetune_updates = cfg.freeze_finetune_updates
+        self.num_updates = 0
+
+        if tgt_dict is not None:
+            self.proj = Linear(d, len(tgt_dict))
+        elif getattr(cfg, "decoder_embed_dim", d) != d:
+            self.proj = Linear(d, cfg.decoder_embed_dim)
+        else:
+            self.proj = None
+
+    def set_num_updates(self, num_updates):
+        """Set the number of parameters updates."""
+        super().set_num_updates(num_updates)
+        self.num_updates = num_updates
+
+    def forward(self, source, padding_mask, tbc=True, **kwargs):
+
+        w2v_args = {
+            "source": source,
+            "padding_mask": padding_mask,
+            "mask": self.apply_mask and self.training,
+        }
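+        # Only backpropagate through the pre-trained encoder once `freeze_finetune_updates` optimizer steps have passed.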
+        ft = self.freeze_finetune_updates <= self.num_updates
+
+        with torch.no_grad() if not ft else contextlib.ExitStack():
+            x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
+
+            if tbc:
+                # B x T x C -> T x B x C
+                x = x.transpose(0, 1)
+
+        x = self.final_dropout(x)
+
+        if self.proj:
+            x = self.proj(x)
+
+        return {
+            "encoder_out": x,  # T x B x C
+            "encoder_padding_mask": padding_mask,  # B x T
+            "padding_mask": padding_mask,
+        }
+
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out["encoder_out"] is not None:
+            encoder_out["encoder_out"] = encoder_out[
+                "encoder_out"
+            ].index_select(1, new_order)
+        if encoder_out["encoder_padding_mask"] is not None:
+            encoder_out["encoder_padding_mask"] = encoder_out[
+                "encoder_padding_mask"
+            ].index_select(0, new_order)
+        return encoder_out
+
+    def max_positions(self):
+        """Maximum input length supported by the encoder."""
+        return None
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        return state_dict
+
+
+class HubertEncoderWrapper(FairseqEncoder):
+    def __init__(self, w2v_model):
+        super().__init__(None)
+        self.w2v_model = w2v_model
+
+    def forward(self, source, padding_mask, **kwargs):
+        w2v_args = {
+            "source": source,
+            "padding_mask": padding_mask,
+        }
+
+        x, padding_mask = self.w2v_model.extract_finetune(**w2v_args)
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        return {
+            "encoder_out": x,  # T x B x C
+            "encoder_padding_mask": padding_mask,  # B x T
+            "padding_mask": padding_mask
+        }
+
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out["encoder_out"] is not None:
+            encoder_out["encoder_out"] = encoder_out[
+                "encoder_out"
+            ].index_select(1, new_order)
+        if encoder_out["encoder_padding_mask"] is not None:
+            encoder_out["encoder_padding_mask"] = encoder_out[
+                "encoder_padding_mask"
+            ].index_select(0, new_order)
+        if encoder_out["padding_mask"] is not None:
+            encoder_out["padding_mask"] = encoder_out[
+                "padding_mask"
+            ].index_select(0, new_order)
+        return encoder_out
+
+@register_model("av_hubert_seq2seq", dataclass=AVHubertSeq2SeqConfig)
+class AVHubertSeq2Seq(FairseqEncoderDecoderModel):
+    def __init__(self, encoder, decoder, tgt_dict, cfg):
+        super().__init__(encoder, decoder)
+        self.cfg = cfg
+        self.freeze_finetune_updates = cfg.freeze_finetune_updates
+
+    @classmethod
+    def build_model(cls, cfg, task):
+        """Build a new model instance."""
+
+        arg_overrides = {
+            "dropout": cfg.dropout,
+            "activation_dropout": cfg.activation_dropout,
+            "dropout_input": cfg.dropout_input,
+            "attention_dropout": cfg.attention_dropout,
+            "mask_length": cfg.mask_length,
+            "mask_prob": cfg.mask_prob,
+            "mask_selection": cfg.mask_selection,
+            "mask_other": cfg.mask_other,
+            "no_mask_overlap": cfg.no_mask_overlap,
+            "mask_channel_length": cfg.mask_channel_length,
+            "mask_channel_prob": cfg.mask_channel_prob,
+            "mask_channel_selection": cfg.mask_channel_selection,
+            "mask_channel_other": cfg.mask_channel_other,
+            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
+            "encoder_layerdrop": cfg.layerdrop,
+            "feature_grad_mult": cfg.feature_grad_mult,
+        }
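+        # Rebuild the AV-HuBERT encoder from the pre-training checkpoint (or from cfg.w2v_args when it is already resolved).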
+
+        if cfg.w2v_args is None:
+            state = checkpoint_utils.load_checkpoint_to_cpu(
+                cfg.w2v_path, arg_overrides
+            )
+            w2v_args = state.get("cfg", None)
+            if w2v_args is None:
+                w2v_args = convert_namespace_to_omegaconf(state["args"])
+            cfg.w2v_args = w2v_args
+        else:
+            state = None
+            w2v_args = cfg.w2v_args
+            if isinstance(w2v_args, Namespace):
+                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
+                    w2v_args
+                )
+
+        assert cfg.normalize == w2v_args.task.normalize, (
+            "Fine-tuning works best when data normalization is the same. "
+            "Please check that --normalize is set or unset for "
+            "both pre-training and here"
+        )
+
+        w2v_args.task.data = cfg.data
+
+        task_pretrain = tasks.setup_task(w2v_args.task)
+        if state is not None:
+            task_pretrain.load_state_dict(state['task_state'])
+
+        encoder_ = task_pretrain.build_model(w2v_args.model)
+
+        encoder = HubertEncoderWrapper(encoder_)
+        if state is not None and not cfg.no_pretrained_weights:
+            # set strict=False because we omit some modules
+            # `mask_emb` only exists for pre-training; drop it if present before loading
+            state['model'].pop('mask_emb', None)
+            encoder.w2v_model.load_state_dict(state["model"], strict=False)
+
+        encoder.w2v_model.remove_pretraining_modules()
+
+        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
+
+        def build_embedding(dictionary, embed_dim):
+            num_embeddings = len(dictionary)
+            padding_idx = dictionary.pad()
+            emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
+            return emb
+
+        decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)
+        decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens)
+
+        return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg)
+
+
+    def forward(self, **kwargs):
+        # The AV-HuBERT encoder is kept frozen here and only its outputs are returned
+        # ("encoder_out", "encoder_padding_mask", "padding_mask"); the seq2seq decoder
+        # is not invoked in this forward pass.
+        with torch.no_grad():
+            output = self.encoder(**kwargs)
+        return output
+
+    def upgrade_state_dict_named(self, state_dict, name):
+        super().upgrade_state_dict_named(state_dict, name)
+        return state_dict
+
+    def set_num_updates(self, num_updates):
+        """Set the number of parameters updates."""
+        super().set_num_updates(num_updates)
+        self.num_updates = num_updates
+
+def Embedding(num_embeddings, embedding_dim, padding_idx):
+    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    nn.init.constant_(m.weight[padding_idx], 0)
+    return m
+
+
+def Linear(in_features, out_features, bias=True):
+    m = nn.Linear(in_features, out_features, bias)
+    nn.init.xavier_uniform_(m.weight)
+    if bias:
+        nn.init.constant_(m.bias, 0.0)
+    return m
diff --git a/slam_llm/models/avhubert/hubert_criterion.py b/slam_llm/models/avhubert/hubert_criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc7262288958455de4189df82020f28f80d58e34
--- /dev/null
+++ b/slam_llm/models/avhubert/hubert_criterion.py
@@ -0,0 +1,169 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import re
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from fairseq import metrics, utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.dataclass import FairseqDataclass
+
+
+@dataclass
+class AVHubertCriterionConfig(FairseqDataclass):
+    pred_masked_weight: float = field(
+        default=1.0,
+        metadata={"help": "weight for predictive loss for masked frames"},
+    )
+    pred_nomask_weight: float = field(
+        default=0.0,
+        metadata={"help": "weight for predictive loss for unmasked frames"},
+    )
+    loss_weights: Optional[List[float]] = field(
+        default=None,
+        metadata={"help": "weights for additional loss terms (not first one)"},
+    )
+    log_keys: List[str] = field(
+        default_factory=lambda: [],
+        metadata={"help": "output keys to log"},
+    )
+
+
+@register_criterion("av_hubert", dataclass=AVHubertCriterionConfig)
+class AVHubertCriterion(FairseqCriterion):
+    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
+        super().__init__(task)
+        self.pred_masked_weight = pred_masked_weight
+        self.pred_nomask_weight = pred_nomask_weight
+        self.loss_weights = loss_weights
+        self.log_keys = [] if log_keys is None else log_keys
+
+    def forward(self, model, sample, reduce=True, log_pred=False):
+        """Compute the loss for the given sample.
+        Returns a tuple with three elements:
+        1) the loss
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(target_list=sample["target_list"], **sample["net_input"])
+        loss = 0.
+        sample_size = 0
+        logging_output = {}
+        reduction = "sum" if reduce else "none"
+
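+        # predictive cross-entropy over masked frames, weighted by pred_masked_weight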
+        loss_m_list = []
+        logp_m_list, targ_m_list = net_output['logit_m_list'], net_output['target_m_list']
+        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
+            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
+            loss_m_list.append(loss_m)
+            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
+        if self.pred_masked_weight > 0:
+            loss += self.pred_masked_weight * sum(loss_m_list)
+            sample_size += targ_m_list[0].numel()
+
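+        # predictive cross-entropy over unmasked frames, weighted by pred_nomask_weight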
+        loss_u_list = []
+        logp_u_list, targ_u_list = net_output['logit_u_list'], net_output['target_u_list']
+        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
+            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
+            loss_u_list.append(loss_u)
+            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
+        if self.pred_nomask_weight > 0:
+            loss += self.pred_nomask_weight * sum(loss_u_list)
+            sample_size += targ_u_list[0].numel()
+
+        if self.loss_weights is not None:
+            assert hasattr(model, "get_extra_losses")
+            extra_losses, names = model.get_extra_losses(net_output)
+            if torch.is_tensor(extra_losses):
+                extra_losses = [extra_losses]
+                names = [names]
+            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+            assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
+            for p, n, coef in zip(extra_losses, names, self.loss_weights):
+                if coef != 0 and p is not None:
+                    p = coef * p.float() * sample_size
+                    loss += p
+                    logging_output[f"loss_{n}"] = p.item()
+
+        logging_output = {
+            "loss": loss.item() if reduce else loss,
+            "ntokens": sample_size,
+            "nsentences": sample["id"].numel(),
+            "sample_size": sample_size,
+            **logging_output,
+        }
+
+        for lk in self.log_keys:
+            if lk in net_output:
+                logging_output[lk] = float(net_output[lk])
+
+        with torch.no_grad():
+            for i, logp_m in enumerate(logp_m_list):
+                # corr_m, count_m = compute_correct(logp_m)
+                if logp_m.numel() == 0:
+                    corr_m, count_m = 0, 0
+                else:
+                    corr_m, count_m = (logp_m.argmax(dim=-1)==targ_m_list[i]).sum().item(), len(targ_m_list[i])
+                logging_output[f"correct_m_{i}"] = corr_m
+                logging_output[f"count_m_{i}"] = count_m
+
+            for i, logp_u in enumerate(logp_u_list):
+                if logp_u.numel() == 0:
+                    corr_u, count_u = 0, 0
+                else:
+                    corr_u, count_u = (logp_u.argmax(dim=-1)==targ_u_list[i]).sum().item(), len(targ_u_list[i])
+                logging_output[f"correct_u_{i}"] = corr_u
+                logging_output[f"count_u_{i}"] = count_u
+
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs) -> None:
+        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
+        if sample_size != ntokens:
+            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
+            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
+        else:
+            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
+
+        counts = {}
+        for lk in logging_outputs[0].keys():
+            if lk.startswith("count_"):
+                val = sum(log[lk] for log in logging_outputs)
+                metrics.log_scalar(lk, val)
+                counts[lk] = val
+
+        for lk in logging_outputs[0].keys():
+            if lk.startswith("loss_"):
+                val = sum(log[lk] for log in logging_outputs)
+                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
+            elif lk.startswith("correct_"):
+                val = sum(log[lk] for log in logging_outputs)
+                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        raise NotImplementedError()
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `reduce_metrics`. Setting this
+        to True will improve distributed training speed.
+        """
+        return False
diff --git a/slam_llm/models/avhubert/hubert_dataset.py b/slam_llm/models/avhubert/hubert_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..292ebdec9e409ec29c2fa800727e6b935c85a5be
--- /dev/null
+++ b/slam_llm/models/avhubert/hubert_dataset.py
@@ -0,0 +1,529 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import logging
+import os
+import sys
+import time
+from typing import Any, List, Optional, Union
+
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from fairseq.data import data_utils
+from fairseq.data.fairseq_dataset import FairseqDataset
+from python_speech_features import logfbank
+from scipy.io import wavfile
+
+DBG=True if len(sys.argv) == 1 else False
+
+if DBG:
+    import utils as custom_utils
+    logging.basicConfig(
+        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        level=os.environ.get("LOGLEVEL", "DEBUG").upper(),
+        stream=sys.stdout,
+    )
+else:
+    from . import utils as custom_utils
+
+logger = logging.getLogger(__name__)
+
+
+def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
+    def is_audio_label_aligned(audio_dur, label_durs):
+        return all([abs(audio_dur - label_dur)<tol for label_dur in label_durs])
+
+    n_long, n_short, n_unaligned = 0, 0, 0
+    names, inds, sizes = [], [], []
+    dur_from_label_list = []
+    is_seq_label = any([x==-1 for x in label_rates])
+    for label_path, label_rate in zip(label_paths, label_rates):
+        label_lengths = [len(line.rstrip().split())/label_rate for line in open(label_path).readlines()]
+        dur_from_label_list.append(label_lengths)
+    dur_from_label_list = list(zip(*dur_from_label_list))
+
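+    # Manifest layout: the first line is the data root; each subsequent tab-separated line
+    # holds (utt_id, video_path, audio_path, ...), with the second-to-last field giving the
+    # sample size used for length filtering.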
+    with open(manifest_path) as f:
+        root = f.readline().strip()
+        for ind, line in enumerate(f):
+            items = line.strip().split("\t")
+            sz = int(items[-2])
+            if min_keep is not None and sz < min_keep:
+                n_short += 1
+            elif max_keep is not None and sz > max_keep:
+                n_long += 1
+            elif (not is_seq_label) and (not is_audio_label_aligned(sz/frame_rate, dur_from_label_list[ind])):
+                n_unaligned += 1
+            else:
+                video_path = items[1]
+                audio_path = items[2]
+                audio_id = items[0]
+                names.append((video_path, audio_path+':'+audio_id))
+                inds.append(ind)
+                sizes.append(sz)
+    tot = ind + 1
+    logger.info(
+        (
+            f"max_keep={max_keep}, min_keep={min_keep}, "
+            f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, "
+            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
+        )
+    )
+    return root, names, inds, tot, sizes
+
+def load_label(label_path, inds, tot):
+    with open(label_path) as f:
+        labels = [line.rstrip() for line in f]
+        assert (
+            len(labels) == tot
+        ), f"number of labels does not match ({len(labels)} != {tot})"
+        labels = [labels[i] for i in inds]
+    return labels
+
+
+def load_label_offset(label_path, inds, tot):
+    with open(label_path) as f:
+        code_lengths = [len(line.encode("utf-8")) for line in f]
+        assert (
+            len(code_lengths) == tot
+        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
+        offsets = list(itertools.accumulate([0] + code_lengths))
+        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
+    return offsets
+
+
+def verify_label_lengths(
+    audio_sizes,
+    audio_rate,
+    label_path,
+    label_rate,
+    inds,
+    tot,
+    tol=0.1,  # tolerance in seconds
+):
+    if label_rate < 0:
+        logger.info(f"{label_path} is sequence label. skipped")
+        return
+
+    with open(label_path) as f:
+        lengths = [len(line.rstrip().split()) for line in f]
+        assert len(lengths) == tot
+        lengths = [lengths[i] for i in inds]
+    num_invalid = 0
+    for i, ind in enumerate(inds):
+        dur_from_audio = audio_sizes[i] / audio_rate
+        dur_from_label = lengths[i] / label_rate
+        if abs(dur_from_audio - dur_from_label) > tol:
+            logger.warning(
+                (
+                    f"audio and label duration differ too much "
+                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
+                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
+                    f"is correctly set (currently {label_rate}). "
+                    f"num. of samples = {audio_sizes[i]}; "
+                    f"label length = {lengths[i]}"
+                )
+            )
+            num_invalid += 1
+    if num_invalid > 0:
+        logger.warning(
+            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
+        )
+
+
+class AVHubertDataset(FairseqDataset):
+    def __init__(
+            self,
+            manifest_path: str,
+            sample_rate: float,
+            label_paths: List[str],
+            label_rates: Union[List[float], float],  # -1 for sequence labels
+            pad_list: List[str],
+            eos_list: List[str],
+            label_processors: Optional[List[Any]] = None,
+            max_keep_sample_size: Optional[int] = None,
+            min_keep_sample_size: Optional[int] = None,
+            max_sample_size: Optional[int] = None,
+            shuffle: bool = True,
+            pad_audio: bool = False,
+            normalize: bool = False,
+            store_labels: bool = True,
+            random_crop: bool = False,
+            single_target: bool = False,
+            stack_order_audio: int=1,
+            skip_verify: bool=False,
+            image_mean: float=0,
+            image_std: float=1,
+            image_crop_size: int=88,
+            image_aug: bool=False,
+            modalities: Optional[List[str]]=None,
+            is_s2s=False,
+            noise_fn=None,
+            noise_prob=0,
+            noise_snr=0,
+            noise_num=1
+    ):
+        self.label_rates = (
+            [label_rates for _ in range(len(label_paths))]
+            if isinstance(label_rates, (int, float))
+            else label_rates
+        )
+        self.modalities = set(modalities)
+        self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(manifest_path, max_keep_sample_size, min_keep_sample_size, frame_rate=sample_rate, label_paths=label_paths, label_rates=self.label_rates)
+        self.sample_rate = sample_rate
+        self.stack_order_audio = stack_order_audio
+        self.shuffle = shuffle
+        self.random_crop = random_crop
+
+        self.num_labels = len(label_paths)
+        self.pad_list = pad_list
+        self.eos_list = eos_list
+        self.label_processors = label_processors
+        self.single_target = single_target
+        self.store_labels = store_labels
+        self.is_s2s = is_s2s
+        self.noise_wav = (
+            [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else []
+        )
+        self.noise_prob, self.noise_snr, self.noise_num = noise_prob, noise_snr, noise_num
+
+        assert self.single_target == (self.label_rates[0] == -1), f"single target should be equivalent to sequence label (label_rate==-1)"
+        if store_labels:
+            self.label_list = [load_label(p, inds, tot) for p in label_paths]
+        else:
+            self.label_paths = label_paths
+            self.label_offsets_list = [
+                load_label_offset(p, inds, tot) for p in label_paths
+            ]
+        assert (
+            label_processors is None
+            or len(label_processors) == self.num_labels
+        )
+        if not skip_verify:
+            for label_path, label_rate in zip(label_paths, self.label_rates):
+                verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot)
+        else:
+            logger.info("Skipping label alignment verification")
+
+        self.max_sample_size = (
+            max_sample_size if max_sample_size is not None else sys.maxsize
+        )
+        self.pad_audio = pad_audio
+        self.normalize = normalize
+        if image_aug:
+            self.transform = custom_utils.Compose([
+                custom_utils.Normalize( 0.0,255.0 ),
+                custom_utils.RandomCrop((image_crop_size, image_crop_size)),
+                custom_utils.HorizontalFlip(0.5),
+                custom_utils.Normalize(image_mean, image_std) ])
+        else:
+            self.transform = custom_utils.Compose([
+                custom_utils.Normalize( 0.0,255.0 ),
+                custom_utils.CenterCrop((image_crop_size, image_crop_size)),
+                custom_utils.Normalize(image_mean, image_std) ])
+        logger.info(f"image transform: {self.transform}")
+
+        logger.info(
+            f"pad_audio={pad_audio}, random_crop={random_crop}, "
+            f"normalize={normalize}, max_sample_size={self.max_sample_size}, "
+            f"seq2seq data={self.is_s2s}")
+        logger.info(
+            f"Noise wav: {noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}"
+        )
+
+    def get_label(self, index, label_idx):
+        if self.store_labels:
+            label = self.label_list[label_idx][index]
+        else:
+            with open(self.label_paths[label_idx]) as f:
+                offset_s, offset_e = self.label_offsets_list[label_idx][index]
+                f.seek(offset_s)
+                label = f.read(offset_e - offset_s)
+
+        if self.label_processors is not None:
+            label = self.label_processors[label_idx](label)
+        return label
+
+    def get_labels(self, index):
+        return [self.get_label(index, i) for i in range(self.num_labels)]
+
+    def load_feature(self, mix_name):
+        """
+        Load image and audio feature
+        Returns:
+        video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F]
+        """
+        def stacker(feats, stack_order):
+            """
+            Concatenating consecutive audio frames
+            Args:
+            feats - numpy.ndarray of shape [T, F]
+            stack_order - int (number of neighboring frames to concatenate)
+            Returns:
+            feats - numpy.ndarray of shape [T', F']
+            """
+            feat_dim = feats.shape[1]
+            if len(feats) % stack_order != 0:
+                res = stack_order - len(feats) % stack_order
+                res = np.zeros([res, feat_dim]).astype(feats.dtype)
+                feats = np.concatenate([feats, res], axis=0)
+            feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim)
+            return feats
+        video_fn, audio_fn = mix_name
+        if 'video' in self.modalities:
+            video_feats = self.load_video(video_fn) # [T, H, W, 1]
+        else:
+            video_feats = None
+        if 'audio' in self.modalities:
+            audio_fn = audio_fn.split(':')[0]
+            sample_rate, wav_data = wavfile.read(audio_fn)
+            assert sample_rate == 16_000 and len(wav_data.shape) == 1
+            if np.random.rand() < self.noise_prob:
+                wav_data = self.add_noise(wav_data)
+            audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32) # [T, F]
+            audio_feats = stacker(audio_feats, self.stack_order_audio) # [T/stack_order_audio, F*stack_order_audio]
+        else:
+            audio_feats = None
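+        # Length-align the two streams: zero-pad the audio features if they are shorter
+        # than the video stream, or trim them if they are longer.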
+        if audio_feats is not None and video_feats is not None:
+            diff = len(audio_feats) - len(video_feats)
+            if diff < 0:
+                audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)])
+            elif diff > 0:
+                audio_feats = audio_feats[:-diff]
+        return video_feats, audio_feats
+
+    def load_video(self, audio_name):
+        feats = custom_utils.load_video(os.path.join(self.audio_root, audio_name))
+        feats = self.transform(feats)
+        feats = np.expand_dims(feats, axis=-1)
+        return feats
+
+    def select_noise(self):
+        rand_indexes = np.random.randint(0, len(self.noise_wav), size=self.noise_num)
+        noise_wav = []
+        for x in rand_indexes:
+            noise_wav.append(wavfile.read(self.noise_wav[x])[1].astype(np.float32))
+        if self.noise_num == 1:
+            return noise_wav[0]
+        else:
+            min_len = min([len(x) for x in noise_wav])
+            noise_wav = [x[:min_len] for x in noise_wav]
+            noise_wav = np.floor(np.stack(noise_wav).mean(axis=0))
+            return noise_wav
+
+    def add_noise(self, clean_wav):
+        clean_wav = clean_wav.astype(np.float32)
+        noise_wav = self.select_noise()
+        if isinstance(self.noise_snr, (int, float)):
+            snr = self.noise_snr
+        elif isinstance(self.noise_snr, tuple):
+            snr = np.random.randint(self.noise_snr[0], self.noise_snr[1] + 1)
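+        # Scale the noise so that the clean-to-noise RMS ratio matches the target SNR (in dB), then mix.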
+        clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1))
+        if len(clean_wav) > len(noise_wav):
+            ratio = int(np.ceil(len(clean_wav)/len(noise_wav)))
+            noise_wav = np.concatenate([noise_wav for _ in range(ratio)])
+        if len(clean_wav) < len(noise_wav):
+            start = 0
+            noise_wav = noise_wav[start: start + len(clean_wav)]
+        noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1))
+        adjusted_noise_rms = clean_rms / (10**(snr/20))
+        adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms)
+        mixed = clean_wav + adjusted_noise_wav
+
+        #Avoid clipping noise
+        max_int16 = np.iinfo(np.int16).max
+        min_int16 = np.iinfo(np.int16).min
+        if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16:
+            if mixed.max(axis=0) >= abs(mixed.min(axis=0)):
+                reduction_rate = max_int16 / mixed.max(axis=0)
+            else:
+                reduction_rate = min_int16 / mixed.min(axis=0)
+            mixed = mixed * reduction_rate
+        mixed = mixed.astype(np.int16)
+        return mixed
+
+    def __getitem__(self, index):
+        video_feats, audio_feats = self.load_feature(self.names[index])
+        if audio_feats is not None:
+            audio_feats = torch.from_numpy(audio_feats.astype(np.float32))
+        if video_feats is not None:
+            video_feats = torch.from_numpy(video_feats.astype(np.float32))
+        if self.normalize and 'audio' in self.modalities:
+            with torch.no_grad():
+                audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:])
+        labels = self.get_labels(index)
+        fid = self.names[index][1].split(':')[1]
+        return {"id": index, 'fid': fid, "video_source": video_feats, 'audio_source': audio_feats, "label_list": labels}
+
+    def __len__(self):
+        return len(self.sizes)
+
+    def crop_to_max_size(self, wav, target_size, start=None):
+        size = len(wav)
+        diff = size - target_size
+        if diff <= 0:
+            return wav, 0
+        # longer utterances
+        if start is None:
+            start, end = 0, target_size
+            if self.random_crop:
+                start = np.random.randint(0, diff + 1)
+                end = size - diff + start
+        else:
+            end = start + target_size
+        return wav[start:end], start
+
+    def collater(self, samples):
+        samples = [s for s in samples if s["id"] is not None]
+        if len(samples) == 0:
+            return {}
+
+        audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples]
+        if audio_source[0] is None:
+            audio_source = None
+        if video_source[0] is None:
+            video_source = None
+        if audio_source is not None:
+            audio_sizes = [len(s) for s in audio_source]
+        else:
+            audio_sizes = [len(s) for s in video_source]
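+        # pad_audio=True: pad every sample up to the longest in the batch; otherwise crop
+        # all samples down to the shortest (both capped by max_sample_size).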
+        if self.pad_audio:
+            audio_size = min(max(audio_sizes), self.max_sample_size)
+        else:
+            audio_size = min(min(audio_sizes), self.max_sample_size)
+        if audio_source is not None:
+            collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size)
+        else:
+            collated_audios, audio_starts = None, None
+        if video_source is not None:
+            collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts)
+        else:
+            collated_videos = None
+        targets_by_label = [
+            [s["label_list"][i] for s in samples]
+            for i in range(self.num_labels)
+        ]
+        targets_list, lengths_list, ntokens_list = self.collater_label(
+            targets_by_label, audio_size, audio_starts
+        )
+        source = {"audio": collated_audios, "video": collated_videos}
+        net_input = {"source": source, "padding_mask": padding_mask}
+        batch = {
+            "id": torch.LongTensor([s["id"] for s in samples]),
+            "net_input": net_input,
+            "utt_id": [s['fid'] for s in samples]
+        }
+
+        if self.single_target:
+            batch["target_lengths"] = lengths_list[0]
+            batch["ntokens"] = ntokens_list[0]
+            if self.is_s2s:
+                batch['target'], net_input['prev_output_tokens'] = targets_list[0][0], targets_list[0][1]
+            else:
+                batch["target"] = targets_list[0]
+        else:
+            batch["target_lengths_list"] = lengths_list
+            batch["ntokens_list"] = ntokens_list
+            batch["target_list"] = targets_list
+        return batch
+
+    def collater_audio(self, audios, audio_size, audio_starts=None):
+        audio_feat_shape = list(audios[0].shape[1:])
+        collated_audios = audios[0].new_zeros([len(audios), audio_size]+audio_feat_shape)
+        padding_mask = (
+            torch.BoolTensor(len(audios), audio_size).fill_(False)
+        )
+        start_known = audio_starts is not None
+        audio_starts = [0 for _ in audios] if not start_known else audio_starts
+        for i, audio in enumerate(audios):
+            diff = len(audio) - audio_size
+            if diff == 0:
+                collated_audios[i] = audio
+            elif diff < 0:
+                assert self.pad_audio
+                collated_audios[i] = torch.cat(
+                    [audio, audio.new_full([-diff]+audio_feat_shape, 0.0)]
+                )
+                padding_mask[i, diff:] = True
+            else:
+                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
+                    audio, audio_size, audio_starts[i] if start_known else None
+                )
+        if len(audios[0].shape) == 2:
+            collated_audios = collated_audios.transpose(1, 2) # [B, T, F] -> [B, F, T]
+        else:
+            collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W]
+        return collated_audios, padding_mask, audio_starts
+
+    def collater_frm_label(
+        self, targets, audio_size, audio_starts, label_rate, pad
+    ):
+        assert label_rate > 0
+        s2f = label_rate / self.sample_rate # num label per sample
+        frm_starts = [int(round(s * s2f)) for s in audio_starts]
+        frm_size = int(round(audio_size * s2f))
+        if not self.pad_audio:
+            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
+            frm_size = min(frm_size, *rem_size)
+        targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
+        logger.debug(f"audio_starts={audio_starts}")
+        logger.debug(f"frame_starts={frm_starts}")
+        logger.debug(f"frame_size={frm_size}")
+
+        lengths = torch.LongTensor([len(t) for t in targets])
+        ntokens = lengths.sum().item()
+        targets = data_utils.collate_tokens(
+            targets, pad_idx=pad, left_pad=False
+        )
+        return targets, lengths, ntokens
+
+    def collater_seq_label(self, targets, pad):
+        lengths = torch.LongTensor([len(t) for t in targets])
+        ntokens = lengths.sum().item()
+        targets = data_utils.collate_tokens(
+            targets, pad_idx=pad, left_pad=False
+        )
+        return targets, lengths, ntokens
+
+    def collater_seq_label_s2s(self, targets, pad):
+        lengths = torch.LongTensor([len(t) for t in targets])
+        ntokens = lengths.sum().item()
+        pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos()
+        targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False)
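+        # Teacher-forcing decoder input: the same token sequences with EOS moved to the front.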
+        prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True)
+        return (targets_, prev_output_tokens), lengths, ntokens
+
+    def collater_label(self, targets_by_label, audio_size, audio_starts):
+        targets_list, lengths_list, ntokens_list = [], [], []
+        itr = zip(targets_by_label, self.label_rates, self.pad_list)
+        for targets, label_rate, pad in itr:
+            if label_rate == -1:
+                if self.is_s2s:
+                    targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad)
+                else:
+                    targets, lengths, ntokens = self.collater_seq_label(targets, pad)
+            else:
+                targets, lengths, ntokens = self.collater_frm_label(
+                    targets, audio_size, audio_starts, label_rate, pad
+                )
+            targets_list.append(targets)
+            lengths_list.append(lengths)
+            ntokens_list.append(ntokens)
+        return targets_list, lengths_list, ntokens_list
+
+    def num_tokens(self, index):
+        return self.size(index)
+
+    def size(self, index):
+        if self.pad_audio:
+            return self.sizes[index]
+        return min(self.sizes[index], self.max_sample_size)
+
+    def ordered_indices(self):
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+
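+        # np.lexsort uses the last appended key (size) as the primary sort key, with the
+        # indices above breaking ties; [::-1] yields a size-descending order.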
+        order.append(self.sizes)
+        return np.lexsort(order)[::-1]
diff --git a/slam_llm/models/avhubert/hubert_pretraining.py b/slam_llm/models/avhubert/hubert_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..5829e6eadbd1da1ae66242bf4478c93e2161c32e
--- /dev/null
+++ b/slam_llm/models/avhubert/hubert_pretraining.py
@@ -0,0 +1,401 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os, glob
+import sys
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+
+from dataclasses import dataclass, field
+from fairseq import metrics, search
+from fairseq.data import Dictionary, encoders
+from fairseq.dataclass.configs import FairseqDataclass
+from fairseq.tasks import register_task
+from fairseq.tasks.fairseq_task import FairseqTask
+from omegaconf import MISSING, II
+from argparse import Namespace
+
+DBG=True if len(sys.argv) == 1 else False
+
+if DBG:
+    from hubert_dataset import AVHubertDataset
+    from sequence_generator import SequenceGenerator
+else:
+    from .hubert_dataset import AVHubertDataset
+    from .sequence_generator import SequenceGenerator
+
+logger = logging.getLogger(__name__)
+
+
+class LabelEncoder(object):
+    def __init__(self, dictionary: Dictionary) -> None:
+        self.dictionary = dictionary
+
+    def __call__(self, label: str) -> List[str]:
+        return self.dictionary.encode_line(
+            label, append_eos=False, add_if_not_exist=False,
+        )
+
+class LabelEncoderS2SToken(object):
+    def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None:
+        self.bpe_tokenizer = bpe_tokenizer
+        self.dictionary = dictionary
+
+    def __call__(self, label: str) -> List[str]:
+        label = self.bpe_tokenizer.encode(label.lower())
+        return self.dictionary.encode_line(
+            label, append_eos=True, add_if_not_exist=False,
+        ).long()
+
+    def decode(self, tok, symbols_ignore=None):
+        tok = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore)
+        if self.bpe_tokenizer:
+            tok = self.bpe_tokenizer.decode(tok)
+        return tok
+
+@dataclass
+class AVHubertPretrainingConfig(FairseqDataclass):
+    input_modality: str = II("task.input_modality")
+    data: str = field(
+        default=MISSING, metadata={"help": "path to data directory"}
+    )
+    labels: List[str] = field(
+        default_factory=lambda: ["ltr"],
+        metadata={
+            "help": (
+                "extension of the label files to load, frame-level labels for"
+                " pre-training, and sequence-level label for fine-tuning"
+            )
+        },
+    )
+    label_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "if set, looks for labels in this directory instead",
+        },
+    )
+    label_rate: int = field(
+        default=-1,
+        metadata={"help": "label frame rate. -1 for sequence label"},
+    )
+
+    sample_rate: int = field(
+        default=16_000,
+        metadata={
+            "help": "target sample rate. audio files will be up/down "
+            "sampled to this rate"
+        },
+    )
+    normalize: bool = field(
+        default=False,
+        metadata={
+            "help": "if set, normalizes input to have 0 mean and unit variance"
+        },
+    )
+    enable_padding: bool = field(
+        default=False,
+        metadata={"help": "pad shorter samples instead of cropping"},
+    )
+    max_sample_size: Optional[int] = field(
+        default=None,
+        metadata={"help": "max sample size to keep in training"},
+    )
+    min_sample_size: Optional[int] = field(
+        default=None,
+        metadata={"help": "min sample size to keep in training"},
+    )
+    max_trim_sample_size: Optional[int] = field(
+        default=II("task.max_sample_size"),
+        metadata={"help": "max sample size to trim to for batching"},
+    )
+    single_target: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "if set, AddTargetDatasets outputs same keys "
+            "as AddTargetDataset"
+        },
+    )
+    random_crop: Optional[bool] = field(
+        default=True,
+        metadata={"help": "always crop from the beginning if false"},
+    )
+    pad_audio: Optional[bool] = field(
+        default=False,
+        metadata={"help": "pad audio to the longest one in the batch if true"},
+    )
+    pdb: Optional[bool] = field(
+        default=False,
+        metadata={"help": "pdb"},
+    )
+    stack_order_audio: int = field(
+        default=1,
+        metadata={"help": "concatenate n consecutive audio frames for one step"},
+    )
+    skip_verify: Optional[bool] = field(
+        default=False,
+        metadata={"help": "skip verifying label-audio alignment"},
+    )
+    image_aug: bool = field(default=False, metadata={'help': 'image data augmentation'})
+    image_crop_size: int = field(
+        default=88, metadata={"help": "image ROI size"})
+    image_mean: float = field(
+        default=0.421, metadata={"help": "image mean"})
+    image_std: float = field(
+        default=0.165, metadata={"help": "image std"})
+    modalities: Optional[List[str]] = field(default_factory=lambda: ["audio", "video"], metadata={'help': 'modalities to load'})
+    is_s2s: bool=field(default=False, metadata={'help': 'seq2seq fine-tuning only'})
+    tokenizer_bpe_name: Optional[str] = field(default=None, metadata={'help': 'tokenizer model name'})
+    tokenizer_bpe_model: Optional[str] = field(default=None, metadata={'help': 'tokenizer model path'})
+    noise_wav: Optional[str] = field(default=None, metadata={'help': 'manifest of noise wav files (one wav file path per line)'})
+    noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
+    noise_snr: Optional[str] = field(default='0', metadata={'help': 'noise SNR in audio'})
+    noise_num: int = field(default=1, metadata={'help': 'number of noise wav files to mix'})
+    fine_tuning: bool = field(default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"})
+
+@register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig)
+class AVHubertPretrainingTask(FairseqTask):
+
+    cfg: AVHubertPretrainingConfig
+
+    def __init__(
+        self,
+        cfg: AVHubertPretrainingConfig,
+    ) -> None:
+        super().__init__(cfg)
+
+        logger.info(f"current directory is {os.getcwd()}")
+        logger.info(f"AVHubertPretrainingTask Config {cfg}")
+
+        self.fine_tuning = cfg.fine_tuning
+        if cfg.fine_tuning:
+            self.state.add_factory("target_dictionary", self.load_dictionaries)
+            if cfg.is_s2s:
+                self.state.add_factory("s2s_tokenizer", self.load_tokenizer)
+        else:
+            self.state.add_factory("dictionaries", self.load_dictionaries)
+
+        self.blank_symbol = "<s>"
+
+    @property
+    def source_dictionary(self) -> Optional[Dictionary]:
+        return None
+
+    @property
+    def target_dictionary(self) -> Optional[Dictionary]:
+        return self.state.target_dictionary
+
+    @property
+    def dictionaries(self) -> List[Dictionary]:
+        return self.state.dictionaries
+
+    def load_dictionaries(self):
+        label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir
+        dictionaries = [
+            Dictionary.load(f"{label_dir}/dict.{label}.txt")
+            for label in self.cfg.labels
+        ]
+        return dictionaries[0] if self.cfg.fine_tuning else dictionaries
+
+    def load_tokenizer(self):
+        bpe_args = Namespace(**{'bpe': self.cfg.tokenizer_bpe_name, f"{self.cfg.tokenizer_bpe_name}_model": self.cfg.tokenizer_bpe_model})
+        bpe_tokenizer = encoders.build_bpe(bpe_args)
+        return bpe_tokenizer
+
+    @property
+    def s2s_tokenizer(self):
+        return self.state.s2s_tokenizer
+
+    @classmethod
+    def setup_task(
+        cls, cfg: AVHubertPretrainingConfig, **kwargs
+    ) -> "AVHubertPretrainingTask":
+        if cfg.pdb:
+            import pdb
+            pdb.set_trace()
+        return cls(cfg)
+
+    def get_label_dir(self) -> str:
+        if self.cfg.label_dir is None:
+            return self.cfg.data
+        return self.cfg.label_dir
+
+    def load_dataset(self, split: str, **kwargs) -> None:
+        manifest = f"{self.cfg.data}/{split}.tsv"
+        dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries
+        pad_list = [dictionary.pad() for dictionary in dictionaries]
+        eos_list = [dictionary.eos() for dictionary in dictionaries]
+        if not self.cfg.is_s2s:
+            procs = [LabelEncoder(dictionary) for dictionary in dictionaries]
+        else:
+            logger.info("Using S2S BPE tokenizer")
+            bpe_tokenizer = self.s2s_tokenizer
+            procs = [LabelEncoderS2SToken(dictionary, bpe_tokenizer) for dictionary in dictionaries]
+        paths = [
+            f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels
+        ]
+        image_aug = self.cfg.image_aug if split == 'train' else False
+        noise_fn = f"{self.cfg.noise_wav}/{split}.tsv" if self.cfg.noise_wav is not None else None
+        noise_snr = eval(self.cfg.noise_snr)
+        noise_num = self.cfg.noise_num
+        self.datasets[split] = AVHubertDataset(
+            manifest,
+            sample_rate=self.cfg.sample_rate,
+            label_paths=paths,
+            label_rates=self.cfg.label_rate,
+            pad_list=pad_list,
+            eos_list=eos_list,
+            label_processors=procs,
+            max_keep_sample_size=self.cfg.max_sample_size,
+            min_keep_sample_size=self.cfg.min_sample_size,
+            max_sample_size=self.cfg.max_trim_sample_size,
+            pad_audio=self.cfg.pad_audio,
+            normalize=self.cfg.normalize,
+            store_labels=False,
+            random_crop=self.cfg.random_crop,
+            single_target=self.cfg.single_target,
+            stack_order_audio=self.cfg.stack_order_audio,
+            skip_verify=self.cfg.skip_verify,
+            image_mean=self.cfg.image_mean,
+            image_std=self.cfg.image_std,
+            image_crop_size=self.cfg.image_crop_size,
+            image_aug=image_aug,
+            modalities=self.cfg.modalities,
+            is_s2s=self.cfg.is_s2s,
+            noise_fn=noise_fn,
+            noise_prob=self.cfg.noise_prob,
+            noise_snr=noise_snr,
+            noise_num=noise_num
+        )
+
+    def max_positions(self) -> Tuple[int, int]:
+        return (sys.maxsize, sys.maxsize)
+
+    def filter_indices_by_size(
+        self, indices: np.array, *args, **kwargs
+    ) -> np.array:
+        return indices
+
+    def build_generator(
+        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
+    ):
+        """
+        Build a :class:`~fairseq.SequenceGenerator` instance for this
+        task.
+        Args:
+            models (List[~fairseq.models.FairseqModel]): ensemble of models
+            args (fairseq.dataclass.configs.GenerationConfig):
+                configuration object (dataclass) for generation
+            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
+                through to SequenceGenerator
+            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
+                If provided, this function constrains the beam search to
+                allowed tokens only at each step. The provided function
+                should take 2 arguments: the batch ID (`batch_id: int`)
+                and a unidimensional tensor of token ids (`inputs_ids:
+                torch.Tensor`). It has to return a `List[int]` with the
+                allowed tokens for the next generation step conditioned
+                on the previously generated tokens (`inputs_ids`) and
+                the batch ID (`batch_id`). This argument is useful for
+                constrained generation conditioned on the prefix, as
+                described in "Autoregressive Entity Retrieval"
+                (https://arxiv.org/abs/2010.00904) and
+                https://github.com/facebookresearch/GENRE.
+        """
+        if getattr(args, "score_reference", False):
+            from fairseq.sequence_scorer import SequenceScorer
+
+            return SequenceScorer(
+                self.target_dictionary,
+                compute_alignment=getattr(args, "print_alignment", False),
+            )
+
+        # Choose search strategy. Defaults to Beam Search.
+        sampling = getattr(args, "sampling", False)
+        sampling_topk = getattr(args, "sampling_topk", -1)
+        sampling_topp = getattr(args, "sampling_topp", -1.0)
+        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
+        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
+        match_source_len = getattr(args, "match_source_len", False)
+        diversity_rate = getattr(args, "diversity_rate", -1)
+        constrained = getattr(args, "constraints", False)
+        if prefix_allowed_tokens_fn is None:
+            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
+        if (
+            sum(
+                int(cond)
+                for cond in [
+                    sampling,
+                    diverse_beam_groups > 0,
+                    match_source_len,
+                    diversity_rate > 0,
+                ]
+            )
+            > 1
+        ):
+            raise ValueError("Provided Search parameters are mutually exclusive.")
+        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
+        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
+
+        if sampling:
+            search_strategy = search.Sampling(
+                self.target_dictionary, sampling_topk, sampling_topp
+            )
+        elif diverse_beam_groups > 0:
+            search_strategy = search.DiverseBeamSearch(
+                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
+            )
+        elif match_source_len:
+            # this is useful for tagging applications where the output
+            # length should match the input length, so we hardcode the
+            # length constraints for simplicity
+            search_strategy = search.LengthConstrainedBeamSearch(
+                self.target_dictionary,
+                min_len_a=1,
+                min_len_b=0,
+                max_len_a=1,
+                max_len_b=0,
+            )
+        elif diversity_rate > -1:
+            search_strategy = search.DiverseSiblingsSearch(
+                self.target_dictionary, diversity_rate
+            )
+        elif constrained:
+            search_strategy = search.LexicallyConstrainedBeamSearch(
+                self.target_dictionary, args.constraints
+            )
+        elif prefix_allowed_tokens_fn:
+            search_strategy = search.PrefixConstrainedBeamSearch(
+                self.target_dictionary, prefix_allowed_tokens_fn
+            )
+        else:
+            search_strategy = search.BeamSearch(self.target_dictionary)
+
+        extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
+        if seq_gen_cls is None:
+            if getattr(args, "print_alignment", False):
+                # SequenceGeneratorWithAlignment is not imported at module level here, so pull
+                # in fairseq's implementation lazily to avoid a NameError.
+                from fairseq.sequence_generator import SequenceGeneratorWithAlignment
+
+                seq_gen_cls = SequenceGeneratorWithAlignment
+                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
+            else:
+                seq_gen_cls = SequenceGenerator
+
+        return seq_gen_cls(
+            models,
+            self.target_dictionary,
+            beam_size=getattr(args, "beam", 5),
+            max_len_a=getattr(args, "max_len_a", 0),
+            max_len_b=getattr(args, "max_len_b", 200),
+            min_len=getattr(args, "min_len", 1),
+            normalize_scores=(not getattr(args, "unnormalized", False)),
+            len_penalty=getattr(args, "lenpen", 1),
+            unk_penalty=getattr(args, "unkpen", 0),
+            temperature=getattr(args, "temperature", 1.0),
+            match_source_len=getattr(args, "match_source_len", False),
+            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
+            search_strategy=search_strategy,
+            **extra_gen_cls_kwargs,
+        )
diff --git a/slam_llm/models/avhubert/infer_s2s.py b/slam_llm/models/avhubert/infer_s2s.py
new file mode 100644
index 0000000000000000000000000000000000000000..805b66fdc976073ba61364072bbd41df7547a4b9
--- /dev/null
+++ b/slam_llm/models/avhubert/infer_s2s.py
@@ -0,0 +1,318 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import ast
+from itertools import chain
+import logging
+import math
+import os
+import sys
+import json
+import hashlib
+import editdistance
+from argparse import Namespace
+
+import numpy as np
+import torch
+from fairseq import checkpoint_utils, options, tasks, utils, distributed_utils
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.logging import progress_bar
+from fairseq.logging.meters import StopwatchMeter, TimeMeter
+from fairseq.models import FairseqLanguageModel
+from omegaconf import DictConfig
+
+from pathlib import Path
+import hydra
+from hydra.core.config_store import ConfigStore
+from fairseq.dataclass.configs import (
+    CheckpointConfig,
+    CommonConfig,
+    CommonEvalConfig,
+    DatasetConfig,
+    DistributedTrainingConfig,
+    GenerationConfig,
+    FairseqDataclass,
+)
+from dataclasses import dataclass, field, is_dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+from omegaconf import OmegaConf
+
+logging.root.setLevel(logging.INFO)
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+config_path = Path(__file__).resolve().parent / "conf"
+
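+# Structured configs for Hydra: OverrideConfig carries test-time overrides (noise, modalities,
+# data/label directories) applied on top of the task config stored in the checkpoint, while
+# InferConfig bundles the standard fairseq config groups used for decoding.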
+@dataclass
+class OverrideConfig(FairseqDataclass):
+    noise_wav: Optional[str] = field(default=None, metadata={'help': 'noise wav file'})
+    noise_prob: float = field(default=0, metadata={'help': 'noise probability'})
+    noise_snr: float = field(default=0, metadata={'help': 'noise SNR in audio'})
+    modalities: List[str] = field(default_factory=lambda: [""], metadata={'help': 'which modality to use'})
+    data: Optional[str] = field(default=None, metadata={'help': 'path to test data directory'})
+    label_dir: Optional[str] = field(default=None, metadata={'help': 'path to test label directory'})
+
+@dataclass
+class InferConfig(FairseqDataclass):
+    task: Any = None
+    generation: GenerationConfig = GenerationConfig()
+    common: CommonConfig = CommonConfig()
+    common_eval: CommonEvalConfig = CommonEvalConfig()
+    checkpoint: CheckpointConfig = CheckpointConfig()
+    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
+    dataset: DatasetConfig = DatasetConfig()
+    override: OverrideConfig = OverrideConfig()
+    is_ax: bool = field(
+        default=False,
+        metadata={
+            "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume"
+        },
+    )
+
+
+def main(cfg: DictConfig):
+
+    if isinstance(cfg, Namespace):
+        cfg = convert_namespace_to_omegaconf(cfg)
+
+    assert cfg.common_eval.path is not None, "--path required for recognition!"
+    assert (
+        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
+    ), "--sampling requires --nbest to be equal to --beam"
+
+    if cfg.common_eval.results_path is not None:
+        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
+        output_path = os.path.join(cfg.common_eval.results_path, "decode.log")
+        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
+            return _main(cfg, h)
+    return _main(cfg, sys.stdout)
+
+
+def get_symbols_to_strip_from_output(generator):
+    if hasattr(generator, "symbols_to_strip_from_output"):
+        return generator.symbols_to_strip_from_output
+    else:
+        return {generator.eos, generator.pad}
+
+def _main(cfg, output_file):
+    logging.basicConfig(
+        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        level=os.environ.get("LOGLEVEL", "INFO").upper(),
+        stream=output_file,
+    )
+    logger = logging.getLogger("hybrid.speech_recognize")
+    if output_file is not sys.stdout:  # also print to stdout
+        logger.addHandler(logging.StreamHandler(sys.stdout))
+
+    utils.import_user_module(cfg.common)
+    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([cfg.common_eval.path])
+    models = [model.eval().cuda() for model in models]  # set eval mode; decoding here assumes a CUDA device is available
+    saved_cfg.task.modalities = cfg.override.modalities
+    task = tasks.setup_task(saved_cfg.task)
+
+    task.build_tokenizer(saved_cfg.tokenizer)
+    task.build_bpe(saved_cfg.bpe)
+
+    logger.info(cfg)
+
+    # Fix seed for stochastic decoding
+    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
+        np.random.seed(cfg.common.seed)
+        utils.set_torch_seed(cfg.common.seed)
+
+    use_cuda = torch.cuda.is_available()
+
+    # Set dictionary
+    dictionary = task.target_dictionary
+
+    # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
+    task.cfg.noise_prob = cfg.override.noise_prob
+    task.cfg.noise_snr = cfg.override.noise_snr
+    task.cfg.noise_wav = cfg.override.noise_wav
+    if cfg.override.data is not None:
+        task.cfg.data = cfg.override.data
+    if cfg.override.label_dir is not None:
+        task.cfg.label_dir = cfg.override.label_dir
+    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
+
+    lms = [None]
+
+    # Optimize ensemble for generation
+    for model in chain(models, lms):
+        if model is None:
+            continue
+        if cfg.common.fp16:
+            model.half()
+        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
+            model.cuda()
+        model.prepare_for_inference_(cfg)
+
+    # Load dataset (possibly sharded)
+    itr = task.get_batch_iterator(
+        dataset=task.dataset(cfg.dataset.gen_subset),
+        max_tokens=cfg.dataset.max_tokens,
+        max_sentences=cfg.dataset.batch_size,
+        max_positions=utils.resolve_max_positions(
+            task.max_positions(), *[m.max_positions() for m in models]
+        ),
+        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
+        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
+        seed=cfg.common.seed,
+        num_shards=cfg.distributed_training.distributed_world_size,
+        shard_id=cfg.distributed_training.distributed_rank,
+        num_workers=cfg.dataset.num_workers,
+        data_buffer_size=cfg.dataset.data_buffer_size,
+    ).next_epoch_itr(shuffle=False)
+    progress = progress_bar.progress_bar(
+        itr,
+        log_format=cfg.common.log_format,
+        log_interval=cfg.common.log_interval,
+        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
+    )
+
+    # Initialize generator
+    if cfg.generation.match_source_len:
+        logger.warning(
+            "The option match_source_len is not applicable to speech recognition. Ignoring it."
+        )
+    gen_timer = StopwatchMeter()
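+    # Optional shallow-fusion language model arguments forwarded to the sequence generator
+    # (lms is [None] here, so no external LM is applied).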
+    extra_gen_cls_kwargs = {
+        "lm_model": lms[0],
+        "lm_weight": cfg.generation.lm_weight,
+    }
+    cfg.generation.score_reference = False  # always run free decoding rather than scoring the reference
+    save_attention_plot = cfg.generation.print_alignment is not None
+    cfg.generation.print_alignment = None  # alignments are not emitted by the generator during recognition
+    generator = task.build_generator(
+        models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
+    )
+
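+    # Turn a tensor of token ids into text: prefer the dataset's label processor when it
+    # exposes a decode() method, otherwise join characters and map '|' back to spaces.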
+    def decode_fn(x):
+        symbols_ignore = get_symbols_to_strip_from_output(generator)
+        symbols_ignore.add(dictionary.pad())
+        if hasattr(task.datasets[cfg.dataset.gen_subset].label_processors[0], 'decode'):
+            return task.datasets[cfg.dataset.gen_subset].label_processors[0].decode(x, symbols_ignore)
+        chars = dictionary.string(x, extra_symbols_to_ignore=symbols_ignore)
+        words = " ".join("".join(chars.split()).replace('|', ' ').split())
+        return words
+
+    num_sentences = 0
+    has_target = True
+    wps_meter = TimeMeter()
+    result_dict = {'utt_id': [], 'ref': [], 'hypo': []}
+    for sample in progress:
+        sample = utils.move_to_cuda(sample) if use_cuda else sample
+        if "net_input" not in sample:
+            continue
+
+        prefix_tokens = None
+        if cfg.generation.prefix_size > 0:
+            prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
+
+        constraints = None
+        if "constraints" in sample:
+            constraints = sample["constraints"]
+
+        gen_timer.start()
+        hypos = task.inference_step(
+            generator,
+            models,
+            sample,
+            prefix_tokens=prefix_tokens,
+            constraints=constraints,
+        )
+        num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
+        gen_timer.stop(num_generated_tokens)
+
+        for i in range(len(sample["id"])):
+            result_dict['utt_id'].append(sample['utt_id'][i])
+            ref_sent = decode_fn(sample['target'][i].int().cpu())
+            result_dict['ref'].append(ref_sent)
+            best_hypo = hypos[i][0]['tokens'].int().cpu()
+            hypo_str = decode_fn(best_hypo)
+            result_dict['hypo'].append(hypo_str)
+            logger.info(f"\nREF:{ref_sent}\nHYP:{hypo_str}\n")
+        wps_meter.update(num_generated_tokens)
+        progress.log({"wps": round(wps_meter.avg)})
+        num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
+
+    logger.info("NOTE: hypothesis and token scores are output in base 2")
+    logger.info("Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
+        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
+
+    yaml_str = OmegaConf.to_yaml(cfg.generation)
+    fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16)
+    fid = fid % 1000000
+    result_fn = f"{cfg.common_eval.results_path}/hypo-{fid}.json"
+    with open(result_fn, 'w') as f:
+        json.dump(result_dict, f, indent=4)
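+    # Corpus-level WER: total word-level edit distance divided by the total number of reference words.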
+    n_err, n_total = 0, 0
+    assert len(result_dict['hypo']) == len(result_dict['ref'])
+    for hypo, ref in zip(result_dict['hypo'], result_dict['ref']):
+        hypo, ref = hypo.strip().split(), ref.strip().split()
+        n_err += editdistance.eval(hypo, ref)
+        n_total += len(ref)
+    wer = 100 * n_err / n_total
+    wer_fn = f"{cfg.common_eval.results_path}/wer.{fid}"
+    with open(wer_fn, "w") as fo:
+        fo.write(f"WER: {wer}\n")
+        fo.write(f"err / num_ref_words = {n_err} / {n_total}\n\n")
+        fo.write(f"{yaml_str}")
+    logger.info(f"WER: {wer}%")
+    return
+
+
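+# `reset_logging` is called in hydra_main() below but is neither imported nor defined in this
+# file; the minimal helper below follows the pattern used in fairseq-style inference scripts
+# (an assumption -- swap in a shared implementation if one exists in this codebase).
+def reset_logging():
+    root = logging.getLogger()
+    for handler in list(root.handlers):
+        root.removeHandler(handler)
+    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        logging.Formatter(
+            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+    )
+    root.addHandler(handler)
+
+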
+@hydra.main(config_path=config_path, config_name="infer")
+def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]:
+    container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
+    cfg = OmegaConf.create(container)
+    OmegaConf.set_struct(cfg, True)
+
+    if cfg.common.reset_logging:
+        reset_logging()
+
+    wer = float("inf")
+
+    try:
+        if cfg.common.profile:
+            with torch.cuda.profiler.profile():
+                with torch.autograd.profiler.emit_nvtx():
+                    distributed_utils.call_main(cfg, main)
+        else:
+            distributed_utils.call_main(cfg, main)
+
+    except BaseException as e:  # pylint: disable=broad-except
+        if not cfg.common.suppress_crashes:
+            raise
+        else:
+            logger.error("Crashed! %s", str(e))
+    return
+
+
+def cli_main() -> None:
+    try:
+        from hydra._internal.utils import (
+            get_args,
+        )  # pylint: disable=import-outside-toplevel
+
+        cfg_name = get_args().config_name or "infer"
+    except ImportError:
+        logger.warning("Failed to get config name from hydra args")
+        cfg_name = "infer"
+
+    cs = ConfigStore.instance()
+    cs.store(name=cfg_name, node=InferConfig)
+
+    for k in InferConfig.__dataclass_fields__:
+        if is_dataclass(InferConfig.__dataclass_fields__[k].type):
+            v = InferConfig.__dataclass_fields__[k].default
+            cs.store(name=k, node=v)
+
+    hydra_main()  # pylint: disable=no-value-for-parameter
+
+
+if __name__ == "__main__":
+    cli_main()
diff --git a/slam_llm/models/avhubert/resnet.py b/slam_llm/models/avhubert/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..0134e111f8892a8c314628e7ffc3543fd9d35c03
--- /dev/null
+++ b/slam_llm/models/avhubert/resnet.py
@@ -0,0 +1,169 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+
+logger = logging.getLogger(__name__)
+
+def conv3x3(in_planes, out_planes, stride=1):
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+def downsample_basic_block(inplanes, outplanes, stride):
+    return nn.Sequential(
+        nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
+        nn.BatchNorm2d(outplanes),
+    )
+
+
+def downsample_basic_block_v2(inplanes, outplanes, stride):
+    return nn.Sequential(
+        nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False),
+        nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False),
+        nn.BatchNorm2d(outplanes),
+    )
+
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type='relu'):
+        super(BasicBlock, self).__init__()
+
+        assert relu_type in ['relu', 'prelu']
+
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        if relu_type == 'relu':
+            self.relu1 = nn.ReLU(inplace=True)
+            self.relu2 = nn.ReLU(inplace=True)
+        elif relu_type == 'prelu':
+            self.relu1 = nn.PReLU(num_parameters=planes)
+            self.relu2 = nn.PReLU(num_parameters=planes)
+        else:
+            raise Exception('relu type not implemented')
+
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu2(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, relu_type='relu', gamma_zero=False, avg_pool_downsample=False):
+        self.inplanes = 64
+        self.relu_type = relu_type
+        self.gamma_zero = gamma_zero
+        self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block
+
+        super(ResNet, self).__init__()
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+
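+        # He (Kaiming) initialization for convolutions; batch-norm weights start at 1, biases at 0.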
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+        if self.gamma_zero:
+            for m in self.modules():
+                if isinstance(m, BasicBlock):
+                    m.bn2.weight.data.zero_()
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = self.downsample_block(
+                inplanes=self.inplanes,
+                outplanes=planes * block.expansion,
+                stride=stride,
+            )
+
+        layers = [block(self.inplanes, planes, stride, downsample, relu_type=self.relu_type)]
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, relu_type=self.relu_type))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        return x
+
+class ResEncoder(nn.Module):
+    def __init__(self, relu_type, weights):
+        super(ResEncoder, self).__init__()
+        self.frontend_nout = 64
+        self.backend_out = 512
+        frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU()
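+        # 3D conv frontend over the video stream: temporal kernel 5 (stride 1) keeps the frame rate,
+        # while the spatial stride and the max-pooling each halve the spatial resolution.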
+        self.frontend3D = nn.Sequential(
+            nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
+            nn.BatchNorm3d(self.frontend_nout),
+            frontend_relu,
+            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
+        self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
+        if weights is not None:
+            logger.info(f"Load {weights} for resnet")
+            std = torch.load(weights, map_location=torch.device('cpu'))['model_state_dict']
+            frontend_std, trunk_std = OrderedDict(), OrderedDict()
+            for key, val in std.items():
+                new_key = '.'.join(key.split('.')[1:])
+                if 'frontend3D' in key:
+                    frontend_std[new_key] = val
+                if 'trunk' in key:
+                    trunk_std[new_key] = val
+            self.frontend3D.load_state_dict(frontend_std)
+            self.trunk.load_state_dict(trunk_std)
+
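+    # Input x: (B, C, T, H, W) video tensor; output: (B, backend_out, T) frame-level features.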
+    def forward(self, x):
+        B, C, T, H, W = x.size()
+        x = self.frontend3D(x)
+        Tnew = x.shape[2]
+        x = self.threeD_to_2D_tensor(x)
+        x = self.trunk(x)
+        x = x.view(B, Tnew, x.size(1))
+        x = x.transpose(1, 2).contiguous()
+        return x
+
+    def threeD_to_2D_tensor(self, x):
+        n_batch, n_channels, s_time, sx, sy = x.shape
+        x = x.transpose(1, 2).contiguous()
+        return x.reshape(n_batch*s_time, n_channels, sx, sy)
diff --git a/slam_llm/models/avhubert/sequence_generator.py b/slam_llm/models/avhubert/sequence_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..428a5765e5cb2f3d56771eac5df2e31f0484fdf6
--- /dev/null
+++ b/slam_llm/models/avhubert/sequence_generator.py
@@ -0,0 +1,985 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Dict, List, Optional
+import sys
+
+import torch
+import torch.nn as nn
+from fairseq import search, utils
+from fairseq.data import data_utils
+from fairseq.models import FairseqIncrementalDecoder
+from torch import Tensor
+from fairseq.ngram_repeat_block import NGramRepeatBlock
+
+
+class SequenceGenerator(nn.Module):
+    def __init__(
+        self,
+        models,
+        tgt_dict,
+        beam_size=1,
+        max_len_a=0,
+        max_len_b=200,
+        max_len=0,
+        min_len=1,
+        normalize_scores=True,
+        len_penalty=1.0,
+        unk_penalty=0.0,
+        temperature=1.0,
+        match_source_len=False,
+        no_repeat_ngram_size=0,
+        search_strategy=None,
+        eos=None,
+        symbols_to_strip_from_output=None,
+        lm_model=None,
+        lm_weight=1.0,
+    ):
+        """Generates translations of a given source sentence.
+
+        Args:
+            models (List[~fairseq.models.FairseqModel]): ensemble of models,
+                currently support fairseq.models.TransformerModel for scripting
+            beam_size (int, optional): beam width (default: 1)
+            max_len_a/b (int, optional): generate sequences of maximum length
+                ax + b, where x is the source length
+            max_len (int, optional): the maximum length of the generated output
+                (not including end-of-sentence)
+            min_len (int, optional): the minimum length of the generated output
+                (not including end-of-sentence)
+            normalize_scores (bool, optional): normalize scores by the length
+                of the output (default: True)
+            len_penalty (float, optional): length penalty, where <1.0 favors
+                shorter, >1.0 favors longer sentences (default: 1.0)
+            unk_penalty (float, optional): unknown word penalty, where <0
+                produces more unks, >0 produces fewer (default: 0.0)
+            temperature (float, optional): temperature, where values
+                >1.0 produce more uniform samples and values <1.0 produce
+                sharper samples (default: 1.0)
+            match_source_len (bool, optional): outputs should match the source
+                length (default: False)
+        """
+        super().__init__()
+        if isinstance(models, EnsembleModel):
+            self.model = models
+        else:
+            self.model = EnsembleModel(models)
+        self.tgt_dict = tgt_dict
+        self.pad = tgt_dict.pad()
+        self.unk = tgt_dict.unk()
+        self.eos = tgt_dict.eos() if eos is None else eos
+        self.symbols_to_strip_from_output = (
+            symbols_to_strip_from_output.union({self.eos})
+            if symbols_to_strip_from_output is not None
+            else {self.eos}
+        )
+        self.vocab_size = len(tgt_dict)
+        self.beam_size = beam_size
+        # the max beam size is the dictionary size - 1, since we never select pad
+        self.beam_size = min(beam_size, self.vocab_size - 1)
+        self.max_len_a = max_len_a
+        self.max_len_b = max_len_b
+        self.min_len = min_len
+        self.max_len = max_len or self.model.max_decoder_positions()
+
+        self.normalize_scores = normalize_scores
+        self.len_penalty = len_penalty
+        self.unk_penalty = unk_penalty
+        self.temperature = temperature
+        self.match_source_len = match_source_len
+
+        if no_repeat_ngram_size > 0:
+            self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size)
+        else:
+            self.repeat_ngram_blocker = None
+
+        assert temperature > 0, "--temperature must be greater than 0"
+
+        self.search = (
+            search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
+        )
+        # We only need to set src_lengths in LengthConstrainedBeamSearch.
+        # As a module attribute, setting it would break in multithread
+        # settings when the model is shared.
+        self.should_set_src_lengths = (
+            hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
+        )
+
+        self.model.eval()
+
+        self.lm_model = lm_model
+        self.lm_weight = lm_weight
+        if self.lm_model is not None:
+            self.lm_model.eval()
+
+    def cuda(self):
+        self.model.cuda()
+        return self
+
+    @torch.no_grad()
+    def forward(
+        self,
+        sample: Dict[str, Dict[str, Tensor]],
+        prefix_tokens: Optional[Tensor] = None,
+        bos_token: Optional[int] = None,
+    ):
+        """Generate a batch of translations.
+
+        Args:
+            sample (dict): batch
+            prefix_tokens (torch.LongTensor, optional): force decoder to begin
+                with these tokens
+            bos_token (int, optional): beginning of sentence token
+                (default: self.eos)
+        """
+        return self._generate(sample, prefix_tokens, bos_token=bos_token)
+
+    # TODO(myleott): unused, deprecate after pytorch-translate migration
+    def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
+        """Iterate over a batched dataset and yield individual translations.
+        Args:
+            cuda (bool, optional): use GPU for generation
+            timer (StopwatchMeter, optional): time generations
+        """
+        for sample in data_itr:
+            s = utils.move_to_cuda(sample) if cuda else sample
+            if "net_input" not in s:
+                continue
+            input = s["net_input"]
+            # model.forward normally channels prev_output_tokens into the decoder
+            # separately, but SequenceGenerator directly calls model.encoder
+            encoder_input = {
+                k: v for k, v in input.items() if k != "prev_output_tokens"
+            }
+            if timer is not None:
+                timer.start()
+            with torch.no_grad():
+                hypos = self.generate(encoder_input)
+            if timer is not None:
+                timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
+            for i, id in enumerate(s["id"].data):
+                # remove padding
+                src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
+                ref = (
+                    utils.strip_pad(s["target"].data[i, :], self.pad)
+                    if s["target"] is not None
+                    else None
+                )
+                yield id, src, ref, hypos[i]
+
+    @torch.no_grad()
+    def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
+        """Generate translations. Match the api of other fairseq generators.
+
+        Args:
+            models (List[~fairseq.models.FairseqModel]): ensemble of models
+            sample (dict): batch
+            prefix_tokens (torch.LongTensor, optional): force decoder to begin
+                with these tokens
+            constraints (torch.LongTensor, optional): force decoder to include
+                the list of constraints
+            bos_token (int, optional): beginning of sentence token
+                (default: self.eos)
+        """
+        return self._generate(sample, **kwargs)
+
+    def _generate(
+        self,
+        sample: Dict[str, Dict[str, Tensor]],
+        prefix_tokens: Optional[Tensor] = None,
+        constraints: Optional[Tensor] = None,
+        bos_token: Optional[int] = None,
+    ):
+        incremental_states = torch.jit.annotate(
+            List[Dict[str, Dict[str, Optional[Tensor]]]],
+            [
+                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
+                for i in range(self.model.models_size)
+            ],
+        )
+        net_input = sample["net_input"]
+
+        if "src_tokens" in net_input:
+            src_tokens = net_input["src_tokens"]
+            # length of the source text being the character length except EndOfSentence and pad
+            src_lengths = (
+                (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
+            )
+        elif "source" in net_input:
+            src_tokens = net_input["source"]
+            src_lengths = (
+                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
+                if net_input["padding_mask"] is not None
+                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
+            )
+        elif "features" in net_input:
+            src_tokens = net_input["features"]
+            src_lengths = (
+                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
+                if net_input["padding_mask"] is not None
+                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
+            )
+        else:
+            raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
+
+        # bsz: total number of sentences in beam
+        # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
+        if src_tokens['audio'] is not None:
+            bsz, src_len = src_tokens['audio'].size()[:2]
+            src_device = src_tokens['audio'].device
+        else:
+            bsz, src_len = net_input['padding_mask'].size()
+            src_device = src_tokens['video'].device
+        beam_size = self.beam_size
+        if constraints is not None and not self.search.supports_constraints:
+            raise NotImplementedError(
+                "Target-side constraints were provided, but search method doesn't support them"
+            )
+
+        # Initialize constraints, when active
+        self.search.init_constraints(constraints, beam_size)
+
+        max_len: int = -1
+        if self.match_source_len:
+            max_len = src_lengths.max().item()
+        else:
+            max_len = min(
+                int(self.max_len_a * src_len + self.max_len_b),
+                self.max_len - 1,
+            )
+        assert (
+            self.min_len <= max_len
+        ), "min_len cannot be larger than max_len, please adjust these!"
+        # compute the encoder output for each beam
+        encoder_outs = self.model.forward_encoder(net_input)
+
+        # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
+        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+        new_order = new_order.to(src_device).long()
+        encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
+        # ensure encoder_outs is a List.
+        assert encoder_outs is not None
+
+        # initialize buffers
+        scores = (
+            torch.zeros(bsz * beam_size, max_len + 1).to(src_device).float()
+        )  # +1 for eos; pad is never chosen for scoring
+        tokens = (
+            torch.zeros(bsz * beam_size, max_len + 2)
+            .to(src_device)
+            .long()
+            .fill_(self.pad)
+        )  # +2 for eos and pad
+        tokens[:, 0] = self.eos if bos_token is None else bos_token
+        attn: Optional[Tensor] = None
+
+        # A list that indicates candidates that should be ignored.
+        # For example, suppose we're sampling and have already finalized 2/5
+        # samples. Then cands_to_ignore would mark 2 positions as being ignored,
+        # so that we only finalize the remaining 3 samples.
+        cands_to_ignore = (
+            torch.zeros(bsz, beam_size).to(src_device).eq(-1)
+        )  # forward and backward-compatible False mask
+
+        # list of completed sentences
+        finalized = torch.jit.annotate(
+            List[List[Dict[str, Tensor]]],
+            [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
+        )  # contains lists of dictionaries of information about the hypothesis being finalized at each step
+
+        # a boolean array indicating if the sentence at the index is finished or not
+        finished = [False for i in range(bsz)]
+        num_remaining_sent = bsz  # number of sentences remaining
+
+        # number of candidate hypos per step
+        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS
+
+        # offset arrays for converting between different indexing schemes
+        bbsz_offsets = (
+            (torch.arange(0, bsz) * beam_size)
+            .unsqueeze(1)
+            .type_as(tokens)
+            .to(src_device)
+        )
+        cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_device)
+
+        reorder_state: Optional[Tensor] = None
+        batch_idxs: Optional[Tensor] = None
+
+        original_batch_idxs: Optional[Tensor] = None
+        if "id" in sample and isinstance(sample["id"], Tensor):
+            original_batch_idxs = sample["id"]
+        else:
+            original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
+
+        for step in range(max_len + 1):  # one extra step for EOS marker
+            # reorder decoder internal states based on the prev choice of beams
+            if reorder_state is not None:
+                if batch_idxs is not None:
+                    # update beam indices to take into account removed sentences
+                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
+                        batch_idxs
+                    )
+                    reorder_state.view(-1, beam_size).add_(
+                        corr.unsqueeze(-1) * beam_size
+                    )
+                    original_batch_idxs = original_batch_idxs[batch_idxs]
+                self.model.reorder_incremental_state(incremental_states, reorder_state)
+                encoder_outs = self.model.reorder_encoder_out(
+                    encoder_outs, reorder_state
+                )
+
+            lprobs, avg_attn_scores = self.model.forward_decoder(
+                tokens[:, : step + 1],
+                encoder_outs,
+                incremental_states,
+                self.temperature,
+            )
+
+            if self.lm_model is not None:
+                lm_out = self.lm_model(tokens[:, : step + 1])
+                probs = self.lm_model.get_normalized_probs(
+                    lm_out, log_probs=True, sample=None
+                )
+                probs = probs[:, -1, :] * self.lm_weight
+                lprobs += probs
+
+            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
+
+            lprobs[:, self.pad] = -math.inf  # never select pad
+            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
+
+            # handle max length constraint
+            if step >= max_len:
+                lprobs[:, : self.eos] = -math.inf
+                lprobs[:, self.eos + 1 :] = -math.inf
+
+            # handle prefix tokens (possibly with different lengths)
+            if (
+                prefix_tokens is not None
+                and step < prefix_tokens.size(1)
+                and step < max_len
+            ):
+                lprobs, tokens, scores = self._prefix_tokens(
+                    step, lprobs, scores, tokens, prefix_tokens, beam_size
+                )
+            elif step < self.min_len:
+                # minimum length constraint (does not apply if using prefix_tokens)
+                lprobs[:, self.eos] = -math.inf
+
+            # Record attention scores, only support avg_attn_scores is a Tensor
+            if avg_attn_scores is not None:
+                if attn is None:
+                    attn = torch.empty(
+                        bsz * beam_size, avg_attn_scores.size(1), max_len + 2
+                    ).to(scores)
+                attn[:, :, step + 1].copy_(avg_attn_scores)
+
+            scores = scores.type_as(lprobs)
+            eos_bbsz_idx = torch.empty(0).to(
+                tokens
+            )  # indices of hypothesis ending with eos (finished sentences)
+            eos_scores = torch.empty(0).to(
+                scores
+            )  # scores of hypothesis ending with eos (finished sentences)
+
+            if self.should_set_src_lengths:
+                self.search.set_src_lengths(src_lengths)
+
+            if self.repeat_ngram_blocker is not None:
+                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
+
+            # Shape: (batch, cand_size)
+            cand_scores, cand_indices, cand_beams = self.search.step(
+                step,
+                lprobs.view(bsz, -1, self.vocab_size),
+                scores.view(bsz, beam_size, -1)[:, :, :step],
+                tokens[:, : step + 1],
+                original_batch_idxs,
+            )
+
+            # cand_bbsz_idx contains beam indices for the top candidate
+            # hypotheses, with a range of values: [0, bsz*beam_size),
+            # and dimensions: [bsz, cand_size]
+            cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+
+            # finalize hypotheses that end in eos
+            # Shape of eos_mask: (batch size, beam size)
+            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
+            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
+
+            # only consider eos when it's among the top beam_size indices
+            # Now we know what beam item(s) to finish
+            # Shape: 1d list of absolute-numbered
+            eos_bbsz_idx = torch.masked_select(
+                cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
+            )
+
+            finalized_sents: List[int] = []
+            if eos_bbsz_idx.numel() > 0:
+                eos_scores = torch.masked_select(
+                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
+                )
+
+                finalized_sents = self.finalize_hypos(
+                    step,
+                    eos_bbsz_idx,
+                    eos_scores,
+                    tokens,
+                    scores,
+                    finalized,
+                    finished,
+                    beam_size,
+                    attn,
+                    src_lengths,
+                    max_len,
+                )
+                num_remaining_sent -= len(finalized_sents)
+
+            assert num_remaining_sent >= 0
+            if num_remaining_sent == 0:
+                break
+            if self.search.stop_on_max_len and step >= max_len:
+                break
+            assert step < max_len, f"{step} < {max_len}"
+
+            # Remove finalized sentences (ones for which {beam_size}
+            # finished hypotheses have been generated) from the batch.
+            if len(finalized_sents) > 0:
+                new_bsz = bsz - len(finalized_sents)
+
+                # construct batch_idxs which holds indices of batches to keep for the next pass
+                batch_mask = torch.ones(
+                    bsz, dtype=torch.bool, device=cand_indices.device
+                )
+                batch_mask[finalized_sents] = False
+                # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
+                batch_idxs = torch.arange(
+                    bsz, device=cand_indices.device
+                ).masked_select(batch_mask)
+
+                # Choose the subset of the hypothesized constraints that will continue
+                self.search.prune_sentences(batch_idxs)
+
+                eos_mask = eos_mask[batch_idxs]
+                cand_beams = cand_beams[batch_idxs]
+                bbsz_offsets.resize_(new_bsz, 1)
+                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+                cand_scores = cand_scores[batch_idxs]
+                cand_indices = cand_indices[batch_idxs]
+
+                if prefix_tokens is not None:
+                    prefix_tokens = prefix_tokens[batch_idxs]
+                src_lengths = src_lengths[batch_idxs]
+                cands_to_ignore = cands_to_ignore[batch_idxs]
+
+                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+                if attn is not None:
+                    attn = attn.view(bsz, -1)[batch_idxs].view(
+                        new_bsz * beam_size, attn.size(1), -1
+                    )
+                bsz = new_bsz
+            else:
+                batch_idxs = None
+
+            # Set active_mask so that values > cand_size indicate eos hypos
+            # and values < cand_size indicate candidate active hypos.
+            # After, the min values per row are the top candidate active hypos
+
+            # Rewrite the operator, since element-wise "or" is not supported in TorchScript.
+
+            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
+            active_mask = torch.add(
+                eos_mask.type_as(cand_offsets) * cand_size,
+                cand_offsets[: eos_mask.size(1)],
+            )
+
+            # get the top beam_size active hypotheses, which are just
+            # the hypos with the smallest values in active_mask.
+            # {active_hypos} indicates which {beam_size} hypotheses
+            # from the list of {2 * beam_size} candidates were
+            # selected. Shapes: (batch size, beam size)
+            new_cands_to_ignore, active_hypos = torch.topk(
+                active_mask, k=beam_size, dim=1, largest=False
+            )
+
+            # update cands_to_ignore to ignore any finalized hypos.
+            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
+            # Make sure there is at least one active item for each sentence in the batch.
+            assert (~cands_to_ignore).any(dim=1).all()
+
+            # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
+            # can be selected more than once).
+            active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
+            active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
+
+            active_bbsz_idx = active_bbsz_idx.view(-1)
+            active_scores = active_scores.view(-1)
+
+            # copy tokens and scores for active hypotheses
+
+            # Set the tokens for each beam (can select the same row more than once)
+            tokens[:, : step + 1] = torch.index_select(
+                tokens[:, : step + 1], dim=0, index=active_bbsz_idx
+            )
+            # Select the next token for each of them
+            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
+                cand_indices, dim=1, index=active_hypos
+            )
+            if step > 0:
+                scores[:, :step] = torch.index_select(
+                    scores[:, :step], dim=0, index=active_bbsz_idx
+                )
+            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
+                cand_scores, dim=1, index=active_hypos
+            )
+
+            # Update constraints based on which candidates were selected for the next beam
+            self.search.update_constraints(active_hypos)
+
+            # copy attention for active hypotheses
+            if attn is not None:
+                attn[:, :, : step + 2] = torch.index_select(
+                    attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
+                )
+
+            # reorder incremental state in decoder
+            reorder_state = active_bbsz_idx
+
+        # sort by score descending
+        for sent in range(len(finalized)):
+            scores = torch.tensor(
+                [float(elem["score"].item()) for elem in finalized[sent]]
+            )
+            _, sorted_scores_indices = torch.sort(scores, descending=True)
+            finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
+            finalized[sent] = torch.jit.annotate(
+                List[Dict[str, Tensor]], finalized[sent]
+            )
+        return finalized
+
+    def _prefix_tokens(
+        self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
+    ):
+        """Handle prefix tokens"""
+        prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
+        prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
+        prefix_mask = prefix_toks.ne(self.pad)
+        lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
+        lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
+            -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
+        )
+        # if prefix includes eos, then we should make sure tokens and
+        # scores are the same across all beams
+        eos_mask = prefix_toks.eq(self.eos)
+        if eos_mask.any():
+            # validate that the first beam matches the prefix
+            first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
+                :, 0, 1 : step + 1
+            ]
+            eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
+            target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
+            assert (first_beam == target_prefix).all()
+
+            # copy tokens, scores and lprobs from the first beam to all beams
+            tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
+            scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
+            lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
+        return lprobs, tokens, scores
+
+    def replicate_first_beam(self, tensor, mask, beam_size: int):
+        tensor = tensor.view(-1, beam_size, tensor.size(-1))
+        tensor[mask] = tensor[mask][:, :1, :]
+        return tensor.view(-1, tensor.size(-1))
+
+    def finalize_hypos(
+        self,
+        step: int,
+        bbsz_idx,
+        eos_scores,
+        tokens,
+        scores,
+        finalized: List[List[Dict[str, Tensor]]],
+        finished: List[bool],
+        beam_size: int,
+        attn: Optional[Tensor],
+        src_lengths,
+        max_len: int,
+    ):
+        """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
+        A sentence is finalized when {beam_size} finished items have been collected for it.
+
+        Returns number of sentences (not beam items) being finalized.
+        These will be removed from the batch and not processed further.
+        Args:
+            bbsz_idx (Tensor):
+        """
+        assert bbsz_idx.numel() == eos_scores.numel()
+
+        # clone relevant token and attention tensors.
+        # tokens is (batch * beam, max_len). So the index_select
+        # gets the newly EOS rows, then selects cols 1..{step + 2}
+        tokens_clone = tokens.index_select(0, bbsz_idx)[
+            :, 1 : step + 2
+        ]  # skip the first index, which is EOS
+
+        tokens_clone[:, step] = self.eos
+        attn_clone = (
+            attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
+            if attn is not None
+            else None
+        )
+
+        # compute scores per token position
+        pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
+        pos_scores[:, step] = eos_scores
+        # convert from cumulative to per-position scores
+        pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
+
+        # normalize sentence-level scores
+        if self.normalize_scores:
+            eos_scores /= (step + 1) ** self.len_penalty
+
+        # cum_unfin records which sentences in the batch are finished.
+        # It helps match indexing between (a) the original sentences
+        # in the batch and (b) the current, possibly-reduced set of
+        # sentences.
+        cum_unfin: List[int] = []
+        prev = 0
+        for f in finished:
+            if f:
+                prev += 1
+            else:
+                cum_unfin.append(prev)
+
+        # The keys here are of the form "{sent}_{unfin_idx}", where
+        # "unfin_idx" is the index in the current (possibly reduced)
+        # list of sentences, and "sent" is the index in the original,
+        # unreduced batch
+        # set() is not supported in script export
+        sents_seen: Dict[str, Optional[Tensor]] = {}
+
+        # For every finished beam item
+        for i in range(bbsz_idx.size()[0]):
+            idx = bbsz_idx[i]
+            score = eos_scores[i]
+            # sentence index in the current (possibly reduced) batch
+            unfin_idx = idx // beam_size
+            # sentence index in the original (unreduced) batch
+            sent = unfin_idx + cum_unfin[unfin_idx]
+            # Cannot create dict for key type '(int, int)' in torchscript.
+            # The workaround is to cast int to string
+            seen = str(sent.item()) + "_" + str(unfin_idx.item())
+            if seen not in sents_seen:
+                sents_seen[seen] = None
+
+            if self.match_source_len and step > src_lengths[unfin_idx]:
+                score = torch.tensor(-math.inf).to(score)
+
+            # An input sentence (among those in a batch) is finished when
+            # beam_size hypotheses have been collected for it
+            if len(finalized[sent]) < beam_size:
+                if attn_clone is not None:
+                    # remove padding tokens from attn scores
+                    hypo_attn = attn_clone[i]
+                else:
+                    hypo_attn = torch.empty(0)
+
+                finalized[sent].append(
+                    {
+                        "tokens": tokens_clone[i],
+                        "score": score,
+                        "attention": hypo_attn,  # src_len x tgt_len
+                        "alignment": torch.empty(0),
+                        "positional_scores": pos_scores[i],
+                    }
+                )
+
+        newly_finished: List[int] = []
+
+        for seen in sents_seen.keys():
+            # check termination conditions for this sentence
+            sent: int = int(float(seen.split("_")[0]))
+            unfin_idx: int = int(float(seen.split("_")[1]))
+
+            if not finished[sent] and self.is_finished(
+                step, unfin_idx, max_len, len(finalized[sent]), beam_size
+            ):
+                finished[sent] = True
+                newly_finished.append(unfin_idx)
+
+        return newly_finished
+
+    def is_finished(
+        self,
+        step: int,
+        unfin_idx: int,
+        max_len: int,
+        finalized_sent_len: int,
+        beam_size: int,
+    ):
+        """
+        Check whether decoding for a sentence is finished, which
+        occurs when the list of finalized sentences has reached the
+        beam size, or when we reach the maximum length.
+        """
+        assert finalized_sent_len <= beam_size
+        if finalized_sent_len == beam_size or step == max_len:
+            return True
+        return False
+
+
+class EnsembleModel(nn.Module):
+    """A wrapper around an ensemble of models."""
+
+    def __init__(self, models):
+        super().__init__()
+        self.models_size = len(models)
+        # method '__len__' is not supported in ModuleList for torch script
+        self.single_model = models[0]
+        self.models = nn.ModuleList(models)
+
+        self.has_incremental: bool = False
+        if all(
+            hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
+            for m in models
+        ):
+            self.has_incremental = True
+
+    def forward(self):
+        pass
+
+    def has_encoder(self):
+        return hasattr(self.single_model, "encoder")
+
+    def has_incremental_states(self):
+        return self.has_incremental
+
+    def max_decoder_positions(self):
+        return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
+
+    @torch.jit.export
+    def forward_encoder(self, net_input: Dict[str, Tensor]):
+        if not self.has_encoder():
+            return None
+        return [model.encoder.forward_torchscript(net_input) for model in self.models]
+
+    @torch.jit.export
+    def forward_decoder(
+        self,
+        tokens,
+        encoder_outs: List[Dict[str, List[Tensor]]],
+        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
+        temperature: float = 1.0,
+    ):
+        log_probs = []
+        avg_attn: Optional[Tensor] = None
+        encoder_out: Optional[Dict[str, List[Tensor]]] = None
+        for i, model in enumerate(self.models):
+            if self.has_encoder():
+                encoder_out = encoder_outs[i]
+            # decode each model
+            if self.has_incremental_states():
+                decoder_out = model.decoder.forward(
+                    tokens,
+                    encoder_out=encoder_out,
+                    incremental_state=incremental_states[i],
+                )
+            else:
+                if hasattr(model, "decoder"):
+                    decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)
+                else:
+                    decoder_out = model.forward(tokens)
+
+            attn: Optional[Tensor] = None
+            decoder_len = len(decoder_out)
+            if decoder_len > 1 and decoder_out[1] is not None:
+                if isinstance(decoder_out[1], Tensor):
+                    attn = decoder_out[1]
+                else:
+                    attn_holder = decoder_out[1]["attn"]
+                    if isinstance(attn_holder, Tensor):
+                        attn = attn_holder
+                    elif attn_holder is not None:
+                        attn = attn_holder[0]
+                if attn is not None:
+                    attn = attn[:, -1, :]
+
+            decoder_out_tuple = (
+                decoder_out[0][:, -1:, :].div_(temperature),
+                None if decoder_len <= 1 else decoder_out[1],
+            )
+            probs = model.get_normalized_probs(
+                decoder_out_tuple, log_probs=True, sample=None
+            )
+            probs = probs[:, -1, :]
+            if self.models_size == 1:
+                return probs, attn
+
+            log_probs.append(probs)
+            if attn is not None:
+                if avg_attn is None:
+                    avg_attn = attn
+                else:
+                    avg_attn.add_(attn)
+
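+        # Average ensemble predictions in probability space: logsumexp over models minus log(model count).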
+        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
+            self.models_size
+        )
+
+        if avg_attn is not None:
+            avg_attn.div_(self.models_size)
+        return avg_probs, avg_attn
+
+    @torch.jit.export
+    def reorder_encoder_out(
+        self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order
+    ):
+        """
+        Reorder encoder output according to *new_order*.
+
+        Args:
+            encoder_out: output from the ``forward()`` method
+            new_order (LongTensor): desired order
+
+        Returns:
+            *encoder_out* rearranged according to *new_order*
+        """
+        new_outs: List[Dict[str, List[Tensor]]] = []
+        if not self.has_encoder():
+            return new_outs
+        for i, model in enumerate(self.models):
+            assert encoder_outs is not None
+            new_outs.append(
+                model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
+            )
+        return new_outs
+
+    @torch.jit.export
+    def reorder_incremental_state(
+        self,
+        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
+        new_order,
+    ):
+        if not self.has_incremental_states():
+            return
+        for i, model in enumerate(self.models):
+            model.decoder.reorder_incremental_state_scripting(
+                incremental_states[i], new_order
+            )
+
+
+class SequenceGeneratorWithAlignment(SequenceGenerator):
+    def __init__(
+        self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs
+    ):
+        """Generates translations of a given source sentence.
+
+        Produces alignments following "Jointly Learning to Align and
+        Translate with Transformer Models" (Garg et al., EMNLP 2019).
+
+        Args:
+            left_pad_target (bool, optional): Whether or not the
+                hypothesis should be left padded or not when they are
+                teacher forced for generating alignments.
+        """
+        super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
+        self.left_pad_target = left_pad_target
+
+        if print_alignment == "hard":
+            self.extract_alignment = utils.extract_hard_alignment
+        elif print_alignment == "soft":
+            self.extract_alignment = utils.extract_soft_alignment
+
+    @torch.no_grad()
+    def generate(self, models, sample, **kwargs):
+        finalized = super()._generate(sample, **kwargs)
+
+        src_tokens = sample["net_input"]["src_tokens"]
+        bsz = src_tokens.shape[0]
+        beam_size = self.beam_size
+        (
+            src_tokens,
+            src_lengths,
+            prev_output_tokens,
+            tgt_tokens,
+        ) = self._prepare_batch_for_alignment(sample, finalized)
+        if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
+            attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
+        else:
+            attn = [
+                finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
+                for i in range(bsz * beam_size)
+            ]
+
+        if src_tokens.device != "cpu":
+            src_tokens = src_tokens.to("cpu")
+            tgt_tokens = tgt_tokens.to("cpu")
+            attn = [i.to("cpu") for i in attn]
+
+        # Process the attn matrix to extract hard alignments.
+        for i in range(bsz * beam_size):
+            alignment = self.extract_alignment(
+                attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
+            )
+            finalized[i // beam_size][i % beam_size]["alignment"] = alignment
+        return finalized
+
+    def _prepare_batch_for_alignment(self, sample, hypothesis):
+        src_tokens = sample["net_input"]["src_tokens"]
+        bsz = src_tokens.shape[0]
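+        # Tile each source sentence (and its length) beam_size times so that every
+        # finalized hypothesis below is teacher-forced against its own copy of the source.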
+        src_tokens = (
+            src_tokens[:, None, :]
+            .expand(-1, self.beam_size, -1)
+            .contiguous()
+            .view(bsz * self.beam_size, -1)
+        )
+        src_lengths = sample["net_input"]["src_lengths"]
+        src_lengths = (
+            src_lengths[:, None]
+            .expand(-1, self.beam_size)
+            .contiguous()
+            .view(bsz * self.beam_size)
+        )
+        prev_output_tokens = data_utils.collate_tokens(
+            [beam["tokens"] for example in hypothesis for beam in example],
+            self.pad,
+            self.eos,
+            self.left_pad_target,
+            move_eos_to_beginning=True,
+        )
+        tgt_tokens = data_utils.collate_tokens(
+            [beam["tokens"] for example in hypothesis for beam in example],
+            self.pad,
+            self.eos,
+            self.left_pad_target,
+            move_eos_to_beginning=False,
+        )
+        return src_tokens, src_lengths, prev_output_tokens, tgt_tokens
+
+
+class EnsembleModelWithAlignment(EnsembleModel):
+    """A wrapper around an ensemble of models."""
+
+    def __init__(self, models):
+        super().__init__(models)
+
+    def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
+        avg_attn = None
+        for model in self.models:
+            decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
+            attn = decoder_out[1]["attn"][0]
+            if avg_attn is None:
+                avg_attn = attn
+            else:
+                avg_attn.add_(attn)
+        if len(self.models) > 1:
+            avg_attn.div_(len(self.models))
+        return avg_attn
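+
+# Minimal usage sketch (illustrative only; the model/tensor names are assumptions):
+#
+#   ensemble = EnsembleModelWithAlignment([model_a, model_b])
+#   attn = ensemble.forward_align(src_tokens, src_lengths, prev_output_tokens)
+#   # `attn` is the decoder cross-attention averaged over the ensemble members.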
diff --git a/slam_llm/models/avhubert/utils.py b/slam_llm/models/avhubert/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c9ffbc397118ba786af646238dac1550f7ec5ca
--- /dev/null
+++ b/slam_llm/models/avhubert/utils.py
@@ -0,0 +1,298 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import cv2
+import torch
+import random
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+
+def load_video(path):
+    for i in range(3):
+        try:
+            cap = cv2.VideoCapture(path)
+            frames = []
+            while True:
+                ret, frame = cap.read()
+                if ret:
+                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    frames.append(frame)
+                else:
+                    break
+            frames = np.stack(frames)
+            return frames
+        except Exception:
+            print(f"failed loading {path} ({i} / 3)")
+            if i == 2:
+                raise ValueError(f"Unable to load {path}")
+
+
+class Compose(object):
+    """Compose several preprocess together.
+    Args:
+        preprocess (list of ``Preprocess`` objects): list of preprocess to compose.
+    """
+
+    def __init__(self, preprocess):
+        self.preprocess = preprocess
+
+    def __call__(self, sample):
+        for t in self.preprocess:
+            sample = t(sample)
+        return sample
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.preprocess:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string
+
+
+class Normalize(object):
+    """Normalize a ndarray image with mean and standard deviation.
+    """
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, frames):
+        """
+        Args:
+            frames (numpy.ndarray): Frames of size (T, H, W) to be normalized.
+        Returns:
+            numpy.ndarray: Normalized frames.
+        """
+        frames = (frames - self.mean) / self.std
+        return frames
+
+    def __repr__(self):
+        return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std)
+
+class CenterCrop(object):
+    """Crop the given image at the center
+    """
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, frames):
+        """
+        Args:
+            img (numpy.ndarray): Images to be cropped.
+        Returns:
+            numpy.ndarray: Cropped image.
+        """
+        t, h, w = frames.shape
+        th, tw = self.size
+        delta_w = (w - tw) // 2
+        delta_h = (h - th) // 2
+        frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
+        return frames
+
+
+class RandomCrop(object):
+    """Crop the given image at the center
+    """
+
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, frames):
+        """
+        Args:
+            img (numpy.ndarray): Images to be cropped.
+        Returns:
+            numpy.ndarray: Cropped image.
+        """
+        t, h, w = frames.shape
+        th, tw = self.size
+        delta_w = random.randint(0, w-tw)
+        delta_h = random.randint(0, h-th)
+        frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
+        return frames
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+class HorizontalFlip(object):
+    """Flip image horizontally.
+    """
+
+    def __init__(self, flip_ratio):
+        self.flip_ratio = flip_ratio
+
+    def __call__(self, frames):
+        """
+        Args:
+            img (numpy.ndarray): Images to be flipped with a probability flip_ratio
+        Returns:
+            numpy.ndarray: Flipped images.
+        """
+        t, h, w = frames.shape
+        if random.random() < self.flip_ratio:
+            for index in range(t):
+                frames[index] = cv2.flip(frames[index], 1)
+        return frames
+
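+# A minimal preprocessing sketch (illustrative only; the crop size, flip ratio and
+# normalization constants below are hypothetical, not values from any recipe):
+#
+#   frames = load_video("clip.mp4")          # (T, H, W) grayscale frames
+#   transform = Compose([
+#       Normalize(0.0, 255.0),               # scale pixel values roughly into [0, 1]
+#       RandomCrop((88, 88)),
+#       HorizontalFlip(flip_ratio=0.5),
+#   ])
+#   frames = transform(frames)
+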
+def compute_mask_indices(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape
+    Args:
+        shape: the shape for which to compute masks.
+            should be of size 2 where first element is batch size and 2nd is timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+    """
+
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    all_num_mask = int(
+        # add a random number for probabilistic rounding
+        mask_prob * all_sz / float(mask_length)
+        + np.random.rand()
+    )
+
+    all_num_mask = max(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * sz / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,  # np.int was removed in NumPy 1.24; the builtin int works here
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+    min_len = min([len(m) for m in mask_idcs])
+    batch_indexes, starts, ends = [], [], []
+    for i, mask_idc in enumerate(mask_idcs):
+        if len(mask_idc) > min_len:
+            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+        vals, run_starts, run_lengths = find_runs(mask[i])
+        start_indices, lengths = run_starts[vals], run_lengths[vals]
+        starts.append(start_indices)
+        ends.append(start_indices+lengths)
+        batch_indexes.append(np.zeros([len(start_indices)])+i)
+    return (
+        mask,
+        np.concatenate(starts).astype(np.int64),
+        np.concatenate(ends).astype(np.int64),
+        np.concatenate(batch_indexes).astype(np.int64),
+    )
+
+def find_runs(x):
+    """Find runs of consecutive items in an array."""
+
+    # ensure array
+    x = np.asanyarray(x)
+    if x.ndim != 1:
+        raise ValueError('only 1D array supported')
+    n = x.shape[0]
+
+    # handle empty array
+    if n == 0:
+        return np.array([]), np.array([]), np.array([])
+
+    else:
+        # find run starts
+        loc_run_start = np.empty(n, dtype=bool)
+        loc_run_start[0] = True
+        np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
+        run_starts = np.nonzero(loc_run_start)[0]
+
+        # find run values
+        run_values = x[loc_run_start]
+
+        # find run lengths
+        run_lengths = np.diff(np.append(run_starts, n))
+
+        return run_values, run_starts, run_lengths
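+
+# Worked example (illustrative, not part of the original source):
+#
+#   >>> find_runs(np.array([0, 1, 1, 0, 1]))
+#   (array([0, 1, 0, 1]), array([0, 1, 3, 4]), array([1, 2, 1, 1]))
+#
+# i.e. runs with values 0, 1, 0, 1 start at indices 0, 1, 3, 4 and have lengths
+# 1, 2, 1, 1; compute_mask_indices keeps only the True runs to report the start/end
+# of every masked span per batch element.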
diff --git a/slam_llm/models/encoder.py b/slam_llm/models/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6810b87446edb8d96a680d273f09e7afb1e66536
--- /dev/null
+++ b/slam_llm/models/encoder.py
@@ -0,0 +1,158 @@
+import types
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from dataclasses import dataclass
+
+class WhisperWrappedEncoder:
+    
+    @classmethod
+    def load(cls, model_config):
+        
+        def extract_variable_length_features(self, x: torch.Tensor):
+            """
+            x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+                the mel spectrogram of the audio
+            """
+            x = F.gelu(self.conv1(x))
+            x = F.gelu(self.conv2(x))
+            x = x.permute(0, 2, 1)
+
+            # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+            # x = (x + self.positional_embedding).to(x.dtype)
+            x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
+
+            for block in self.blocks:
+                x = block(x)
+
+            x = self.ln_post(x)
+            return x
+
+        import whisper
+        encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
+        encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
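+        # Usage sketch (illustrative; the mel tensor shape follows the docstring above):
+        #   encoder = WhisperWrappedEncoder.load(model_config)
+        #   feats = encoder.extract_variable_length_features(mel)  # mel: (batch, n_mels, n_frames)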
+        return encoder
+
+
+class BEATsEncoder:
+
+    @classmethod
+    def load(cls, model_config):
+        from .BEATs.BEATs import BEATs, BEATsConfig
+        checkpoint = torch.load(model_config.encoder_path)
+        cfg = BEATsConfig(checkpoint['cfg'])
+        BEATs_model = BEATs(cfg)
+        BEATs_model.load_state_dict(checkpoint['model'])
+
+        return BEATs_model
+
+
+@dataclass
+class UserDirModule:
+    user_dir: str
+    
+class EATEncoder:
+    
+    @classmethod
+    def load(cls, model_config):
+        import fairseq
+        model_path = UserDirModule(model_config.encoder_fairseq_dir)
+        fairseq.utils.import_user_module(model_path)
+        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
+        model = models[0]
+
+        return model
+    
+    def extract_features(self, source, padding_mask):
+        return self.model.extract_features(source, padding_mask = padding_mask, mask=False, remove_extra_tokens = False)['x']
+
+class SpatialASTEncoder:
+    @classmethod
+    def load(cls, model_config):
+        from functools import partial
+        from .SpatialAST import SpatialAST 
+        binaural_encoder = SpatialAST.BinauralEncoder(
+            num_classes=355, drop_path_rate=0.1, num_cls_tokens=3,
+            patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 
+            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)
+        )
+
+        checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu')
+        binaural_encoder.load_state_dict(checkpoint['model'], strict=False) 
+        return binaural_encoder
+
+class WavLMEncoder(nn.Module):
+    def __init__(self, config, model):
+        super().__init__()
+        self.config = config
+        self.model = model
+
+    @classmethod
+    def load(cls, model_config):
+        from .wavlm.WavLM import WavLM, WavLMConfig
+        checkpoint = torch.load(model_config.encoder_path)
+        cfg = WavLMConfig(checkpoint['cfg'])
+        WavLM_model = WavLM(cfg)
+        WavLM_model.load_state_dict(checkpoint['model'])
+        assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match"
+ 
+        return cls(cfg, WavLM_model)
+
+    def extract_features(self, source, padding_mask):
+        return self.model.extract_features(source, padding_mask)[0]
+
+class AVHubertEncoder:
+
+    @classmethod
+    def load(cls, model_config):
+        import fairseq
+        from .avhubert import hubert_pretraining, hubert, hubert_asr
+        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
+        model = models[0]
+        return model
+
+class HubertEncoder:
+
+    @classmethod
+    def load(cls, model_config):
+        import fairseq
+        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
+        model = models[0]
+        if model_config.encoder_type == "pretrain":
+            pass
+        elif model_config.encoder_type == "finetune":
+            model.w2v_encoder.proj = None
+            model.w2v_encoder.apply_mask = False
+        else:
+            assert model_config.encoder_type in ["pretrain", "finetune"], "input_type must be one of [pretrain, finetune]" 
+        return model
+
+
+class HfTextEncoder:
+
+    @classmethod
+    def load(cls, model_config):
+        from transformers import AutoModel
+        model = AutoModel.from_pretrained(model_config.encoder_path)
+        return model
+
+class MusicFMEncoder(nn.Module):
+    def __init__(self, config, model):
+        super().__init__()
+        self.config = config
+        self.model = model
+
+    @classmethod
+    def load(cls, model_config):
+        from .musicfm.model.musicfm_25hz import MusicFM25Hz
+        model = MusicFM25Hz(
+            stat_path = model_config.encoder_stat_path,
+            model_path = model_config.encoder_path,
+            w2v2_config_path = model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft")
+        )
+        return cls(model_config, model)
+
+    def extract_features(self, source, padding_mask=None):
+        _, hidden_states = self.model.get_predictions(source)
+        out = hidden_states[self.config.encoder_layer_idx]
+        return out
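+
+# Minimal usage sketch (illustrative; the checkpoint path and config fields are
+# hypothetical and depend on the config actually passed in):
+#
+#   model_config.encoder_path = "/path/to/WavLM-Large.pt"
+#   encoder = WavLMEncoder.load(model_config)               # also checks model_config.normalize
+#   feats = encoder.extract_features(wav, padding_mask)     # wav: (batch, samples)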
diff --git a/slam_llm/models/musicfm/model/__init__.py b/slam_llm/models/musicfm/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..99a8091366e2ac7f301452706d5cfdff5e320a0d
--- /dev/null
+++ b/slam_llm/models/musicfm/model/__init__.py
@@ -0,0 +1,2 @@
+
+
diff --git a/slam_llm/models/musicfm/model/musicfm_25hz.py b/slam_llm/models/musicfm/model/musicfm_25hz.py
new file mode 100644
index 0000000000000000000000000000000000000000..9acdca859ba3822f82664a74adea7dc0f471ac75
--- /dev/null
+++ b/slam_llm/models/musicfm/model/musicfm_25hz.py
@@ -0,0 +1,253 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+import json
+import random
+import torch
+from torch import nn
+from einops import rearrange
+
+from ..modules.random_quantizer import RandomProjectionQuantizer
+from ..modules.features import MelSTFT
+from ..modules.conv import Conv2dSubsampling
+
+
+class MusicFM25Hz(nn.Module):
+    """
+    MusicFM
+
+    Input: 128-band mel spectrogram
+    Frontend: 2-layer Residual convolution
+    Backend: 12-layer Conformer
+    Quantizer: a codebook for mel spectrogram
+    """
+
+    def __init__(
+        self,
+        num_codebooks=1,
+        codebook_dim=16,
+        codebook_size=4096,
+        features=["melspec_2048"],
+        hop_length=240,
+        n_mels=128,
+        conv_dim=512,
+        encoder_dim=1024,
+        encoder_depth=12,
+        mask_hop=0.4,
+        mask_prob=0.6,
+        is_flash=False,
+        stat_path="./data/fma_stats.json",
+        model_path="./data/pretrained_fma.pt",
+        w2v2_config_path="facebook/wav2vec2-conformer-rope-large-960h-ft",
+    ):
+        super(MusicFM25Hz, self).__init__()
+
+        # global variables
+        self.hop_length = hop_length
+        self.mask_hop = mask_hop
+        self.mask_prob = mask_prob
+        self.num_codebooks = num_codebooks
+        self.codebook_size = codebook_size
+        self.features = features
+
+        # load feature mean / std stats
+        with open(stat_path, "r") as f:
+            self.stat = json.load(f)
+
+        # feature extractor
+        self.preprocessor_melspec_2048 = MelSTFT(
+            n_fft=2048, hop_length=hop_length, is_db=True
+        )
+
+        # random quantizer
+        seed = 142
+        for feature in self.features:
+            for i in range(num_codebooks):
+                setattr(
+                    self,
+                    f"quantizer_{feature}_{i}",
+                    RandomProjectionQuantizer(
+                        n_mels * 4, codebook_dim, codebook_size, seed=seed + i
+                    ),
+                )
+
+        # two residual convolution layers + one projection layer
+        self.conv = Conv2dSubsampling(
+            1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
+        )
+
+        # Conformer
+        if is_flash:
+            from modules.flash_conformer import (
+                Wav2Vec2ConformerEncoder,
+                Wav2Vec2ConformerConfig,
+            )
+        else:
+            from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+                Wav2Vec2ConformerEncoder,
+                Wav2Vec2ConformerConfig,
+            )
+        config = Wav2Vec2ConformerConfig.from_pretrained(
+            w2v2_config_path
+        )
+        config.num_hidden_layers = encoder_depth
+        config.hidden_size = encoder_dim
+
+        self.conformer = Wav2Vec2ConformerEncoder(config)
+
+        # projection
+        self.linear = nn.Linear(encoder_dim, codebook_size)
+
+        # loss function
+        self.loss = nn.CrossEntropyLoss()
+
+        # cls token (used for sequence classification)
+        random.seed(seed)
+        self.cls_token = nn.Parameter(torch.randn(encoder_dim))
+
+        # load model
+        if model_path:
+            S = torch.load(model_path)["state_dict"]
+            SS = {k[6:]: v for k, v in S.items()}
+            self.load_state_dict(SS, strict=True)
+
+    def masking(self, x):
+        """random masking of 400ms with given probability"""
+        mx = x.clone()
+        b, t = mx.shape
+        len_masking_raw = int(24000 * self.mask_hop)
+        len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
+
+        # get random mask indices
+        start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
+        time_domain_masked_indices = torch.nonzero(
+            start_indices.repeat_interleave(len_masking_raw, dim=1)
+        )
+        token_domain_masked_indices = torch.nonzero(
+            start_indices.repeat_interleave(len_masking_token, dim=1)
+        )
+
+        # mask with random values
+        masking_noise = (
+            torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
+        )  # 0 mean 0.1 std
+        mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
+
+        return mx, token_domain_masked_indices
+
+    @torch.no_grad()
+    def preprocessing(self, x, features):
+        """extract classic audio features"""
+        # check precision
+        if x.dtype == torch.float16:
+            precision = 16
+        else:
+            precision = 32
+
+        out = {}
+        for key in features:
+            layer = getattr(self, "preprocessor_%s" % key)
+            out[key] = layer.float()(x.float())[..., :-1]
+            if precision == 16:
+                out[key] = out[key].half()
+        return out
+
+    def encoder(self, x):
+        """2-layer conv + w2v-conformer"""
+        x = self.conv(x)
+        out = self.conformer(x, output_hidden_states=True)
+        hidden_emb = out["hidden_states"]
+        last_emb = out["last_hidden_state"]
+        logits = self.linear(last_emb)
+        logits = {
+            key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
+            for i, key in enumerate(self.features)
+        }
+        return logits, hidden_emb
+
+    @torch.no_grad()
+    def normalize(self, x):
+        """normalize the input audio to have zero mean unit variance"""
+        for key in x.keys():
+            x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
+        return x
+
+    @torch.no_grad()
+    def rearrange(self, x):
+        """rearrange the batch to flatten every 4 steps"""
+        for key in x.keys():
+            if key == "chromagram":
+                x[key] = rearrange(x[key], "b f t -> b t f")
+            else:
+                x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
+        return x
+
+    @torch.no_grad()
+    def tokenize(self, x):
+        out = {}
+        for key in x.keys():
+            layer = getattr(self, "quantizer_%s" % key)
+            out[key] = layer(x[key])
+        return out
+
+    def get_targets(self, x):
+        x = self.preprocessing(x, features=self.features)
+        x = self.normalize(x)
+        x = self.rearrange(x)
+        target_tokens = self.tokenize(x)
+        return target_tokens
+
+    def get_predictions(self, x):
+        # preprocessing
+        x = self.preprocessing(x, features=["melspec_2048"])
+        x = self.normalize(x)
+
+        # encoding
+        logits, hidden_emb = self.encoder(x["melspec_2048"])
+
+        return logits, hidden_emb
+
+    def get_latent(self, x, layer_ix=12):
+        _, hidden_states = self.get_predictions(x)
+        emb = hidden_states[layer_ix]
+        return emb
+
+    def get_loss(self, logits, target_tokens, masked_indices):
+        losses = {}
+        accuracies = {}
+        for key in logits.keys():
+            masked_logits = logits[key][tuple(masked_indices.t())]
+            masked_tokens = target_tokens[key][tuple(masked_indices.t())]
+            losses[key] = self.loss(masked_logits, masked_tokens)
+            accuracies[key] = (
+                torch.sum(masked_logits.argmax(-1) == masked_tokens)
+                / masked_tokens.numel()
+            )
+        return losses, accuracies
+
+    def forward(self, x):
+        # get target feature tokens
+        target_tokens = self.get_targets(x)
+
+        # masking
+        x, masked_indices = self.masking(x)
+
+        # forward
+        logits, hidden_emb = self.get_predictions(x)
+
+        # get loss
+        losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
+
+        return logits, hidden_emb, losses, accuracies
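+
+# Minimal inference sketch (illustrative; the paths and the 24 kHz, 10 s input are
+# assumptions, not values shipped with this repository):
+#
+#   model = MusicFM25Hz(stat_path="stats.json", model_path="pretrained.pt")
+#   wav = torch.randn(1, 24000 * 10)                  # (batch, samples) at 24 kHz
+#   logits, hidden_emb = model.get_predictions(wav)
+#   emb = hidden_emb[12]                              # (batch, frames, encoder_dim)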
diff --git a/slam_llm/models/musicfm/modules/__init__.py b/slam_llm/models/musicfm/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..99a8091366e2ac7f301452706d5cfdff5e320a0d
--- /dev/null
+++ b/slam_llm/models/musicfm/modules/__init__.py
@@ -0,0 +1,2 @@
+
+
diff --git a/slam_llm/models/musicfm/modules/conv.py b/slam_llm/models/musicfm/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc3a83cfe0c3dfb2aaba46483615d3f3c75d7565
--- /dev/null
+++ b/slam_llm/models/musicfm/modules/conv.py
@@ -0,0 +1,82 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+from torch import nn
+from einops import rearrange
+
+
+class Res2dModule(nn.Module):
+    def __init__(self, idim, odim, stride=(2, 2)):
+        super(Res2dModule, self).__init__()
+        self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
+        self.bn1 = nn.BatchNorm2d(odim)
+        self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
+        self.bn2 = nn.BatchNorm2d(odim)
+        self.relu = nn.ReLU()
+
+        # residual
+        self.diff = False
+        if (idim != odim) or (stride[0] > 1):
+            self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
+            self.bn3 = nn.BatchNorm2d(odim)
+            self.diff = True
+
+    def forward(self, x):
+        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
+        if self.diff:
+            x = self.bn3(self.conv3(x))
+        out = x + out
+        out = self.relu(out)
+        return out
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Args:
+        idim (int): Input dimension.
+        hdim (int): Hidden dimension.
+        odim (int): Output dimension.
+        strides (list): Sizes of strides.
+        n_bands (int): Number of frequency bands.
+    """
+
+    def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
+        """Construct an Conv2dSubsampling object."""
+        super(Conv2dSubsampling, self).__init__()
+
+        self.conv = nn.Sequential(
+            Res2dModule(idim, hdim, (2, strides[0])),
+            Res2dModule(hdim, hdim, (2, strides[1])),
+        )
+        self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
+
+    def forward(self, x):
+        """Subsample x.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, idim, time).
+
+        Returns:
+            torch.Tensor: Subsampled tensor (#batch, time', odim),
+                where time' = time // 4.
+        """
+
+        if x.dim() == 3:
+            x = x.unsqueeze(1)  # (b, c, f, t)
+        x = self.conv(x)
+        x = rearrange(x, "b c f t -> b t (c f)")
+        x = self.linear(x)
+        return x
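+
+# Shape walkthrough (illustrative, with n_bands=128 and strides=[2, 2] as used by
+# MusicFM25Hz; frame counts are approximate because of convolution padding):
+#
+#   x: (batch, 128, time)    -> unsqueeze           -> (batch, 1, 128, time)
+#   two Res2dModule blocks   -> (batch, hdim, 32, ~time // 4)
+#   rearrange + linear       -> (batch, ~time // 4, odim)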
diff --git a/slam_llm/models/musicfm/modules/features.py b/slam_llm/models/musicfm/modules/features.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29f859d32d628497751d02419b3e5f749d2030a
--- /dev/null
+++ b/slam_llm/models/musicfm/modules/features.py
@@ -0,0 +1,45 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+import torchaudio
+from torch import nn
+
+
+class MelSTFT(nn.Module):
+    def __init__(
+        self,
+        sample_rate=24000,
+        n_fft=2048,
+        hop_length=240,
+        n_mels=128,
+        is_db=False,
+    ):
+        super(MelSTFT, self).__init__()
+
+        # spectrogram
+        self.mel_stft = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
+        )
+
+        # amplitude to decibel
+        self.is_db = is_db
+        if is_db:
+            self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+    def forward(self, waveform):
+        if self.is_db:
+            return self.amplitude_to_db(self.mel_stft(waveform))
+        else:
+            return self.mel_stft(waveform)
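+
+# Usage sketch (the 24 kHz, 10-second random waveform is an assumption for illustration):
+#
+#   mel = MelSTFT(n_fft=2048, hop_length=240, is_db=True)
+#   spec = mel(torch.randn(1, 24000 * 10))   # (batch, 128 mel bins, ~1001 frames)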
diff --git a/slam_llm/models/musicfm/modules/flash_conformer.py b/slam_llm/models/musicfm/modules/flash_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..689e2c1c1c705a69ed987ed7507e02fc58e36341
--- /dev/null
+++ b/slam_llm/models/musicfm/modules/flash_conformer.py
@@ -0,0 +1,2114 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Wav2Vec2-Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+
+from transformers.activations import ACT2FN
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    CausalLMOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+    Wav2Vec2BaseModelOutput,
+    XVectorOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+    ModelOutput,
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 64.21
+
+
+WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "facebook/wav2vec2-conformer-rel-pos-large",
+    # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
+]
+
+
+@dataclass
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
+    """
+    Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
+
+    Args:
+        loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+            paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
+        projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+            projected quantized states.
+        projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+            Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+            target vectors for contrastive loss.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+            The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+        diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+            The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    projected_states: torch.FloatTensor = None
+    projected_quantized_states: torch.FloatTensor = None
+    codevector_perplexity: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    contrastive_loss: Optional[torch.FloatTensor] = None
+    diversity_loss: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
+def _compute_mask_indices(
+    shape: Tuple[int, int],
+    mask_prob: float,
+    mask_length: int,
+    attention_mask: Optional[torch.LongTensor] = None,
+    min_masks: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+    CPU as part of the preprocessing during training.
+
+    Args:
+        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+               the first element is the batch size and the second element is the length of the axis to span.
+        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+                    independently generated mask spans of length `mask_length` is computed by
+                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+                    actual percentage will be smaller.
+        mask_length: size of the mask
+        min_masks: minimum number of masked spans
+        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+                        each batch dimension.
+    """
+    batch_size, sequence_length = shape
+
+    if mask_length < 1:
+        raise ValueError("`mask_length` has to be bigger than 0.")
+
+    if mask_length > sequence_length:
+        raise ValueError(
+            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+            f" and `sequence_length`: {sequence_length}`"
+        )
+
+    # epsilon is used for probabilistic rounding
+    epsilon = np.random.rand(1).item()
+
+    def compute_num_masked_span(input_length):
+        """Given input length, compute how many spans should be masked"""
+        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+        num_masked_span = max(num_masked_span, min_masks)
+
+        # make sure num masked span <= sequence_length
+        if num_masked_span * mask_length > sequence_length:
+            num_masked_span = sequence_length // mask_length
+
+        # make sure num_masked span is also <= input_length - (mask_length - 1)
+        if input_length - (mask_length - 1) < num_masked_span:
+            num_masked_span = max(input_length - (mask_length - 1), 0)
+
+        return num_masked_span
+
+    # compute number of masked spans in batch
+    input_lengths = (
+        attention_mask.sum(-1).detach().tolist()
+        if attention_mask is not None
+        else [sequence_length for _ in range(batch_size)]
+    )
+
+    # SpecAugment mask to fill
+    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+    spec_aug_mask_idxs = []
+
+    max_num_masked_span = compute_num_masked_span(sequence_length)
+
+    if max_num_masked_span == 0:
+        return spec_aug_mask
+
+    for input_length in input_lengths:
+        # compute num of masked spans for this input
+        num_masked_span = compute_num_masked_span(input_length)
+
+        # get random indices to mask
+        spec_aug_mask_idx = np.random.choice(
+            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+        )
+
+        # pick first sampled index that will serve as a dummy index to pad vector
+        # to ensure same dimension for all batches due to probabilistic rounding
+        # Picking first sample just pads those vectors twice.
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller then
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
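+
+# Worked example (illustrative): with shape=(1, 10), mask_prob=0.4 and mask_length=2,
+# int(0.4 * 10 / 2 + epsilon) = 2 spans of length 2 are drawn per row, so roughly 4 of
+# the 10 positions come back True in the boolean mask (fewer if the spans overlap).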
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
+def _sample_negative_indices(
+    features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+    """
+    Sample `num_negatives` vectors from feature vectors.
+    """
+    batch_size, sequence_length = features_shape
+
+    # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+    sequence_length_range = np.arange(sequence_length)
+
+    # get `num_negatives` random vector indices from the same utterance
+    sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+    mask_time_indices = (
+        mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
+    )
+
+    for batch_idx in range(batch_size):
+        high = mask_time_indices[batch_idx].sum() - 1
+        mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+        feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+        sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+        # avoid sampling the same positive vector, but keep the distribution uniform
+        sampled_indices[sampled_indices >= feature_indices] += 1
+
+        # remap to actual indices
+        sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+        # correct for batch size
+        sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+    return sampled_negative_indices
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+
+        hidden_states = hidden_states.transpose(-2, -1)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = hidden_states.transpose(-2, -1)
+
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+        self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            kernel_size=config.num_conv_pos_embeddings,
+            padding=config.num_conv_pos_embeddings // 2,
+            groups=config.num_conv_pos_embedding_groups,
+        )
+
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
+            deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
+        else:
+            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+        self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.transpose(1, 2)
+
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.padding(hidden_states)
+        hidden_states = self.activation(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
+    """Rotary positional embedding
+    Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        dim = config.hidden_size // config.num_attention_heads
+        base = config.rotary_embedding_base
+
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self.cached_sequence_length = None
+        self.cached_rotary_positional_embedding = None
+
+    def forward(self, hidden_states):
+        sequence_length = hidden_states.shape[1]
+
+        if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+            return self.cached_rotary_positional_embedding
+
+        self.cached_sequence_length = sequence_length
+        time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+        embeddings = torch.cat((freqs, freqs), dim=-1)
+
+        cos_embeddings = embeddings.cos()[:, None, None, :]
+        sin_embeddings = embeddings.sin()[:, None, None, :]
+        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
+        return self.cached_rotary_positional_embedding
+
+
+class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
+    """Relative positional encoding module."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.max_len = config.max_source_positions
+        self.d_model = config.hidden_size
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+    def extend_pe(self, x):
+        # Reset the positional encodings
+        if self.pe is not None:
+            # self.pe contains both positive and negative parts
+            # the length of self.pe is 2 * input_len - 1
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of query vector and `j` is the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, hidden_states: torch.Tensor):
+        self.extend_pe(hidden_states)
+        start_idx = self.pe.size(1) // 2 - hidden_states.size(1) + 1
+        end_idx = self.pe.size(1) // 2 + hidden_states.size(1)
+        relative_position_embeddings = self.pe[:, start_idx:end_idx]
+
+        return relative_position_embeddings
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerSamePadLayer(nn.Module):
+    def __init__(self, num_conv_pos_embeddings):
+        super().__init__()
+        self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+    def forward(self, hidden_states):
+        if self.num_pad_remove > 0:
+            hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureEncoder(nn.Module):
+    """Construct the features from raw audio waveform"""
+
+    def __init__(self, config):
+        super().__init__()
+
+        if config.feat_extract_norm == "group":
+            conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
+                Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
+                for i in range(config.num_feat_extract_layers - 1)
+            ]
+        elif config.feat_extract_norm == "layer":
+            conv_layers = [
+                Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+            ]
+        else:
+            raise ValueError(
+                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+            )
+        self.conv_layers = nn.ModuleList(conv_layers)
+        self.gradient_checkpointing = False
+        self._requires_grad = True
+
+    def _freeze_parameters(self):
+        for param in self.parameters():
+            param.requires_grad = False
+        self._requires_grad = False
+
+    def forward(self, input_values):
+        hidden_states = input_values[:, None]
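+        # (batch_size, raw_sequence_length) -> (batch_size, 1, raw_sequence_length)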
+
+        # make sure hidden_states require grad for gradient_checkpointing
+        if self._requires_grad and self.training:
+            hidden_states.requires_grad = True
+
+        for conv_layer in self.conv_layers:
+            if self._requires_grad and self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(conv_layer),
+                    hidden_states,
+                )
+            else:
+                hidden_states = conv_layer(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureProjection(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+        self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+        self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+    def forward(self, hidden_states):
+        # non-projected hidden states are needed for quantization
+        norm_hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.projection(norm_hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states, norm_hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeedForward(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+    def forward(self, hidden_states):
+        hidden_states = self.intermediate_dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.intermediate_dropout(hidden_states)
+
+        hidden_states = self.output_dense(hidden_states)
+        hidden_states = self.output_dropout(hidden_states)
+        return hidden_states
+
+
+class Wav2Vec2ConformerConvolutionModule(nn.Module):
+    """Convolution block used in the conformer block"""
+
+    def __init__(self, config):
+        super().__init__()
+        if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+            raise ValueError("`config.conv_depthwise_kernel_size` should be an odd number for 'SAME' padding")
+        self.layer_norm = nn.LayerNorm(config.hidden_size)
+        self.pointwise_conv1 = torch.nn.Conv1d(
+            config.hidden_size,
+            2 * config.hidden_size,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.glu = torch.nn.GLU(dim=1)
+        self.depthwise_conv = torch.nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            config.conv_depthwise_kernel_size,
+            stride=1,
+            padding=(config.conv_depthwise_kernel_size - 1) // 2,
+            groups=config.hidden_size,
+            bias=False,
+        )
+        self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
+        self.activation = ACT2FN[config.hidden_act]
+        self.pointwise_conv2 = torch.nn.Conv1d(
+            config.hidden_size,
+            config.hidden_size,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=False,
+        )
+        self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
+
+    def forward(self, hidden_states):
+        hidden_states = self.layer_norm(hidden_states)
+        # exchange the temporal dimension and the feature dimension
+        hidden_states = hidden_states.transpose(1, 2)
+
+        # GLU mechanism
+        # => (batch, 2*channel, time)
+        hidden_states = self.pointwise_conv1(hidden_states)
+        # => (batch, channel, time)
+        hidden_states = self.glu(hidden_states)
+
+        # 1D Depthwise Conv
+        hidden_states = self.depthwise_conv(hidden_states)
+        hidden_states = self.batch_norm(hidden_states)
+        hidden_states = self.activation(hidden_states)
+
+        hidden_states = self.pointwise_conv2(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+class Wav2Vec2ConformerSelfAttention(nn.Module):
+    """Construct an Wav2Vec2ConformerSelfAttention object.
+    Can be enhanced with rotary or relative position embeddings.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.head_size = config.hidden_size // config.num_attention_heads
+        self.num_heads = config.num_attention_heads
+        self.position_embeddings_type = config.position_embeddings_type
+
+        self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+        self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+        self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+        self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.dropout = nn.Dropout(p=config.attention_dropout)
+        self.dropout_p = config.attention_dropout
+
+        self.is_causal = config.is_causal
+
+        if self.position_embeddings_type == "relative":
+            # linear transformation for positional encoding
+            self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+            # these two learnable biases are used in matrix c and matrix d
+            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+            self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+            self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        relative_position_embeddings: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # self-attention mechanism
+        batch_size, sequence_length, hidden_size = hidden_states.size()
+
+        # make sure query/key states can be != value states
+        query_key_states = hidden_states
+        value_states = hidden_states
+
+        if self.position_embeddings_type == "rotary":
+            if relative_position_embeddings is None:
+                raise ValueError(
+                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+                )
+            query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+        # project query_key_states and value_states
+        query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+        key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+        value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+        # => (batch, head, time1, d_k)
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+            hidden_states = F.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout_p if self.training else 0.0,  # only apply attention dropout during training
+                is_causal=self.is_causal,
+            )
+        probs = None
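+        # flash attention never materializes the attention matrix, so no attention
+        # probabilities are returned (even when output_attentions is requested);
+        # the commented-out block below is the remainder of the original eager
+        # attention path, kept for reference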
+
+        # # apply attention_mask if necessary
+        # if attention_mask is not None:
+        #     scores = scores + attention_mask
+
+        # # => (batch, head, time1, time2)
+        # probs = torch.softmax(scores, dim=-1)
+        # probs = self.dropout(probs)
+
+        # # => (batch, head, time1, d_k)
+        # hidden_states = torch.matmul(probs, value)
+
+        # => (batch, time1, hidden_size)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+        hidden_states = self.linear_out(hidden_states)
+
+        return hidden_states, probs
+
+    def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+        batch_size, sequence_length, hidden_size = hidden_states.size()
+        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+        cos = relative_position_embeddings[0, :sequence_length, ...]
+        sin = relative_position_embeddings[1, :sequence_length, ...]
+
+        # rotate hidden_states with rotary embeddings
+        hidden_states = hidden_states.transpose(0, 1)
+        rotated_states_begin = hidden_states[..., : self.head_size // 2]
+        rotated_states_end = hidden_states[..., self.head_size // 2 :]
+        rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+        hidden_states = (hidden_states * cos) + (rotated_states * sin)
+        hidden_states = hidden_states.transpose(0, 1)
+
+        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+        return hidden_states
+
+    def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+        # 1. project positional embeddings
+        # => (batch, head, 2*time1-1, d_k)
+        proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+        proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+            relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+        )
+        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+        # 2. Add bias to query
+        # => (batch, head, time1, d_k)
+        query = query.transpose(1, 2)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        # 3. attention score: first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # => (batch, head, time1, time2)
+        scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        # 4. then compute matrix b and matrix d
+        # => (batch, head, time1, 2*time1-1)
+        scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+        # 5. shift matrix b and matrix d
+        zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+        scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+        scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+        scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+        scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+        scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
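+        # after the shift and slice, scores_bd has shape (batch, head, time1, time1)
+        # and is aligned with scores_ac so the two can be summed below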
+
+        # 6. sum matrices
+        # => (batch, head, time1, time2)
+        scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+        return scores
+
+
+class Wav2Vec2ConformerEncoderLayer(nn.Module):
+    """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+    def __init__(self, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        dropout = config.attention_dropout
+
+        # Feed-forward 1
+        self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+        self.ffn1 = Wav2Vec2ConformerFeedForward(config)
+
+        # Self-Attention
+        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+        self.self_attn_dropout = torch.nn.Dropout(dropout)
+        self.self_attn = Wav2Vec2ConformerSelfAttention(config)
+
+        # Conformer Convolution
+        self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
+
+        # Feed-forward 2
+        self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+        self.ffn2 = Wav2Vec2ConformerFeedForward(config)
+        self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask: Optional[torch.Tensor] = None,
+        relative_position_embeddings: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ):
+
+        # 1. Feed-Forward 1 layer
+        residual = hidden_states
+        hidden_states = self.ffn1_layer_norm(hidden_states)
+        hidden_states = self.ffn1(hidden_states)
+        hidden_states = hidden_states * 0.5 + residual
+        residual = hidden_states
+
+        # 2. Self-Attention layer
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            relative_position_embeddings=relative_position_embeddings,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.self_attn_dropout(hidden_states)
+        hidden_states = hidden_states + residual
+
+        # 3. Convolutional Layer
+        residual = hidden_states
+        hidden_states = self.conv_module(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # 4. Feed-Forward 2 Layer
+        residual = hidden_states
+        hidden_states = self.ffn2_layer_norm(hidden_states)
+        hidden_states = self.ffn2(hidden_states)
+        hidden_states = hidden_states * 0.5 + residual
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states, attn_weights
+
+
+class Wav2Vec2ConformerEncoder(nn.Module):
+    def __init__(self, config, is_causal=False):
+        super().__init__()
+        config.is_causal = is_causal
+        self.config = config
+
+        if config.position_embeddings_type == "relative":
+            self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
+        elif config.position_embeddings_type == "rotary":
+            self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
+        else:
+            self.embed_positions = None
+
+        self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
+        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout)
+        self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if attention_mask is not None:
+            # make sure padded tokens output 0
+            hidden_states[~attention_mask] = 0.0
+
+            # extend attention_mask
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+            attention_mask = attention_mask.expand(
+                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+            )
+
+        hidden_states = self.dropout(hidden_states)
+
+        if self.embed_positions is not None:
+            relative_position_embeddings = self.embed_positions(hidden_states)
+        else:
+            relative_position_embeddings = None
+
+        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = np.random.uniform(0, 1)
+
+            skip_the_layer = self.training and (dropout_probability < self.config.layerdrop)
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
+                if self.gradient_checkpointing and self.training:
+                    # create gradient checkpointing function
+                    def create_custom_forward(module):
+                        def custom_forward(*inputs):
+                            return module(*inputs, output_attentions)
+
+                        return custom_forward
+
+                    layer_outputs = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(layer),
+                        hidden_states,
+                        attention_mask,
+                        relative_position_embeddings,
+                    )
+                else:
+                    layer_outputs = layer(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        relative_position_embeddings=relative_position_embeddings,
+                        output_attentions=output_attentions,
+                    )
+                hidden_states = layer_outputs[0]
+
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
+    """
+    Vector quantization using Gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.num_groups = config.num_codevector_groups
+        self.num_vars = config.num_codevectors_per_group
+
+        if config.codevector_dim % self.num_groups != 0:
+            raise ValueError(
+                f"`config.codevector_dim {config.codevector_dim} must be divisible "
+                f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
+            )
+
+        # storage for codebook variables (codewords)
+        self.codevectors = nn.Parameter(
+            torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+        )
+        self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+        # can be decayed for training
+        self.temperature = 2
+
+    @staticmethod
+    def _compute_perplexity(probs, mask=None):
+        if mask is not None:
+            mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+            probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+            marginal_probs = probs.sum(dim=0) / mask.sum()
+        else:
+            marginal_probs = probs.mean(dim=0)
+
+        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+        return perplexity
+
+    def forward(self, hidden_states, mask_time_indices=None):
+        batch_size, sequence_length, hidden_size = hidden_states.shape
+
+        # project to codevector dim
+        hidden_states = self.weight_proj(hidden_states)
+        hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+        if self.training:
+            # sample codevector probs via Gumbel softmax in a differentiable way
+            codevector_probs = nn.functional.gumbel_softmax(
+                hidden_states.float(), tau=self.temperature, hard=True
+            ).type_as(hidden_states)
+
+            # compute perplexity
+            codevector_soft_dist = torch.softmax(
+                hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+            )
+            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+        else:
+            # take argmax in a non-differentiable way
+            # compute hard codevector distribution (one-hot)
+            codevector_idx = hidden_states.argmax(dim=-1)
+            codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
+                -1, codevector_idx.view(-1, 1), 1.0
+            )
+            codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+        codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+        # use probs to retrieve codevectors
+        codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+        codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+        codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
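+        # codevectors: (batch_size, sequence_length, codevector_dim); perplexity is a scalar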
+
+        return codevectors, perplexity
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapter(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        # feature dim might need to be down-projected
+        if config.output_hidden_size != config.hidden_size:
+            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+        else:
+            self.proj = self.proj_layer_norm = None
+
+        self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
+        self.layerdrop = config.layerdrop
+
+    def forward(self, hidden_states):
+        # down project hidden_states if necessary
+        if self.proj is not None and self.proj_layer_norm is not None:
+            hidden_states = self.proj(hidden_states)
+            hidden_states = self.proj_layer_norm(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+
+        for layer in self.layers:
+            layerdrop_prob = np.random.random()
+            if not self.training or (layerdrop_prob > self.layerdrop):
+                hidden_states = layer(hidden_states)
+
+        hidden_states = hidden_states.transpose(1, 2)
+        return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapterLayer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.conv = nn.Conv1d(
+            config.output_hidden_size,
+            2 * config.output_hidden_size,
+            config.adapter_kernel_size,
+            stride=config.adapter_stride,
+            padding=1,
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
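+        # GLU halves the channel dimension back to output_hidden_size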
+        hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+        return hidden_states
+
+
+class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = Wav2Vec2ConformerConfig
+    base_model_prefix = "wav2vec2_conformer"
+    main_input_name = "input_values"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        # Wav2Vec2ConformerForPreTraining's last 2 linear layers need standard Linear init.
+        if isinstance(module, Wav2Vec2ConformerForPreTraining):
+            module.project_hid.reset_parameters()
+            module.project_q.reset_parameters()
+            module.project_hid._is_hf_initialized = True
+            module.project_q._is_hf_initialized = True
+        # gumbel softmax requires special init
+        elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+            module.weight_proj.bias.data.zero_()
+            nn.init.uniform_(module.codevectors)
+        elif isinstance(module, Wav2Vec2ConformerSelfAttention):
+            if hasattr(module, "pos_bias_u"):
+                nn.init.xavier_uniform_(module.pos_bias_u)
+            if hasattr(module, "pos_bias_v"):
+                nn.init.xavier_uniform_(module.pos_bias_v)
+        elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
+            nn.init.normal_(
+                module.conv.weight,
+                mean=0,
+                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+            )
+            nn.init.constant_(module.conv.bias, 0)
+        elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
+            k = math.sqrt(1 / module.projection.in_features)
+            nn.init.uniform_(module.projection.weight, a=-k, b=k)
+            nn.init.uniform_(module.projection.bias, a=-k, b=k)
+        elif isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Conv1d):
+            nn.init.kaiming_normal_(module.weight)
+
+            if module.bias is not None:
+                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                nn.init.uniform_(module.bias, a=-k, b=k)
+
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+    ):
+        """
+        Computes the output length of the convolutional layers
+        """
+
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
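+        # A worked example (assuming the default wav2vec2 feature extractor with
+        # conv_kernel=(10, 3, 3, 3, 3, 2, 2) and conv_stride=(5, 2, 2, 2, 2, 2, 2),
+        # which this config is not guaranteed to use): 16000 input samples
+        # (1 s at 16 kHz) reduce to 49 feature frames before any adapter layers.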
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(
+        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+    ):
+        # Effectively attention_mask.sum(-1), but not in-place so that it can
+        # run in inference mode.
+        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+        output_lengths = output_lengths.to(torch.long)
+
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output length indices are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
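+        # e.g. output_length == 3 with feature_vector_length == 5:
+        # [0, 0, 1, 0, 0] -> flip/cumsum/flip -> [True, True, True, False, False]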
+        return attention_mask
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
+            module.gradient_checkpointing = value
+
+
+WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
+    Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+    Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+    Auli.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, etc.).
+
+    This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
+    regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
+
+    Parameters:
+        config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
+    Args:
+        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+            1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            <Tip warning={true}>
+
+            `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+            True`. For all models whose processor has `config.return_attention_mask == False`, such as
+            [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
+            `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
+            such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
+            that these models also yield slightly different results depending on whether `input_values` is padded or
+            not.
+
+            </Tip>
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
+    WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
+    def __init__(self, config: Wav2Vec2ConformerConfig):
+        super().__init__(config)
+        self.config = config
+        self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
+        self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
+
+        # model only needs masking vector if mask prob is > 0.0
+        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+        self.encoder = Wav2Vec2ConformerEncoder(config)
+
+        self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        self.feature_extractor._freeze_parameters()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+    def _mask_hidden_states(
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ):
+        """
+        Masks extracted features along time axis and/or along feature axis according to
+        [SpecAugment](https://arxiv.org/abs/1904.08779).
+        """
+
+        # `config.apply_spec_augment` can set masking to False
+        if not getattr(self.config, "apply_spec_augment", True):
+            return hidden_states
+
+        # generate indices & apply SpecAugment along time axis
+        batch_size, sequence_length, hidden_size = hidden_states.size()
+
+        if mask_time_indices is not None:
+            # apply SpecAugment along time axis with given mask_time_indices
+            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+        elif self.config.mask_time_prob > 0 and self.training:
+            mask_time_indices = _compute_mask_indices(
+                (batch_size, sequence_length),
+                mask_prob=self.config.mask_time_prob,
+                mask_length=self.config.mask_time_length,
+                attention_mask=attention_mask,
+                min_masks=self.config.mask_time_min_masks,
+            )
+            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+        if self.config.mask_feature_prob > 0 and self.training:
+            # generate indices & apply SpecAugment along feature axis
+            mask_feature_indices = _compute_mask_indices(
+                (batch_size, hidden_size),
+                mask_prob=self.config.mask_feature_prob,
+                mask_length=self.config.mask_feature_length,
+                min_masks=self.config.mask_feature_min_masks,
+            )
+            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+            hidden_states[mask_feature_indices] = 0
+
+        return hidden_states
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Wav2Vec2BaseModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        extract_features = self.feature_extractor(input_values)
+        extract_features = extract_features.transpose(1, 2)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        hidden_states, extract_features = self.feature_projection(extract_features)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
+
+        encoder_outputs = self.encoder(
+            hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = encoder_outputs[0]
+
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
+        if not return_dict:
+            return (hidden_states, extract_features) + encoder_outputs[1:]
+
+        return Wav2Vec2BaseModelOutput(
+            last_hidden_state=hidden_states,
+            extract_features=extract_features,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
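+# A minimal usage sketch for the bare encoder above (assuming the upstream
+# "facebook/wav2vec2-conformer-rel-pos-large" checkpoint mentioned in the
+# docstrings is compatible with this local copy; `waveform` stands in for a
+# 1-D float array sampled at 16 kHz):
+#   from transformers import AutoFeatureExtractor
+#   feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+#   model = Wav2Vec2ConformerModel.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+#   inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
+#   with torch.no_grad():
+#       last_hidden = model(**inputs).last_hidden_state   # (batch, frames, hidden_size)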
+
+@add_start_docstrings(
+    """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
+)
+class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+    def __init__(self, config: Wav2Vec2ConformerConfig):
+        super().__init__(config)
+        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+        self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
+
+        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
+    def set_gumbel_temperature(self, temperature: int):
+        """
+        Set the Gumbel softmax temperature to a given value. Only necessary for training
+        """
+        self.quantizer.temperature = temperature
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+    @staticmethod
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
+    def compute_contrastive_logits(
+        target_features: torch.FloatTensor,
+        negative_features: torch.FloatTensor,
+        predicted_features: torch.FloatTensor,
+        temperature: int = 0.1,
+    ):
+        """
+        Compute logits for contrastive loss using cosine similarity as the distance measure between
+        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+        """
+        target_features = torch.cat([target_features, negative_features], dim=0)
+
+        logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+            target_features
+        )
+
+        # apply temperature
+        logits = logits / temperature
+        return logits
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.BoolTensor] = None,
+        sampled_negative_indices: Optional[torch.BoolTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
+        r"""
+        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices to mask extracted features for contrastive loss. When in training mode, the model learns to predict
+            masked extracted features in *config.proj_codevector_dim* space.
+        sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+            Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+            Required input for pre-training.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
+        >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+        ...     _compute_mask_indices,
+        ...     _sample_negative_indices,
+        ... )
+        >>> from datasets import load_dataset
+
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+        >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1
+
+        >>> # compute masked indices
+        >>> batch_size, raw_sequence_length = input_values.shape
+        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+        >>> mask_time_indices = _compute_mask_indices(
+        ...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+        ... )
+        >>> sampled_negative_indices = _sample_negative_indices(
+        ...     features_shape=(batch_size, sequence_length),
+        ...     num_negatives=model.config.num_negatives,
+        ...     mask_time_indices=mask_time_indices,
+        ... )
+        >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+        >>> sampled_negative_indices = torch.tensor(
+        ...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+        ... )
+
+        >>> with torch.no_grad():
+        ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
+
+        >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
+        >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+        >>> # show that cosine similarity is much higher than random
+        >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+        tensor(True)
+
+        >>> # for contrastive loss training model should be put into train mode
+        >>> model = model.train()
+        >>> loss = model(
+        ...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+        ... ).loss
+        ```"""
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if mask_time_indices is not None:
+            mask_time_indices = mask_time_indices.to(torch.bool)
+
+        outputs = self.wav2vec2_conformer(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            mask_time_indices=mask_time_indices,
+            return_dict=return_dict,
+        )
+
+        # 1. project all transformed features (including masked) to final vq dim
+        transformer_features = self.project_hid(outputs[0])
+
+        # 2. quantize all (unmasked) extracted features and project to final vq dim
+        extract_features = self.dropout_features(outputs[1])
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        quantized_features, codevector_perplexity = self.quantizer(
+            extract_features, mask_time_indices=mask_time_indices
+        )
+        quantized_features = self.project_q(quantized_features)
+
+        loss = contrastive_loss = diversity_loss = None
+        if sampled_negative_indices is not None:
+            batch_size, sequence_length, hidden_size = quantized_features.shape
+
+            # for training, we sample negatives
+            # 3. sample K negatives (distractors) quantized states for contrastive loss
+            # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
+            # sample negative quantized vectors BTC => (BxT)C
+            negative_quantized_features = quantized_features.view(-1, hidden_size)[
+                sampled_negative_indices.long().view(-1)
+            ]
+            negative_quantized_features = negative_quantized_features.view(
+                batch_size, sequence_length, -1, hidden_size
+            ).permute(2, 0, 1, 3)
+
+            # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
+            # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
+            logits = self.compute_contrastive_logits(
+                quantized_features[None, :],
+                negative_quantized_features,
+                transformer_features,
+                self.config.contrastive_logits_temperature,
+            )
+
+            # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
+            # its cosine similarity will be masked
+            neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
+
+            if neg_is_pos.any():
+                logits[1:][neg_is_pos] = float("-inf")
+
+            # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
+            # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
+            logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
+            target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
+
+            contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+            # 7. compute diversity loss: \mathbf{L}_d
+            num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
+            diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+
+            # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
+            loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
+
+        if not return_dict:
+            if loss is not None:
+                return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+        return Wav2Vec2ConformerForPreTrainingOutput(
+            loss=loss,
+            projected_states=transformer_features,
+            projected_quantized_states=quantized_features,
+            codevector_perplexity=codevector_perplexity,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            contrastive_loss=contrastive_loss,
+            diversity_loss=diversity_loss,
+        )
+
+
+@add_start_docstrings(
+    """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+    WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+                "or define `vocab_size` of your model's configuration."
+            )
+        output_hidden_size = (
+            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+        )
+        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters
+        will not be updated during training.
+        """
+        self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
+    )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, CausalLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+            Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or equal to
+            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+            config.vocab_size - 1]`.
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.wav2vec2_conformer(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            if labels.max() >= self.config.vocab_size:
+                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+            # retrieve loss input_lengths from attention_mask
+            attention_mask = (
+                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+            )
+            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+            # assuming that padded tokens are filled with -100
+            # when not being attended to
+            labels_mask = labels >= 0
+            target_lengths = labels_mask.sum(-1)
+            flattened_targets = labels.masked_select(labels_mask)
+
+            # ctc_loss doesn't support fp16
+            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+            with torch.backends.cudnn.flags(enabled=False):
+                loss = nn.functional.ctc_loss(
+                    log_probs,
+                    flattened_targets,
+                    input_lengths,
+                    target_lengths,
+                    blank=self.config.pad_token_id,
+                    reduction=self.config.ctc_loss_reduction,
+                    zero_infinity=self.config.ctc_zero_infinity,
+                )
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
+
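+# A minimal greedy CTC decoding sketch (assuming `processor` is the matching
+# Wav2Vec2Processor for the loaded checkpoint; illustrative only):
+#   logits = model(input_values).logits
+#   predicted_ids = torch.argmax(logits, dim=-1)
+#   transcription = processor.batch_decode(predicted_ids)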
+
+@add_start_docstrings(
+    """
+    Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
+    tasks like SUPERB Keyword Spotting.
+    """,
+    WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+            )
+        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+        not be updated during training.
+        """
+        self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.wav2vec2_conformer.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+    )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.wav2vec2_conformer(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        if attention_mask is None:
+            pooled_output = hidden_states.mean(dim=1)
+        else:
+            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+            hidden_states[~padding_mask] = 0.0
+            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
+    """,
+    WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+            )
+        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels
+
+        self.init_weights()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+        not be updated during training.
+        """
+        self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.wav2vec2_conformer.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+    )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.wav2vec2_conformer(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
+class AMSoftmaxLoss(nn.Module):
+    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+        super(AMSoftmaxLoss, self).__init__()
+        self.scale = scale
+        self.margin = margin
+        self.num_labels = num_labels
+        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+        self.loss = nn.CrossEntropyLoss()
+
+    def forward(self, hidden_states, labels):
+        labels = labels.flatten()
+        weight = nn.functional.normalize(self.weight, dim=0)
+        hidden_states = nn.functional.normalize(hidden_states, dim=1)
+        cos_theta = torch.mm(hidden_states, weight)
+        psi = cos_theta - self.margin
+
+        onehot = nn.functional.one_hot(labels, self.num_labels)
+        logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+        loss = self.loss(logits, labels)
+
+        return loss
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
+class TDNNLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+        self.out_conv_dim = config.tdnn_dim[layer_id]
+        self.kernel_size = config.tdnn_kernel[layer_id]
+        self.dilation = config.tdnn_dilation[layer_id]
+
+        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+        self.activation = nn.ReLU()
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.unsqueeze(1)
+        hidden_states = nn.functional.unfold(
+            hidden_states,
+            (self.kernel_size, self.in_conv_dim),
+            stride=(1, self.in_conv_dim),
+            dilation=(self.dilation, 1),
+        )
+        hidden_states = hidden_states.transpose(1, 2)
+        hidden_states = self.kernel(hidden_states)
+
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+@add_start_docstrings(
+    """
+    Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+    """,
+    WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+        self.tdnn = nn.ModuleList(tdnn_layers)
+
+        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+        self.init_weights()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+        not be updated during training.
+        """
+        self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.wav2vec2_conformer.parameters():
+            param.requires_grad = False
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
+    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+        """
+        Computes the output length of the TDNN layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return (input_length - kernel_size) // stride + 1
+
+        for kernel_size in self.config.tdnn_kernel:
+            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+        return input_lengths
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=XVectorOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+    )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, XVectorOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.wav2vec2_conformer(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+
+        for tdnn_layer in self.tdnn:
+            hidden_states = tdnn_layer(hidden_states)
+
+        # Statistic Pooling
+        if attention_mask is None:
+            mean_features = hidden_states.mean(dim=1)
+            std_features = hidden_states.std(dim=1)
+        else:
+            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+            mean_features = []
+            std_features = []
+            for i, length in enumerate(tdnn_output_lengths):
+                mean_features.append(hidden_states[i, :length].mean(dim=0))
+                std_features.append(hidden_states[i, :length].std(dim=0))
+            mean_features = torch.stack(mean_features)
+            std_features = torch.stack(std_features)
+        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+        output_embeddings = self.feature_extractor(statistic_pooling)
+        logits = self.classifier(output_embeddings)
+
+        loss = None
+        if labels is not None:
+            loss = self.objective(logits, labels)
+
+        if not return_dict:
+            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return XVectorOutput(
+            loss=loss,
+            logits=logits,
+            embeddings=output_embeddings,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/slam_llm/models/musicfm/modules/random_quantizer.py b/slam_llm/models/musicfm/modules/random_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..257164a951b2137e377369605abb74a2557bd965
--- /dev/null
+++ b/slam_llm/models/musicfm/modules/random_quantizer.py
@@ -0,0 +1,83 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+import torch
+from torch import nn, einsum
+from einops import rearrange
+
+
+class RandomProjectionQuantizer(nn.Module):
+    """
+    Random projection and codebook lookup module
+
+    Some code is borrowed from:
+     https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
+    But I did normalization using pre-computed global mean & variance instead of using layer norm.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        codebook_dim,
+        codebook_size,
+        seed=142,
+    ):
+        super().__init__()
+
+        # random seed
+        torch.manual_seed(seed)
+
+        # randomly initialized projection
+        random_projection = torch.empty(input_dim, codebook_dim)
+        nn.init.xavier_normal_(random_projection)
+        self.register_buffer("random_projection", random_projection)
+
+        # randomly initialized codebook
+        codebook = torch.empty(codebook_size, codebook_dim)
+        nn.init.normal_(codebook)
+        self.register_buffer("codebook", codebook)
+
+    def codebook_lookup(self, x):
+        # reshape
+        b = x.shape[0]
+        x = rearrange(x, "b n e -> (b n) e")
+
+        # L2 normalization
+        normalized_x = nn.functional.normalize(x, dim=1, p=2)
+        normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
+
+        # compute distances
+        distances = torch.cdist(normalized_codebook, normalized_x)
+
+        # get nearest
+        nearest_indices = torch.argmin(distances, dim=0)
+
+        # reshape
+        xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
+
+        return xq
+
+    @torch.no_grad()
+    def forward(self, x):
+        # always eval
+        self.eval()
+
+        # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
+        x = einsum("b n d, d e -> b n e", x, self.random_projection)
+
+        # codebook lookup
+        xq = self.codebook_lookup(x)
+
+        return xq
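+
+# A minimal usage sketch (illustrative only; the shapes below are assumptions, not values taken
+# from the MusicFM pipeline):
+#
+#     quantizer = RandomProjectionQuantizer(input_dim=1024, codebook_dim=16, codebook_size=4096)
+#     features = torch.randn(2, 100, 1024)   # [batch, length, input_dim]
+#     tokens = quantizer(features)           # [batch, length] integer codebook indices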
diff --git a/slam_llm/models/projector.py b/slam_llm/models/projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..69825c58c037de5b116dfafa9d86fe437b3c1cc5
--- /dev/null
+++ b/slam_llm/models/projector.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+
+
+class EncoderProjectorConcat(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.k = config.encoder_projector_ds_rate
+        self.encoder_dim = config.encoder_dim
+        self.llm_dim = config.llm_dim
+        self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(2048, config.llm_dim)
+
+    def forward(self, x):
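+        # x: [batch, seq_len, encoder_dim]. Frames are grouped k at a time and concatenated
+        # along the feature dimension (downsampling the sequence length by k), then projected
+        # to the LLM embedding dimension by a two-layer MLP.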
+        batch_size, seq_len, dim = x.size()
+        num_frames_to_discard = seq_len % self.k
+        if num_frames_to_discard > 0:
+            x = x[:, :-num_frames_to_discard, :]
+        seq_len = x.size(1)
+        
+        x = x.contiguous()
+        x = x.view(batch_size, seq_len // self.k, dim * self.k)
+        x = self.linear1(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+        return x
+
+class EncoderProjectorCov1d(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.k = config.encoder_projector_ds_rate
+        self.encoder_dim = config.encoder_dim
+        self.llm_dim = config.llm_dim
+        self.conv1d = nn.Conv1d(in_channels=self.encoder_dim, out_channels=self.encoder_dim, kernel_size=self.k, stride=self.k, padding=0)
+        self.linear1 = nn.Linear(self.encoder_dim, 2048)
+        self.relu1 = nn.ReLU()
+        self.linear2 = nn.Linear(2048, self.llm_dim)
+        self.relu2 = nn.ReLU()
+    
+    def forward(self, x):
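+        # x: [batch, seq_len, encoder_dim]. A strided Conv1d (kernel_size = stride = k)
+        # downsamples the sequence length by k before the ReLU/linear projection to llm_dim.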
+        x = x.transpose(1, 2)
+        x = self.conv1d(x)
+        x = x.transpose(1, 2)
+        x = self.relu1(x)
+        x = self.linear1(x)
+        x = self.relu2(x)
+        x = self.linear2(x)
+        return x
+
+class EncoderProjectorQFormer(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.encoder_dim = config.encoder_dim
+        self.llm_dim = config.llm_dim
+        from transformers import Blip2QFormerConfig, Blip2QFormerModel
+        configuration = Blip2QFormerConfig()
+        configuration.encoder_hidden_size = self.encoder_dim
+        configuration.num_hidden_layers = 8
+
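+        # 64 learnable query embeddings; the Q-Former cross-attends from these queries to the
+        # encoder outputs, producing a fixed-length sequence that is then projected to the LLM dim.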
+        self.query_len = 64
+        self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
+        self.query.data.normal_(mean=0.0, std=1.0)
+        self.qformer = Blip2QFormerModel(configuration)
+
+        self.linear = nn.Linear(configuration.hidden_size, self.llm_dim)
+        self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)
+
+    def forward(self, x, atts):
+        query = self.query.expand(x.shape[0], -1, -1)
+        
+        query_output = self.qformer(
+            query_embeds=query,
+            encoder_hidden_states=x,
+            encoder_attention_mask=atts,
+            return_dict=True,
+        )
+        
+        query_proj = self.norm(self.linear(query_output.last_hidden_state))
+        
+        return query_proj
\ No newline at end of file
diff --git a/slam_llm/models/slam_model.py b/slam_llm/models/slam_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3d7e032e53f7c2d3f3e741142b360f3cb8710ba
--- /dev/null
+++ b/slam_llm/models/slam_model.py
@@ -0,0 +1,443 @@
+import os
+import types
+import torch
+import soundfile as sf
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from typing import List, Optional, Tuple, Union
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
+from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
+
+from slam_llm.utils.config_utils import generate_peft_config
+from slam_llm.utils.train_utils import print_module_size, print_model_size
+from peft import PeftModel, PeftConfig
+from torch.nn import CrossEntropyLoss
+from slam_llm.utils.metric import compute_accuracy
+
+import logging
+logger = logging.getLogger(__name__)
+
+def model_factory(train_config, model_config, **kwargs):
+    # return necessary components for training
+    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)
+
+    encoder = setup_encoder(train_config, model_config, **kwargs)
+
+    # llm
+    llm = setup_llm(train_config, model_config, **kwargs)
+
+    # projector
+    encoder_projector = setup_encoder_projector(
+        train_config, model_config, **kwargs
+    )
+    model = slam_model(
+        encoder,
+        llm,
+        encoder_projector,
+        tokenizer,
+        train_config,
+        model_config,
+        **kwargs,
+    )
+
+    ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
+    if ckpt_path is not None:
+        logger.info("loading other parts from: {}".format(ckpt_path))
+        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
+        model.load_state_dict(ckpt_dict, strict=False)
+
+    print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
+    return model, tokenizer
+
+
+def setup_tokenizer(train_config, model_config, **kwargs):
+    # Load the tokenizer and add special tokens
+    if "vallex" in model_config.llm_name.lower():
+        return None  
+    elif "mupt" in model_config.llm_name.lower():
+        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path,
+                                            trust_remote_code=True,
+                                            use_fast=False)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+    return tokenizer
+
+
+def setup_encoder(train_config, model_config, **kwargs):
+    encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else []
+    if len(encoder_list) == 0:
+        return None
+    if len(encoder_list) == 1:
+        encoder_name = encoder_list[0]
+        if encoder_name == "whisper" or encoder_name == "qwen-audio":
+            from slam_llm.models.encoder import WhisperWrappedEncoder
+            encoder = WhisperWrappedEncoder.load(model_config)
+        if encoder_name == "beats": 
+            from slam_llm.models.encoder import BEATsEncoder
+            encoder = BEATsEncoder.load(model_config)
+        if encoder_name == "eat":
+            from slam_llm.models.encoder import EATEncoder
+            encoder = EATEncoder.load(model_config)
+        if encoder_name == "SpatialAST":
+            from slam_llm.models.encoder import SpatialASTEncoder
+            encoder = SpatialASTEncoder.load(model_config)
+        if encoder_name == "wavlm":
+            from slam_llm.models.encoder import WavLMEncoder
+            encoder = WavLMEncoder.load(model_config)
+        if encoder_name == "av_hubert":
+            from slam_llm.models.encoder import AVHubertEncoder
+            encoder = AVHubertEncoder.load(model_config)
+        if encoder_name == "hubert":
+            from slam_llm.models.encoder import HubertEncoder
+            encoder = HubertEncoder.load(model_config)
+        if encoder_name == "musicfm":
+            from slam_llm.models.encoder import MusicFMEncoder
+            encoder = MusicFMEncoder.load(model_config)
+
+        if "llama" in encoder_name.lower():
+            from slam_llm.models.encoder import HfTextEncoder
+            encoder = HfTextEncoder.load(model_config)
+    print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
+
+    if train_config.freeze_encoder:
+        for name, param in encoder.named_parameters(): 
+            param.requires_grad = False
+        encoder.eval()
+    print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
+
+    return encoder
+
+def setup_llm(train_config, model_config, **kwargs):
+    from pkg_resources import packaging
+    use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None
+    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.low_cpu_fsdp:
+        """
+        For FSDP, we can save CPU memory by loading the pretrained model on rank 0 only.
+        This avoids CPU OOM when loading large models such as Llama 70B, where the model
+        alone would consume 2+ TB of CPU memory (70 * 4 * 8). This adds some communication
+        overhead and currently requires a recent PyTorch nightly.
+        """
+        # v = packaging.version.parse(torch.__version__)
+        # verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        # if not verify_latest_nightly:
+        #     raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+        #                     "please install latest nightly.")
+        rank = int(os.environ["RANK"])
+        if rank == 0:
+            if "vallex" in model_config.llm_name.lower():
+                from src.slam_llm.models.vallex.vallex_config import VallexConfig
+                from src.slam_llm.models.vallex.vallex_model import VALLE
+                vallex_config = VallexConfig(
+                    **model_config
+                )
+                model = VALLE(vallex_config)
+            elif "aya" in model_config.llm_name.lower():
+                model = AutoModelForSeq2SeqLM.from_pretrained(
+                    model_config.llm_path,
+                    load_in_8bit=True if train_config.quantization else None,
+                    device_map="auto" if train_config.quantization else None,
+                    use_cache=use_cache,
+                )
+            else:
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_config.llm_path,
+                    load_in_8bit=True if train_config.quantization else None,
+                    device_map="auto" if train_config.quantization else None,
+                    use_cache=use_cache,
+                )
+        else:
+            llama_config = AutoConfig.from_pretrained(model_config.llm_path)
+            llama_config.use_cache = use_cache
+            # with torch.device("meta"):
+            if "aya" in model_config.llm_name.lower():
+                model = AutoModelForSeq2SeqLM.from_config(llama_config)
+            else:
+                model = AutoModelForCausalLM.from_config(llama_config) #(FIX:MZY): torch 2.0.1 does not support `meta`
+
+    else:
+        if "vallex" in model_config.llm_name.lower():
+            from src.slam_llm.models.vallex.vallex_config import VallexConfig
+            from src.slam_llm.models.vallex.vallex_model import VALLE
+            vallex_config = VallexConfig(
+                **model_config
+            )
+            model = VALLE(vallex_config)
+        elif "aya" in model_config.llm_name.lower():
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_config.llm_path,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+                use_cache=use_cache,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_config.llm_path,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+                use_cache=use_cache,
+            )
+    if (train_config.enable_fsdp or train_config.enable_ddp) and train_config.use_fast_kernels:
+        """
+        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' enables the use of Flash
+        Attention or xFormers memory-efficient kernels, depending on the hardware.
+        This can speed up fine-tuning.
+        """
+        try:
+            from optimum.bettertransformer import BetterTransformer
+            model = BetterTransformer.transform(model)
+        except ImportError:
+            logger.warning("Module 'optimum' not found. Please install 'optimum' before proceeding.")
+
+    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
+
+    # Prepare the model for int8 training if quantization is enabled
+    if train_config.quantization:
+        model = prepare_model_for_kbit_training(model)
+
+    if train_config.freeze_llm:  # TODO: test official `freeze_layers` and `num_freeze_layers`
+        for name, param in model.named_parameters(): 
+            param.requires_grad = False
+        model.eval()
+        
+    if kwargs.get("peft_ckpt", None): # (FIX:MZY):reload will get wrong results when decoding
+        logger.info("loading peft_ckpt from: {}".format(kwargs.get("peft_ckpt")))
+        model = PeftModel.from_pretrained(model=model, model_id=kwargs.get("peft_ckpt"), is_trainable=True)
+        model.print_trainable_parameters()
+    elif train_config.use_peft:
+        logger.info("setup peft...")
+        peft_config = generate_peft_config(train_config)
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+
+    print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
+    return model
+
+def setup_encoder_projector(train_config, model_config, **kwargs):
+    if model_config.encoder_projector == "linear":
+        from slam_llm.models.projector import EncoderProjectorConcat
+        encoder_projector = EncoderProjectorConcat(model_config)
+    elif model_config.encoder_projector == "cov1d-linear":
+        from slam_llm.models.projector import EncoderProjectorCov1d
+        encoder_projector = EncoderProjectorCov1d(model_config)
+    elif model_config.encoder_projector == "q-former":
+        from slam_llm.models.projector import EncoderProjectorQFormer
+        encoder_projector = EncoderProjectorQFormer(model_config)
+    else:
+        return None
+    print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0)
+    return encoder_projector
+
+
+class slam_model(nn.Module):
+    def __init__(
+        self,
+        encoder: nn.Module,
+        llm: nn.Module,
+        encoder_projector: nn.Module,
+        tokenizer, 
+        train_config, 
+        model_config, 
+        **kwargs
+    ):
+        super().__init__()
+        # modality encoder 
+        self.encoder = encoder
+
+        # llm
+        self.llm = llm
+
+        # projector
+        self.encoder_projector = encoder_projector
+
+        # tokenizer
+        self.tokenizer = tokenizer
+        self.metric = kwargs.get("metric", "acc")
+
+        self.train_config = train_config
+        self.model_config = model_config
+
+        if train_config.get("enable_deepspeed", False):
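+            # With DeepSpeed (typically fp16/bf16 training), run LayerNorm in fp32 for numerical
+            # stability and cast the result back to the input dtype.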
+            def new_forward(self, input):
+                output = F.layer_norm(
+                    input.float(),
+                    self.normalized_shape,
+                    self.weight.float() if self.weight is not None else None,
+                    self.bias.float() if self.bias is not None else None,
+                    self.eps,
+                )
+                return output.type_as(input)
+            for item in self.modules():
+                if isinstance(item, nn.LayerNorm):
+                    item.forward = types.MethodType(new_forward, item)
+
+
+
+    def forward(self,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.LongTensor] = None,
+                past_key_values: Optional[List[torch.FloatTensor]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,
+                **kwargs,
+                ):
+        audio_mel = kwargs.get("audio_mel", None)
+        audio_mel_mask = kwargs.get("audio_mel_mask", None)
+        audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper
+
+        audio = kwargs.get("audio", None)
+        audio_mask = kwargs.get("audio_mask", None)
+        visual = kwargs.get("visual", None)
+        visual_mask = kwargs.get("visual_mask", None)
+
+
+        # for text encoder
+        instruct_ids = kwargs.get("instruct_ids", None)
+        instruct_mask = kwargs.get("instruct_mask", None)
+
+        modality_mask = kwargs.get("modality_mask", None)
+        
+        zh_data = kwargs.get("zh", None)
+        en_data = kwargs.get("en", None)
+
+        encoder_outs = None
+        if audio_mel is not None or audio is not None or visual is not None:
+            if self.train_config.freeze_encoder: # freeze encoder
+                self.encoder.eval()
+
+            if self.model_config.encoder_name == "whisper":
+                encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim
+            if self.model_config.encoder_name == "beats":
+                encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim
+            if self.model_config.encoder_name == "eat":
+                encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask = None, mask=False, remove_extra_tokens = False)['x']
+            if self.model_config.encoder_name == "SpatialAST":
+                encoder_outs = self.encoder(audio) # output: [bs, seq_len=3+512, dim=768]
+            if self.model_config.encoder_name == "wavlm":
+                encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask
+            if self.model_config.encoder_name == "hubert":
+                results = self.encoder(source = audio, padding_mask = 1-audio_mask)
+                if self.model_config.encoder_type == "pretrain":
+                    encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"]
+                if self.model_config.encoder_type == "finetune":
+                    encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
+                    encoder_outs = encoder_outs.transpose(0, 1)
+            if self.model_config.encoder_name == "av_hubert":
+                results = self.encoder(source={'video':visual, 'audio':audio}, padding_mask=visual_mask) # bs*seq*dim  
+                encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"]
+                encoder_outs = encoder_outs.transpose(0, 1)
+                audio_mel_post_mask = (~audio_mel_post_mask).float()
+            if self.model_config.encoder_name == 'musicfm':
+                encoder_outs = self.encoder.extract_features(audio, padding_mask = None) # MusicFM doesn't support padding mask 
+            if self.encoder is None:
+                encoder_outs = audio_mel if audio_mel is not None else audio
+
+            if self.model_config.encoder_projector == "q-former":
+                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
+            if self.model_config.encoder_projector == "linear":
+                encoder_outs = self.encoder_projector(encoder_outs)
+            if self.model_config.encoder_projector == "cov1d-linear": 
+                encoder_outs = self.encoder_projector(encoder_outs) 
+
+        if instruct_ids is not None:
+            if self.encoder is not None:
+                encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state
+
+            if self.model_config.encoder_projector == "q-former":
+                encoder_outs = self.encoder_projector(encoder_outs, instruct_mask)
+            if self.model_config.encoder_projector == "linear":
+                encoder_outs = self.encoder_projector(encoder_outs)
+
+        if input_ids is not None:
+            input_ids[input_ids == -1] = 0
+            if isinstance(self.llm, T5ForConditionalGeneration):
+                inputs_embeds = self.llm.shared(input_ids)
+            else:
+                if hasattr(self.llm.model, "embed_tokens"):
+                    inputs_embeds = self.llm.model.embed_tokens(input_ids)
+                elif hasattr(self.llm.model.model, "embed_tokens"):
+                    inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+                else:
+                    inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
+
+        if modality_mask is not None:
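+            # Splice the projected encoder features into the token embedding sequence: positions
+            # flagged by modality_mask (the modality placeholder span) are overwritten with the
+            # encoder outputs, while all other positions keep their text embeddings.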
+            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
+            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
+
+            encoder_outs_pad = torch.zeros_like(inputs_embeds)
+            for i in range(encoder_outs.shape[0]):
+                encoder_outs_pad[
+                    i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i]
+                ] = encoder_outs[i][:modality_lengths[i]]
+            
+            inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
+
+        if kwargs.get("inference_mode", False):
+            return inputs_embeds, attention_mask
+
+        if zh_data is not None and en_data is not None:
+            model_outputs, acc = self.llm(zh=zh_data, en=en_data)
+        else:
+            model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+            acc = -1
+            if self.metric:
+                with torch.no_grad():
+                    preds = torch.argmax(model_outputs.logits, -1)
+                    acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
+
+        return model_outputs, acc
+    
+    @torch.no_grad()
+    def generate(self,
+                input_ids: torch.LongTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.LongTensor] = None,
+                past_key_values: Optional[List[torch.FloatTensor]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                return_dict: Optional[bool] = None,
+                **kwargs,
+                ):
+        kwargs["inference_mode"] = True
+
+        inputs_embeds, attention_mask = self.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+        model_outputs = self.llm.generate(
+            inputs_embeds=inputs_embeds,
+            # max_length=kwargs.get("max_length", 200),
+            max_new_tokens=kwargs.get("max_new_tokens", 200),
+            num_beams=kwargs.get("num_beams", 4),
+            do_sample=kwargs.get("do_sample", False),
+            min_length=kwargs.get("min_length", 1),
+            top_p=kwargs.get("top_p", 1.0),
+            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+            length_penalty=kwargs.get("length_penalty", 1.0),
+            temperature=kwargs.get("temperature", 1.0),
+            attention_mask=attention_mask,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id
+        )
+
+        return model_outputs
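+
+# A minimal inference sketch (illustrative only; batches are normally produced by the dataset
+# collator, and the exact kwargs -- e.g. "audio_mel", "audio_mel_mask", "modality_mask" --
+# depend on the configured encoder):
+#
+#     model, tokenizer = model_factory(train_config, model_config)
+#     output_ids = model.generate(
+#         input_ids=batch["input_ids"],
+#         attention_mask=batch["attention_mask"],
+#         audio_mel=batch["audio_mel"],
+#         audio_mel_mask=batch["audio_mel_mask"],
+#         modality_mask=batch["modality_mask"],
+#     )
+#     print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))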
diff --git a/slam_llm/models/vallex/__init__.py b/slam_llm/models/vallex/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/slam_llm/models/vallex/activation.py b/slam_llm/models/vallex/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6f3e55e87b32b1045a5323c720a93a9efefe13a
--- /dev/null
+++ b/slam_llm/models/vallex/activation.py
@@ -0,0 +1,179 @@
+from typing import Optional, Tuple, List
+import math
+
+import torch
+from torch import Tensor
+from torch.nn import Linear, Module
+from torch.nn import functional as F
+from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
+
+
+class MultiheadAttention(Module):
+    __constants__ = ["batch_first"]
+    bias_k: Optional[torch.Tensor]
+    bias_v: Optional[torch.Tensor]
+
+    def __init__(
+            self,
+            embed_dim,
+            num_heads,
+            dropout=0.0,
+            bias=True,
+            add_bias_kv=False,
+            add_zero_attn=False,
+            kdim=None,
+            vdim=None,
+            batch_first=False,
+            linear1_cls=Linear,
+            linear2_cls=Linear,
+            device=None,
+            dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super(MultiheadAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self._qkv_same_embed_dim = False
+
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.batch_first = batch_first
+        self.head_dim = embed_dim // num_heads
+        assert (
+                self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+
+        self.k_proj = Linear(self.kdim, embed_dim)
+        self.v_proj = Linear(self.vdim, embed_dim)
+        self.q_proj = Linear(embed_dim, embed_dim)
+        
+        self.out_proj = NonDynamicallyQuantizableLinear(
+            embed_dim, embed_dim, bias=bias, **factory_kwargs
+        )
+
+        self.add_zero_attn = add_zero_attn
+        self.scaling = self.head_dim**-0.5
+
+    def __setstate__(self, state):
+        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
+        if "_qkv_same_embed_dim" not in state:
+            state["_qkv_same_embed_dim"] = True
+
+        super(MultiheadAttention, self).__setstate__(state)
+
+    def forward(
+            self,
+            query: Tensor,
+            key: Tensor,
+            value: Tensor,
+            key_padding_mask: Optional[Tensor] = None,
+            need_weights: bool = True,
+            attn_mask: Optional[Tensor] = None,
+            average_attn_weights: bool = True,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+
+        # query/key/value: (B, T, C)
+        B, T, C = query.size()
+        
+        q = self.q_proj(query)
+        k = self.k_proj(key)
+        v = self.v_proj(value)
+        q *= self.scaling
+        
+        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
+        
+        attn_weights = q @ k.transpose(-2, -1) # B, nh, T, T
+        
+        if attn_mask is not None:
+            # attn_mask is inf
+            # attn_mask = attn_mask.unsqueeze(0)
+            # attn_weights += attn_mask
+            if torch.is_floating_point(attn_mask):
+                # print(attn_weights.size(), attn_mask.size())
+                attn_weights += attn_mask.unsqueeze(0).unsqueeze(1)
+            else:
+                attn_weights = attn_weights.masked_fill(attn_mask, float('-inf'))
+
+        if key_padding_mask is not None:
+            # don't attend to padding symbols
+            attn_weights = attn_weights.view(B, self.num_heads, T, T)
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1)
+                .unsqueeze(2)
+                .to(torch.bool),
+                float("-inf"),
+            )
+        attn_weights_float = F.softmax(attn_weights, dim=-1)
+        attn = attn_weights_float @ v
+        
+        y = attn.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+        y = self.out_proj(y)
+        return y, attn_weights
+    
+    def infer(self, 
+              x: Tensor,
+              key_padding_mask: Optional[Tensor] = None,
+              need_weights: bool = True,
+              attn_mask: Optional[Tensor] = None,
+              average_attn_weights: bool = True,
+              past_kv = None,
+              use_cache = False):
+        
+        # print("debug:"+str(x.size()))
+        
+        B, T, C = x.size()
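+        # Incremental-decoding path: x holds only the newly generated positions; keys/values for
+        # the prefix come from past_kv, and when use_cache is True the concatenated (k, v) pair
+        # is returned as `present` for the next step.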
+        
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+        q *= self.scaling
+        
+        # k = k.view(T, B*self.num_heads, self.head_dim).transpose(0, 1)  # (B, nh, T, hs)
+        # q = q.view(T, B*self.num_heads, self.head_dim).transpose(0, 1)  # (B, nh, T, hs)
+        # v = v.view(T, B*self.num_heads, self.head_dim).transpose(0, 1)  # (B, nh, T, hs)
+        
+        k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
+        
+        if past_kv is not None:
+            past_key = past_kv[0]
+            past_value = past_kv[1]
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+        
+        FULL_T = k.shape[-2]
+        
+        if use_cache is True:
+            present = (k, v)
+        else:
+            present = None
+        
+        # print(q.size(), k.size())
+        attn_weights = q @ k.transpose(-2, -1)
+        # print(attn_mask.size())
+        attn_weights = attn_weights.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
+        
+        # if key_padding_mask is not None:
+        #     # don't attend to padding symbols
+        #     attn_weights = attn_weights.view(B, self.num_heads, T, T)
+        #     attn_weights = attn_weights.view(B, -1, self.num_heads, T, T)
+        #     attn_weights = attn_weights.masked_fill(
+        #         key_padding_mask.unsqueeze(1)
+        #         .unsqueeze(2)
+        #         .unsqueeze(3)
+        #         .to(torch.bool),
+        #         float("-inf"),
+        #     )
+        attn_weights_float = F.softmax(attn_weights, dim=-1)
+        # attn_weights = attn_weights_float.type_as(attn_weights)
+        # attn = torch.bmm(attn_weights, v)
+        attn = attn_weights_float @ v
+        
+        y = attn.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+        y = self.out_proj(y)
+        return (y, present)
\ No newline at end of file
diff --git a/slam_llm/models/vallex/scaling.py b/slam_llm/models/vallex/scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..a34de303ce471279cdb78003f5409927c6c941a6
--- /dev/null
+++ b/slam_llm/models/vallex/scaling.py
@@ -0,0 +1,1404 @@
+# Copyright    2022  Xiaomi Corp.        (authors: Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import collections
+import logging
+import random
+import math
+from functools import reduce
+from itertools import repeat
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Embedding as ScaledEmbedding
+
+class Transpose(nn.Identity):
+    """(N, T, D) -> (N, D, T)"""
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input.transpose(1, 2)
+
+class ActivationBalancerFunction(torch.autograd.Function):
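+    """
+    Identity in the forward pass. In the backward pass the gradient is nudged by a correction
+    proportional to its magnitude, using the pre-computed sign/scale factors, which pushes the
+    activation statistics (proportion of positive values, mean absolute value) toward the
+    configured bounds.
+    """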
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        scale_factor: Tensor,
+        sign_factor: Optional[Tensor],
+        channel_dim: int,
+    ) -> Tensor:
+        if channel_dim < 0:
+            channel_dim += x.ndim
+        ctx.channel_dim = channel_dim
+        xgt0 = x > 0
+        if sign_factor is None:
+            ctx.save_for_backward(xgt0, scale_factor)
+        else:
+            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+        if len(ctx.saved_tensors) == 3:
+            xgt0, scale_factor, sign_factor = ctx.saved_tensors
+            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+                scale_factor = scale_factor.unsqueeze(-1)
+                sign_factor = sign_factor.unsqueeze(-1)
+            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        else:
+            xgt0, scale_factor = ctx.saved_tensors
+            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+                scale_factor = scale_factor.unsqueeze(-1)
+            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        neg_delta_grad = x_grad.abs() * factor
+        return (
+            x_grad - neg_delta_grad,
+            None,
+            None,
+            None,
+        )
+
+
+def _compute_scale_factor(
+    x: Tensor,
+    channel_dim: int,
+    min_abs: float,
+    max_abs: float,
+    gain_factor: float,
+    max_factor: float,
+) -> Tensor:
+    if channel_dim < 0:
+        channel_dim += x.ndim
+    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+
+    if min_abs == 0.0:
+        below_threshold = 0.0
+    else:
+        # below_threshold is 0 if x_abs_mean > min_abs; it can be at most
+        # max_factor if x_abs_mean < min_abs.
+        below_threshold = (
+            (min_abs - x_abs_mean) * (gain_factor / min_abs)
+        ).clamp(min=0, max=max_factor)
+
+    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
+        min=0, max=max_factor
+    )
+
+    return below_threshold - above_threshold
+
+
+def _compute_sign_factor(
+    x: Tensor,
+    channel_dim: int,
+    min_positive: float,
+    max_positive: float,
+    gain_factor: float,
+    max_factor: float,
+) -> Tensor:
+    if channel_dim < 0:
+        channel_dim += x.ndim
+    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+    proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
+    if min_positive == 0.0:
+        factor1 = 0.0
+    else:
+        # 0 if proportion_positive >= min_positive, else can be
+        # as large as max_factor.
+        factor1 = (
+            (min_positive - proportion_positive) * (gain_factor / min_positive)
+        ).clamp_(min=0, max=max_factor)
+
+    if max_positive == 1.0:
+        factor2 = 0.0
+    else:
+        # 0 if proportion_positive <= max_positive, else can be as large as
+        # max_factor (it is subtracted below, so it pushes sign_factor down).
+        factor2 = (
+            (proportion_positive - max_positive)
+            * (gain_factor / (1.0 - max_positive))
+        ).clamp_(min=0, max=max_factor)
+    sign_factor = factor1 - factor2
+    # require min_positive != 0 or max_positive != 1:
+    assert not isinstance(sign_factor, float)
+    return sign_factor
+
+
+class ActivationScaleBalancerFunction(torch.autograd.Function):
+    """
+    This object is used in class ActivationBalancer when the user specified
+    min_positive=0, max_positive=1, so there are no constraints on the signs
+    of the activations and only the absolute value has a constraint.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        sign_factor: Tensor,
+        scale_factor: Tensor,
+        channel_dim: int,
+    ) -> Tensor:
+        if channel_dim < 0:
+            channel_dim += x.ndim
+        ctx.channel_dim = channel_dim
+        xgt0 = x > 0
+        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+        xgt0, sign_factor, scale_factor = ctx.saved_tensors
+        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+            sign_factor = sign_factor.unsqueeze(-1)
+            scale_factor = scale_factor.unsqueeze(-1)
+
+        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+        neg_delta_grad = x_grad.abs() * factor
+        return (
+            x_grad - neg_delta_grad,
+            None,
+            None,
+            None,
+        )
+
+
+class RandomClampFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        min: Optional[float],
+        max: Optional[float],
+        prob: float,
+        reflect: float,
+    ) -> Tensor:
+        x_clamped = torch.clamp(x, min=min, max=max)
+        mask = torch.rand_like(x) < prob
+        ans = torch.where(mask, x_clamped, x)
+        if x.requires_grad:
+            ctx.save_for_backward(ans == x)
+            ctx.reflect = reflect
+        if reflect != 0.0:
+            ans = ans * (1.0 + reflect) - (x * reflect)
+        return ans
+
+    @staticmethod
+    def backward(
+        ctx, ans_grad: Tensor
+    ) -> Tuple[Tensor, None, None, None, None]:
+        (is_same,) = ctx.saved_tensors
+        x_grad = ans_grad * is_same.to(ans_grad.dtype)
+        reflect = ctx.reflect
+        if reflect != 0.0:
+            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+        return x_grad, None, None, None, None
+
+
+def random_clamp(
+    x: Tensor,
+    min: Optional[float] = None,
+    max: Optional[float] = None,
+    prob: float = 0.5,
+    reflect: float = 0.0,
+):
+    return RandomClampFunction.apply(x, min, max, prob, reflect)
+
+
+def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
+    """
+    A randomized way of casting a floating point value to half precision.
+    """
+    if x.dtype == torch.float16:
+        return x
+    x_abs = x.abs()
+    is_too_small = x_abs < min_abs
+    # for elements where is_too_small is true, random_val will contain +-min_abs with
+    # probability (x.abs() / min_abs), and 0.0 otherwise.  [so this preserves expectations,
+    # for those elements].
+    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+    return torch.where(is_too_small, random_val, x).to(torch.float16)
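+# Illustrative sketch (values are arbitrary): for an element with |x| < min_abs,
+# the cast keeps +-min_abs with probability |x| / min_abs and 0 otherwise, so
+# the expected value is preserved even though float16 would otherwise lose it.
+#     x = torch.tensor([1.0e-07, 2.0e-06, 0.5])
+#     x_half = random_cast_to_half(x)   # dtype: torch.float16
+#     # averaging many such casts approaches x in expectation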
+
+
+class RandomGradFunction(torch.autograd.Function):
+    """
+    Does nothing in forward pass; in backward pass, gets rid of very small grads using
+    randomized approach that preserves expectations (intended to reduce roundoff).
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+        ctx.min_abs = min_abs
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+        if ans_grad.dtype == torch.float16:
+            return (
+                random_cast_to_half(
+                    ans_grad.to(torch.float32), min_abs=ctx.min_abs
+                ),
+                None,
+            )
+        else:
+            return ans_grad, None
+
+
+class RandomGrad(torch.nn.Module):
+    """
+    Gets rid of very small gradients using an expectation-preserving method, intended to increase
+    accuracy of training when using amp (automatic mixed precision)
+    """
+
+    def __init__(self, min_abs: float = 5.0e-06):
+        super(RandomGrad, self).__init__()
+        self.min_abs = min_abs
+
+    def forward(self, x: Tensor):
+        if (
+            torch.jit.is_scripting()
+            or not self.training
+            or torch.jit.is_tracing()
+        ):
+            return x
+        else:
+            return RandomGradFunction.apply(x, self.min_abs)
+
+
+class SoftmaxFunction(torch.autograd.Function):
+    """
+    Tries to handle half-precision derivatives in a randomized way that should
+    be more accurate for training than the default behavior.
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor, dim: int):
+        ans = x.softmax(dim=dim)
+        # if x dtype is float16, x.softmax() returns a float32 because
+        # (presumably) that op does not support float16, and autocast
+        # is enabled.
+        if torch.is_autocast_enabled():
+            ans = ans.to(torch.float16)
+        ctx.save_for_backward(ans)
+        ctx.x_dtype = x.dtype
+        ctx.dim = dim
+        return ans
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        (ans,) = ctx.saved_tensors
+        with torch.cuda.amp.autocast(enabled=False):
+            ans_grad = ans_grad.to(torch.float32)
+            ans = ans.to(torch.float32)
+            x_grad = ans_grad * ans
+            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+            return x_grad, None
+
+
+def softmax(x: Tensor, dim: int):
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
+        return x.softmax(dim)
+
+    return SoftmaxFunction.apply(x, dim)
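+# Illustrative usage (a minimal sketch; shapes are arbitrary): behaves like
+# Tensor.softmax in the forward pass, but the backward pass runs in float32,
+# which mainly matters when training with amp.
+#     scores = torch.randn(4, 8, requires_grad=True)
+#     probs = softmax(scores, dim=-1)
+#     probs[:, 0].sum().backward()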
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        coeffs: Tensor,
+        direction: Tensor,
+        channel_dim: int,
+        grad_scale: float,
+    ) -> Tensor:
+        ctx.channel_dim = channel_dim
+        ctx.grad_scale = grad_scale
+        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad, *args):
+        with torch.enable_grad():
+            (x_orig, coeffs, new_direction) = ctx.saved_tensors
+            x_orig.requires_grad = True
+            num_channels = x_orig.shape[ctx.channel_dim]
+            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+            new_direction.requires_grad = False
+            x = x - x.mean(dim=0)
+            x_var = (x ** 2).mean()
+            x_residual = x - coeffs * new_direction
+            x_residual_var = (x_residual ** 2).mean()
+            # `variance_proportion` is the proportion of the variance accounted for
+            # by the top eigen-direction.  This is to be minimized.
+            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+            variance_proportion.backward()
+        x_orig_grad = x_orig.grad
+        x_extra_grad = (
+            x_orig.grad
+            * ctx.grad_scale
+            * x_grad.norm()
+            / (x_orig_grad.norm() + 1.0e-20)
+        )
+        return x_grad + x_extra_grad.detach(), None, None, None, None
+
+
+class BasicNorm(torch.nn.Module):
+    """
+    This is intended to be a simpler, and hopefully cheaper, replacement for
+    LayerNorm.  The observation this is based on, is that Transformer-type
+    networks, especially with pre-norm, sometimes seem to set one of the
+    feature dimensions to a large constant value (e.g. 50), which "defeats"
+    the LayerNorm because the output magnitude is then not strongly dependent
+    on the other (useful) features.  Presumably the weight and bias of the
+    LayerNorm are required to allow it to do this.
+
+    So the idea is to introduce this large constant value as an explicit
+    parameter, that takes the role of the "eps" in LayerNorm, so the network
+    doesn't have to do this trick.  We make the "eps" learnable.
+
+    Args:
+       num_channels: the number of channels, e.g. 512.
+       channel_dim: the axis/dimension corresponding to the channel,
+         interpreted as an offset from the input's ndim if negative.
+         This is NOT the num_channels; it should typically be one of
+         {-2, -1, 0, 1, 2, 3}.
+       eps: the initial "epsilon" that we add as ballast in:
+             scale = ((input_vec**2).mean() + epsilon)**-0.5
+          Note: our epsilon is actually large, but we keep the name
+          to indicate the connection with conventional LayerNorm.
+       learn_eps: if true, we learn epsilon; if false, we keep it
+         at the initial value.
+       eps_min: the minimum allowed value of log(epsilon), occasionally used
+         to clamp it during training.
+       eps_max: the maximum allowed value of log(epsilon), occasionally used
+         to clamp it during training.
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int = -1,  # CAUTION: see documentation.
+        eps: float = 0.25,
+        learn_eps: bool = True,
+        eps_min: float = -3.0,
+        eps_max: float = 3.0,
+    ) -> None:
+        super(BasicNorm, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        if learn_eps:
+            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
+        else:
+            self.register_buffer("eps", torch.tensor(eps).log().detach())
+        self.eps_min = eps_min
+        self.eps_max = eps_max
+
+    def forward(self, x: Tensor) -> Tensor:
+        assert x.shape[self.channel_dim] == self.num_channels
+        eps = self.eps
+        if self.training and random.random() < 0.25:
+            # with probability 0.25, in training mode, clamp eps between the min
+            # and max; this will encourage it to learn parameters within the
+            # allowed range by making parameters that are outside the allowed
+            # range noisy.  On the other iterations eps is used unclamped, so it
+            # still receives gradients that allow the parameter to get back into
+            # the allowed region if it happens to exit it.
+            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+        scales = (
+            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+        ) ** -0.5
+        return x * scales
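+# Illustrative usage (a minimal sketch; the 512-channel size is arbitrary):
+# with the default channel_dim=-1 the module computes
+#     x * (x.pow(2).mean(dim=-1, keepdim=True) + eps) ** -0.5
+# where eps is the learned epsilon (stored internally as its log).
+#     norm = BasicNorm(num_channels=512)
+#     y = norm(torch.randn(10, 100, 512))   # same shape as the input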
+
+
+def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
+    """
+    Behaves like a constructor of a modified version of nn.Linear
+    that gives an easy way to set the default initial parameter scale.
+
+    Args:
+        Accepts the standard args and kwargs that nn.Linear accepts
+        e.g. in_features, out_features, bias=False.
+
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects the initial values of the weight and bias).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+    """
+    ans = nn.Linear(*args, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(
+                ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
+            )
+    return ans
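+# Illustrative usage (a minimal sketch; the sizes are arbitrary): identical to
+# nn.Linear except that the initial weight (and bias range) is scaled down,
+# e.g. for the final projection of a residual branch.
+#     proj = ScaledLinear(512, 512, bias=True, initial_scale=0.25)
+#     y = proj(torch.randn(10, 512))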
+
+
+def ScaledConv1d(
+    *args,
+    initial_scale: float = 1.0,
+    kernel_size: int = 3,
+    padding: str = "same",
+    **kwargs,
+) -> nn.Conv1d:
+    """
+    Behaves like a constructor of a modified version of nn.Conv1d
+    that gives an easy way to set the default initial parameter scale.
+
+    Args:
+        Accepts the standard args and kwargs that nn.Conv1d accepts
+        e.g. in_channels, out_channels, bias=False.
+
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects the initial values of the weight and bias).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+    """
+    ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
+    with torch.no_grad():
+        ans.weight[:] *= initial_scale
+        if ans.bias is not None:
+            torch.nn.init.uniform_(
+                ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
+            )
+    return ans
+
+
+def TransposeScaledConv1d(
+    *args,
+    initial_scale: float = 1.0,
+    kernel_size: int = 3,
+    padding: str = "same",
+    **kwargs,
+) -> nn.Sequential:
+    """
+    Transpose -> ScaledConv1d
+    """
+    return nn.Sequential(
+        Transpose(),
+        ScaledConv1d(
+            *args,
+            initial_scale=initial_scale,
+            kernel_size=kernel_size,
+            padding=padding,
+            **kwargs,
+        ),
+    )
+
+
+def ScaledConv1dTranspose(
+    *args,
+    initial_scale: float = 1.0,
+    kernel_size: int = 3,
+    padding: str = "same",
+    **kwargs,
+) -> nn.Sequential:
+    """
+    ScaledConv1d -> Transpose
+    """
+    return nn.Sequential(
+        ScaledConv1d(
+            *args,
+            initial_scale=initial_scale,
+            kernel_size=kernel_size,
+            padding=padding,
+            **kwargs,
+        ),
+        Transpose(),
+    )
+
+
+def TransposeConv1d(
+    *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+    """
+    Transpose -> Conv1d
+    """
+    return nn.Sequential(
+        Transpose(),
+        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+    )
+
+
+def Conv1dTranspose(
+    *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+    """
+    Conv1d -> Transpose
+    """
+    return nn.Sequential(
+        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+        Transpose(),
+    )
+
+
+class SRLinear(nn.Linear):
+    """https://arxiv.org/abs/2303.06296
+    Stabilizing Transformer Training by Preventing Attention Entropy Collapse
+    """
+
+    def __init__(self, in_features, out_features, bias=True, **kwargs):
+        super().__init__(in_features, out_features, bias=bias, **kwargs)
+        self.register_buffer(
+            "u", nn.functional.normalize(torch.randn(in_features), dim=0)
+        )
+        with torch.no_grad():
+            sigma = self.get_sigma()
+        self.register_buffer("spectral_norm", sigma)
+        self.sigma = nn.Parameter(torch.ones(1))
+
+    def get_sigma(self):
+        with torch.no_grad():
+            u = self.u
+            v = self.weight.mv(u)
+            v = nn.functional.normalize(v, dim=0)
+            u = self.weight.T.mv(v)
+            u = nn.functional.normalize(u, dim=0)
+            self.u.data.copy_(u)
+        return torch.einsum("c,cd,d->", v, self.weight, u)
+
+    def get_weight(self):
+        sigma = self.get_sigma()
+        if self.training:
+            self.spectral_norm.data.copy_(sigma)
+        weight = (self.sigma / sigma) * self.weight
+        return weight
+
+    def forward(self, x):
+        return nn.functional.linear(x, self.get_weight(), self.bias)
+
+
+class SRConv1d(SRLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        kernel_size,
+        stride: int = 1,
+        padding: str = "same",
+        bias: bool = True,
+        **kwargs,
+    ):
+        in_features = in_features * kernel_size
+        super().__init__(in_features, out_features, bias=bias, **kwargs)
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+
+    def forward(self, x):
+        in_features = self.in_features // self.kernel_size
+        weight = self.get_weight().view(
+            self.out_features, in_features, self.kernel_size
+        )
+        return nn.functional.conv1d(
+            x, weight, bias=self.bias, stride=self.stride, padding=self.padding
+        )
+
+
+def TransposeSRConv1d(
+    *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+    """
+    Transpose -> SRConv1d
+    """
+    return nn.Sequential(
+        Transpose(),
+        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+    )
+
+
+def SRConv1dTranspose(
+    *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+    """
+    SRConv1d -> Transpose
+    """
+    return nn.Sequential(
+        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+        Transpose(),
+    )
+
+
+class ActivationBalancer(torch.nn.Module):
+    """
+    Modifies the backpropped derivatives of a function to try to encourage, for
+    each channel, that it is positive at least a proportion `threshold` of the
+    time.  It does this by multiplying negative derivative values by up to
+    (1+max_factor), and positive derivative values by up to (1-max_factor),
+    interpolated from 1 at the threshold to those extremal values when none
+    of the inputs are positive.
+
+    Args:
+           num_channels: the number of channels
+           channel_dim: the dimension/axis corresponding to the channel, e.g.
+               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+           min_positive: the minimum, per channel, of the proportion of the time
+               that (x > 0), below which we start to modify the derivatives.
+           max_positive: the maximum, per channel, of the proportion of the time
+               that (x > 0), above which we start to modify the derivatives.
+           max_factor: the maximum factor by which we modify the derivatives for
+              either the sign constraint or the magnitude constraint;
+              e.g. with max_factor=0.02, the derivatives would be multiplied by
+              values in the range [0.98..1.02].
+           sign_gain_factor: determines the 'gain' with which we increase the
+              change in gradient once the constraints on min_positive and max_positive
+              are violated.
+           scale_gain_factor: determines the 'gain' with which we increase the
+              change in gradient once the constraints on min_abs and max_abs
+              are violated.
+           min_abs:  the minimum average-absolute-value difference from the mean
+               value per channel, which we allow, before we start to modify
+               the derivatives to prevent this.
+           max_abs:  the maximum average-absolute-value difference from the mean
+               value per channel, which we allow, before we start to modify
+               the derivatives to prevent this.
+          min_prob: determines the minimum probability with which we modify the
+             gradients for the {min,max}_positive and {min,max}_abs constraints,
+             on each forward().  This is done randomly to prevent all layers
+             from doing it at the same time.  Early in training we may use
+             higher probabilities than this; it will decay to this value.
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int,
+        min_positive: float = 0.05,
+        max_positive: float = 0.95,
+        max_factor: float = 0.04,
+        sign_gain_factor: float = 0.01,
+        scale_gain_factor: float = 0.02,
+        min_abs: float = 0.2,
+        max_abs: float = 100.0,
+        min_prob: float = 0.1,
+    ):
+        super(ActivationBalancer, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.min_positive = min_positive
+        self.max_positive = max_positive
+        self.max_factor = max_factor
+        self.min_abs = min_abs
+        self.max_abs = max_abs
+        self.min_prob = min_prob
+        self.sign_gain_factor = sign_gain_factor
+        self.scale_gain_factor = scale_gain_factor
+
+        # cpu_count measures how many times the forward() function has been called.
+        # We occasionally sync it to the registered buffer `count`, which exists so
+        # that the value is saved and restored along with the model checkpoint.
+        self.cpu_count = 0
+        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (
+            torch.jit.is_scripting()
+            or not x.requires_grad
+            or torch.jit.is_tracing()
+        ):
+            return _no_op(x)
+
+        count = self.cpu_count
+        self.cpu_count += 1
+
+        if random.random() < 0.01:
+            # Occasionally sync self.cpu_count with self.count.
+            # count affects the decay of 'prob'.  don't do this on every iter,
+            # because syncing with the GPU is slow.
+            self.cpu_count = max(self.cpu_count, self.count.item())
+            self.count.fill_(self.cpu_count)
+
+        # the prob of doing some work exponentially decreases from 0.5 till it hits
+        # a floor at min_prob (==0.1, by default)
+        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
+
+        if random.random() < prob:
+            sign_gain_factor = 0.5
+            if self.min_positive != 0.0 or self.max_positive != 1.0:
+                sign_factor = _compute_sign_factor(
+                    x,
+                    self.channel_dim,
+                    self.min_positive,
+                    self.max_positive,
+                    gain_factor=self.sign_gain_factor / prob,
+                    max_factor=self.max_factor,
+                )
+            else:
+                sign_factor = None
+
+            scale_factor = _compute_scale_factor(
+                x.detach(),
+                self.channel_dim,
+                min_abs=self.min_abs,
+                max_abs=self.max_abs,
+                gain_factor=self.scale_gain_factor / prob,
+                max_factor=self.max_factor,
+            )
+            return ActivationBalancerFunction.apply(
+                x,
+                scale_factor,
+                sign_factor,
+                self.channel_dim,
+            )
+        else:
+            return _no_op(x)
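+# Illustrative usage (a minimal sketch; the sizes are arbitrary): the module is
+# an identity in the forward pass and only nudges gradients, so it is typically
+# placed directly before a nonlinearity.
+#     balancer = ActivationBalancer(num_channels=512, channel_dim=-1)
+#     x = torch.randn(10, 100, 512, requires_grad=True)
+#     y = balancer(x)   # same values as x; only backprop is affected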
+
+
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+    """
+    Returns x unmodified, but in backprop will put a penalty for the excess of
+    the absolute values of elements of x over the limit "limit".  E.g. if
+    limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+    Caution: the value of this penalty will be affected by grad scaling used
+    in automatic mixed precision training.  For the way we use this it
+    shouldn't really matter, or may even be helpful; we just use it
+    to disallow really implausible values of scores to be given to softmax.
+    """
+    x_sign = x.sign()
+    over_limit = (x.abs() - limit) > 0
+    # The following is a memory efficient way to penalize the absolute values of
+    # x that's over the limit.  (The memory efficiency comes when you think
+    # about which items torch needs to cache for the autograd, and which ones it
+    # can throw away).  The numerical value of aux_loss as computed here will
+    # actually be larger than it should be, by limit * over_limit.sum(), but it
+    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+    # limit).relu().
+    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
+    # note: we don't do sum() here on aux_loss, but it's as if we had done
+    # sum() due to how with_loss() works.
+    x = with_loss(x, aux_loss)
+    # you must use x for something, or this will be ineffective.
+    return x
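+# Illustrative usage (a minimal sketch; the limit/penalty values are arbitrary):
+# typically applied to raw attention scores before softmax, so that gradients
+# push scores whose absolute value exceeds `limit` back towards the limit.
+#     scores = 20.0 * torch.randn(4, 100, 100, requires_grad=True)
+#     scores = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04)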
+
+
+def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
+    if x.ndim == 2:
+        return x.diag()
+    else:
+        (batch, dim, dim) = x.shape
+        x = x.reshape(batch, dim * dim)
+        x = x[:, :: dim + 1]
+        assert x.shape == (batch, dim)
+        return x
+
+
+def _whitening_metric(x: Tensor, num_groups: int):
+    """
+    Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
+    of the centered feature covariance are the same within each group's covariance matrix
+    and also between groups.
+    Args:
+        x: a Tensor of shape (*, num_channels)
+     num_groups:  the number of groups of channels, a number >=1 that divides num_channels
+    Returns:
+        Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+    greater than 1.0 otherwise.
+    """
+    assert x.dtype != torch.float16
+    x = x.reshape(-1, x.shape[-1])
+    (num_frames, num_channels) = x.shape
+    assert num_channels % num_groups == 0
+    channels_per_group = num_channels // num_groups
+    x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+    # x now has shape (num_groups, num_frames, channels_per_group)
+    # subtract the mean so we use the centered, not uncentered, covariance.
+    # My experience has been that when we "mess with the gradients" like this,
+    # it's better not do anything that tries to move the mean around, because
+    # that can easily cause instability.
+    x = x - x.mean(dim=1, keepdim=True)
+    # x_covar: (num_groups, channels_per_group, channels_per_group)
+    x_covar = torch.matmul(x.transpose(1, 2), x)
+    x_covar_mean_diag = _diag(x_covar).mean()
+    # the following expression is what we'd get if we took the matrix product
+    # of each covariance and measured the mean of its trace, i.e.
+    # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+    x_covarsq_mean_diag = (x_covar ** 2).sum() / (
+        num_groups * channels_per_group
+    )
+    # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+    metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
+    return metric
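+# Illustrative check (a minimal sketch): for roughly isotropic data the metric is
+# close to 1.0, and it grows once a single direction starts to dominate.
+#     x = torch.randn(1000, 256)
+#     _whitening_metric(x, 1)                      # ~1.0
+#     _whitening_metric(x + 5.0 * x[:, :1], 1)     # noticeably > 1.0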
+
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: Tensor,
+        num_groups: int,
+        whitening_limit: float,
+        grad_scale: float,
+    ) -> Tensor:
+        ctx.save_for_backward(x)
+        ctx.num_groups = num_groups
+        ctx.whitening_limit = whitening_limit
+        ctx.grad_scale = grad_scale
+        return x
+
+    @staticmethod
+    def backward(ctx, x_grad: Tensor):
+        (x_orig,) = ctx.saved_tensors
+        with torch.enable_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                x_detached = x_orig.to(torch.float32).detach()
+                x_detached.requires_grad = True
+
+                metric = _whitening_metric(x_detached, ctx.num_groups)
+
+                if random.random() < 0.005 or __name__ == "__main__":
+                    logging.info(
+                        f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+                        f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
+                    )
+
+                (metric - ctx.whitening_limit).relu().backward()
+                penalty_grad = x_detached.grad
+                scale = ctx.grad_scale * (
+                    x_grad.to(torch.float32).norm()
+                    / (penalty_grad.norm() + 1.0e-20)
+                )
+                penalty_grad = penalty_grad * scale
+        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+
+
+class Whiten(nn.Module):
+    def __init__(
+        self,
+        num_groups: int,
+        whitening_limit: float,
+        prob: Union[float, Tuple[float, float]],
+        grad_scale: float,
+    ):
+        """
+        Args:
+          num_groups: the number of groups to divide the channel dim into before
+            whitening.  We will attempt to make the feature covariance
+            within each group, after mean subtraction, as "white" as possible,
+            while having the same trace across all groups.
+         whitening_limit: a value greater than 1.0, that dictates how much
+           freedom we have to violate the constraints.  1.0 would mean perfectly
+           white, with exactly the same trace across groups; larger values
+           give more freedom.  E.g. 2.0.
+         prob: the probability with which we apply the gradient modification
+           (also affects the grad scale).  May be supplied as a float,
+           or as a pair (min_prob, max_prob)
+
+          grad_scale: determines the scale on the gradient term from this object,
+            relative to the rest of the gradient on the attention weights.
+            E.g. 0.02 (you may want to use smaller values than this if prob is large)
+        """
+        super(Whiten, self).__init__()
+        assert num_groups >= 1
+        assert whitening_limit >= 1
+        assert grad_scale >= 0
+        self.num_groups = num_groups
+        self.whitening_limit = whitening_limit
+        if isinstance(prob, float):
+            assert 0 < prob <= 1
+            self.prob = prob
+        else:
+            (self.min_prob, self.max_prob) = prob
+            assert 0 < self.min_prob < self.max_prob <= 1
+            self.prob = self.max_prob
+
+        self.grad_scale = grad_scale
+
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        In the forward pass, this function just returns the input unmodified.
+        In the backward pass, it will modify the gradients to ensure that the
+        distribution in each group has close to (lambda times I) as the covariance
+        after mean subtraction, with the same lambda across groups.
+        For whitening_limit > 1, there will be more freedom to violate this
+        constraint.
+
+        Args:
+           x: the input of shape (*, num_channels)
+
+        Returns:
+            x, unmodified.   You should make sure
+        you use the returned value, or the graph will be freed
+        and nothing will happen in backprop.
+        """
+        if (
+            not x.requires_grad
+            or random.random() > self.prob
+            or self.grad_scale == 0
+        ):
+            return _no_op(x)
+        else:
+            if hasattr(self, "min_prob") and random.random() < 0.25:
+                # occasionally switch between min_prob and max_prob, based on whether
+                # we are above or below the threshold.
+                if (
+                    _whitening_metric(x.to(torch.float32), self.num_groups)
+                    > self.whitening_limit
+                ):
+                    # there would be a change to the grad.
+                    self.prob = self.max_prob
+                else:
+                    self.prob = self.min_prob
+
+            return WhiteningPenaltyFunction.apply(
+                x, self.num_groups, self.whitening_limit, self.grad_scale
+            )
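+# Illustrative usage (a minimal sketch; the hyper-parameters are arbitrary):
+# like ActivationBalancer, this is an identity in the forward pass and only
+# modifies gradients when the whitening metric exceeds the limit.
+#     whiten = Whiten(num_groups=1, whitening_limit=5.0, prob=0.25, grad_scale=0.02)
+#     x = torch.randn(10, 100, 512, requires_grad=True)
+#     y = whiten(x)   # same values as x; only backprop is affected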
+
+
+class WithLoss(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: Tensor, y: Tensor):
+        ctx.y_shape = y.shape
+        return x
+
+    @staticmethod
+    def backward(ctx, ans_grad: Tensor):
+        return ans_grad, torch.ones(
+            ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
+        )
+
+
+def with_loss(x, y):
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
+        return x
+    # returns x but adds y.sum() to the loss function.
+    return WithLoss.apply(x, y)
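+# Illustrative sketch: with_loss(x, aux) returns x unchanged but contributes
+# aux.sum() to whatever loss is backpropagated through it, because the backward
+# pass feeds a gradient of ones into aux.
+#     x = torch.randn(5, requires_grad=True)
+#     y = with_loss(x, 0.01 * x.abs())
+#     y.sum().backward()   # x.grad also contains the gradient of (0.01 * x.abs()).sum()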
+
+
+def _no_op(x: Tensor) -> Tensor:
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
+        return x
+    else:
+        # a no-op function that will have a node in the autograd graph,
+        # to avoid certain bugs relating to backward hooks
+        return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return _no_op(x)
+
+
+class MaxEig(torch.nn.Module):
+    """
+    Modifies the backpropped derivatives of a function to try to discourage
+    that any given direction in activation space accounts for more than
+    a specified proportion of the covariance (e.g. 0.2).
+
+
+    Args:
+           num_channels: the number of channels
+           channel_dim: the dimension/axis corresponding to the channel, e.g.
+               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+           max_var_per_eig:  the maximum proportion of the variance of the
+               features/channels, after mean subtraction, that can come from
+               any given eigenvalue.
+           min_prob: the minimum probability with which we apply this during any invocation
+               of forward(), assuming last time we applied the constraint it was
+               not active; supplied for speed.
+           scale: determines the scale with which we modify the gradients, relative
+               to the existing / unmodified gradients
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        channel_dim: int,
+        max_var_per_eig: float = 0.2,
+        min_prob: float = 0.01,
+        scale: float = 0.01,
+    ):
+        super(MaxEig, self).__init__()
+        self.num_channels = num_channels
+        self.channel_dim = channel_dim
+        self.scale = scale
+        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+        self.max_var_per_eig = max_var_per_eig
+
+        # we figure out the dominant direction using the power method: starting with
+        # a random vector, keep multiplying by the covariance and renormalizing.
+        with torch.no_grad():
+            # arbitrary.. would use randn() but want to leave the rest of the model's
+            # random parameters unchanged for comparison
+            direction = torch.arange(num_channels).to(torch.float)
+            direction = direction / direction.norm()
+            self.register_buffer("max_eig_direction", direction)
+
+        self.min_prob = min_prob
+        # cur_prob is the current probability we'll use to apply the MaxEig constraint.
+        # We'll regress this towards min_prob each time we try to apply it and it is
+        # not active.
+        self.cur_prob = 1.0
+
+    def forward(self, x: Tensor) -> Tensor:
+        if (
+            torch.jit.is_scripting()
+            or self.max_var_per_eig <= 0
+            or random.random() > self.cur_prob
+            or torch.jit.is_tracing()
+        ):
+            return _no_op(x)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            eps = 1.0e-20
+            orig_x = x
+            x = x.to(torch.float32)
+            with torch.no_grad():
+                x = x.transpose(self.channel_dim, -1).reshape(
+                    -1, self.num_channels
+                )
+                x = x - x.mean(dim=0)
+                new_direction, coeffs = self._find_direction_coeffs(
+                    x, self.max_eig_direction
+                )
+                x_var = (x ** 2).mean()
+                x_residual = x - coeffs * new_direction
+                x_residual_var = (x_residual ** 2).mean()
+
+                # `variance_proportion` is the proportion of the variance accounted for
+                # by the top eigen-direction.
+                variance_proportion = (x_var - x_residual_var) / (
+                    x_var + 1.0e-20
+                )
+
+                # ensure new direction is nonzero even if x == 0, by including `direction`.
+                self._set_direction(
+                    0.1 * self.max_eig_direction + new_direction
+                )
+
+            if random.random() < 0.01 or __name__ == "__main__":
+                logging.info(
+                    f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
+                )
+
+            if variance_proportion >= self.max_var_per_eig:
+                # The constraint is active.  Note, we should quite rarely
+                # reach here, only near the beginning of training if we are
+                # starting to diverge, should this constraint be active.
+                cur_prob = self.cur_prob
+                self.cur_prob = (
+                    1.0  # next time, do the update with probability 1.0.
+                )
+                return MaxEigLimiterFunction.apply(
+                    orig_x, coeffs, new_direction, self.channel_dim, self.scale
+                )
+            else:
+                # let self.cur_prob exponentially approach self.min_prob, as
+                # long as the constraint is inactive.
+                self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
+                return orig_x
+
+    def _set_direction(self, direction: Tensor):
+        """
+        Sets self.max_eig_direction to a normalized version of `direction`
+        """
+        direction = direction.detach()
+        direction = direction / direction.norm()
+        direction_sum = direction.sum().item()
+        if direction_sum - direction_sum == 0:  # no inf/nan
+            self.max_eig_direction[:] = direction
+        else:
+            logging.info(
+                f"Warning: sum of direction in MaxEig is {direction_sum}, "
+                "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
+            )
+
+    def _find_direction_coeffs(
+        self, x: Tensor, prev_direction: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """
+            Figure out (an approximation to) the proportion of the variance of a set of
+            feature vectors that can be attributed to the top eigen-direction.
+            Args:
+             x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+          prev_direction:  a Tensor of shape (num_channels,), that is our previous estimate
+                   of the top eigen-direction, or a random direction if this is the first
+                   iteration.  Does not have to be normalized, but should be nonzero.
+
+        Returns: (cur_direction, coeffs), where:
+             cur_direction: a Tensor of shape (num_channels,) that is the current
+                estimate of the top eigen-direction.
+             coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
+                approximately minimizes, (x - coeffs * cur_direction).norm()
+        """
+        (num_frames, num_channels) = x.shape
+        assert num_channels > 1 and num_frames > 1
+        assert prev_direction.shape == (num_channels,)
+        # `coeffs` are the coefficients of `prev_direction` in x.
+        # actually represent the coeffs up to a constant positive factor.
+        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+        cur_direction = (x * coeffs).sum(dim=0) / (
+            (coeffs ** 2).sum() + 1.0e-20
+        )
+        return cur_direction, coeffs
+
+
+class DoubleSwishFunction(torch.autograd.Function):
+    """
+      double_swish(x) = x * torch.sigmoid(x-1)
+    This is a definition, originally motivated by its close numerical
+    similarity to swish(swish(x)), where swish(x) =  x * sigmoid(x).
+
+    Memory-efficient derivative computation:
+     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+     double_swish'(x) = d/dx double_swish(x) =  x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+     Now, s'(x) = s(x) * (1-s(x)).
+     double_swish'(x) =  x * s'(x) + s(x).
+                      =  x * s(x) * (1-s(x)) + s(x).
+                     = double_swish(x) * (1-s(x)) + s(x)
+     ... so we just need to remember s(x) but not x itself.
+    """
+
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        requires_grad = x.requires_grad
+        x_dtype = x.dtype
+        if x.dtype == torch.float16:
+            x = x.to(torch.float32)
+
+        s = torch.sigmoid(x - 1.0)
+        y = x * s
+
+        if requires_grad:
+            deriv = y * (1 - s) + s
+            # notes on derivative of x * sigmoid(x - 1):
+            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound
+            # max \simeq 1.1990.   Take ceil to be 1.2 so it's an upper bound.
+            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
+            # floors), should be expectation-preserving.
+            floor = -0.043637
+            ceil = 1.2
+            d_scaled = (deriv - floor) * (
+                255.0 / (ceil - floor)
+            ) + torch.rand_like(deriv)
+            if __name__ == "__main__":
+                # for self-testing only.
+                assert d_scaled.min() >= 0.0
+                assert d_scaled.max() < 256.0
+            d_int = d_scaled.to(torch.uint8)
+            ctx.save_for_backward(d_int)
+        if x_dtype == torch.float16 or torch.is_autocast_enabled():
+            y = y.to(torch.float16)
+        return y
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        (d,) = ctx.saved_tensors
+        # the same constants as used in forward pass.
+        floor = -0.043637
+        ceil = 1.2
+        d = d * ((ceil - floor) / 255.0) + floor
+        return y_grad * d
+
+
+class DoubleSwish(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+        that we approximate closely with x * sigmoid(x-1).
+        """
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return x * torch.sigmoid(x - 1.0)
+        return DoubleSwishFunction.apply(x)
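+# Illustrative usage (a minimal sketch): numerically DoubleSwish()(x) matches
+# x * torch.sigmoid(x - 1.0); the custom autograd function only changes how the
+# derivative is stored (quantized to uint8) to save activation memory.
+#     act = DoubleSwish()
+#     y = act(torch.randn(10, 512, requires_grad=True))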
+
+
+def BalancedDoubleSwish(
+    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
+) -> nn.Sequential:
+    """
+    ActivationBalancer -> DoubleSwish
+    """
+    balancer = ActivationBalancer(
+        d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
+    )
+    return nn.Sequential(
+        balancer,
+        DoubleSwish(),
+    )
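+# Illustrative usage (a minimal sketch; the sizes are arbitrary): a typical place
+# for this block is the hidden layer of a feed-forward module.
+#     ffn = nn.Sequential(
+#         ScaledLinear(512, 2048),
+#         BalancedDoubleSwish(2048),
+#         ScaledLinear(2048, 512, initial_scale=0.25),
+#     )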
+
+
+def _test_max_eig():
+    for proportion in [0.1, 0.5, 10.0]:
+        logging.info(f"proportion = {proportion}")
+        x = torch.randn(100, 128)
+        direction = torch.randn(128)
+        coeffs = torch.randn(100, 1)
+        x += proportion * direction * coeffs
+
+        x.requires_grad = True
+
+        num_channels = 128
+        m = MaxEig(
+            num_channels,
+            channel_dim=1,
+            max_var_per_eig=0.5,
+            scale=0.1,
+        )
+
+        for _ in range(4):
+            y = m(x)
+
+        y_grad = torch.randn_like(x)
+        y.backward(gradient=y_grad)
+
+        if proportion < 0.2:
+            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
+        elif proportion > 1.0:
+            assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_whiten():
+    for proportion in [0.1, 0.5, 10.0]:
+        logging.info(f"_test_whiten(): proportion = {proportion}")
+        x = torch.randn(100, 128)
+        direction = torch.randn(128)
+        coeffs = torch.randn(100, 1)
+        x += proportion * direction * coeffs
+
+        x.requires_grad = True
+
+        num_channels = 128
+        m = Whiten(
+            num_groups=1,
+            whitening_limit=5.0,
+            prob=1.0,
+            grad_scale=0.1,
+        )
+
+        for _ in range(4):
+            y = m(x)
+
+        y_grad = torch.randn_like(x)
+        y.backward(gradient=y_grad)
+
+        if proportion < 0.2:
+            assert torch.allclose(x.grad, y_grad)
+        elif proportion > 1.0:
+            assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_activation_balancer_sign():
+    probs = torch.arange(0, 1, 0.01)
+    N = 1000
+    x = 1.0 * (
+        (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
+    )
+    x = x.detach()
+    x.requires_grad = True
+    m = ActivationBalancer(
+        probs.numel(),
+        channel_dim=0,
+        min_positive=0.05,
+        max_positive=0.95,
+        max_factor=0.2,
+        min_abs=0.0,
+    )
+
+    y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+    y = m(x)
+    y.backward(gradient=y_grad)
+    print("_test_activation_balancer_sign: x = ", x)
+    print("_test_activation_balancer_sign: y grad = ", y_grad)
+    print("_test_activation_balancer_sign: x grad = ", x.grad)
+
+
+def _test_activation_balancer_magnitude():
+    magnitudes = torch.arange(0, 1, 0.01)
+    N = 1000
+    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
+        -1
+    )
+    x = x.detach()
+    x.requires_grad = True
+    m = ActivationBalancer(
+        magnitudes.numel(),
+        channel_dim=0,
+        min_positive=0.0,
+        max_positive=1.0,
+        max_factor=0.2,
+        min_abs=0.2,
+        max_abs=0.8,
+        min_prob=1.0,
+    )
+
+    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+    y = m(x)
+    y.backward(gradient=y_grad)
+    print("_test_activation_balancer_magnitude: x = ", x)
+    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
+    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
+
+
+def _test_basic_norm():
+    num_channels = 128
+    m = BasicNorm(num_channels=num_channels, channel_dim=1)
+
+    x = torch.randn(500, num_channels)
+
+    y = m(x)
+
+    assert y.shape == x.shape
+    x_rms = (x ** 2).mean().sqrt()
+    y_rms = (y ** 2).mean().sqrt()
+    print("x rms = ", x_rms)
+    print("y rms = ", y_rms)
+    assert y_rms < x_rms
+    assert y_rms > 0.5 * x_rms
+
+
+def _test_double_swish_deriv():
+    x = torch.randn(10, 12, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    m = DoubleSwish()
+
+    tol = (1.2 - (-0.043637)) / 255.0
+    torch.autograd.gradcheck(m, x, atol=tol)
+
+    # for self-test.
+    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    y = m(x)
+
+
+def _test_softmax():
+    a = torch.randn(2, 10, dtype=torch.float64)
+    b = a.clone()
+    a.requires_grad = True
+    b.requires_grad = True
+    a.softmax(dim=1)[:, 0].sum().backward()
+    print("a grad = ", a.grad)
+    softmax(b, dim=1)[:, 0].sum().backward()
+    print("b grad = ", b.grad)
+    assert torch.allclose(a.grad, b.grad)
+
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_softmax()
+    _test_whiten()
+    _test_max_eig()
+    _test_activation_balancer_sign()
+    _test_activation_balancer_magnitude()
+    _test_basic_norm()
+    _test_double_swish_deriv()
diff --git a/slam_llm/models/vallex/transformers.py b/slam_llm/models/vallex/transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..183b65ca3190146fb195a8c590a76db91ad49ea2
--- /dev/null
+++ b/slam_llm/models/vallex/transformers.py
@@ -0,0 +1,613 @@
+import copy
+import numbers
+from functools import partial
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from .activation import MultiheadAttention
+from .scaling import ActivationBalancer, BasicNorm as _BasicNorm
+
+_shape_t = Union[int, List[int], torch.Size]
+
+
+class LayerNorm(nn.Module):
+    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(
+        self,
+        normalized_shape: _shape_t,
+        eps: float = 1e-5,
+        elementwise_affine: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super(LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(
+                torch.empty(self.normalized_shape, **factory_kwargs)
+            )
+            self.bias = nn.Parameter(
+                torch.empty(self.normalized_shape, **factory_kwargs)
+            )
+        else:
+            self.register_parameter("weight", None)
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+            nn.init.zeros_(self.bias)
+
+    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+        if isinstance(input, tuple):
+            input, embedding = input
+            return (
+                F.layer_norm(
+                    input,
+                    self.normalized_shape,
+                    self.weight,
+                    self.bias,
+                    self.eps,
+                ),
+                embedding,
+            )
+
+        assert embedding is None
+        return F.layer_norm(
+            input, self.normalized_shape, self.weight, self.bias, self.eps
+        )
+
+    def extra_repr(self) -> str:
+        return (
+            "{normalized_shape}, eps={eps}, "
+            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
+        )
+
+
+class AdaptiveLayerNorm(nn.Module):
+    r"""Adaptive Layer Normalization"""
+
+    def __init__(self, d_model, norm) -> None:
+        super(AdaptiveLayerNorm, self).__init__()
+        self.project_layer = nn.Linear(d_model, 2 * d_model)
+        self.norm = norm
+        self.d_model = d_model
+        self.eps = self.norm.eps
+
+    def forward(self, input: Tensor, embedding: Optional[Tensor] = None) -> Tensor:
+        if isinstance(input, tuple):
+            input, embedding = input
+            weight, bias = torch.split(
+                self.project_layer(embedding),
+                split_size_or_sections=self.d_model,
+                dim=-1,
+            )
+            return (weight * self.norm(input) + bias, embedding)
+
+        weight, bias = torch.split(
+            self.project_layer(embedding),
+            split_size_or_sections=self.d_model,
+            dim=-1,
+        )
+        return weight * self.norm(input) + bias
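+# Illustrative usage (a minimal sketch; the sizes are arbitrary): the conditioning
+# embedding is projected to a per-channel (weight, bias) pair that modulates the
+# output of the wrapped norm.
+#     ada_ln = AdaptiveLayerNorm(512, norm=LayerNorm(512))
+#     x = torch.randn(10, 100, 512)
+#     emb = torch.randn(10, 1, 512)   # broadcasts over the time axis
+#     y = ada_ln(x, embedding=emb)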
+
+
+class BasicNorm(_BasicNorm):
+    def __init__(
+        self,
+        d_model: int,
+        eps: float = 1e-5,
+        device=None,
+        dtype=None,
+    ):
+        super(BasicNorm, self).__init__(d_model, eps=eps)
+
+    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+        if isinstance(input, tuple):
+            input, embedding = input
+            return (
+                super(BasicNorm, self).forward(input),
+                embedding,
+            )
+
+        assert embedding is None
+        return super(BasicNorm, self).forward(input)
+
+
+class BalancedBasicNorm(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        eps: float = 1e-5,
+        device=None,
+        dtype=None,
+    ):
+        super(BalancedBasicNorm, self).__init__()
+        self.balancer = ActivationBalancer(
+            d_model,
+            channel_dim=-1,
+            min_positive=0.45,
+            max_positive=0.55,
+            max_abs=6.0,
+        )
+        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
+
+    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+        if isinstance(input, tuple):
+            input, embedding = input
+            return self.norm((self.balancer(input), embedding))
+
+        assert embedding is None
+        return self.norm(self.balancer(input))
+
+
+class IdentityNorm(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        eps: float = 1e-5,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super(IdentityNorm, self).__init__()
+
+    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+        if isinstance(input, tuple):
+            return input
+
+        assert embedding is None
+        return input
+
+
+class TransformerEncoderLayer(nn.Module):
+    __constants__ = ["batch_first", "norm_first"]
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+        batch_first: bool = False,
+        norm_first: bool = False,
+        device=None,
+        dtype=None,
+        linear1_self_attention_cls: nn.Module = nn.Linear,
+        linear2_self_attention_cls: nn.Module = nn.Linear,
+        linear1_feedforward_cls: nn.Module = nn.Linear,
+        linear2_feedforward_cls: nn.Module = nn.Linear,
+        layer_norm_cls: nn.Module = LayerNorm,
+        layer_norm_eps: float = 1e-5,
+        adaptive_layer_norm=False,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super(TransformerEncoderLayer, self).__init__()
+        self.self_attn = MultiheadAttention(
+            d_model,
+            nhead,
+            dropout=dropout,
+            batch_first=batch_first,
+            linear1_cls=linear1_self_attention_cls,
+            linear2_cls=linear2_self_attention_cls,
+            **factory_kwargs,
+        )
+
+        # Implementation of Feedforward model
+        
+        self.dropout = nn.Dropout(dropout)
+
+        self.norm_first = norm_first
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        # Legacy string support for activation function.
+        if isinstance(activation, str):
+            activation = _get_activation_fn(activation)
+        elif isinstance(activation, partial):
+            activation = activation(d_model)
+
+        self.activation = activation
+
+        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
+        if layer_norm_cls == IdentityNorm:
+            norm2 = BalancedBasicNorm(
+                d_model, eps=layer_norm_eps, **factory_kwargs
+            )
+        else:
+            norm2 = layer_norm_cls(
+                d_model, eps=layer_norm_eps, **factory_kwargs
+            )
+
+        self.norm1 = norm1
+        self.linear1 = linear1_feedforward_cls(
+            d_model, dim_feedforward, **factory_kwargs
+        )
+        self.linear2 = linear2_feedforward_cls(
+            dim_feedforward, d_model, **factory_kwargs
+        )
+        self.norm2 = norm2
+        
+
+    def __setstate__(self, state):
+        super(TransformerEncoderLayer, self).__setstate__(state)
+        if not hasattr(self, "activation"):
+            self.activation = F.relu
+
+    def forward(
+        self,
+        src: Tensor,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""Pass the input through the encoder layer.
+
+        Args:
+            src: the sequence to the encoder layer (required).
+            src_mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        x, stage_embedding = src, None
+        is_src_tuple = False
+        if isinstance(src, tuple):
+            x, stage_embedding = src
+            is_src_tuple = True
+
+        if src_key_padding_mask is not None:
+            _skpm_dtype = src_key_padding_mask.dtype
+            if _skpm_dtype != torch.bool and not torch.is_floating_point(
+                src_key_padding_mask
+            ):
+                raise AssertionError(
+                    "only bool and floating types of key_padding_mask are supported"
+                )
+
+        if self.norm_first:
+            x = x + self._sa_block(
+                self.norm1(x, stage_embedding),
+                src_mask,
+                src_key_padding_mask,
+            )
+            
+            x = x + self._ff_block(self.norm2(x, stage_embedding))
+        else:
+            x = self.norm1(
+                x + self._sa_block(x, src_mask, src_key_padding_mask),
+                stage_embedding,
+            )
+            x = self.norm2(x + self._ff_block(x), stage_embedding)
+
+        if is_src_tuple:
+            return (x, stage_embedding)
+        return x
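+    # Summary of the two residual layouts above, assuming plain norms:
+    #   pre-norm  (norm_first=True):  x <- x + SA(LN(x)); x <- x + FF(LN(x))
+    #   post-norm (norm_first=False): x <- LN(x + SA(x)); x <- LN(x + FF(x))
+    # `stage_embedding` is simply forwarded to the norm layers; plain norms
+    # ignore it, adaptive variants consume it.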
+
+    def infer(
+        self,
+        src: Tensor,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        past_kv: Optional[Tensor] = None,
+        use_cache: bool = False,
+    ):
+        x, stage_embedding = src, None
+        is_src_tuple = False
+        if isinstance(src, tuple):
+            x, stage_embedding = src
+            is_src_tuple = True
+
+        if src_key_padding_mask is not None:
+            _skpm_dtype = src_key_padding_mask.dtype
+            if _skpm_dtype != torch.bool and not torch.is_floating_point(
+                src_key_padding_mask
+            ):
+                raise AssertionError(
+                    "only bool and floating types of key_padding_mask are supported"
+                )
+
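+        # NOTE: only the pre-norm (norm_first=True) branch is implemented for
+        # incremental decoding; with norm_first=False `x_attn_out`/`kv` would
+        # be undefined below.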
+        if self.norm_first:
+            x_attn_out, kv = self.self_attn.infer(
+                self.norm1(x, stage_embedding),
+                attn_mask=src_mask,
+                key_padding_mask=src_key_padding_mask,
+                need_weights=False,
+                past_kv=past_kv,
+                use_cache=use_cache,
+            )
+            x = x + x_attn_out
+            x = x + self._ff_block(self.norm2(x, stage_embedding))
+
+        if is_src_tuple:
+            return (x, stage_embedding)
+        return (x, kv)
+
+    # self-attention block
+    def _sa_block(
+        self,
+        x: Tensor,
+        attn_mask: Optional[Tensor],
+        key_padding_mask: Optional[Tensor],
+    ) -> Tensor:
+        x = self.self_attn(
+            x,
+            x,
+            x,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+            need_weights=False,
+        )[0]
+        return self.dropout1(x)
+
+    # feed forward block
+    def _ff_block(self, x: Tensor) -> Tensor:
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        return self.dropout2(x)
+
+
+class TransformerEncoder(nn.Module):
+    __constants__ = ["norm"]
+
+    def __init__(self, encoder_layer, num_layers, norm=None):
+        super(TransformerEncoder, self).__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        src: Tensor,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        return_layer_states: bool = False,
+    ) -> Tensor:
+        output = src
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
+            )
+            # print(i, output.mean())
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+    def infer(
+        self,
+        src: Tensor,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        return_layer_states: bool = False,
+        past_kv: Optional[Tensor] = None,
+        use_cache: bool = False,
+    ):
+        if past_kv is None:
+            past_length = 0
+            past_kv = tuple([None] * self.num_layers)
+        else:
+            past_length = past_kv[0][0].size(-2)
+        new_kv = () if use_cache else None
+        output = src
+        for i, (mod, past_layer_kv) in enumerate(zip(self.layers, past_kv)):
+            output, kv = mod.infer(
+                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
+            )
+            # print(i, output.mean())
+            if use_cache:
+                new_kv = new_kv + (kv,)
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output, new_kv
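+    # KV-cache layout used by infer(): `past_kv` is a tuple with one entry per
+    # layer, each holding that layer's cached (key, value) tensors, so
+    # past_kv[0][0].size(-2) is the number of already-decoded timesteps.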
+
+
+class TransformerDecoderLayer(nn.Module):
+    __constants__ = ["batch_first", "norm_first"]
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+        linear1_self_attention_cls: nn.Module = nn.Linear,
+        linear2_self_attention_cls: nn.Module = nn.Linear,
+        linear1_feedforward_cls: nn.Module = nn.Linear,
+        linear2_feedforward_cls: nn.Module = nn.Linear,
+        batch_first: bool = False,
+        norm_first: bool = False,
+        device=None,
+        dtype=None,
+        layer_norm_cls: nn.Module = LayerNorm,
+        layer_norm_eps: float = 1e-5,
+        adaptive_layer_norm=False,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super(TransformerDecoderLayer, self).__init__()
+        self.self_attn = MultiheadAttention(
+            d_model,
+            nhead,
+            dropout=dropout,
+            batch_first=batch_first,
+            linear1_cls=linear1_self_attention_cls,
+            linear2_cls=linear2_self_attention_cls,
+            **factory_kwargs,
+        )
+        self.multihead_attn = MultiheadAttention(
+            d_model,
+            nhead,
+            dropout=dropout,
+            batch_first=batch_first,
+            linear1_cls=linear1_self_attention_cls,
+            linear2_cls=linear2_self_attention_cls,
+            **factory_kwargs,
+        )
+        # Implementation of Feedforward model
+        self.linear1 = linear1_feedforward_cls(
+            d_model, dim_feedforward, **factory_kwargs
+        )
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = linear2_feedforward_cls(
+            dim_feedforward, d_model, **factory_kwargs
+        )
+
+        self.norm_first = norm_first
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        # Legacy string support for activation function.
+        if isinstance(activation, str):
+            self.activation = _get_activation_fn(activation)
+        elif isinstance(activation, partial):
+            self.activation = activation(d_model)
+        else:
+            self.activation = activation
+
+        if adaptive_layer_norm:
+            norm1 = layer_norm_cls(
+                d_model, eps=layer_norm_eps, **factory_kwargs
+            )
+            norm2 = layer_norm_cls(
+                d_model, eps=layer_norm_eps, **factory_kwargs
+            )
+            norm3 = layer_norm_cls(
+                d_model, eps=layer_norm_eps, **factory_kwargs
+            )
+
+            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
+            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
+            self.norm3 = AdaptiveLayerNorm(d_model, norm3)
+        else:
+            self.norm1 = layer_norm_cls(
+                d_model, eps=layer_norm_eps, **factory_kwargs
+            )
+            self.norm2 = layer_norm_cls(
+                d_model, eps=layer_norm_eps, **factory_kwargs
+            )
+            if layer_norm_cls == IdentityNorm:
+                self.norm3 = BalancedBasicNorm(
+                    d_model, eps=layer_norm_eps, **factory_kwargs
+                )
+            else:
+                self.norm3 = layer_norm_cls(
+                    d_model, eps=layer_norm_eps, **factory_kwargs
+                )
+
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        tgt_is_tuple = False
+        if isinstance(tgt, tuple):
+            x, stage_embedding = tgt
+            tgt_is_tuple = True
+        else:
+            x, stage_embedding = tgt, None
+
+        if self.norm_first:
+            x = x + self._sa_block(
+                self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
+            )
+            x = x + self._mha_block(
+                self.norm2(x, stage_embedding),
+                memory,
+                memory_mask,
+                memory_key_padding_mask,
+            )
+            x = x + self._ff_block(self.norm3(x, stage_embedding))
+        else:
+            x = self.norm1(
+                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
+                stage_embedding,
+            )
+            x = self.norm2(
+                x
+                + self._mha_block(
+                    x, memory, memory_mask, memory_key_padding_mask
+                ),
+                stage_embedding,
+            )
+            x = self.norm3(x + self._ff_block(x), stage_embedding)
+
+        if tgt_is_tuple:
+            return (x, stage_embedding)
+        return x
+
+    # self-attention block
+    def _sa_block(
+        self,
+        x: Tensor,
+        attn_mask: Optional[Tensor],
+        key_padding_mask: Optional[Tensor],
+    ) -> Tensor:
+        x = self.self_attn(
+            x,
+            x,
+            x,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+            need_weights=False,
+        )[0]
+        return self.dropout1(x)
+
+    # multihead attention block
+    def _mha_block(
+        self,
+        x: Tensor,
+        mem: Tensor,
+        attn_mask: Optional[Tensor],
+        key_padding_mask: Optional[Tensor],
+    ) -> Tensor:
+        x = self.multihead_attn(
+            x,
+            mem,
+            mem,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+            need_weights=False,
+        )[0]
+        return self.dropout2(x)
+
+    # feed forward block
+    def _ff_block(self, x: Tensor) -> Tensor:
+        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+        return self.dropout3(x)
+
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
+    if activation == "relu":
+        return F.relu
+    elif activation == "gelu":
+        return F.gelu
+
+    raise RuntimeError(
+        "activation should be relu/gelu, not {}".format(activation)
+    )
diff --git a/slam_llm/models/vallex/vallex_config.py b/slam_llm/models/vallex/vallex_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..242218f2f4559b74ba90a76c31bfce6bbac1cdc7
--- /dev/null
+++ b/slam_llm/models/vallex/vallex_config.py
@@ -0,0 +1,56 @@
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+from transformers import AutoConfig
+
+logger = logging.get_logger(__name__)
+
+
+class VallexConfig(PretrainedConfig):
+    
+    model_type = "vallex"
+    
+    def __init__(self,
+            n_layer=24,
+            n_head=16,
+            n_dim=1024,
+            prefix_mode=1,
+            num_quantizers=8,
+            sample_rate=24000,
+            ar_at_dict="",
+            ar_st_dict="",
+            nar_at_dict="",
+            nar_st_dict="",
+            nar_scale_factor=1.0,
+            prepend_bos=True,
+            norm_first=True,
+            eps=0.0,
+            only_ar=False,
+            only_nar=False,
+            **kwargs
+        ):
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_dim = n_dim
+        self.prefix_mode = prefix_mode
+        self.num_quantizers = num_quantizers
+        self.sample_rate = sample_rate
+        self.nar_scale_factor = nar_scale_factor
+        self.prepend_bos = prepend_bos
+        self.norm_first = norm_first
+        
+        self.ar_at_dict = ar_at_dict
+        self.ar_st_dict = ar_st_dict
+        self.nar_at_dict = nar_at_dict
+        self.nar_st_dict = nar_st_dict
+        self.eps = eps
+        self.only_ar = only_ar
+        self.only_nar = only_nar
+        
+        super().__init__(
+            **kwargs
+        )
+        
+
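+# Minimal usage sketch (the dictionary paths below are illustrative
+# placeholders, not files shipped with this diff):
+#   cfg = VallexConfig(
+#       n_layer=24, n_head=16, n_dim=1024,
+#       ar_at_dict="/path/to/ar_at_dict.txt", ar_st_dict="/path/to/ar_st_dict.txt",
+#       nar_at_dict="/path/to/nar_at_dict.txt", nar_st_dict="/path/to/nar_st_dict.txt",
+#   )
+#   cfg.save_pretrained("exp/vallex_config")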
+AutoConfig.register("vallex", VallexConfig)
\ No newline at end of file
diff --git a/slam_llm/models/vallex/vallex_model.py b/slam_llm/models/vallex/vallex_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a57426e0f0ffb40a89b1b39b323284c7638b15
--- /dev/null
+++ b/slam_llm/models/vallex/vallex_model.py
@@ -0,0 +1,772 @@
+import random
+from typing import Dict, Iterator, List, Tuple, Union
+from fairseq import utils
+import numpy as np
+import torch
+import math
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq.data import Dictionary
+from slam_llm.models.vallex.transformers import (
+    LayerNorm,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+)
+from slam_llm.models.vallex.vallex_config import VallexConfig
+from transformers.modeling_utils import PreTrainedModel
+from transformers import AutoModel
+from dataclasses import dataclass
+
+@dataclass
+class ModelOutput:
+    logits: torch.Tensor
+    loss: torch.Tensor
+    acc: torch.Tensor
+
+def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, scale=1, prob_mask=None):
+    """Label-smoothed NLL loss, normalized by the number of non-padding tokens."""
+    if target.dim() == lprobs.dim() - 1:
+        target = target.unsqueeze(-1)
+    if prob_mask is not None:
+        lprobs = lprobs.masked_fill(prob_mask, 0.0)
+        n_class = (1 - prob_mask.float()).sum()
+    else:
+        n_class = lprobs.size(-1)
+    nll_loss = -lprobs.gather(dim=-1, index=target)
+    smooth_loss = -lprobs.sum(dim=-1, keepdim=True) * scale
+    if ignore_index is not None:
+        pad_mask = target.eq(ignore_index)
+        nll_loss.masked_fill_(pad_mask, 0.0)
+        smooth_loss.masked_fill_(pad_mask, 0.0)
+        num_tokens = (1 - pad_mask.to(torch.float)).sum()
+    else:
+        nll_loss = nll_loss.squeeze(-1)
+        smooth_loss = smooth_loss.squeeze(-1)
+        # without an ignore_index every target token counts
+        num_tokens = float(target.numel())
+    if reduce:
+        nll_loss = nll_loss.sum()
+        smooth_loss = smooth_loss.sum()
+    eps_i = epsilon / (n_class - 1)
+    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
+    return loss / num_tokens, nll_loss / num_tokens
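+# With eps_i = epsilon / (n_class - 1), the per-token objective above is
+#   (1 - epsilon - eps_i) * (-log p(target)) + eps_i * sum_c (-log p(c)) * scale,
+# i.e. standard label smoothing, and the sums are normalized by the number of
+# non-padding target tokens.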
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+    def __init__(self, embedding_dim, padding_idx, init_size=1024):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx if padding_idx is not None else 0
+        self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            init_size, embedding_dim, padding_idx
+        )
+        self.onnx_trace = False
+        self.register_buffer("_float_tensor", torch.FloatTensor(1))
+        self.max_positions = int(1e5)
+
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
+    @staticmethod
+    def get_embedding(
+        num_embeddings: int, embedding_dim: int, padding_idx = None
+    ):
+        half_dim = embedding_dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
+            1
+        ) * emb.unsqueeze(0)
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
+            num_embeddings, -1
+        )
+        if embedding_dim % 2 == 1:
+            # zero pad
+            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+        if padding_idx is not None:
+            emb[padding_idx, :] = 0
+        return emb
+
+    def forward(
+        self,
+        input,
+        incremental_state = None,
+        timestep = None,
+        positions = None,
+    ):
+        bspair = torch.onnx.operators.shape_as_tensor(input)
+        bsz, seq_len = bspair[0], bspair[1]
+        max_pos = self.padding_idx + 1 + seq_len
+        if self.weights is None or max_pos > self.weights.size(0):
+            # recompute/expand embeddings if needed
+            self.weights = SinusoidalPositionalEmbedding.get_embedding(
+                max_pos, self.embedding_dim, self.padding_idx
+            )
+        self.weights = self.weights.to(self._float_tensor)
+
+        if incremental_state is not None:
+            # positions is the same for every token when decoding a single step
+            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+            if self.onnx_trace:
+                return (
+                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
+                    .unsqueeze(1)
+                    .repeat(bsz, 1, 1)
+                )
+            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+        positions = utils.make_positions(
+            input, self.padding_idx, onnx_trace=self.onnx_trace
+        )
+        if self.onnx_trace:
+            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            embedding_shape = torch.cat(
+                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
+            )
+            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
+                flat_embeddings, embedding_shape
+            )
+            return embeddings
+        return (
+            self.weights.index_select(0, positions.view(-1))
+            .view(bsz, seq_len, -1)
+            .detach()
+        )
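+# The table built by get_embedding() is the usual sinusoidal scheme with the
+# sin and cos halves concatenated (not interleaved), roughly
+#   emb[pos, i]            = sin(pos * 10000^(-i / (half_dim - 1)))
+#   emb[pos, half_dim + i] = cos(pos * 10000^(-i / (half_dim - 1)))
+# with the padding position zeroed out.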
+
+
+class Transpose(nn.Identity):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input.transpose(1, 2)
+
+
+class VALLF(PreTrainedModel):
+    config_class = VallexConfig
+    
+    def __init__(
+        self,
+        config: VallexConfig
+    ):
+        super().__init__(config)
+        
+        self.ar_at_dict = Dictionary.load(self.config.ar_at_dict)
+        self.ar_st_dict = Dictionary.load(self.config.ar_st_dict)
+        self.nar_at_dict = Dictionary.load(self.config.nar_at_dict)
+        self.nar_st_dict = Dictionary.load(self.config.nar_st_dict)
+        
+        self.ar_at_dict.tts_flag = self.ar_at_dict.add_symbol("<TTS>")
+        self.ar_st_dict.asr_flag = self.ar_st_dict.add_symbol("<ASR>")
+        self.ar_st_dict.mt_flag = self.ar_st_dict.add_symbol("<MT>")
+        
+        self.padding_idx = self.ar_at_dict.pad()
+        self.config = config
+        d_model = self.config.n_dim
+        nar_scale_factor = self.config.nar_scale_factor
+        prepend_bos = self.config.prepend_bos
+        
+        norm_first = self.config.norm_first
+        num_layers = self.config.n_layer
+        self.NUM_AUDIO_TOKENS = self.ar_at_dict.eos()
+        
+        nar_d_model = int(d_model * nar_scale_factor)
+
+        self.ar_text_embedding = nn.Embedding(len(self.ar_st_dict), d_model, self.ar_st_dict.pad())  # W_x
+        if config.only_ar:
+            pass
+        else:
+            self.nar_text_embedding = nn.Embedding(len(self.nar_st_dict), d_model, self.nar_st_dict.pad())
+
+        # ID self.NUM_AUDIO_TOKENS     -> PAD
+        # ID self.NUM_AUDIO_TOKENS + 1 -> BOS
+        self.ar_audio_prepend_bos = prepend_bos
+        self.ar_audio_embedding = EncodecDecoderLstm(
+            dictionary=self.ar_at_dict, emb_dim=d_model
+        )
+
+        self.ar_text_prenet = nn.Identity()
+        self.ar_audio_prenet = nn.Identity()
+
+        self.ar_text_position = SinusoidalPositionalEmbedding(
+            d_model,
+            padding_idx=self.ar_at_dict.pad(),
+            init_size=1024+self.ar_at_dict.pad()+1
+        )
+        self.ar_audio_position = SinusoidalPositionalEmbedding(
+            d_model,
+            padding_idx=self.ar_at_dict.pad(),
+            init_size=1024+self.ar_at_dict.pad()+1
+        )
+
+        self.ar_decoder = TransformerEncoder(
+            TransformerEncoderLayer(
+                d_model,
+                self.config.n_head,
+                dim_feedforward=d_model * 4,
+                dropout=0.1,
+                batch_first=True,
+                norm_first=norm_first,
+            ),
+            num_layers=num_layers,
+            norm=LayerNorm(d_model) if norm_first else None,
+        )
+        self.ar_predict_layer = nn.Linear(
+            d_model, len(self.ar_at_dict), bias=False
+        )
+
+        self.rng = random.Random(0)
+        self.num_heads = self.config.n_head
+        self.prefix_mode = self.config.prefix_mode
+        self.num_quantizers = self.config.num_quantizers
+
+        assert self.num_quantizers >= 1
+        if config.only_ar:
+            pass
+        else:
+            if self.num_quantizers > 1:
+                self.nar_audio_embeddings = NATEncodecDecoderLstm(
+                    codecs=[0, 1, 2, 3, 4, 5, 6, 7], dictionary=self.nar_at_dict, emb_dim=d_model
+                )  # W_a
+
+                self.nar_text_prenet = nn.Identity()
+                self.nar_audio_prenet = nn.Identity()
+
+                self.nar_text_position = SinusoidalPositionalEmbedding(
+                    d_model,
+                    padding_idx=self.nar_at_dict.pad(),
+                    init_size=1024+self.nar_at_dict.pad()+1
+                )
+                self.nar_audio_position = SinusoidalPositionalEmbedding(
+                    d_model,
+                    padding_idx=self.nar_at_dict.pad(),
+                    init_size=1024+self.nar_at_dict.pad()+1
+                )
+
+                self.nar_decoder = TransformerEncoder(
+                    TransformerEncoderLayer(
+                        nar_d_model,
+                        int(self.num_heads * nar_scale_factor),
+                        dim_feedforward=nar_d_model * 4,
+                        dropout=0.1,
+                        batch_first=True,
+                        norm_first=norm_first,
+                        adaptive_layer_norm=True,
+                    ),
+                    num_layers=int(num_layers * nar_scale_factor),
+                    norm=nn.LayerNorm(nar_d_model)
+                    if norm_first
+                    else None,
+                )
+                self.nar_predict_layers = nn.ModuleList(
+                    [
+                        nn.Linear(nar_d_model, len(self.nar_at_dict), bias=False)
+                        for i in range(self.num_quantizers)
+                    ]
+                )
+                self.nar_stage_embeddings = None  # only indexed by the prefix_mode == 0 inference path
+
+    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
+        assert stage > 0
+        if stage == 1:
+            for name, param in self.named_parameters():
+                if name.startswith("ar_"):
+                    print(f" AR parameter: {name}")
+                    yield param
+
+        if stage == 2:
+            for name, param in self.named_parameters():
+                if name.startswith("nar_"):
+                    print(f"NAR parameter: {name}")
+                    yield param
+
+    def stage_named_parameters(
+        self, stage: int = 1
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        assert stage > 0
+        if stage == 1:
+            for pair in self.named_parameters():
+                if pair[0].startswith("ar_"):
+                    yield pair
+
+        if stage == 2:
+            for pair in self.named_parameters():
+                if pair[0].startswith("nar_"):
+                    yield pair
+
+    def pad_y_eos(self, y, y_mask_int, eos_id):
+        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
+            y_mask_int, (0, 1), value=1
+        )
+        # inputs, targets
+        if self.ar_audio_prepend_bos:
+            return (
+                F.pad(targets[:, :-1], (1, 0), value=self.NUM_AUDIO_TOKENS + 1),
+                targets,
+            )
+
+        return targets[:, :-1], targets[:, 1:]
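+    # pad_y_eos example with eos_id = E and prepend_bos = True:
+    #   y = [[a, b]], y_mask_int = [[0, 0]]  ->  targets = [[a, b, E]]
+    #   returns inputs = [[BOS, a, b]] (BOS = NUM_AUDIO_TOKENS + 1) and
+    #   targets = [[a, b, E]], i.e. the usual one-step teacher-forcing shift.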
+
+class VALLE(VALLF):
+    config_class = VallexConfig
+    
+    def __init__(
+        self,
+        config: VallexConfig,
+        **kwargs,
+    ):
+        super(VALLE, self).__init__(
+            config,
+            **kwargs,
+        )
+        print(config)
+        self.config = config
+        d_model = self.config.n_dim
+        self.eps = config.eps
+        
+        self.language_ID = {
+            'en': 0,
+            'zh': 1,
+        }
+        # language embedding table: index 0 = en, 1 = zh, 2 = padding
+        self.ar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
+        self.nar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
+        self.embed_scale = 32.0  # matches sqrt(d_model) for d_model = 1024
+        
+    def forward(
+        self,
+        zh,
+        en
+    ):
+        """
+        "zh": {
+            "st_tokens": zh_st,
+            "at_tokens_wbos": zh_prev_at,
+            "at_tokens_tgt": zh_tgt_at,
+            "self_atten_mask": zh_self_atten_mask,
+            "padding_mask": zh_padding_mask,
+            "langid": zh_id.long()
+        },
+        "en": {
+            "st_tokens": en_st,
+            "at_tokens_wbos": en_prev_at,
+            "at_tokens_tgt": en_tgt_at,
+            "self_atten_mask": en_self_atten_mask,
+            "padding_mask": en_padding_mask,
+            "langid": en_id.long()
+        }
+        """
+        # Randomly train on either the zh or the en view of this batch.
+        use_zh = np.random.randint(low=0, high=1000) % 2 == 0
+        data = zh if use_zh else en
+        
+        st_tokens = data["st_tokens"]
+        at_tokens_wbos = data["at_tokens_wbos"]
+        at_tokens_tgt = data["at_tokens_tgt"]
+        self_atten_mask = data["self_atten_mask"]
+        padding_mask = data["padding_mask"]
+        langid = data["langid"]
+        
+        st_len = st_tokens.size(1)
+        st_emb = self.embed_scale * self.ar_text_embedding(st_tokens)
+        src_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
+        st_emb += src_lang_emb
+        st_pos = self.ar_text_position(st_tokens)
+        st_emb += st_pos
+        
+        at_emb, _ = self.ar_audio_embedding(at_tokens_wbos, None)
+        at_emb = self.embed_scale * at_emb
+        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
+        at_emb += tgt_lang_emb
+        at_pos = self.ar_audio_position(at_tokens_wbos)
+        at_emb += at_pos
+        
+        x = torch.concat([st_emb, at_emb], dim=1)
+        
+        x = self.ar_decoder(
+            x,
+            mask=self_atten_mask,
+            src_key_padding_mask=padding_mask
+        )
+        x = self.ar_predict_layer(x)
+        x = x[:, st_len:, :]
+        loss, nll_loss, lprob, right_rate = self.calculate_loss(
+            x, at_tokens_tgt
+        )
+        return ModelOutput(logits=lprob, loss=loss, acc=right_rate), right_rate
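+    # Note: the AR loss above is computed only over the audio-token positions
+    # (x[:, st_len:, :]); the text tokens condition the decoder but are never
+    # predicted.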
+
+    def calculate_loss(self, encoder_out, target, reduce=True, scale=1.0, prob_mask=None, acc=True):
+        lprob = self.get_normalized_probs(encoder_out, log_probs=True)
+        with torch.no_grad():
+            mask = target.ne(self.padding_idx)
+            n_correct = torch.sum(
+                lprob.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
+            )
+            total = torch.sum(mask)
+            right_rate = n_correct * 100.0 / total
+        
+        lprob, target = lprob.view(-1, lprob.size(-1)), target.view(-1)
+        loss, nll_loss = label_smoothed_nll_loss(
+            lprob,
+            target,
+            self.eps,
+            ignore_index=self.padding_idx,
+            reduce=reduce,
+            scale=scale,
+            prob_mask=prob_mask
+        )
+        
+        return loss, nll_loss, lprob, right_rate
+    
+    def get_normalized_probs(self, encoder_out, log_probs, sample=None):
+        if torch.is_tensor(encoder_out):
+            logits = encoder_out.float()
+            if log_probs:
+                return F.log_softmax(logits, dim=-1)
+            else:
+                return F.softmax(logits, dim=-1)
+            
+    
+    def inference_24L(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: torch.Tensor,
+        enroll_x_lens: torch.Tensor,
+        top_k: int = -100,
+        temperature: float = 1.0,
+        prompt_language: str = None,
+        text_language: str = None,
+        best_of: int = 1,
+        length_penalty: float = 1.0,
+        return_worst: bool = False,
+        at_eos: int = -1
+    ) -> torch.Tensor:
+        assert x.ndim == 2, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.ndim == 3, y.shape
+        assert y.shape[0] == 1, y.shape
+
+        assert torch.all(x_lens > 0)
+        self.NUM_AUDIO_TOKENS = at_eos
+        text = x
+        x = self.embed_scale * self.ar_text_embedding(text)
+        prompt_language_id = prompt_language.to(x.device)
+        text_language_id = text_language.to(x.device)
+        src_lang_emb = self.embed_scale * self.ar_language_embedding(prompt_language_id)
+        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(text_language_id)
+        x[:, :enroll_x_lens, :] += src_lang_emb
+        x[:, enroll_x_lens:, :] += tgt_lang_emb
+        x = self.ar_text_prenet(x)
+        x_pos = self.ar_text_position(text)
+
+        text_len = x_lens.max()
+        prompts = y
+        prefix_len = y.shape[1]
+
+        # AR Decoder
+        # TODO: manage decoder steps to avoid repetitive computation
+        y = prompts[..., 0]
+        if self.ar_audio_prepend_bos:
+            y = F.pad(y, (1, 0), value=self.ar_at_dict.tts_flag)
+
+        x_len = x_lens.max()
+        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+
+        kv_cache = None
+        use_kv_caching = True
+
+        sum_logprobs = torch.zeros(best_of, device=y.device)  # implement batch decoding here
+        x = x.repeat(best_of, 1, 1)
+        y = y.repeat(best_of, 1)
+        lstm_h = None
+        first_ar = True
+        while True:
+            if first_ar:
+                y_emb, lstm_h = self.ar_audio_embedding(y, lstm_h)
+                y_emb = y_emb * self.embed_scale
+                y_emb = self.ar_audio_prenet(y_emb)
+                y_pos = self.ar_audio_position(y)
+                y_emb[:, :prefix_len] = y_emb[:, :prefix_len] + src_lang_emb
+                y_emb[:, prefix_len:] = y_emb[:, prefix_len:] + tgt_lang_emb
+                xy_pos_token = torch.concat([x_pos+x, y_pos+y_emb], dim=1)
+                first_ar = False
+            else:
+                y_emb_cur, lstm_h = self.ar_audio_embedding(y[:, -1:], lstm_h)
+                y_emb_cur = y_emb_cur * self.embed_scale
+                y_emb_cur = self.ar_audio_prenet(y_emb_cur)
+                y_pos_cur = self.ar_audio_position(y)[:, -1:]
+                y_emb_cur = y_emb_cur + src_lang_emb
+                y_emb_cur = y_emb_cur + tgt_lang_emb
+                xy_pos_token = torch.concat([xy_pos_token, y_pos_cur+y_emb_cur], dim=1)
+            # print(xy_pos_token.size())
+
+            y_len = y.shape[1]
+            x_attn_mask_pad = F.pad(
+                x_attn_mask,
+                (0, y_len),
+                value=True,
+            )
+            y_attn_mask = F.pad(
+                torch.triu(
+                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
+                ),
+                (x_len, 0),
+                value=False,
+            )
+            xy_attn_mask = torch.concat(
+                [x_attn_mask_pad, y_attn_mask], dim=0
+            ).to(y.device)
+
+
+            if use_kv_caching and kv_cache is not None:
+                xy_pos = xy_pos_token[:, [-1]]
+                xy_attn_mask = xy_attn_mask[:, [-1]]
+            else:
+                xy_pos = xy_pos_token
+
+            xy_dec, kv_cache = self.ar_decoder.infer(
+                xy_pos,
+                mask=xy_attn_mask,
+                past_kv=kv_cache,
+                use_cache=use_kv_caching,
+            )
+
+            logits = self.ar_predict_layer(xy_dec[:, -1])
+            samples, current_logprobs = topk_sampling(
+                logits, top_k=top_k, top_p=1, temperature=temperature
+            )
+            sum_logprobs += current_logprobs * (y[:, -1] != self.NUM_AUDIO_TOKENS)
+            samples[y[:, -1] == self.NUM_AUDIO_TOKENS] = self.NUM_AUDIO_TOKENS
+            completed = (samples[:, -1] == self.NUM_AUDIO_TOKENS).all()
+            if (
+                completed
+                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 32
+            ):  
+                if prompts.shape[1] == y.shape[1]:
+                    raise SyntaxError(
+                        "well trained model shouldn't reach here."
+                    )
+                lengths = torch.sum(y != self.NUM_AUDIO_TOKENS, dim=1)
+                avg_logprobs = sum_logprobs / lengths ** length_penalty
+                # choose the best beam according to sum_logprobs
+                best_beam = y[torch.argmax(avg_logprobs), :]
+                worst_beam = y[torch.argmin(avg_logprobs), :]
+                # strip all eos tokens
+                best_beam = best_beam[best_beam != self.NUM_AUDIO_TOKENS]
+                worst_beam = worst_beam[worst_beam != self.NUM_AUDIO_TOKENS]
+                if return_worst:
+                    y = worst_beam.unsqueeze(0)
+                else:
+                    y = best_beam.unsqueeze(0)
+                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
+                break
+
+            y = torch.concat([y, samples], dim=1)
+
+        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
+        if self.num_quantizers == 1:
+            return torch.stack(codes, dim=-1)
+
+        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
+            enrolled_len = enroll_x_lens.max().item()
+            # SOS + Synthesis Text + EOS
+            text = torch.concat(
+                [
+                    text[:, :1],
+                    text[:, enrolled_len - 1 :],
+                ],
+                dim=1,
+            )
+            text_len = text_len - (enrolled_len - 2)
+            assert text.shape[0] == 1
+
+        x = self.embed_scale * self.nar_text_embedding(text)
+        # Add language embedding
+        prompt_language_id = prompt_language.to(x.device)
+        text_language_id = text_language.to(x.device)
+        src_lang_emb = self.embed_scale * self.nar_language_embedding(prompt_language_id)
+        tgt_lang_emb = self.embed_scale * self.nar_language_embedding(text_language_id)
+        x[:, :enroll_x_lens, :] += src_lang_emb
+        x[:, enroll_x_lens:, :] += tgt_lang_emb
+        x = self.nar_text_prenet(x)
+        x_pos = self.nar_text_position(text)
+
+        if self.prefix_mode == 0:
+            for i, predict_layer in enumerate(
+                self.nar_predict_layers
+            ):
+                y_pos = self.nar_audio_prenet(y_emb)
+                y_pos = self.nar_audio_position(y_pos)
+                xy_pos = torch.concat([x, y_pos], dim=1)
+
+                xy_dec, _ = self.nar_decoder(
+                    (xy_pos, self.nar_stage_embeddings[i].weight)
+                )
+                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+                samples = torch.argmax(logits, dim=-1)
+                codes.append(samples)
+
+                if i < self.num_quantizers - 2:
+                    y_emb[:, :prefix_len] += self.embed_scale * self.nar_audio_embeddings(
+                        prompts[..., i + 1]
+                    )[0]
+                    y_emb[:, prefix_len:] += self.embed_scale * self.nar_audio_embeddings(samples)[0]
+        else:
+            y_pos = self.nar_audio_position(y[:, int(self.ar_audio_prepend_bos):])
+            
+            ref_at_emb = self.embed_scale * self.nar_audio_embeddings(prompts)[0] + src_lang_emb
+            est_at = y[:, prefix_len+int(self.ar_audio_prepend_bos):].unsqueeze(-1)
+            for i in range(1, 8):
+                y_emb, _ = self.nar_audio_embeddings(est_at)
+                y_emb = self.embed_scale * y_emb + tgt_lang_emb
+                
+                y_emb = torch.concat([ref_at_emb, y_emb], dim=1)
+                xy_pos = torch.concat([x+x_pos, y_emb+y_pos], dim=1)
+
+                xy_dec = self.nar_decoder(
+                    xy_pos
+                )
+                logits = self.nar_predict_layers[i-1](xy_dec[:, text_len + prefix_len :])
+                # print(logits.size(), xy_pos.size(), xy_dec.size())
+                samples = torch.argmax(logits, dim=-1)
+                est_at = torch.concat([est_at, samples.unsqueeze(-1)], dim=-1)
+                codes.append(samples)
+
+        assert len(codes) == self.num_quantizers
+        return torch.stack(codes, dim=-1)
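+    # inference_24L sketch: the first codec quantizer is decoded
+    # autoregressively (KV cache + top-k sampling) until the EOS audio token,
+    # then the remaining quantizers are filled in non-autoregressively, one
+    # nar_predict_layer per quantizer, conditioned on the text, the prompt
+    # codes and the codes predicted so far.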
+            
+def top_k_top_p_filtering(
+    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+    if top_k > 0:
+        top_k = min(
+            max(top_k, min_tokens_to_keep), logits.size(-1)
+        )  # Safety check
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(
+            F.softmax(sorted_logits, dim=-1), dim=-1
+        )
+
+        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+            ..., :-1
+        ].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+    return logits
+
+
+def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
+    if temperature != 1.0:
+        logits = logits / temperature
+    # Top-p/top-k filtering
+    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+    # Sample
+    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+    logprobs = F.log_softmax(logits.float(), dim=-1)
+    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
+    return token, current_logprobs
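+# Usage sketch: for logits of shape (B, V),
+#   token, logprob = topk_sampling(logits, top_k=10, temperature=1.0)
+# returns the sampled ids with shape (B, 1) and, for each row, the
+# log-probability of the sampled token under the filtered distribution (B,).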
+
+class SLSTM(nn.Module):
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional=False):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)            
+        if bidirectional:
+            self.out_fc = nn.Linear(dimension*2, dimension)
+        else:
+            self.out_fc = None
+
+    def forward(self, x, hidden=None):
+        # x: (B, D, T); the LSTM here is seq-first, so permute to (T, B, D)
+        x = x.permute(2, 0, 1)
+        y, hidden = self.lstm(x, hidden)
+        if self.out_fc is not None:
+            y = self.out_fc(y)
+        if self.skip:
+            y = y + x
+        # back to (B, D, T)
+        y = y.permute(1, 2, 0)
+        return y, hidden
+    
+class EncodecDecoderLstm(nn.Module):
+    def __init__(self, dictionary, emb_dim, 
+                 out_dim=None,
+                 num_layers=3, lstm_skip=True, lstm_bidire=False,
+                 activation_param={'alpha': 1.0}, **kwargs):
+        super().__init__()
+        
+        # Identity()
+        if out_dim is None:
+            out_dim = emb_dim
+        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
+        self.elu = nn.ELU(**activation_param)
+        self.embedding_dim = emb_dim
+        self.padding_idx = dictionary.pad()
+        self.emb = nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index)
+    
+    def forward(self, x, hidden=None):
+        """
+        Args:
+            x (LongTensor): audio token ids of shape (B, T)
+        Returns:
+            (B, T, D) embeddings and the updated LSTM hidden state
+        """
+        quantized_out = self.emb(x)
+        out, hidden = self.slstm(quantized_out.permute(0,2,1), hidden)
+        out = self.elu(out)
+        return out.permute(0,2,1), hidden
+
+class NATEncodecDecoderLstm(nn.Module):
+    def __init__(self, codecs, dictionary, emb_dim, 
+                 out_dim=None,
+                 num_layers=3, lstm_skip=True, lstm_bidire=False,
+                 activation_param={'alpha': 1.0}, **kwargs):
+        super().__init__()
+        
+        # Identity()
+        if out_dim is None:
+            out_dim = emb_dim
+        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
+        self.elu = nn.ELU(**activation_param)
+        self.codecs = codecs
+        self.embedding_dim = emb_dim
+        self.padding_idx = dictionary.pad()
+        self.emb_list = nn.ModuleList(
+            [nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index) for i in range(len(self.codecs))]
+        )
+    
+    def forward(self, x, hidden=None):
+        """
+        Args:
+            x (LongTensor): codec token ids of shape (B, T) or (B, T, K);
+                each of the K codebooks is embedded separately and summed
+        Returns:
+            (B, T, D) embeddings and the updated LSTM hidden state
+        """
+        if len(x.size()) == 2:
+            x = x.unsqueeze(-1)
+        
+        if x.size(2) != len(self.codecs) and x.size(1) == len(self.codecs):
+            x = x.permute(0, 2, 1)
+        
+        quantized_out = 0
+        for i in range(x.size(2)):
+            quantized = self.emb_list[i](x[: , :, i])
+            quantized_out = quantized_out + quantized
+        # quantized_out = quantized_out / len(self.codecs)
+        
+        out, hidden = self.slstm(quantized_out.permute(0,2,1), hidden)
+        out = self.elu(out)
+        return out.permute(0,2,1), hidden
+
+AutoModel.register(VallexConfig, VALLE)
\ No newline at end of file
diff --git a/slam_llm/models/wavlm/WavLM.py b/slam_llm/models/wavlm/WavLM.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0134c660c0bc13bd35ee16867ef6bd71ce1d362
--- /dev/null
+++ b/slam_llm/models/wavlm/WavLM.py
@@ -0,0 +1,743 @@
+# --------------------------------------------------------
+# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
+# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import logging
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import LayerNorm
+from .modules import (
+    Fp32GroupNorm,
+    Fp32LayerNorm,
+    GradMultiply,
+    MultiheadAttention,
+    SamePad,
+    init_bert_params,
+    get_activation_fn,
+    TransposeLast,
+    GLU_Linear,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def compute_mask_indices(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape
+
+    Args:
+        shape: the shape for which to compute masks.
+            should be of size 2 where first element is batch size and 2nd is timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from poisson distribution with lambda = mask length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, switch to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+    """
+
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    all_num_mask = int(
+        # add a random number for probabilistic rounding
+        mask_prob * all_sz / float(mask_length)
+        + np.random.rand()
+    )
+
+    all_num_mask = max(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * sz / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    np.int64,  # np.int was removed in NumPy >= 1.24
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+    min_len = min([len(m) for m in mask_idcs])
+    for i, mask_idc in enumerate(mask_idcs):
+        if len(mask_idc) > min_len:
+            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+
+    return mask
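+# Usage sketch: for a batch of 2 utterances of 100 frames,
+#   mask = compute_mask_indices((2, 100), None, mask_prob=0.65, mask_length=10)
+# returns a (2, 100) boolean array marking the frames whose features should be
+# replaced by the learned mask embedding.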
+
+
+class WavLMConfig:
+    def __init__(self, cfg=None):
+        self.extractor_mode: str = "default"     # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
+        self.encoder_layers: int = 12     # num encoder layers in the transformer
+
+        self.encoder_embed_dim: int = 768     # encoder embedding dimension
+        self.encoder_ffn_embed_dim: int = 3072     # encoder embedding dimension for FFN
+        self.encoder_attention_heads: int = 12     # num encoder attention heads
+        self.activation_fn: str = "gelu"     # activation function to use
+
+        self.layer_norm_first: bool = False     # apply layernorm first in the transformer
+        self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"     # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
+        self.conv_bias: bool = False     # include bias in conv encoder
+        self.feature_grad_mult: float = 1.0     # multiply feature extractor var grads by this
+
+        self.normalize: bool = False  # normalize input to have 0 mean and unit variance during training
+
+        # dropouts
+        self.dropout: float = 0.1     # dropout probability for the transformer
+        self.attention_dropout: float = 0.1     # dropout probability for attention weights
+        self.activation_dropout: float = 0.0     # dropout probability after activation in FFN
+        self.encoder_layerdrop: float = 0.0     # probability of dropping a transformer layer
+        self.dropout_input: float = 0.0     # dropout to apply to the input (after feat extr)
+        self.dropout_features: float = 0.0     # dropout to apply to the features (after feat extr)
+
+        # masking
+        self.mask_length: int = 10     # mask length
+        self.mask_prob: float = 0.65     # probability of replacing a token with mask
+        self.mask_selection: str = "static"     # how to choose mask length
+        self.mask_other: float = 0     # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
+        self.no_mask_overlap: bool = False     # whether to allow masks to overlap
+        self.mask_min_space: int = 1     # min space between spans (if no overlap is enabled)
+
+        # channel masking
+        self.mask_channel_length: int = 10     # length of the mask for features (channels)
+        self.mask_channel_prob: float = 0.0     # probability of replacing a feature with 0
+        self.mask_channel_selection: str = "static"     # how to choose mask length for channel masking
+        self.mask_channel_other: float = 0     # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
+        self.no_mask_channel_overlap: bool = False     # whether to allow channel masks to overlap
+        self.mask_channel_min_space: int = 1     # min space between spans (if no overlap is enabled)
+
+        # positional embeddings
+        self.conv_pos: int = 128     # number of filters for convolutional positional embeddings
+        self.conv_pos_groups: int = 16     # number of groups for convolutional positional embedding
+
+        # relative position embedding
+        self.relative_position_embedding: bool = False     # apply relative position embedding
+        self.num_buckets: int = 320     # number of buckets for relative position embedding
+        self.max_distance: int = 1280     # maximum distance for relative position embedding
+        self.gru_rel_pos: bool = False     # apply gated relative position embedding
+
+        if cfg is not None:
+            self.update(cfg)
+
+    def update(self, cfg: dict):
+        self.__dict__.update(cfg)
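+# Typical loading sketch, assuming a checkpoint laid out like the official
+# WavLM release (the path is a placeholder):
+#   ckpt = torch.load("/path/to/WavLM-Large.pt", map_location="cpu")
+#   cfg = WavLMConfig(ckpt["cfg"])
+#   model = WavLM(cfg)
+#   model.load_state_dict(ckpt["model"])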
+
+
+class WavLM(nn.Module):
+    def __init__(
+        self,
+        cfg: WavLMConfig,
+    ) -> None:
+        super().__init__()
+        logger.info(f"WavLM Config: {cfg.__dict__}")
+
+        self.cfg = cfg
+        feature_enc_layers = eval(cfg.conv_feature_layers)
+        self.embed = feature_enc_layers[-1][0]
+
+        self.feature_extractor = ConvFeatureExtractionModel(
+            conv_layers=feature_enc_layers,
+            dropout=0.0,
+            mode=cfg.extractor_mode,
+            conv_bias=cfg.conv_bias,
+        )
+
+        self.post_extract_proj = (
+            nn.Linear(self.embed, cfg.encoder_embed_dim)
+            if self.embed != cfg.encoder_embed_dim
+            else None
+        )
+
+        self.mask_prob = cfg.mask_prob
+        self.mask_selection = cfg.mask_selection
+        self.mask_other = cfg.mask_other
+        self.mask_length = cfg.mask_length
+        self.no_mask_overlap = cfg.no_mask_overlap
+        self.mask_min_space = cfg.mask_min_space
+
+        self.mask_channel_prob = cfg.mask_channel_prob
+        self.mask_channel_selection = cfg.mask_channel_selection
+        self.mask_channel_other = cfg.mask_channel_other
+        self.mask_channel_length = cfg.mask_channel_length
+        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+        self.mask_channel_min_space = cfg.mask_channel_min_space
+
+        self.dropout_input = nn.Dropout(cfg.dropout_input)
+        self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+        self.feature_grad_mult = cfg.feature_grad_mult
+
+        self.mask_emb = nn.Parameter(
+            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+        )
+
+        self.encoder = TransformerEncoder(cfg)
+        self.layer_norm = LayerNorm(self.embed)
+
+    def apply_mask(self, x, padding_mask):
+        B, T, C = x.shape
+        if self.mask_prob > 0:
+            mask_indices = compute_mask_indices(
+                (B, T),
+                padding_mask,
+                self.mask_prob,
+                self.mask_length,
+                self.mask_selection,
+                self.mask_other,
+                min_masks=2,
+                no_overlap=self.no_mask_overlap,
+                min_space=self.mask_min_space,
+            )
+            mask_indices = torch.from_numpy(mask_indices).to(x.device)
+            x[mask_indices] = self.mask_emb
+        else:
+            mask_indices = None
+
+        if self.mask_channel_prob > 0:
+            mask_channel_indices = compute_mask_indices(
+                (B, C),
+                None,
+                self.mask_channel_prob,
+                self.mask_channel_length,
+                self.mask_channel_selection,
+                self.mask_channel_other,
+                no_overlap=self.no_mask_channel_overlap,
+                min_space=self.mask_channel_min_space,
+            )
+            mask_channel_indices = (
+                torch.from_numpy(mask_channel_indices)
+                .to(x.device)
+                .unsqueeze(1)
+                .expand(-1, T, -1)
+            )
+            x[mask_channel_indices] = 0
+
+        return x, mask_indices
+
+    def forward_padding_mask(
+            self, features: torch.Tensor, padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
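+        # Map the sample-level padding mask onto the (downsampled) feature frames:
+        # leftover samples that do not fill a whole frame are dropped, and a frame
+        # is treated as padding only if every sample it covers is padding.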
+        extra = padding_mask.size(1) % features.size(1)
+        if extra > 0:
+            padding_mask = padding_mask[:, :-extra]
+        padding_mask = padding_mask.view(
+            padding_mask.size(0), features.size(1), -1
+        )
+        padding_mask = padding_mask.all(-1)
+        return padding_mask
+
+    def extract_features(
+        self,
+        source: torch.Tensor,
+        padding_mask: Optional[torch.Tensor] = None,
+        mask: bool = False,
+        ret_conv: bool = False,
+        output_layer: Optional[int] = None,
+        ret_layer_results: bool = False,
+    ):
+
+        if self.feature_grad_mult > 0:
+            features = self.feature_extractor(source)
+            if self.feature_grad_mult != 1.0:
+                features = GradMultiply.apply(features, self.feature_grad_mult)
+        else:
+            with torch.no_grad():
+                features = self.feature_extractor(source)
+
+        features = features.transpose(1, 2)
+        features = self.layer_norm(features)
+
+        if padding_mask is not None:
+            padding_mask = self.forward_padding_mask(features, padding_mask)
+
+        if self.post_extract_proj is not None:
+            features = self.post_extract_proj(features)
+
+        features = self.dropout_input(features)
+
+        if mask:
+            x, mask_indices = self.apply_mask(
+                features, padding_mask
+            )
+        else:
+            x = features
+
+        # feature: (B, T, D), float
+        # target: (B, T), long
+        # x: (B, T, D), float
+        # padding_mask: (B, T), bool
+        # mask_indices: (B, T), bool
+        x, layer_results = self.encoder(
+            x,
+            padding_mask=padding_mask,
+            layer=None if output_layer is None else output_layer - 1
+        )
+
+        res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
+
+        feature = res["features"] if ret_conv else res["x"]
+        if ret_layer_results:
+            feature = (feature, res["layer_results"])
+        return feature, res["padding_mask"]
+
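+# Illustrative usage sketch (added commentary; the checkpoint layout and file name
+# are assumptions, not part of this file):
+#
+#   ckpt = torch.load("WavLM-Base.pt", map_location="cpu")
+#   cfg = WavLMConfig(ckpt["cfg"])
+#   model = WavLM(cfg)
+#   model.load_state_dict(ckpt["model"])
+#   model.eval()
+#   wav = torch.randn(1, 16000)          # one second of 16 kHz audio
+#   feats, padding_mask = model.extract_features(wav)
+#   # feats: (B, T', encoder_embed_dim) frame-level representations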
+
+class ConvFeatureExtractionModel(nn.Module):
+    def __init__(
+            self,
+            conv_layers: List[Tuple[int, int, int]],
+            dropout: float = 0.0,
+            mode: str = "default",
+            conv_bias: bool = False,
+            conv_type: str = "default"
+    ):
+        super().__init__()
+
+        assert mode in {"default", "layer_norm"}
+
+        def block(
+                n_in,
+                n_out,
+                k,
+                stride,
+                is_layer_norm=False,
+                is_group_norm=False,
+                conv_bias=False,
+        ):
+            def make_conv():
+                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
+                nn.init.kaiming_normal_(conv.weight)
+                return conv
+
+            assert not (
+                is_layer_norm and is_group_norm
+            ), "layer norm and group norm are exclusive"
+
+            if is_layer_norm:
+                return nn.Sequential(
+                    make_conv(),
+                    nn.Dropout(p=dropout),
+                    nn.Sequential(
+                        TransposeLast(),
+                        Fp32LayerNorm(dim, elementwise_affine=True),
+                        TransposeLast(),
+                    ),
+                    nn.GELU(),
+                )
+            elif is_group_norm:
+                return nn.Sequential(
+                    make_conv(),
+                    nn.Dropout(p=dropout),
+                    Fp32GroupNorm(dim, dim, affine=True),
+                    nn.GELU(),
+                )
+            else:
+                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
+
+        self.conv_type = conv_type
+        if self.conv_type == "default":
+            in_d = 1
+            self.conv_layers = nn.ModuleList()
+            for i, cl in enumerate(conv_layers):
+                assert len(cl) == 3, "invalid conv definition: " + str(cl)
+                (dim, k, stride) = cl
+
+                self.conv_layers.append(
+                    block(
+                        in_d,
+                        dim,
+                        k,
+                        stride,
+                        is_layer_norm=mode == "layer_norm",
+                        is_group_norm=mode == "default" and i == 0,
+                        conv_bias=conv_bias,
+                    )
+                )
+                in_d = dim
+        elif self.conv_type == "conv2d":
+            in_d = 1
+            self.conv_layers = nn.ModuleList()
+            for i, cl in enumerate(conv_layers):
+                assert len(cl) == 3
+                (dim, k, stride) = cl
+
+                self.conv_layers.append(
+                    torch.nn.Conv2d(in_d, dim, k, stride)
+                )
+                self.conv_layers.append(torch.nn.ReLU())
+                in_d = dim
+        elif self.conv_type == "custom":
+            in_d = 1
+            idim = 80
+            self.conv_layers = nn.ModuleList()
+            for i, cl in enumerate(conv_layers):
+                assert len(cl) == 3
+                (dim, k, stride) = cl
+                self.conv_layers.append(
+                    torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
+                )
+                self.conv_layers.append(
+                    torch.nn.LayerNorm([dim, idim])
+                )
+                self.conv_layers.append(torch.nn.ReLU())
+                in_d = dim
+                if (i + 1) % 2 == 0:
+                    self.conv_layers.append(
+                        torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+                    )
+                    idim = int(math.ceil(idim / 2))
+        else:
+            pass
+
+    def forward(self, x, mask=None):
+
+        # BxT -> BxCxT
+        x = x.unsqueeze(1)
+        if self.conv_type == "custom":
+            for conv in self.conv_layers:
+                if isinstance(conv, nn.LayerNorm):
+                    x = x.transpose(1, 2)
+                    x = conv(x).transpose(1, 2)
+                else:
+                    x = conv(x)
+            x = x.transpose(2, 3).contiguous()
+            x = x.view(x.size(0), -1, x.size(-1))
+        else:
+            for conv in self.conv_layers:
+                x = conv(x)
+            if self.conv_type == "conv2d":
+                b, c, t, f = x.size()
+                x = x.transpose(2, 3).contiguous().view(b, c * f, t)
+        return x
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+
+        self.dropout = args.dropout
+        self.embedding_dim = args.encoder_embed_dim
+
+        self.pos_conv = nn.Conv1d(
+            self.embedding_dim,
+            self.embedding_dim,
+            kernel_size=args.conv_pos,
+            padding=args.conv_pos // 2,
+            groups=args.conv_pos_groups,
+        )
+        dropout = 0
+        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+        nn.init.constant_(self.pos_conv.bias, 0)
+
+        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+        if hasattr(args, "relative_position_embedding"):
+            self.relative_position_embedding = args.relative_position_embedding
+            self.num_buckets = args.num_buckets
+            self.max_distance = args.max_distance
+        else:
+            self.relative_position_embedding = False
+            self.num_buckets = 0
+            self.max_distance = 0
+
+        self.layers = nn.ModuleList(
+            [
+                TransformerSentenceEncoderLayer(
+                    embedding_dim=self.embedding_dim,
+                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
+                    num_attention_heads=args.encoder_attention_heads,
+                    dropout=self.dropout,
+                    attention_dropout=args.attention_dropout,
+                    activation_dropout=args.activation_dropout,
+                    activation_fn=args.activation_fn,
+                    layer_norm_first=args.layer_norm_first,
+                    has_relative_attention_bias=(self.relative_position_embedding and i == 0),
+                    num_buckets=self.num_buckets,
+                    max_distance=self.max_distance,
+                    gru_rel_pos=args.gru_rel_pos,
+                )
+                for i in range(args.encoder_layers)
+            ]
+        )
+
+        self.layer_norm_first = args.layer_norm_first
+        self.layer_norm = LayerNorm(self.embedding_dim)
+        self.layerdrop = args.encoder_layerdrop
+
+        self.apply(init_bert_params)
+
+    def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
+        x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
+
+        if self.layer_norm_first and layer is None:
+            x = self.layer_norm(x)
+
+        return x, layer_results
+
+    def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
+
+        if padding_mask is not None:
+            x[padding_mask] = 0
+
+        x_conv = self.pos_conv(x.transpose(1, 2))
+        x_conv = x_conv.transpose(1, 2)
+        x = x + x_conv
+
+        if not self.layer_norm_first:
+            x = self.layer_norm(x)
+
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        layer_results = []
+        z = None
+        if tgt_layer is not None:
+            layer_results.append((x, z))
+        r = None
+        pos_bias = None
+        for i, layer in enumerate(self.layers):
+            dropout_probability = np.random.random()
+            if not self.training or (dropout_probability > self.layerdrop):
+                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
+                                       self_attn_mask=streaming_mask, pos_bias=pos_bias)
+            if tgt_layer is not None:
+                layer_results.append((x, z))
+            if i == tgt_layer:
+                r = x
+                break
+
+        if r is not None:
+            x = r
+
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+
+        return x, layer_results
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+    """
+    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+    models.
+    """
+
+    def __init__(
+            self,
+            embedding_dim: int = 768,
+            ffn_embedding_dim: int = 3072,
+            num_attention_heads: int = 8,
+            dropout: float = 0.1,
+            attention_dropout: float = 0.1,
+            activation_dropout: float = 0.1,
+            activation_fn: str = "relu",
+            layer_norm_first: bool = False,
+            has_relative_attention_bias: bool = False,
+            num_buckets: int = 0,
+            max_distance: int = 0,
+            rescale_init: bool = False,
+            gru_rel_pos: bool = False,
+    ) -> None:
+
+        super().__init__()
+        # Initialize parameters
+        self.embedding_dim = embedding_dim
+        self.dropout = dropout
+        self.activation_dropout = activation_dropout
+
+        # Initialize blocks
+        self.activation_name = activation_fn
+        self.activation_fn = get_activation_fn(activation_fn)
+        self.self_attn = MultiheadAttention(
+            self.embedding_dim,
+            num_attention_heads,
+            dropout=attention_dropout,
+            self_attention=True,
+            has_relative_attention_bias=has_relative_attention_bias,
+            num_buckets=num_buckets,
+            max_distance=max_distance,
+            rescale_init=rescale_init,
+            gru_rel_pos=gru_rel_pos,
+        )
+
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(self.activation_dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.layer_norm_first = layer_norm_first
+
+        # layer norm associated with the self attention layer
+        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+
+        if self.activation_name == "glu":
+            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
+        else:
+            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+        # layer norm associated with the position wise feed-forward NN
+        self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+    def forward(
+            self,
+            x: torch.Tensor,
+            self_attn_mask: torch.Tensor = None,
+            self_attn_padding_mask: torch.Tensor = None,
+            need_weights: bool = False,
+            pos_bias=None
+    ):
+        """
+        LayerNorm is applied either before or after the self-attention/ffn
+        modules, similar to the original Transformer implementation.
+        """
+        residual = x
+
+        if self.layer_norm_first:
+            x = self.self_attn_layer_norm(x)
+            x, attn, pos_bias = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=self_attn_padding_mask,
+                need_weights=False,
+                attn_mask=self_attn_mask,
+                position_bias=pos_bias
+            )
+            x = self.dropout1(x)
+            x = residual + x
+
+            residual = x
+            x = self.final_layer_norm(x)
+            if self.activation_name == "glu":
+                x = self.fc1(x)
+            else:
+                x = self.activation_fn(self.fc1(x))
+            x = self.dropout2(x)
+            x = self.fc2(x)
+            x = self.dropout3(x)
+            x = residual + x
+        else:
+            x, attn, pos_bias = self.self_attn(
+                query=x,
+                key=x,
+                value=x,
+                key_padding_mask=self_attn_padding_mask,
+                need_weights=need_weights,
+                attn_mask=self_attn_mask,
+                position_bias=pos_bias
+            )
+
+            x = self.dropout1(x)
+            x = residual + x
+
+            x = self.self_attn_layer_norm(x)
+
+            residual = x
+            if self.activation_name == "glu":
+                x = self.fc1(x)
+            else:
+                x = self.activation_fn(self.fc1(x))
+            x = self.dropout2(x)
+            x = self.fc2(x)
+            x = self.dropout3(x)
+            x = residual + x
+            x = self.final_layer_norm(x)
+
+        return x, attn, pos_bias
+
diff --git a/slam_llm/models/wavlm/modules.py b/slam_llm/models/wavlm/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f19ab8a48665d5fa5db87371ba2b8e524fcfbf7
--- /dev/null
+++ b/slam_llm/models/wavlm/modules.py
@@ -0,0 +1,827 @@
+# --------------------------------------------------------
+# WavLM: Large-Scale Self-Supervised  Pre-training  for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
+# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import warnings
+from typing import Dict, Optional, Tuple
+import torch
+from torch import Tensor, nn
+from torch.nn import Parameter
+import torch.nn.functional as F
+
+
+class TransposeLast(nn.Module):
+    def __init__(self, deconstruct_idx=None):
+        super().__init__()
+        self.deconstruct_idx = deconstruct_idx
+
+    def forward(self, x):
+        if self.deconstruct_idx is not None:
+            x = x[self.deconstruct_idx]
+        return x.transpose(-2, -1)
+
+
+class Fp32LayerNorm(nn.LayerNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.layer_norm(
+            input.float(),
+            self.normalized_shape,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
+
+
+class Fp32GroupNorm(nn.GroupNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.group_norm(
+            input.float(),
+            self.num_groups,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
+
+
+class GradMultiply(torch.autograd.Function):
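+    """Identity in the forward pass; scales the incoming gradient by `scale` in backward."""
+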
+    @staticmethod
+    def forward(ctx, x, scale):
+        ctx.scale = scale
+        res = x.new(x)
+        return res
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad * ctx.scale, None
+
+
+class SamePad(nn.Module):
+    def __init__(self, kernel_size, causal=False):
+        super().__init__()
+        if causal:
+            self.remove = kernel_size - 1
+        else:
+            self.remove = 1 if kernel_size % 2 == 0 else 0
+
+    def forward(self, x):
+        if self.remove > 0:
+            x = x[:, :, : -self.remove]
+        return x
+
+
+class Swish(nn.Module):
+    """Swish function
+    """
+
+    def __init__(self):
+        """Construct an MultiHeadedAttention object."""
+        super(Swish, self).__init__()
+        self.act = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return x * self.act(x)
+
+
+class GLU_Linear(nn.Module):
+    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+        super(GLU_Linear, self).__init__()
+
+        self.glu_type = glu_type
+        self.output_dim = output_dim
+
+        if glu_type == "sigmoid":
+            self.glu_act = torch.nn.Sigmoid()
+        elif glu_type == "swish":
+            self.glu_act = Swish()
+        elif glu_type == "relu":
+            self.glu_act = torch.nn.ReLU()
+        elif glu_type == "gelu":
+            self.glu_act = torch.nn.GELU()
+
+        if bias_in_glu:
+            self.linear = nn.Linear(input_dim, output_dim * 2, True)
+        else:
+            self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+    def forward(self, x):
+        # the input is assumed to have the feature (channel) dimension last; the linear layer produces 2 * output_dim features and the second half gates the first half
+        x = self.linear(x)
+
+        if self.glu_type == "bilinear":
+            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+        else:
+            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+
+        return x
+
+
+def gelu_accurate(x):
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str):
+    """Returns the activation function corresponding to `activation`"""
+
+    if activation == "relu":
+        return F.relu
+    elif activation == "gelu":
+        return gelu
+    elif activation == "gelu_fast":
+        warnings.warn(
+            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
+        )
+        return gelu_accurate
+    elif activation == "gelu_accurate":
+        return gelu_accurate
+    elif activation == "tanh":
+        return torch.tanh
+    elif activation == "linear":
+        return lambda x: x
+    elif activation == "glu":
+        return lambda x: x
+    else:
+        raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def init_bert_params(module):
+    """
+    Initialize the weights specific to the BERT Model.
+    This overrides the default initializations depending on the specified arguments.
+        1. If normal_init_linear_weights is set then weights of linear
+           layer will be initialized using the normal distribution and
+           bias will be set to the specified value.
+        2. If normal_init_embed_weights is set then weights of embedding
+           layer will be initialized using the normal distribution.
+        3. If normal_init_proj_weights is set then weights of
+           in_project_weight for MultiheadAttention are initialized using
+           the normal distribution (to be validated).
+    """
+
+    def normal_(data):
+        # with FSDP, module params will be on CUDA, so we cast them back to CPU
+        # so that the RNG is consistent with and without FSDP
+        data.copy_(
+            data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
+        )
+
+    if isinstance(module, nn.Linear):
+        normal_(module.weight.data)
+        if module.bias is not None:
+            module.bias.data.zero_()
+    if isinstance(module, nn.Embedding):
+        normal_(module.weight.data)
+        if module.padding_idx is not None:
+            module.weight.data[module.padding_idx].zero_()
+    if isinstance(module, MultiheadAttention):
+        normal_(module.q_proj.weight.data)
+        normal_(module.k_proj.weight.data)
+        normal_(module.v_proj.weight.data)
+
+
+def quant_noise(module, p, block_size):
+    """
+    Wraps modules and applies quantization noise to the weights for
+    subsequent quantization with Iterative Product Quantization as
+    described in "Training with Quantization Noise for Extreme Model Compression"
+
+    Args:
+        - module: nn.Module
+        - p: amount of Quantization Noise
+        - block_size: size of the blocks for subsequent quantization with iPQ
+
+    Remarks:
+        - Module weights must have the right sizes wrt the block size
+        - Only Linear, Embedding and Conv2d modules are supported for the moment
+        - For more detail on how to quantize by blocks with convolutional weights,
+          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+        - We implement the simplest form of noise here as stated in the paper
+          which consists in randomly dropping blocks
+    """
+
+    # if no quantization noise, don't register hook
+    if p <= 0:
+        return module
+
+    # supported modules
+    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+    # test whether module.weight has the right sizes wrt block_size
+    is_conv = module.weight.ndim == 4
+
+    # 2D matrix
+    if not is_conv:
+        assert (
+            module.weight.size(1) % block_size == 0
+        ), "Input features must be a multiple of block sizes"
+
+    # 4D matrix
+    else:
+        # 1x1 convolutions
+        if module.kernel_size == (1, 1):
+            assert (
+                module.in_channels % block_size == 0
+            ), "Input channels must be a multiple of block sizes"
+        # regular convolutions
+        else:
+            k = module.kernel_size[0] * module.kernel_size[1]
+            assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+    def _forward_pre_hook(mod, input):
+        # no noise for evaluation
+        if mod.training:
+            if not is_conv:
+                # gather weight and sizes
+                weight = mod.weight
+                in_features = weight.size(1)
+                out_features = weight.size(0)
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                mask = torch.zeros(
+                    in_features // block_size * out_features, device=weight.device
+                )
+                mask.bernoulli_(p)
+                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+            else:
+                # gather weight and sizes
+                weight = mod.weight
+                in_channels = mod.in_channels
+                out_channels = mod.out_channels
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                if mod.kernel_size == (1, 1):
+                    mask = torch.zeros(
+                        int(in_channels // block_size * out_channels),
+                        device=weight.device,
+                    )
+                    mask.bernoulli_(p)
+                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+                else:
+                    mask = torch.zeros(
+                        weight.size(0), weight.size(1), device=weight.device
+                    )
+                    mask.bernoulli_(p)
+                    mask = (
+                        mask.unsqueeze(2)
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                    )
+
+            # scale weights and apply mask
+            mask = mask.to(
+                torch.bool
+            )  # x.bool() is not currently supported in TorchScript
+            s = 1 / (1 - p)
+            mod.weight.data = s * weight.masked_fill(mask, 0)
+
+    module.register_forward_pre_hook(_forward_pre_hook)
+    return module
+
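+# Illustrative sketch (added commentary; the sizes are only an example): `quant_noise`
+# is applied at construction time, as done for the q/k/v/out projections in
+# MultiheadAttention below.
+#
+#   proj = quant_noise(nn.Linear(768, 768), p=0.1, block_size=8)
+#   # while training, each forward pass zeroes random weight blocks and rescales
+#   # the surviving weights by 1 / (1 - p); in eval mode the hook does nothing.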
+
+class MultiheadAttention(nn.Module):
+    """Multi-headed attention.
+
+    See "Attention Is All You Need" for more details.
+    """
+
+    def __init__(
+            self,
+            embed_dim,
+            num_heads,
+            kdim=None,
+            vdim=None,
+            dropout=0.0,
+            bias=True,
+            add_bias_kv=False,
+            add_zero_attn=False,
+            self_attention=False,
+            encoder_decoder_attention=False,
+            q_noise=0.0,
+            qn_block_size=8,
+            has_relative_attention_bias=False,
+            num_buckets=32,
+            max_distance=128,
+            gru_rel_pos=False,
+            rescale_init=False,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        self.num_heads = num_heads
+        self.dropout_module = nn.Dropout(dropout)
+
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+        self.head_dim = embed_dim // num_heads
+        self.q_head_dim = self.head_dim
+        self.k_head_dim = self.head_dim
+        assert (
+                self.head_dim * num_heads == self.embed_dim
+        ), "embed_dim must be divisible by num_heads"
+        self.scaling = self.head_dim ** -0.5
+
+        self.self_attention = self_attention
+        self.encoder_decoder_attention = encoder_decoder_attention
+
+        assert not self.self_attention or self.qkv_same_dim, (
+            "Self-attention requires query, key and " "value to be of the same size"
+        )
+
+        k_bias = True
+        if rescale_init:
+            k_bias = False
+
+        k_embed_dim = embed_dim
+        q_embed_dim = embed_dim
+
+        self.k_proj = quant_noise(
+            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
+        )
+        self.v_proj = quant_noise(
+            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
+        self.q_proj = quant_noise(
+            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
+        )
+
+        self.out_proj = quant_noise(
+            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
+        )
+
+        if add_bias_kv:
+            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
+        self.gru_rel_pos = gru_rel_pos
+        if self.gru_rel_pos:
+            self.grep_linear = nn.Linear(self.q_head_dim, 8)
+            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.qkv_same_dim:
+            # Empirically observed the convergence to be much better with
+            # the scaled initialization
+            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+        else:
+            nn.init.xavier_uniform_(self.k_proj.weight)
+            nn.init.xavier_uniform_(self.v_proj.weight)
+            nn.init.xavier_uniform_(self.q_proj.weight)
+
+        nn.init.xavier_uniform_(self.out_proj.weight)
+        if self.out_proj.bias is not None:
+            nn.init.constant_(self.out_proj.bias, 0.0)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
+        if self.has_relative_attention_bias:
+            nn.init.xavier_normal_(self.relative_attention_bias.weight)
+
+    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
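+        # The bucketing follows the T5 scheme: half of the buckets encode the sign
+        # of the relative offset; within each half, offsets below max_exact keep
+        # their own bucket and larger offsets are binned logarithmically up to
+        # max_distance, so far-apart positions share buckets.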
+        num_buckets = self.num_buckets
+        max_distance = self.max_distance
+        relative_buckets = 0
+
+        if bidirectional:
+            num_buckets = num_buckets // 2
+            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
+            relative_positions = torch.abs(relative_positions)
+        else:
+            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+
+        max_exact = num_buckets // 2
+        is_small = relative_positions < max_exact
+
+        relative_position_if_large = max_exact + (
+                torch.log(relative_positions.float() / max_exact)
+                / math.log(max_distance / max_exact)
+                * (num_buckets - max_exact)
+        ).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
+        )
+
+        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self, query_length, key_length):
+        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_positions_bucket(
+            relative_position,
+            bidirectional=True
+        )
+        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        values = self.relative_attention_bias(relative_position_bucket)
+        values = values.permute([2, 0, 1])
+        return values
+
+    def forward(
+            self,
+            query,
+            key: Optional[Tensor],
+            value: Optional[Tensor],
+            key_padding_mask: Optional[Tensor] = None,
+            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+            need_weights: bool = True,
+            static_kv: bool = False,
+            attn_mask: Optional[Tensor] = None,
+            before_softmax: bool = False,
+            need_head_weights: bool = False,
+            position_bias: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+        """Input shape: Time x Batch x Channel
+
+        Args:
+            key_padding_mask (ByteTensor, optional): mask to exclude
+                keys that are pads, of shape `(batch, src_len)`, where
+                padding elements are indicated by 1s.
+            need_weights (bool, optional): return the attention weights,
+                averaged over heads (default: False).
+            attn_mask (ByteTensor, optional): typically used to
+                implement causal attention, where the mask prevents the
+                attention from looking forward in time (default: None).
+            before_softmax (bool, optional): return the raw attention
+                weights and values before the attention softmax.
+            need_head_weights (bool, optional): return the attention
+                weights for each head. Implies *need_weights*. Default:
+                return the average attention weights over all heads.
+        """
+        if need_head_weights:
+            need_weights = True
+
+        is_tpu = query.device.type == "xla"
+
+        tgt_len, bsz, embed_dim = query.size()
+        src_len = tgt_len
+        assert embed_dim == self.embed_dim
+        assert list(query.size()) == [tgt_len, bsz, embed_dim]
+        if key is not None:
+            src_len, key_bsz, _ = key.size()
+            if not torch.jit.is_scripting():
+                assert key_bsz == bsz
+                assert value is not None
+                assert (src_len, bsz) == value.shape[:2]
+
+        if self.has_relative_attention_bias and position_bias is None:
+            position_bias = self.compute_bias(tgt_len, src_len)
+            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
+
+        if (
+                not is_tpu  # don't use PyTorch version on TPUs
+                and incremental_state is None
+                and not static_kv
+                # A workaround for quantization to work. Otherwise JIT compilation
+                # treats bias in linear module as method.
+                and not torch.jit.is_scripting()
+                and self.q_head_dim == self.head_dim
+        ):
+            assert key is not None and value is not None
+            assert attn_mask is None
+
+            attn_mask_rel_pos = None
+            if position_bias is not None:
+                attn_mask_rel_pos = position_bias
+                if self.gru_rel_pos:
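+                    # Gated relative position bias: a gate derived from the query
+                    # content (one value per head and query position) rescales the
+                    # shared position bias before it is used as the attention mask.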
+                    query_layer = query.transpose(0, 1)
+                    new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
+                    query_layer = query_layer.view(*new_x_shape)
+                    query_layer = query_layer.permute(0, 2, 1, 3)
+                    _B, _H, _L, __ = query_layer.size()
+
+                    gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
+                        _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
+                    gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+                    attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+
+                attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
+            k_proj_bias = self.k_proj.bias
+            if k_proj_bias is None:
+                k_proj_bias = torch.zeros_like(self.q_proj.bias)
+
+            x, attn = F.multi_head_attention_forward(
+                query,
+                key,
+                value,
+                self.embed_dim,
+                self.num_heads,
+                torch.empty([0]),
+                torch.cat((self.q_proj.bias, k_proj_bias, self.v_proj.bias)),
+                self.bias_k,
+                self.bias_v,
+                self.add_zero_attn,
+                self.dropout_module.p,
+                self.out_proj.weight,
+                self.out_proj.bias,
+                self.training,
+                # self.training or self.dropout_module.apply_during_inference,
+                key_padding_mask,
+                need_weights,
+                attn_mask_rel_pos,
+                use_separate_proj_weight=True,
+                q_proj_weight=self.q_proj.weight,
+                k_proj_weight=self.k_proj.weight,
+                v_proj_weight=self.v_proj.weight,
+            )
+            return x, attn, position_bias
+
+        if incremental_state is not None:
+            saved_state = self._get_input_buffer(incremental_state)
+            if saved_state is not None and "prev_key" in saved_state:
+                # previous time steps are cached - no need to recompute
+                # key and value if they are static
+                if static_kv:
+                    assert self.encoder_decoder_attention and not self.self_attention
+                    key = value = None
+        else:
+            saved_state = None
+
+        if self.self_attention:
+            q = self.q_proj(query)
+            k = self.k_proj(query)
+            v = self.v_proj(query)
+        elif self.encoder_decoder_attention:
+            # encoder-decoder attention
+            q = self.q_proj(query)
+            if key is None:
+                assert value is None
+                k = v = None
+            else:
+                k = self.k_proj(key)
+                v = self.v_proj(key)
+
+        else:
+            assert key is not None and value is not None
+            q = self.q_proj(query)
+            k = self.k_proj(key)
+            v = self.v_proj(value)
+        q *= self.scaling
+
+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+            if attn_mask is not None:
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [
+                        key_padding_mask,
+                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+                    ],
+                    dim=1,
+                )
+
+        q = (
+            q.contiguous()
+                .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
+                .transpose(0, 1)
+        )
+        if k is not None:
+            k = (
+                k.contiguous()
+                    .view(-1, bsz * self.num_heads, self.k_head_dim)
+                    .transpose(0, 1)
+            )
+        if v is not None:
+            v = (
+                v.contiguous()
+                    .view(-1, bsz * self.num_heads, self.head_dim)
+                    .transpose(0, 1)
+            )
+
+        if saved_state is not None:
+            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+            if "prev_key" in saved_state:
+                _prev_key = saved_state["prev_key"]
+                assert _prev_key is not None
+                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    k = prev_key
+                else:
+                    assert k is not None
+                    k = torch.cat([prev_key, k], dim=1)
+                src_len = k.size(1)
+            if "prev_value" in saved_state:
+                _prev_value = saved_state["prev_value"]
+                assert _prev_value is not None
+                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+                if static_kv:
+                    v = prev_value
+                else:
+                    assert v is not None
+                    v = torch.cat([prev_value, v], dim=1)
+            prev_key_padding_mask: Optional[Tensor] = None
+            if "prev_key_padding_mask" in saved_state:
+                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+            assert k is not None and v is not None
+            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+                key_padding_mask=key_padding_mask,
+                prev_key_padding_mask=prev_key_padding_mask,
+                batch_size=bsz,
+                src_len=k.size(1),
+                static_kv=static_kv,
+            )
+
+            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+            saved_state["prev_key_padding_mask"] = key_padding_mask
+            # In this branch incremental_state is never None
+            assert incremental_state is not None
+            incremental_state = self._set_input_buffer(incremental_state, saved_state)
+        assert k is not None
+        assert k.size(1) == src_len
+
+        # This is part of a workaround to get around fork/join parallelism
+        # not supporting Optional types.
+        if key_padding_mask is not None and key_padding_mask.dim() == 0:
+            key_padding_mask = None
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.size(0) == bsz
+            assert key_padding_mask.size(1) == src_len
+
+        if self.add_zero_attn:
+            assert v is not None
+            src_len += 1
+            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+            if attn_mask is not None:
+                attn_mask = torch.cat(
+                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
+                )
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [
+                        key_padding_mask,
+                        torch.zeros(key_padding_mask.size(0), 1).type_as(
+                            key_padding_mask
+                        ),
+                    ],
+                    dim=1,
+                )
+
+        attn_weights = torch.bmm(q, k.transpose(1, 2))
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_weights += attn_mask
+
+        if key_padding_mask is not None:
+            # don't attend to padding symbols
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            if not is_tpu:
+                attn_weights = attn_weights.masked_fill(
+                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                    float("-inf"),
+                )
+            else:
+                attn_weights = attn_weights.transpose(0, 2)
+                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+                attn_weights = attn_weights.transpose(0, 2)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if before_softmax:
+            return attn_weights, v, position_bias
+
+        if position_bias is not None:
+            if self.gru_rel_pos == 1:
+                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
+                _B, _H, _L, __ = query_layer.size()
+                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
+                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
+                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+                position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+
+            position_bias = position_bias.view(attn_weights.size())
+
+            attn_weights = attn_weights + position_bias
+
+        attn_weights_float = F.softmax(
+            attn_weights, dim=-1
+        )
+        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_probs = self.dropout_module(attn_weights)
+
+        assert v is not None
+        attn = torch.bmm(attn_probs, v)
+        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        attn = self.out_proj(attn)
+        attn_weights: Optional[Tensor] = None
+        if need_weights:
+            attn_weights = attn_weights_float.view(
+                bsz, self.num_heads, tgt_len, src_len
+            ).transpose(1, 0)
+            if not need_head_weights:
+                # average attention weights over heads
+                attn_weights = attn_weights.mean(dim=0)
+
+        return attn, attn_weights, position_bias
+
+    @staticmethod
+    def _append_prev_key_padding_mask(
+            key_padding_mask: Optional[Tensor],
+            prev_key_padding_mask: Optional[Tensor],
+            batch_size: int,
+            src_len: int,
+            static_kv: bool,
+    ) -> Optional[Tensor]:
+        # saved key padding masks have shape (bsz, seq_len)
+        if prev_key_padding_mask is not None and static_kv:
+            new_key_padding_mask = prev_key_padding_mask
+        elif prev_key_padding_mask is not None and key_padding_mask is not None:
+            new_key_padding_mask = torch.cat(
+                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
+            )
+        # During incremental decoding, as the padding token enters and
+        # leaves the frame, there will be a time when prev or current
+        # is None
+        elif prev_key_padding_mask is not None:
+            if src_len > prev_key_padding_mask.size(1):
+                filler = torch.zeros(
+                    (batch_size, src_len - prev_key_padding_mask.size(1)),
+                    device=prev_key_padding_mask.device,
+                )
+                new_key_padding_mask = torch.cat(
+                    [prev_key_padding_mask.float(), filler.float()], dim=1
+                )
+            else:
+                new_key_padding_mask = prev_key_padding_mask.float()
+        elif key_padding_mask is not None:
+            if src_len > key_padding_mask.size(1):
+                filler = torch.zeros(
+                    (batch_size, src_len - key_padding_mask.size(1)),
+                    device=key_padding_mask.device,
+                )
+                new_key_padding_mask = torch.cat(
+                    [filler.float(), key_padding_mask.float()], dim=1
+                )
+            else:
+                new_key_padding_mask = key_padding_mask.float()
+        else:
+            new_key_padding_mask = prev_key_padding_mask
+        return new_key_padding_mask
+
+    def _get_input_buffer(
+            self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+    ) -> Dict[str, Optional[Tensor]]:
+        result = self.get_incremental_state(incremental_state, "attn_state")
+        if result is not None:
+            return result
+        else:
+            empty_result: Dict[str, Optional[Tensor]] = {}
+            return empty_result
+
+    def _set_input_buffer(
+            self,
+            incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+            buffer: Dict[str, Optional[Tensor]],
+    ):
+        return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+        return attn_weights
diff --git a/slam_llm/policies/__init__.py b/slam_llm/policies/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..309b87277c583123820458295626de24cd1d10fb
--- /dev/null
+++ b/slam_llm/policies/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from slam_llm.policies.mixed_precision import *
+from slam_llm.policies.wrapping import *
+from slam_llm.policies.activation_checkpointing_functions import apply_fsdp_checkpointing
+from slam_llm.policies.anyprecision_optimizer import AnyPrecisionAdamW
diff --git a/slam_llm/policies/activation_checkpointing_functions.py b/slam_llm/policies/activation_checkpointing_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fa98fe68c3a0581d134f4038d2cdde9d0985003
--- /dev/null
+++ b/slam_llm/policies/activation_checkpointing_functions.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from functools import partial
+
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper,
+    CheckpointImpl,
+    apply_activation_checkpointing,
+)
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+non_reentrant_wrapper = partial(
+    checkpoint_wrapper,
+    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+)
+
+check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
+
+
+def apply_fsdp_checkpointing(model):
+    """apply activation checkpointing to model
+    returns None as model is updated directly
+    """
+    print(f"--> applying fsdp activation checkpointing...")
+
+    apply_activation_checkpointing(
+        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
+    )
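+
+
+# Rough usage sketch (assumed, not part of the original file): this is typically
+# called after the model has been wrapped with FSDP, e.g.
+#
+#   model = FSDP(model, ...)
+#   apply_fsdp_checkpointing(model)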
diff --git a/slam_llm/policies/anyprecision_optimizer.py b/slam_llm/policies/anyprecision_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0172229656e611616ff8c05d07ecfabdac05ce3c
--- /dev/null
+++ b/slam_llm/policies/anyprecision_optimizer.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# AnyPrecisionAdamW: a flexible precision AdamW optimizer
+# with optional Kahan summation for high precision weight updates.
+# Allows direct control over momentum, variance and auxiliary compensation
+# buffer dtypes.
+# Optional Kahan summation is used to offset precision reduction for
+# the weight updates. This allows full training in BFloat16 (equal or
+# better than FP32 results in many cases) due to high precision weight updates.
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class AnyPrecisionAdamW(Optimizer):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0.0,
+        use_kahan_summation=False,
+        momentum_dtype=torch.bfloat16,
+        variance_dtype=torch.bfloat16,
+        compensation_buffer_dtype=torch.bfloat16,
+    ):
+        """
+        Args:
+                params (iterable): iterable of parameters to optimize or dicts defining
+                    parameter groups
+                lr (float, optional): learning rate (default: 1e-3)
+                betas (Tuple[float, float], optional): coefficients used for computing
+                    running averages of gradient and its square (default: (0.9, 0.999))
+                eps (float, optional): term added to the denominator to improve
+                    numerical stability (default: 1e-8)
+                weight_decay (float, optional): weight decay coefficient (default: 0.0)
+
+                # Any Precision specific
+                use_kahan_summation = creates auxiliary buffer to ensure high precision
+                model param updates (default: False)
+                momentum_dtype = dtype for momentum  (default: BFloat16)
+                variance_dtype = dtype for uncentered variance (default: BFloat16)
+                compensation_buffer_dtype  = dtype for Kahan summation
+                                             buffer (default: BFloat16)
+
+                # Usage
+                This optimizer implements optimizer states, and Kahan summation
+                for high precision updates, all in user controlled dtypes.
+                Defaults are momentum and variance in BFloat16.
+                This can be run in FSDP mixed precision, amp, or full precision,
+                depending on what training pipeline you wish to work with.
+
+                Setting to use_kahan_summation = False, and changing momentum and
+                variance dtypes to FP32, reverts this to a standard AdamW optimizer.
+
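+                # Example (illustrative only; variable names are assumptions)
+                #   optim = AnyPrecisionAdamW(model.parameters(), lr=1e-4,
+                #                             use_kahan_summation=True)
+                #   loss.backward(); optim.step(); optim.zero_grad()
+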
+        """
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            use_kahan_summation=use_kahan_summation,
+            momentum_dtype=momentum_dtype,
+            variance_dtype=variance_dtype,
+            compensation_buffer_dtype=compensation_buffer_dtype,
+        )
+
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+
+        if closure is not None:
+            with torch.enable_grad():
+                # to fix linter, we do not keep the returned loss for use atm.
+                closure()
+
+        for group in self.param_groups:
+
+            beta1, beta2 = group["betas"]
+            lr = group["lr"]
+            weight_decay = group["weight_decay"]
+            eps = group["eps"]
+            use_kahan_summation = group["use_kahan_summation"]
+
+            momentum_dtype = group["momentum_dtype"]
+            variance_dtype = group["variance_dtype"]
+            compensation_buffer_dtype = group["compensation_buffer_dtype"]
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                if p.grad.is_sparse:
+                    raise RuntimeError(
+                        "AnyPrecisionAdamW does not support sparse gradients"
+                    )
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+
+                    state["step"] = torch.tensor(0.0)
+
+                    # momentum - EMA of gradient values
+                    state["exp_avg"] = torch.zeros_like(
+                        p,
+                        dtype=momentum_dtype,
+                    )
+
+                    # variance uncentered - EMA of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(
+                        p,
+                        dtype=variance_dtype,
+                    )
+
+                    # optional Kahan summation - accumulated error tracker
+                    if use_kahan_summation:
+                        state["compensation"] = torch.zeros_like(
+                            p,
+                            dtype=compensation_buffer_dtype,
+                        )
+
+                # main processing -------------------------
+
+                # update the steps for each param group update
+                state["step"] += 1
+                step = state["step"]
+
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = state["exp_avg_sq"]
+
+                grad = p.grad
+
+                # weight decay, AdamW style
+                if weight_decay:
+                    p.data.mul_(1 - lr * weight_decay)
+
+                # update momentum
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+                # update uncentered variance
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                # adjust using bias1
+                bias_correction1 = 1 - beta1**step
+
+                step_size = lr / bias_correction1
+
+                # adjust using bias2
+                denom_correction = (1 - beta2**step) ** 0.5  # avoids math import
+
+                centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(
+                    eps, alpha=1
+                )
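+
+                # Taken together, the update applied below is the standard
+                # bias-corrected AdamW step,
+                #   p <- p - lr/(1 - beta1^t) * exp_avg / (sqrt(exp_avg_sq/(1 - beta2^t)) + eps),
+                # with eps added after the denominator's bias correction.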
+
+                # apply the lr-scaled update, using Kahan summation compensation if enabled
+                if use_kahan_summation:
+                    compensation = state["compensation"]
+
+                    compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
+
+                    # update weights with compensation (Kahan summation)
+                    # save error back to compensation for next iteration
+                    temp_buffer = p.detach().clone()
+                    p.data.add_(compensation)
+                    compensation.add_(temp_buffer.sub_(p.data))
+
+                else:
+                    # usual AdamW updates
+                    p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
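+
+
+# Usage sketch (illustrative only; the hyperparameters below are examples,
+# not project defaults):
+#
+#   optimizer = AnyPrecisionAdamW(
+#       model.parameters(),
+#       lr=1e-4,
+#       weight_decay=0.01,
+#       momentum_dtype=torch.bfloat16,
+#       variance_dtype=torch.bfloat16,
+#       use_kahan_summation=True,
+#   )
+#
+# Keeping the optimizer state in bfloat16 with Kahan compensation trades a
+# little precision for a large reduction in optimizer-state memory.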
\ No newline at end of file
diff --git a/slam_llm/policies/mixed_precision.py b/slam_llm/policies/mixed_precision.py
new file mode 100644
index 0000000000000000000000000000000000000000..5175b2ad9690f67571c1ac7ac27d3c2495d3f914
--- /dev/null
+++ b/slam_llm/policies/mixed_precision.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import torch
+
+from torch.distributed.fsdp import (
+    MixedPrecision,
+)
+
+# requires grad scaler in main loop
+fpSixteen = MixedPrecision(
+    param_dtype=torch.float16,
+    # Gradient communication precision.
+    reduce_dtype=torch.float16,
+    # Buffer precision.
+    buffer_dtype=torch.float16,
+)
+
+bfSixteen = MixedPrecision(
+    param_dtype=torch.bfloat16,
+    # Gradient communication precision.
+    reduce_dtype=torch.bfloat16,
+    # Buffer precision.
+    buffer_dtype=torch.bfloat16,
+    cast_forward_inputs=True,
+)
+
+bfSixteen_mixed = MixedPrecision(
+    param_dtype=torch.float32,
+    reduce_dtype=torch.bfloat16,
+    buffer_dtype=torch.bfloat16,
+)
+
+fp32_policy = MixedPrecision(
+    param_dtype=torch.float32,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+)
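+
+
+# Usage sketch (illustrative): these policies are intended for FSDP's
+# `mixed_precision` argument, e.g.
+#
+#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+#   model = FSDP(model, mixed_precision=bfSixteen)
+#
+# bfSixteen_mixed keeps parameters in fp32 while reducing gradients and
+# buffers in bf16; fp32_policy disables any precision reduction.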
diff --git a/slam_llm/policies/wrapping.py b/slam_llm/policies/wrapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..2703ebc684936ac79f87b01fee44af2cec7e4c1b
--- /dev/null
+++ b/slam_llm/policies/wrapping.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import functools
+
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+    size_based_auto_wrap_policy,
+)
+
+
+def get_size_policy(min_params=1e8):
+    num_wrap_policy = functools.partial(
+        size_based_auto_wrap_policy, min_num_params=min_params
+    )
+    return num_wrap_policy
+
+
+def get_llama_wrapper():
+    """we register our main layer class and use the fsdp transformer wrapping policy
+    ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers
+    """
+    # ====   use new transformer wrapper
+
+    llama_auto_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls={
+            LlamaDecoderLayer,
+        },
+    )
+
+    return llama_auto_wrap_policy
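+
+
+# Usage sketch (illustrative): the returned policy is meant for FSDP's
+# `auto_wrap_policy` argument, e.g.
+#
+#   model = FSDP(model, auto_wrap_policy=get_llama_wrapper())
+#
+# so that each LlamaDecoderLayer becomes its own FSDP unit while the
+# embeddings remain in the root unit.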
diff --git a/slam_llm/utils/__init__.py b/slam_llm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46f9a8ca7dc578bbe8ac822a8f98e5880351a70
--- /dev/null
+++ b/slam_llm/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from slam_llm.utils.memory_utils import MemoryTrace
+from slam_llm.utils.dataset_utils import *
+from slam_llm.utils.fsdp_utils import fsdp_auto_wrap_policy
+from slam_llm.utils.train_utils import *
\ No newline at end of file
diff --git a/slam_llm/utils/checkpoint_handler.py b/slam_llm/utils/checkpoint_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..50449e232a7481519125feb9354b27b0b6933538
--- /dev/null
+++ b/slam_llm/utils/checkpoint_handler.py
@@ -0,0 +1,333 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+import os
+from pathlib import Path
+from datetime import datetime
+import torch
+import time
+from collections import OrderedDict
+
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    StateDictType,
+    FullStateDictConfig,  # general model non-sharded, non-flattened params
+    LocalStateDictConfig,  # flattened params, usable only by FSDP
+    # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
+)
+
+from torch.distributed.checkpoint import (
+    FileSystemReader,
+    FileSystemWriter,
+    save_state_dict,
+    load_state_dict,
+)
+from torch.distributed.checkpoint.default_planner import (
+    DefaultSavePlanner,
+    DefaultLoadPlanner,
+)
+
+
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+import torch.distributed.checkpoint as dist_cp
+import torch.distributed as dist
+
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def get_date_of_run():
+    """create date and time for file save uniqueness
+    example: 2022-05-07-08:31:12_PM
+    """
+    date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p")
+    logger.info(f"--> current date and time of run = {date_of_run}")
+    return date_of_run
+
+
+# create singleton saving policies to avoid making over and over
+fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+
+
+def load_model_sharded(model, rank, cfg):
+    # torch.manual_seed(103)
+    folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+    )
+
+    load_dir = Path.cwd() / folder_name
+
+    if not load_dir.exists():
+        if rank == 0:
+            logger.info(f"No sharded_state_dict checkpoint directory found...skipping")
+        return
+    if rank == 0:
+         logger.info(f"loading model from model path: {load_dir} ")
+    reader = FileSystemReader(load_dir)
+
+    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+        checkpoint = {"model": model.state_dict()}
+        if rank == 0:
+            ck = checkpoint.keys()
+            logger.info(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+      
+        dist_cp.load_state_dict(
+            state_dict=checkpoint,
+            storage_reader=reader,
+        )
+        if rank == 0:
+            logger.info(f"checkpoint after load_state_dict()")
+            ck = checkpoint.keys()
+            logger.info(f" checkpoint key len = {len(ck)} and \n keys =  {ck}")
+        model.load_state_dict(checkpoint["model"])
+    if rank == 0:
+        logger.info(f"Sharded state checkpoint loaded from {load_dir}")
+
+
+def save_model_and_optimizer_sharded(model, rank, cfg, optim=None):
+    """save model and optimizer via sharded_state_dict to save_dir"""
+    
+    folder_name = (
+        cfg.dist_checkpoint_root_folder
+        + "/"
+        + cfg.dist_checkpoint_folder
+        + "-"
+        + cfg.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    if rank == 0:
+        logger.info(f"Saving model to {save_dir}")
+
+    distributed_writer = dist_cp.FileSystemWriter(
+        save_dir,
+    )
+    t0 = time.perf_counter()
+
+    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+        
+        state_dict = {"model": model.state_dict()}
+        if optim is not None:
+            state_dict["optim"] = FSDP.optim_state_dict(model, optim)
+
+        dist_cp.save_state_dict(
+            state_dict=state_dict,
+            storage_writer=distributed_writer,
+            planner=DefaultSavePlanner(),
+            
+        )
+    dist.barrier()
+    t1 = time.perf_counter()
+    if rank == 0:
+        logger.info(f"Sharded state checkpoint saved to {save_dir}")
+        logger.info(
+            f"Checkpoint Time = {t1-t0:.4f}\n"
+        )
+
+
+def save_model_checkpoint(
+    model,
+    optimizer,
+    rank,
+    cfg,
+    epoch=1,
+):
+    """saving model via rank0 cpu streaming and full_state_dict"""
+
+    with FSDP.state_dict_type(
+        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
+    ):
+        cpu_state = model.state_dict()
+
+        logger.info(f"saving process: rank {rank}  done w model state_dict\n")
+   
+
+    if rank == 0:
+        logger.info(f"--> saving model ...")
+        # create save path
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
+        save_dir.mkdir(parents=True, exist_ok=True)
+        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
+        save_full_path = str(save_dir) + "/" + save_name
+
+        # save model
+        torch.save(cpu_state, save_full_path)
+
+        
+        logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
+
+def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"):
+    logger.info(f"--> saving model ...")
+    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
+    os.makedirs(save_dir, exist_ok=True)
+    # save_full_path = os.path.join(save_dir, "model.pt")
+    save_full_path = save_dir
+    model.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True)
+    logger.info(f"encoder saved at {save_full_path}")
+      
+def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True):
+    logger.info(f"--> saving model ...")
+    save_dir = os.path.join(cfg.output_dir, checkpoint_name)
+    os.makedirs(save_dir, exist_ok=True)
+    save_full_path = os.path.join(save_dir, "model.pt")
+    if cfg.enable_ddp:
+        model = model.module
+    cpu_state = model.state_dict()
+    if save_trainable_only:
+        state_dict = OrderedDict()
+        for name, para in model.named_parameters():
+            if para.requires_grad:
+                state_dict[name] = cpu_state[name]
+    else:
+        state_dict = cpu_state
+    torch.save(state_dict, save_full_path)
+    logger.info(f"encoder saved at {save_full_path}")
+    
+def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
+    with FSDP.state_dict_type(
+        model, StateDictType.FULL_STATE_DICT, fullstate_save_policy
+    ):
+        cpu_state = model.state_dict()
+        logger.info(f"saving process: rank {rank}  done w model state_dict\n")
+
+    if rank == 0:
+        logger.info(f"--> saving model ...")
+        save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1))
+        os.makedirs(save_dir, exist_ok=True)
+
+        if not cfg.freeze_llm:
+            llm_dict = {}
+            for key in cpu_state.keys():
+                if key.startswith("llm."):
+                    llm_dict[key] = cpu_state[key]
+            model.llm.save_pretrained(save_directory=save_dir, state_dict=llm_dict)
+            logger.info(f"llm saved at {save_dir}")
+
+        save_full_path = os.path.join(save_dir, "model.pt")
+        encoder_dict = {}
+        if not cfg.freeze_encoder:
+            for key in cpu_state.keys():
+                if key.startswith("encoder."):
+                    encoder_dict[key] = cpu_state[key]
+        for key in cpu_state.keys():
+            if key.startswith("encoder_projector."):
+                encoder_dict[key] = cpu_state[key]
+        torch.save(encoder_dict, save_full_path)
+        logger.info(f"encoder saved at {save_full_path}")
+
+        logger.info(f"model checkpoint saved for epoch {epoch+1}\n")
+        
+    dist.barrier()
+
+def load_model_checkpoint(model, rank, cfg):
+    """load local checkpoint to rank0 cpu
+    must be called * before * passing to FSDP"""
+
+    if rank != 0:
+        return
+
+    # where is the checkpoint at...
+    full_state_dict_model_path = (
+        Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename
+    )
+    # is it present...
+    if not full_state_dict_model_path.is_file():
+        logger.info(
+            f"model checkpoint {full_state_dict_model_path} not present. Returning..."
+        )
+        return
+
+
+    model_checkpoint = torch.load(full_state_dict_model_path)
+    # integrate into loaded model
+    model.load_state_dict(model_checkpoint)
+
+    
+    logger.info(f"model checkpoint loaded to rank0 cpu")
+
+
+def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
+    """save optimizer state via full state dict"""
+
+   
+    logger.info(f"--> optim state call on rank {rank}\n")
+
+    # pull all sharded optimizer states to rank0 cpu...
+
+    optim_state = FSDP.full_optim_state_dict(model, optimizer)
+
+    
+    logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")
+
+    if rank == 0:
+        folder_name = (
+            cfg.dist_checkpoint_root_folder
+            + "/"
+            + cfg.dist_checkpoint_folder
+            + "-"
+            + cfg.model_name
+        )
+        save_dir = Path.cwd() / folder_name
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        opt_save_name = (
+            "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
+        )
+        opt_save_full_path = save_dir / opt_save_name
+
+        logger.info(f"--> saving optimizer state...")
+
+        torch.save(optim_state, opt_save_full_path)
+
+        logger.info(f"--> saved {opt_save_full_path} to disk")
+
+
+def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
+    """load an fsdp optimizer full_state checkpoint using scatter method
+    this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
+    """
+
+
+    if not optimizer_checkpoint_path.is_file():
+        logger.info(
+            f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
+        )
+        return
+
+    full_osd = None
+
+    if rank == 0:
+        full_osd = torch.load(optimizer_checkpoint_path)
+
+    # called from all ranks, though only rank0 has a valid param for full_osd
+    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
+
+    logger.info(f"optimizer shard loaded on rank {rank}")
+
+def load_sharded_model_single_gpu(model, model_path):
+
+    reader = FileSystemReader(model_path)
+
+    state_dict = {
+        "model": model.state_dict()
+    }
+
+    dist_cp.load_state_dict(
+        state_dict=state_dict,
+        storage_reader=reader,
+        no_dist=True,
+    )
+
+    model.load_state_dict(state_dict["model"])
+
+    logger.info(f"Sharded state checkpoint loaded from {model_path}")
+    return model
diff --git a/slam_llm/utils/compute_aac_metrics.py b/slam_llm/utils/compute_aac_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..6390fd038cefe10c3d0af698d120d29037deb109
--- /dev/null
+++ b/slam_llm/utils/compute_aac_metrics.py
@@ -0,0 +1,38 @@
+from aac_metrics import evaluate
+import sys
+
+def compute_wer(ref_file,
+                hyp_file):
+    pred_captions = []
+    gt_captions = []
+
+    with open(hyp_file, 'r') as hyp_reader:
+        for line in hyp_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            pred_captions.append(" ".join(value))
+    with open(ref_file, 'r') as ref_reader:
+        for line in ref_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            gt_captions.append(" ".join(value))
+
+    print('Used lines:', len(pred_captions))
+    candidates: list[str] = pred_captions
+    mult_references: list[list[str]] = [[gt] for gt in gt_captions]
+
+    corpus_scores, _ = evaluate(candidates, mult_references)
+    print(corpus_scores)
+    # dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider"
+    # {"bleu_1": tensor(0.4278), "bleu_2": ..., ...}
+
+
+if __name__ == '__main__':
+    if len(sys.argv) != 3:
+        print("usage : python compute_aac_metrics.py test.ref test.hyp")
+        sys.exit(0)
+
+    ref_file = sys.argv[1]
+    hyp_file = sys.argv[2]
+    compute_wer(ref_file, hyp_file)
\ No newline at end of file
diff --git a/slam_llm/utils/compute_ppl.py b/slam_llm/utils/compute_ppl.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dd0f159548609f291842def106d64c59c50daff
--- /dev/null
+++ b/slam_llm/utils/compute_ppl.py
@@ -0,0 +1,45 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+from tqdm import tqdm
+import json
+
+# MODEL_PATH = "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+MODEL_PATH = "/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf"
+# MODEL_PATH = "/nfs/maziyang.mzy/models/Llama-2-7b-chat-hf"
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
+
+device = 'cuda:7'
+model.to(device)
+model.eval()
+
+corpus_path = "/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl"
+corpus = []
+with open(corpus_path, encoding='utf-8') as fin:
+    for line in fin:
+        data_dict = json.loads(line.strip())
+        corpus.append(data_dict.get("target", None))
+
+cumulative_log_likelihood = 0
+total_tokens = 0
+
+for sentence in tqdm(corpus):
+    inputs = tokenizer(sentence.strip().lower(), return_tensors="pt").to(device)
+
+    input_ids = inputs["input_ids"]
+    # input_len = input_ids.size(1)
+    input_len = len(sentence.split(" "))
+    total_tokens += input_len
+
+    with torch.no_grad():
+        outputs = model(**inputs, labels=input_ids)
+        log_likelihood = outputs.loss * input_len
+        cumulative_log_likelihood += log_likelihood.item()
+
+
+average_log_likelihood = cumulative_log_likelihood / total_tokens
+corpus_ppl = torch.exp(torch.tensor(average_log_likelihood)).item()
+
+print(f"Model: {MODEL_PATH}")
+print(f"Corpus: {corpus_path}")
+print(f"Corpus Perplexity: {corpus_ppl}")
diff --git a/slam_llm/utils/compute_utils.py b/slam_llm/utils/compute_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c85d63ae10708aa3f12a397522ec2aa8de0f3cf
--- /dev/null
+++ b/slam_llm/utils/compute_utils.py
@@ -0,0 +1,3 @@
+
+def calculate_output_length_1d(L_in, kernel_size, stride, padding=0):
+    return (L_in + 2 * padding - kernel_size) // stride + 1
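+
+# Worked example: a 1-D conv with kernel_size=3, stride=2, padding=1 over an
+# input of length 100 yields (100 + 2*1 - 3) // 2 + 1 = 50 output frames.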
\ No newline at end of file
diff --git a/slam_llm/utils/compute_wer.py b/slam_llm/utils/compute_wer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4759b3f37e528444bce10123f6816672d22dd2d1
--- /dev/null
+++ b/slam_llm/utils/compute_wer.py
@@ -0,0 +1,197 @@
+import os
+import numpy as np
+import sys
+
+def build_diff(ref, hyp, path):
+    result = []
+    ref = list(map(lambda x: x.lower(), ref))
+    hyp = list(map(lambda x: x.lower(), hyp))
+    r_record = -1
+    h_record = -1
+    # path = path+[(len(ref), len(hyp))]
+
+    for rpointer, hpointer in path:
+        if rpointer!=r_record+1 or hpointer!=h_record+1:
+            r_buffer = ' '.join(ref[r_record+1:rpointer])
+            r_buffer = r_buffer if len(r_buffer)>0 else "*"
+            h_buffer = ' '.join(hyp[h_record+1:hpointer])
+            h_buffer = h_buffer if len(h_buffer)>0 else "*"
+            result.append(f"({r_buffer}->{h_buffer})")
+
+        result.append(ref[rpointer])
+        r_record = rpointer
+        h_record = hpointer
+
+    if r_record<len(ref)-1 or h_record<len(hyp)-1:
+        r_buffer = ' '.join(ref[r_record+1:])
+        r_buffer = r_buffer if len(r_buffer)>0 else "*"
+        h_buffer = ' '.join(hyp[h_record+1:])
+        h_buffer = h_buffer if len(h_buffer)>0 else "*"
+        result.append(f"({r_buffer}->{h_buffer})")
+    return ' '.join(result)
+
+
+
+def compute_wer(ref_file,
+                hyp_file,
+                cer_detail_file):
+    rst = {
+        'Wrd': 0,
+        'Corr': 0,
+        'Ins': 0,
+        'Del': 0,
+        'Sub': 0,
+        'Snt': 0,
+        'Err': 0.0,
+        'S.Err': 0.0,
+        'wrong_words': 0,
+        'wrong_sentences': 0
+    }
+
+    hyp_dict = {}
+    ref_dict = {}
+    with open(hyp_file, 'r') as hyp_reader:
+        for line in hyp_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            hyp_dict[key] = value
+    with open(ref_file, 'r') as ref_reader:
+        for line in ref_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            ref_dict[key] = value
+
+    cer_detail_writer = open(cer_detail_file, 'w')
+    for hyp_key in hyp_dict:
+        if hyp_key in ref_dict:
+            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
+            # if out_item['ins'] > 10 or out_item['del'] > 10:
+            #     print(hyp_key + print_cer_detail(out_item))
+            #     print("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))))
+            #     print("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))))
+            rst['Wrd'] += out_item['nwords']
+            rst['Corr'] += out_item['cor']
+            rst['wrong_words'] += out_item['wrong']
+            rst['Ins'] += out_item['ins']
+            rst['Del'] += out_item['del']
+            rst['Sub'] += out_item['sub']
+            rst['Snt'] += 1
+            if out_item['wrong'] > 0:
+                rst['wrong_sentences'] += 1
+            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
+            cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+            cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
+            cer_detail_writer.write("diff:" + '\t' + build_diff(ref_dict[hyp_key], hyp_dict[hyp_key], out_item['path']) + '\n')
+
+    if rst['Wrd'] > 0:
+        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
+    if rst['Snt'] > 0:
+        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
+
+    cer_detail_writer.write('\n')
+    cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
+                            ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
+    cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
+    cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
+
+     
+def compute_wer_by_line(hyp,
+                        ref):
+    hyp = list(map(lambda x: x.lower(), hyp))
+    ref = list(map(lambda x: x.lower(), ref))
+
+    len_hyp = len(hyp)
+    len_ref = len(ref)
+
+    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+    for i in range(len_hyp + 1):
+        cost_matrix[i][0] = i
+    for j in range(len_ref + 1):
+        cost_matrix[0][j] = j
+
+    for i in range(1, len_hyp + 1):
+        for j in range(1, len_ref + 1):
+            if hyp[i - 1] == ref[j - 1]:
+                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+            else:
+                substitution = cost_matrix[i - 1][j - 1] + 1
+                insertion = cost_matrix[i - 1][j] + 1
+                deletion = cost_matrix[i][j - 1] + 1
+
+                compare_val = [substitution, insertion, deletion]
+
+                min_val = min(compare_val)
+                operation_idx = compare_val.index(min_val) + 1
+                cost_matrix[i][j] = min_val
+                ops_matrix[i][j] = operation_idx
+
+    match_idx = []
+    i = len_hyp
+    j = len_ref
+    rst = {
+        'nwords': len_ref,
+        'cor': 0,
+        'wrong': 0,
+        'ins': 0,
+        'del': 0,
+        'sub': 0,
+        'path': []
+    }
+    while i >= 0 or j >= 0:
+        i_idx = max(0, i)
+        j_idx = max(0, j)
+
+        if ops_matrix[i_idx][j_idx] == 0:  # correct
+            if i - 1 >= 0 and j - 1 >= 0:
+                match_idx.append((j - 1, i - 1))
+                rst['cor'] += 1
+
+            i -= 1
+            j -= 1
+
+        elif ops_matrix[i_idx][j_idx] == 2:  # insert
+            i -= 1
+            rst['ins'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 3:  # delete
+            j -= 1
+            rst['del'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
+            i -= 1
+            j -= 1
+            rst['sub'] += 1
+
+        if i < 0 and j >= 0:
+            rst['del'] += 1
+        elif j < 0 and i >= 0:
+            rst['ins'] += 1
+
+    match_idx.reverse()
+    wrong_cnt = cost_matrix[len_hyp][len_ref]
+    rst['wrong'] = wrong_cnt
+    rst['path'] = match_idx
+
+    return rst
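+
+# Worked example: ref = "the cat sat", hyp = "the cat sat down" yields
+# nwords=3, cor=3, ins=1, del=0, sub=0 and wrong=1, so the line contributes
+# one error and a per-line WER of 1/3.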
+
+def print_cer_detail(rst):
+    return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+            + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+            + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+            + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
+
+if __name__ == '__main__':
+    if len(sys.argv) != 4:
+        print("usage : python compute-wer.py test.ref test.hyp test.wer")
+        sys.exit(0)
+
+    ref_file = sys.argv[1]
+    hyp_file = sys.argv[2]
+    cer_detail_file = sys.argv[3]
+    compute_wer(ref_file, hyp_file, cer_detail_file)
diff --git a/slam_llm/utils/config_utils.py b/slam_llm/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0057ce4d2d247ffdc8e123b21b60c4cbdd9d521
--- /dev/null
+++ b/slam_llm/utils/config_utils.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import inspect
+# from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
+from peft import (
+    LoraConfig,
+    AdaptionPromptConfig,
+    PrefixTuningConfig,
+)
+from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
+
+# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
+
+from omegaconf import OmegaConf
+
+import logging
+logger = logging.getLogger(__name__)
+
+# def update_config(config, **kwargs):
+#     if isinstance(config, (tuple, list)):
+#         for c in config:
+#             update_config(c, **kwargs)
+#     else:
+#         for k, v in kwargs.items():
+#             if hasattr(config, k):
+#                 setattr(config, k, v)
+#             elif "." in k:
+#                 # allow --some_config.some_param=True
+#                 config_name, param_name = k.split(".")
+#                 if type(config).__name__ == config_name:
+#                     if hasattr(config, param_name):
+#                         setattr(config, param_name, v)
+#                     else:
+#                         # In case of specialized config we can warm user
+#                         # In case of specialized config we can warn the user
+#             elif isinstance(config, train_config):
+#                 logger.warning(f"Warning: unknown parameter {k}")
+
+
+def generate_peft_config(train_config):
+    # configs = (lora_config, llama_adapter_config, prefix_config)
+    # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
+    peft_configs = {"lora": LoraConfig,
+                    "llama_adapter": AdaptionPromptConfig,
+                    "prefix": PrefixTuningConfig
+                    }
+    # names = tuple(c.__name__.rstrip("_config") for c in configs)
+    #
+    # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
+    #
+    # config = configs[names.index(train_config.peft_method)]()
+    config = train_config.peft_config
+
+    params = OmegaConf.to_container(config, resolve=True)
+    # peft_config = peft_configs[names.index(train_config.peft_method)](**params)
+    params.pop("peft_method", None) #(FIX:MZY): remove peft_method from params to avoid error
+    peft_config = peft_configs[config.get("peft_method", "lora")](**params)
+
+    return peft_config
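+
+
+# Illustrative example (field names follow peft.LoraConfig; actual values come
+# from the OmegaConf train_config.peft_config and are not defined here):
+#
+#   peft_config:
+#     peft_method: lora
+#     r: 8
+#     lora_alpha: 32
+#     target_modules: [q_proj, v_proj]
+#     lora_dropout: 0.05
+#     task_type: CAUSAL_LM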
+
+
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
+                kwargs["sampler"] = DistributedSampler(
+                    dataset,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+            if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed:
+                kwargs["sampler"] = DistributedSampler(
+                    dataset,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = dataset.collator
+            logger.info(f"Using batching strategy: {train_config.batching_strategy}")
+
+        return kwargs
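+
+
+# Usage sketch (illustrative): the returned kwargs are meant to be splatted
+# into a DataLoader, e.g.
+#
+#   from torch.utils.data import DataLoader
+#   dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, "train")
+#   train_dataloader = DataLoader(dataset, **dl_kwargs)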
diff --git a/slam_llm/utils/custom_utils.py b/slam_llm/utils/custom_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c9ffbc397118ba786af646238dac1550f7ec5ca
--- /dev/null
+++ b/slam_llm/utils/custom_utils.py
@@ -0,0 +1,298 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import cv2
+import torch
+import random
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+
+def load_video(path):
+    for i in range(3):
+        try:
+            cap = cv2.VideoCapture(path)
+            frames = []
+            while True:
+                ret, frame = cap.read()
+                if ret:
+                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                    frames.append(frame)
+                else:
+                    break
+            frames = np.stack(frames)
+            return frames
+        except Exception:
+            print(f"failed loading {path} ({i} / 3)")
+            if i == 2:
+                raise ValueError(f"Unable to load {path}")
+
+
+class Compose(object):
+    """Compose several preprocess together.
+    Args:
+        preprocess (list of ``Preprocess`` objects): list of preprocess to compose.
+    """
+
+    def __init__(self, preprocess):
+        self.preprocess = preprocess
+
+    def __call__(self, sample):
+        for t in self.preprocess:
+            sample = t(sample)
+        return sample
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.preprocess:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string
+
+
+class Normalize(object):
+    """Normalize a ndarray image with mean and standard deviation.
+    """
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, frames):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+        Returns:
+            Tensor: Normalized Tensor image.
+        """
+        frames = (frames - self.mean) / self.std
+        return frames
+
+    def __repr__(self):
+        return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std)
+
+class CenterCrop(object):
+    """Crop the given image at the center
+    """
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, frames):
+        """
+        Args:
+            img (numpy.ndarray): Images to be cropped.
+        Returns:
+            numpy.ndarray: Cropped image.
+        """
+        t, h, w = frames.shape
+        th, tw = self.size
+        delta_w = int(round((w - tw))/2.)
+        delta_h = int(round((h - th))/2.)
+        frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
+        return frames
+
+
+class RandomCrop(object):
+    """Crop the given image at the center
+    """
+
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, frames):
+        """
+        Args:
+            img (numpy.ndarray): Images to be cropped.
+        Returns:
+            numpy.ndarray: Cropped image.
+        """
+        t, h, w = frames.shape
+        th, tw = self.size
+        delta_w = random.randint(0, w-tw)
+        delta_h = random.randint(0, h-th)
+        frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw]
+        return frames
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+class HorizontalFlip(object):
+    """Flip image horizontally.
+    """
+
+    def __init__(self, flip_ratio):
+        self.flip_ratio = flip_ratio
+
+    def __call__(self, frames):
+        """
+        Args:
+            img (numpy.ndarray): Images to be flipped with a probability flip_ratio
+        Returns:
+            numpy.ndarray: Cropped image.
+        """
+        t, h, w = frames.shape
+        if random.random() < self.flip_ratio:
+            for index in range(t):
+                frames[index] = cv2.flip(frames[index], 1)
+        return frames
+
+def compute_mask_indices(
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+) -> np.ndarray:
+    """
+    Computes random mask spans for a given shape
+    Args:
+        shape: the shape for which to compute masks.
+            should be of size 2 where first element is batch size and 2nd is timesteps
+        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+        mask_type: how to compute mask lengths
+            static = fixed size
+            uniform = sample from uniform distribution [mask_other, mask_length*2]
+            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+            poisson = sample from Poisson distribution with lambda = mask_length
+        min_masks: minimum number of masked spans
+        no_overlap: if true, switches to an alternative recursive algorithm that prevents spans from overlapping
+        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+    """
+
+    bsz, all_sz = shape
+    mask = np.full((bsz, all_sz), False)
+
+    all_num_mask = int(
+        # add a random number for probabilistic rounding
+        mask_prob * all_sz / float(mask_length)
+        + np.random.rand()
+    )
+
+    all_num_mask = max(min_masks, all_num_mask)
+
+    mask_idcs = []
+    for i in range(bsz):
+        if padding_mask is not None:
+            sz = all_sz - padding_mask[i].long().sum().item()
+            num_mask = int(
+                # add a random number for probabilistic rounding
+                mask_prob * sz / float(mask_length)
+                + np.random.rand()
+            )
+            num_mask = max(min_masks, num_mask)
+        else:
+            sz = all_sz
+            num_mask = all_num_mask
+
+        if mask_type == "static":
+            lengths = np.full(num_mask, mask_length)
+        elif mask_type == "uniform":
+            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+        elif mask_type == "normal":
+            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+            lengths = [max(1, int(round(x))) for x in lengths]
+        elif mask_type == "poisson":
+            lengths = np.random.poisson(mask_length, size=num_mask)
+            lengths = [int(round(x)) for x in lengths]
+        else:
+            raise Exception("unknown mask selection " + mask_type)
+
+        if sum(lengths) == 0:
+            lengths[0] = min(mask_length, sz - 1)
+
+        if no_overlap:
+            mask_idc = []
+
+            def arrange(s, e, length, keep_length):
+                span_start = np.random.randint(s, e - length)
+                mask_idc.extend(span_start + i for i in range(length))
+
+                new_parts = []
+                if span_start - s - min_space >= keep_length:
+                    new_parts.append((s, span_start - min_space + 1))
+                if e - span_start - keep_length - min_space > keep_length:
+                    new_parts.append((span_start + length + min_space, e))
+                return new_parts
+
+            parts = [(0, sz)]
+            min_length = min(lengths)
+            for length in sorted(lengths, reverse=True):
+                lens = np.fromiter(
+                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
+                    int,
+                )
+                l_sum = np.sum(lens)
+                if l_sum == 0:
+                    break
+                probs = lens / np.sum(lens)
+                c = np.random.choice(len(parts), p=probs)
+                s, e = parts.pop(c)
+                parts.extend(arrange(s, e, length, min_length))
+            mask_idc = np.asarray(mask_idc)
+        else:
+            min_len = min(lengths)
+            if sz - min_len <= num_mask:
+                min_len = sz - num_mask - 1
+
+            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+            mask_idc = np.asarray(
+                [
+                    mask_idc[j] + offset
+                    for j in range(len(mask_idc))
+                    for offset in range(lengths[j])
+                ]
+            )
+
+        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+    min_len = min([len(m) for m in mask_idcs])
+    batch_indexes, starts, ends = [], [], []
+    for i, mask_idc in enumerate(mask_idcs):
+        if len(mask_idc) > min_len:
+            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+        mask[i, mask_idc] = True
+        vals, run_starts, run_lengths = find_runs(mask[i])
+        start_indices, lengths = run_starts[vals == True], run_lengths[vals == True]
+        starts.append(start_indices)
+        ends.append(start_indices+lengths)
+        batch_indexes.append(np.zeros([len(start_indices)])+i)
+    return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64)
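+
+# Usage sketch (illustrative): for a batch of 2 sequences of 100 frames,
+#
+#   mask, starts, ends, batch_idx = compute_mask_indices(
+#       (2, 100), padding_mask=None, mask_prob=0.65, mask_length=10
+#   )
+#
+# returns a boolean (2, 100) mask together with flat start/end/batch indices
+# of the masked runs, with every row trimmed to the same number of masked
+# positions.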
+
+def find_runs(x):
+    """Find runs of consecutive items in an array."""
+
+    # ensure array
+    x = np.asanyarray(x)
+    if x.ndim != 1:
+        raise ValueError('only 1D array supported')
+    n = x.shape[0]
+
+    # handle empty array
+    if n == 0:
+        return np.array([]), np.array([]), np.array([])
+
+    else:
+        # find run starts
+        loc_run_start = np.empty(n, dtype=bool)
+        loc_run_start[0] = True
+        np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
+        run_starts = np.nonzero(loc_run_start)[0]
+
+        # find run values
+        run_values = x[loc_run_start]
+
+        # find run lengths
+        run_lengths = np.diff(np.append(run_starts, n))
+
+        return run_values, run_starts, run_lengths
diff --git a/slam_llm/utils/dataset_utils.py b/slam_llm/utils/dataset_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9726a7aad5e7b388381719b666b8993c6b62401d
--- /dev/null
+++ b/slam_llm/utils/dataset_utils.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import importlib
+from pathlib import Path
+
+import torch
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def load_module_from_py_file(py_file: str) -> object:
+    """
+    This method loads a module from a py file which is not in the Python path
+    """
+    module_name = Path(py_file).name
+    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
+    spec = importlib.util.spec_from_loader(module_name, loader)
+    module = importlib.util.module_from_spec(spec)
+
+    loader.exec_module(module)
+
+    return module
+
+
+def get_custom_dataset(dataset_config, tokenizer, split: str):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_custom_dataset"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_config, tokenizer, split)
+    except AttributeError as e:
+        logger.info(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
+        raise e
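+
+
+# Example (hypothetical paths): dataset_config.file may point at a module,
+#   dataset_config.file = "examples/s2s/speech_dataset_s2s.py"
+# which must define get_custom_dataset(dataset_config, tokenizer, split), or
+# name an explicit factory after a ":" separator,
+#   dataset_config.file = "examples/s2s/speech_dataset_s2s.py:get_speech_dataset"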
+
+
+def get_preprocessed_dataset(
+    tokenizer, dataset_config, split: str = "train"
+) -> torch.utils.data.Dataset:
+
+    def get_split():
+        return (
+            dataset_config.train_split
+            if split == "train"
+            else dataset_config.test_split
+        )
+
+    return get_custom_dataset(
+        dataset_config,
+        tokenizer,
+        get_split(),
+    )
diff --git a/slam_llm/utils/deepspeed_utils.py b/slam_llm/utils/deepspeed_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..66c49e593a65253ba8aca52b702ef4b4e54676d7
--- /dev/null
+++ b/slam_llm/utils/deepspeed_utils.py
@@ -0,0 +1,601 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+import time
+import yaml
+from contextlib import nullcontext
+from pathlib import Path
+from pkg_resources import packaging
+
+
+import functools
+import hydra
+import torch
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+from omegaconf import DictConfig
+from tqdm import tqdm
+from transformers import LlamaTokenizer
+from typing import Any, Callable, List, Optional
+from textwrap import dedent
+from hydra import version
+from hydra.main import _UNSPECIFIED_, _get_rerun_conf
+from hydra._internal.deprecation_warning import deprecation_warning
+from hydra._internal.utils import _run_hydra, get_args_parser
+from hydra.types import TaskFunction
+from hydra.core.utils import _flush_loggers, configure_log
+
+
+from slam_llm.utils.checkpoint_handler import (
+    save_model_checkpoint,
+    save_model_checkpoint_deepspeed,
+    save_model_and_optimizer_sharded,
+    save_optimizer_checkpoint,
+    save_model_checkpoint_peft,
+    save_model_checkpoint_peft_full_shard,
+)
+from slam_llm.policies import fpSixteen, bfSixteen_mixed, get_llama_wrapper
+from slam_llm.utils.memory_utils import MemoryTrace
+from slam_llm.utils.metric import compute_accuracy
+
+import wandb
+import logging
+
+logger = logging.getLogger(__name__)
+
+# For deepspeed --local_rank argument
+def deepspeed_main_wrapper(
+    config_path: Optional[str] = _UNSPECIFIED_,
+    config_name: Optional[str] = None,
+    version_base: Optional[str] = _UNSPECIFIED_,
+) -> Callable[[TaskFunction], Any]:
+    """
+    :param config_path: The config path, a directory where Hydra will search for
+                        config files. This path is added to Hydra's searchpath.
+                        Relative paths are interpreted relative to the declaring python
+                        file. Alternatively, you can use the prefix `pkg://` to specify
+                        a python package to add to the searchpath.
+                        If config_path is None no directory is added to the Config search path.
+    :param config_name: The name of the config (usually the file name without the .yaml extension)
+    """
+
+    version.setbase(version_base)
+
+    if config_path is _UNSPECIFIED_:
+        if version.base_at_least("1.2"):
+            config_path = None
+        elif version_base is _UNSPECIFIED_:
+            url = "https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path"
+            deprecation_warning(
+                message=dedent(
+                    f"""
+                config_path is not specified in @hydra.main().
+                See {url} for more information."""
+                ),
+                stacklevel=2,
+            )
+            config_path = "."
+        else:
+            config_path = "."
+
+    def main_decorator(task_function: TaskFunction) -> Callable[[], None]:
+        @functools.wraps(task_function)
+        def decorated_main(cfg_passthrough: Optional[DictConfig] = None) -> Any:
+            if cfg_passthrough is not None:
+                return task_function(cfg_passthrough)
+            else:
+                args_parser = get_args_parser()
+                args_parser.add_argument("--local_rank", type=int, default=-1)
+                args = args_parser.parse_args()
+                if args.experimental_rerun is not None:
+                    cfg = _get_rerun_conf(args.experimental_rerun, args.overrides)
+                    task_function(cfg)
+                    _flush_loggers()
+                else:
+                    # no return value from run_hydra() as it may sometime actually run the task_function
+                    # multiple times (--multirun)
+                    _run_hydra(
+                        args=args,
+                        args_parser=args_parser,
+                        task_function=task_function,
+                        config_path=config_path,
+                        config_name=config_name,
+                    )
+
+        return decorated_main
+
+    return main_decorator
+
+
+
+def set_tokenizer_params(tokenizer: LlamaTokenizer):
+    tokenizer.pad_token_id = 0
+    tokenizer.padding_side = "left"
+
+
+# Converting Bytes to Megabytes
+def byte2mb(x):
+    return int(x / 2**20)
+
+
+def train(
+    model,
+    train_dataloader,
+    eval_dataloader,
+    tokenizer,
+    gradient_accumulation_steps,
+    train_config,
+    log_config,
+    local_rank=None,
+    rank=None,
+):
+    """
+    Trains the model on the given dataloader
+
+    Args:
+        model: The model to be trained
+        train_dataloader: The dataloader containing the training data
+        optimizer: The optimizer used for training
+        lr_scheduler: The learning rate scheduler
+        gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
+        num_epochs: The number of epochs to train for
+        local_rank: The rank of the current node in a distributed setting
+        train_config: The training configuration
+        log_config: The logging configuration
+        eval_dataloader: The dataloader containing the eval data
+        tokenizer: tokenizer used in the eval for decoding the predictions
+
+    Returns: results dictionary containing average training and validation perplexity and loss
+    """
+    # Create a gradient scaler for fp16
+    # if train_config.use_fp16 and train_config.enable_fsdp:
+    #     scaler = ShardedGradScaler()
+    # elif train_config.use_fp16 and not train_config.enable_fsdp:
+    #     scaler = torch.cuda.amp.GradScaler()
+    if train_config.enable_fsdp or train_config.enable_ddp:
+        world_size = int(os.environ["WORLD_SIZE"])
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+
+    train_prep = []
+    train_loss = []
+    train_acc = []
+    val_prep = []
+    val_loss = []
+    val_acc = []
+    epoch_times = []
+    checkpoint_times = []
+    results = {}
+    best_val_loss = float("inf")
+    best_val_acc = 0.0
+    for epoch in range(train_config.num_epochs):
+        epoch_start_time = time.perf_counter()
+        with MemoryTrace() as memtrace:  # track the memory usage
+            model.train()
+            total_loss = 0.0
+            total_acc = 0.0
+            total_length = len(train_dataloader) // gradient_accumulation_steps
+            pbar = tqdm(
+                colour="blue",
+                desc=f"Training Epoch: {epoch+1}",
+                total=total_length,
+                dynamic_ncols=True,
+            )
+            for step, batch in enumerate(train_dataloader):
+                for key in batch.keys():
+                    batch[key] = (
+                        batch[key].to(local_rank).half()
+                        if isinstance(batch[key], torch.Tensor)
+                        and batch[key].dtype == torch.float32
+                        else (
+                            batch[key].to(local_rank)
+                            if isinstance(batch[key], torch.Tensor)
+                            else batch[key]
+                        )
+                    )
+                # with autocast():
+                outputs, *rest = model(**batch)
+                acc = rest[0] if rest else -1
+                loss = outputs.loss
+
+                loss = loss / gradient_accumulation_steps
+                acc = acc / gradient_accumulation_steps
+
+                if log_config.use_wandb and step % log_config.log_interval == 0:
+                    if train_config.enable_fsdp or train_config.enable_ddp:
+                        if rank == 0:
+                            wandb.log(
+                                {
+                                    "train_inner/train_inner_loss": loss,
+                                    "train_inner/train_inner_accuracy": acc,
+                                },
+                                step=(epoch * total_length + step),
+                            )
+                    else:
+                        wandb.log(
+                            {
+                                "train_inner/train_inner_loss": loss,
+                                "train_inner/train_inner_accuracy": acc,
+                            },
+                            step=(epoch * total_length + step),
+                        )
+
+                total_loss += loss.detach().float()
+                total_acc += acc
+
+                # deepspeed should handle gradient accumulate
+                model.backward(loss)
+                model.step()
+
+                if (step + 1) % gradient_accumulation_steps == 0 or step == len(
+                    train_dataloader
+                ) - 1:
+                    pbar.update(1)
+
+                pbar.set_description(
+                    f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})"
+                )
+
+                if (
+                    (epoch * total_length + step + 1) % train_config.validation_interval
+                    == 0
+                    and train_config.run_validation
+                ):
+                    eval_ppl, eval_epoch_loss, *rest = evaluation(
+                        model, train_config, eval_dataloader, local_rank, tokenizer
+                    )
+                    eval_epoch_acc = rest[0] if rest else -1
+                    checkpoint_start_time = time.perf_counter()
+
+                    if train_config.save_model and (eval_epoch_loss < best_val_loss):
+                        checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}"
+                        save_model_checkpoint_deepspeed(
+                            model, train_config, checkpoint_name
+                        )
+
+                    checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+                    checkpoint_times.append(checkpoint_end_time)
+                    if eval_epoch_loss < best_val_loss:
+                        best_val_loss = eval_epoch_loss
+                        if rank == 0:
+                            logger.info(
+                                f"best eval loss on epoch {epoch+1} is {best_val_loss}"
+                            )
+                    val_loss.append(eval_epoch_loss)
+                    val_prep.append(eval_ppl)
+                    if rest:
+                        if eval_epoch_acc > best_val_acc:
+                            best_val_acc = eval_epoch_acc
+                            if rank == 0:
+                                logger.info(
+                                    f"best eval acc on epoch {epoch+1} is {best_val_acc}"
+                                )
+                        val_acc.append(rest[0])
+                    else:
+                        val_acc.append(-1)
+
+                    if log_config.use_wandb:
+                        if rank == 0:
+                            wandb.log(
+                                {
+                                    "valid/val_epoch_loss": eval_epoch_loss,
+                                    "valid/val_perplexity": eval_ppl,
+                                    "valid/best_val_loss": best_val_loss,
+                                    "valid/val_accuracy": val_acc[-1],
+                                    "valid/val_best_accuracy": best_val_acc,
+                                }
+                            )
+
+                if train_config.run_test_during_validation:
+                    if rank == 0:
+                        logger.info("=====================================")
+                        logger.info(
+                            f"Test the file {train_config.run_test_during_validation_file} during validation:"
+                        )
+                        with autocast():
+                            logger.info(
+                                model.inference(
+                                    train_config.run_test_during_validation_file,
+                                    train_config.run_test_during_validation_prompt,
+                                )
+                            )
+                        logger.info("=====================================")
+                    dist.barrier()
+            pbar.close()
+
+        epoch_end_time = time.perf_counter() - epoch_start_time
+        epoch_times.append(epoch_end_time)
+        # Reducing total_loss across all devices if there's more than one CUDA device
+        if torch.cuda.device_count() > 1 and (
+            train_config.enable_fsdp or train_config.enable_ddp
+        ):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(total_acc, op=dist.ReduceOp.SUM)
+        train_epoch_loss = total_loss / len(train_dataloader)
+        train_epoch_acc = total_acc / len(train_dataloader)
+        if train_config.enable_fsdp or train_config.enable_ddp:
+            train_epoch_loss = train_epoch_loss / world_size
+            train_epoch_acc = train_epoch_acc / world_size
+        train_perplexity = torch.exp(train_epoch_loss)
+
+        train_prep.append(train_perplexity)
+        train_loss.append(train_epoch_loss)
+        train_acc.append(train_epoch_acc)
+
+        if log_config.use_wandb:
+            if train_config.enable_fsdp or train_config.enable_ddp:
+                if rank == 0:
+                    wandb.log(
+                        {
+                            "train/train_perplexity": train_perplexity,
+                            "train/train_epoch_loss": train_epoch_loss,
+                            "train/train_epoch_acc": train_epoch_acc,
+                        }
+                    )
+            else:
+                wandb.log(
+                    {
+                        "train/train_perplexity": train_perplexity,
+                        "train/train_epoch_loss": train_epoch_loss,
+                        "train/train_epoch_acc": train_epoch_acc,
+                    }
+                )
+
+        if rank == 0:
+            logger.info(
+                f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
+            )
+
+        if rank == 0:
+            logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB")
+            logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+            logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+            logger.info(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            logger.info(
+                f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB"
+            )
+
+        # Update the learning rate as needed
+        # lr_scheduler.step()
+
+    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_checkpoint_time = (
+        sum(checkpoint_times) / len(checkpoint_times)
+        if len(checkpoint_times) > 0
+        else 0
+    )
+    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_loss = sum(train_loss) / len(train_loss)
+    avg_train_acc = sum(train_acc) / len(train_acc)
+    if train_config.run_validation:
+        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_loss = sum(val_loss) / len(val_loss)
+        avg_eval_acc = sum(val_acc) / len(val_acc)
+
+    results["avg_train_prep"] = avg_train_prep
+    results["avg_train_loss"] = avg_train_loss
+    results["avg_train_acc"] = avg_train_acc
+    if train_config.run_validation:
+        results["avg_eval_prep"] = avg_eval_prep
+        results["avg_eval_loss"] = avg_eval_loss
+        results["avg_eval_acc"] = avg_eval_acc
+    results["avg_epoch_time"] = avg_epoch_time
+    results["avg_checkpoint_time"] = avg_checkpoint_time
+
+    # saving the training params including fsdp setting for reference.
+    # if (train_config.enable_fsdp or train_config.enable_ddp)and not train_config.use_peft:
+    #     save_train_params(train_config, fsdp_config, rank)
+
+    return results
+
+
+def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
+    """
+    Evaluates the model on the given dataloader
+
+    Args:
+        model: The model to evaluate
+        train_config: The training configuration
+        eval_dataloader: The dataloader containing the evaluation data
+        local_rank: The rank of the current node in a distributed setting
+        tokenizer: The tokenizer used to decode predictions
+
+    Returns: eval_ppl, eval_epoch_loss, eval_epoch_acc
+    """
+    world_size = int(os.environ["WORLD_SIZE"])
+    model.eval()
+    eval_preds = []
+    eval_loss = 0.0  # Initialize evaluation loss
+    eval_acc = 0.0
+    autocast = (
+        torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+    )  # (Fix:MZY): fix expected scalar type mismatch in norm
+
+    with MemoryTrace() as memtrace:
+        total_length = len(eval_dataloader)
+        pbar = tqdm(
+            colour="green",
+            desc=f"Evaluating Epoch",
+            total=total_length,
+            dynamic_ncols=True,
+        )
+        for step, batch in enumerate(eval_dataloader):
+            for key in batch.keys():
+                batch[key] = (
+                    batch[key].to(local_rank).half()
+                    if isinstance(batch[key], torch.Tensor) and batch[key].dtype==torch.float32
+                    else (
+                        batch[key].to(local_rank) if isinstance(batch[key], torch.Tensor) else batch[key]
+                    )
+                )
+            # Ensure no gradients are computed for this scope to save memory
+            with torch.no_grad():
+                # Forward pass and compute loss
+                with autocast():  # (Fix:MZY): fix expected scalar type mismatch in norm
+                    outputs, *rest = model(**batch)
+                acc = rest[0] if rest else -1
+                loss = outputs.loss
+
+                eval_loss += loss.detach().float()
+                eval_acc += acc
+            # Decode predictions and add to evaluation predictions list
+            preds = torch.argmax(outputs.logits, -1)
+            eval_preds.extend(
+                tokenizer.batch_decode(
+                    preds.detach().cpu().numpy(), skip_special_tokens=True
+                )
+            )
+            pbar.update(1)
+            pbar.set_description(
+                f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}"
+            )
+
+    # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if torch.cuda.device_count() > 1:
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+        dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM)
+
+    # Compute average loss and perplexity
+    eval_epoch_loss = eval_loss / len(eval_dataloader)
+    eval_epoch_acc = eval_acc / len(eval_dataloader)
+    eval_epoch_loss = eval_epoch_loss / world_size
+    eval_epoch_acc = eval_epoch_acc / world_size
+    eval_ppl = torch.exp(eval_epoch_loss)
+
+    # Print evaluation metrics
+    if local_rank == 0:
+        logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}")
+
+    model.train()
+    return eval_ppl, eval_epoch_loss, eval_epoch_acc
+
+
+def freeze_transformer_layers(model, num_layer):
+    for i, layer in enumerate(model.model.layers):
+        if i < num_layer:
+            for param in layer.parameters():
+                param.requires_grad = False
+
+
+def check_frozen_layers_peft_model(model):
+    for i, layer in enumerate(model.base_model.model.model.layers):
+        for name, param in layer.named_parameters():
+            logger.info(
+                f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}"
+            )
+
+
+def setup():
+    """Initialize the process group for distributed training"""
+    dist.init_process_group("nccl")
+
+
+def setup_environ_flags(rank):
+    """Set environment flags for debugging purposes"""
+    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag will help with CUDA memory fragmentation that can lead to OOM in some cases.
+    # Note this is only available in PyTorch Nightlies (as of July 30 2023)
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
+    if rank == 0:
+        logger.info(f"--> Running with torch dist debug set to detail")
+
+
+def cleanup():
+    """Clean up the process group after training"""
+    dist.destroy_process_group()
+
+
+def clear_gpu_cache(rank=None):
+    """Clear the GPU cache for all ranks"""
+    if rank == 0:
+        logger.info(f"Clearing GPU cache for all ranks")
+    torch.cuda.empty_cache()
+
+
+def get_parameter_dtypes(model):
+    """Get the data types of model parameters"""
+    parameter_dtypes = {}
+    for name, parameter in model.named_parameters():
+        parameter_dtypes[name] = parameter.dtype
+    return parameter_dtypes
+
+
+def print_model_size(model, config, rank: int = 0) -> None:
+    """
+    Log the model name and the number of trainable parameters.
+
+    Args:
+        model: The PyTorch model.
+        config: Config object that provides ``model_name``.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        logger.info(f"--> Model {config.model_name}")
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        logger.info(
+            f"--> {config.model_name} has {total_params / 1e6} Million params\n"
+        )
+
+
+def print_module_size(module, module_name, rank: int = 0) -> None:
+    """
+    Log the module name and the number of trainable parameters.
+
+    Args:
+        module: The PyTorch module.
+        module_name (str): Name of the module.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        logger.info(f"--> Module {module_name}")
+        total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
+        logger.info(f"--> {module_name} has {total_params / 1e6} Million params\n")
+
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by the converter script in the inference folder to fetch the HF model name or path.
+    It would also be helpful as a log for future reference.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries,
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {
+        k: str(v) for k, v in vars(train_config).items() if not k.startswith("__")
+    }
+    fsdp_config_dict = {
+        k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith("__")
+    }
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # If the directory does not exist, create it
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir, "train_params.yaml")
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        logger.info(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, "w") as f:
+            f.write(config_yaml)
+        if rank == 0:
+            logger.info(f"training params are saved in {file_name}")
diff --git a/slam_llm/utils/fsdp_utils.py b/slam_llm/utils/fsdp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e1f7d2ec544915f0288e674dd39752bc3bd0b55
--- /dev/null
+++ b/slam_llm/utils/fsdp_utils.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+def fsdp_auto_wrap_policy(model, transformer_layer_name):
+    import functools
+
+    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+
+    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
+
+    def lambda_policy_fn(module):
+        if (
+            len(list(module.named_children())) == 0
+            and getattr(module, "weight", None) is not None
+            and module.weight.requires_grad
+        ):
+            return True
+        return False
+
+    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
+    transformer_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=(
+            PrefixEncoder,
+            PromptEncoder,
+            PromptEmbedding,
+            transformer_layer_name,
+            # FullyShardedDataParallelPlugin.get_module_class_from_name(
+            #     model, transformer_layer_name
+            # ),
+        ),
+    )
+
+    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
+    return auto_wrap_policy
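+
+
+# Usage sketch (illustrative; the decoder-layer class is an assumption): the returned
+# callable is what gets passed to FSDP as its auto-wrap policy, e.g.
+#   >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+#   >>> from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+#   >>> policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
+#   >>> model = FSDP(model, auto_wrap_policy=policy)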
\ No newline at end of file
diff --git a/slam_llm/utils/llm_tn.py b/slam_llm/utils/llm_tn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7967ab4be997036565934d4214df8f65166d7dd4
--- /dev/null
+++ b/slam_llm/utils/llm_tn.py
@@ -0,0 +1,34 @@
+import sys
+import os
+import re
+import string
+from whisper_normalizer.english import EnglishTextNormalizer
+
+english_normalizer = EnglishTextNormalizer()
+
+def reduce_repeated_words(text):
+    pattern ="."
+    for i in range(1, 50):
+        p = pattern * i
+        text = re.sub(f'({p})' + r'\1{4,200}', r'\1', text)
+    for i in range (50, 100):
+        p = pattern * i
+        text = re.sub(f'({p})' + r'\1{3,200}', r'\1', text)
+    return text
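+
+# Illustrative behaviour (inferred from the patterns above): any substring of length
+# 1-49 repeated 5+ times in a row (or length 50-99 repeated 4+ times) is collapsed to
+# a single occurrence, which trims degenerate repetition loops in ASR/LLM output, e.g.
+#   >>> reduce_repeated_words("no no no no no no no ")
+#   'no '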
+
+def normalize_text(srcfn, dstfn):
+    with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write:
+        all_lines = f_read.readlines()
+        for line in all_lines:
+            line = line.strip()
+            line_arr = line.split()
+            key = line_arr[0]
+            conts = " ".join(line_arr[1:])
+            normalized_conts = english_normalizer(conts)
+            reduced_conts = reduce_repeated_words(normalized_conts)
+            f_write.write("{0}\t{1}\n".format(key, reduced_conts))
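+
+# Assumed input format for normalize_text (inferred from the parsing above): one
+# Kaldi-style "<utt-id> <transcript>" entry per line; the output keeps the id and
+# writes the normalized, repetition-reduced transcript tab-separated, e.g.
+#   in:  "utt1 Hello, World!"
+#   out: "utt1\thello world"   (exact normalization depends on whisper_normalizer)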
+
+if __name__ == "__main__":
+    srcfn = sys.argv[1]
+    dstfn = sys.argv[2]
+    normalize_text(srcfn, dstfn)
\ No newline at end of file
diff --git a/slam_llm/utils/memory_utils.py b/slam_llm/utils/memory_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..19e3e7b16c59d5bc77ebe710e39bba8c708d0956
--- /dev/null
+++ b/slam_llm/utils/memory_utils.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import gc
+import psutil
+import threading
+
+import torch
+
+def byte2gb(x):
+    return int(x / 2**30)
+
+
+# This context manager is used to track the peak memory usage of the process
+class MemoryTrace:
+    def __enter__(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+        self.begin = byte2gb(torch.cuda.memory_allocated())
+        self.process = psutil.Process()
+        self.cpu_begin = byte2gb(self.cpu_mem_used())
+        self.peak_monitoring = True
+        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
+        peak_monitor_thread.daemon = True
+        peak_monitor_thread.start()
+        return self
+
+    def cpu_mem_used(self):
+        """get resident set size memory for the current process"""
+        return self.process.memory_info().rss
+
+    def peak_monitor_func(self):
+        self.cpu_peak = -1
+
+        while True:
+            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
+
+            # can't sleep or will not catch the peak right (this comment is here on purpose)
+            # time.sleep(0.001) # 1msec
+
+            if not self.peak_monitoring:
+                break
+
+    def __exit__(self, *exc):
+        self.peak_monitoring = False
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        self.end = byte2gb(torch.cuda.memory_allocated())
+        self.peak = byte2gb(torch.cuda.max_memory_allocated())
+        cuda_info = torch.cuda.memory_stats()
+        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+        self.used = byte2gb(self.end - self.begin)
+        self.peaked = byte2gb(self.peak - self.begin)
+        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+
+        self.cpu_end = self.cpu_mem_used()
+        self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
+        self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
+        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
\ No newline at end of file
diff --git a/slam_llm/utils/metric.py b/slam_llm/utils/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..44fd563e3aaabdf7236997fbaae2594ec9d20187
--- /dev/null
+++ b/slam_llm/utils/metric.py
@@ -0,0 +1,20 @@
+import torch
+
+def compute_accuracy(pad_outputs, pad_targets, ignore_label):
+    """Calculate accuracy.
+
+    Args:
+        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
+        pad_targets (LongTensor): Target label tensors (B, Lmax).
+        ignore_label (int): Ignore label id.
+
+    Returns:
+        float: Accuracy value (0.0 - 1.0).
+
+    """
+    mask = pad_targets != ignore_label
+    numerator = torch.sum(
+        pad_outputs.masked_select(mask) == pad_targets.masked_select(mask)
+    )
+    denominator = torch.sum(mask)
+    return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type
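+
+# Example (illustrative): with ignore_label=-100, only unmasked positions count.
+#   >>> preds = torch.tensor([[1, 2, 3, 0]])
+#   >>> targets = torch.tensor([[1, 2, 9, -100]])
+#   >>> compute_accuracy(preds, targets, ignore_label=-100)
+#   tensor(0.6667)   # 2 of the 3 unmasked positions match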
\ No newline at end of file
diff --git a/slam_llm/utils/model_utils.py b/slam_llm/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f61a6ce081d30e61d46c2b8e2a1c26f1d1edfc
--- /dev/null
+++ b/slam_llm/utils/model_utils.py
@@ -0,0 +1,29 @@
+from slam_llm.utils.dataset_utils import load_module_from_py_file
+from pathlib import Path
+
+def get_custom_model_factory(model_config):
+    custom_model_path = model_config.get(
+        "file", None
+    )
+    if custom_model_path is None:
+        from slam_llm.models.slam_model import model_factory
+        return model_factory
+
+    if ":" in custom_model_path:
+        module_path, func_name = custom_model_path.split(":")
+    else:
+        module_path, func_name = custom_model_path, "model_factory"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Model file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Model py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    return getattr(module, func_name)
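+
+
+# Usage sketch (path and names are hypothetical): model_config.file selects the factory, e.g.
+#   file: "path/to/custom_model.py:model_factory"
+# or, without the ":<name>" suffix, a function called "model_factory" is looked up
+# in the given .py file.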
diff --git a/slam_llm/utils/num2word.py b/slam_llm/utils/num2word.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6b87812a18d2760cc28a64afd4c5ab1f483a29
--- /dev/null
+++ b/slam_llm/utils/num2word.py
@@ -0,0 +1,19 @@
+
+import sys
+from num2words import num2words
+
+file = sys.argv[1]
+out_file = sys.argv[2]
+
+with open(file) as f:
+    lines = f.readlines()
+
+with open(out_file, "w") as fw:
+    for line in lines:
+        key, content = line.strip().split(maxsplit=1)
+        new_content = ""
+        for ct in content.split():
+            if ct.isdigit():
+                ct = num2words(ct)
+            new_content += ct + " "
+        fw.write(key + " " + new_content + "\n")
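+
+# Example line transformation (illustrative): digit-only tokens are spelled out,
+# everything else is passed through unchanged, e.g.
+#   in:  "utt1 i have 2 cats"
+#   out: "utt1 i have two cats "   (note the trailing space from the rebuild loop)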
\ No newline at end of file
diff --git a/slam_llm/utils/preprocess_text.py b/slam_llm/utils/preprocess_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..87a22f0b66ce11852faff02d280a7ff48ea3aff8
--- /dev/null
+++ b/slam_llm/utils/preprocess_text.py
@@ -0,0 +1,37 @@
+
+import sys
+import re
+import string
+
+in_f = sys.argv[1]
+out_f = sys.argv[2]
+
+
+with open(in_f, "r", encoding="utf-8") as f:
+  lines = f.readlines()
+
+with open(out_f, "w", encoding="utf-8") as f:
+  for line in lines:
+    outs = line.strip().split("\t", 1)
+    if len(outs) == 2:
+      idx, text = outs
+      text = re.sub("<|", "", text)
+      text = re.sub("|>", "", text)
+      text = re.sub("—", "", text)
+      # text = re.sub("<s>", "", text)
+      # text = re.sub("@@", "", text)
+      # text = re.sub("@", "", text)
+      # text = re.sub("<unk>", "", text)
+      # text = re.sub(" ", "", text)
+      # text = text.lower()
+      translator = str.maketrans('', '', string.punctuation.replace("'", ""))
+      result = text.translate(translator)
+      text = result.upper()
+    else:
+      idx = outs[0]
+      text = " "
+
+    # text = [x for x in text]
+    # text = " ".join(text)
+    out = "{} {}\n".format(idx, text)
+    f.write(out)
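+
+# Illustrative effect of the cleanup above: "<|" / "|>" markers, em-dashes and
+# punctuation (apostrophes kept) are stripped and the text is upper-cased, e.g.
+#   in:  "utt1\tHello, it's me!"
+#   out: "utt1 HELLO IT'S ME\n"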
diff --git a/slam_llm/utils/train_utils.py b/slam_llm/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bf160cacec9bfa2fe78f9f705a41fd0dc9c563e
--- /dev/null
+++ b/slam_llm/utils/train_utils.py
@@ -0,0 +1,628 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+import time
+import yaml
+from contextlib import nullcontext
+from pathlib import Path
+from pkg_resources import packaging
+
+
+import torch
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+from torch.distributed.fsdp import ShardingStrategy
+from torch.distributed.fsdp import StateDictType
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from tqdm import tqdm
+from transformers import LlamaTokenizer
+
+
+from slam_llm.utils.checkpoint_handler import (
+    save_model_checkpoint, 
+    save_model_and_optimizer_sharded, 
+    save_optimizer_checkpoint, 
+    save_model_checkpoint_peft,
+    save_model_checkpoint_peft_full_shard
+)
+from slam_llm.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from slam_llm.utils.memory_utils import MemoryTrace
+from slam_llm.utils.metric import compute_accuracy
+
+import wandb
+import logging
+logger = logging.getLogger(__name__)
+
+
+def set_tokenizer_params(tokenizer: LlamaTokenizer):
+    tokenizer.pad_token_id = 0
+    tokenizer.padding_side = "left"
+
+# Converting Bytes to Megabytes
+def byte2mb(x):
+    return int(x / 2**20)
+
+def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, log_config, fsdp_config=None, local_rank=None, rank=None):
+    """
+    Trains the model on the given dataloader
+
+    Args:
+        model: The model to be trained
+        train_dataloader: The dataloader containing the training data
+        eval_dataloader: The dataloader containing the eval data
+        tokenizer: The tokenizer used in the eval for decoding the predictions
+        optimizer: The optimizer used for training
+        lr_scheduler: The learning rate scheduler
+        gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
+        train_config: The training configuration (including the number of epochs)
+        log_config: The logging configuration
+        fsdp_config: The FSDP configuration (used when FSDP is enabled)
+        local_rank: The rank of the current node in a distributed setting
+        rank: The global rank of the current process in a distributed setting
+
+    Returns: results dictionary containing average training and validation perplexity and loss
+    """
+    # Create a gradient scaler for fp16
+    # if train_config.use_fp16 and train_config.enable_fsdp:
+    #     scaler = ShardedGradScaler()
+    # elif train_config.use_fp16 and not train_config.enable_fsdp:
+    #     scaler = torch.cuda.amp.GradScaler()
+    if train_config.use_fp16:
+        scaler = torch.cuda.amp.GradScaler()
+        if train_config.enable_fsdp:
+            scaler = ShardedGradScaler()
+    if train_config.enable_fsdp or train_config.enable_ddp:
+        world_size = int(os.environ["WORLD_SIZE"])
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+    
+    train_prep = []
+    train_loss = []
+    train_acc = []
+    val_prep = []
+    val_loss =[]
+    val_acc = []
+    epoch_times = []
+    checkpoint_times = []
+    results = {}
+    best_val_loss = float("inf")
+    best_val_acc = 0.0
+    for epoch in range(train_config.num_epochs):
+        epoch_start_time = time.perf_counter()
+        with MemoryTrace() as memtrace:  # track the memory usage
+            model.train()
+            total_loss = 0.0
+            total_acc = 0.0
+            total_length = len(train_dataloader)//gradient_accumulation_steps
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
+            for step, batch in enumerate(train_dataloader):
+                for key in batch.keys():
+                    if train_config.enable_fsdp or train_config.enable_ddp:
+                        batch[key] = batch[key].to(local_rank) if isinstance(batch[key], torch.Tensor) else batch[key]
+                        if isinstance(batch[key], dict):
+                            for k2 in batch[key].keys():
+                                batch[key][k2] = batch[key][k2].to(local_rank) if isinstance(batch[key][k2], torch.Tensor) else batch[key][k2]
+                    else:
+                        batch[key] = batch[key].to('cuda:0') if isinstance(batch[key], torch.Tensor) else batch[key]
+                        if isinstance(batch[key], dict):
+                            for k2 in batch[key].keys():
+                                batch[key][k2] = batch[key][k2].to('cuda:0') if isinstance(batch[key][k2], torch.Tensor) else batch[key][k2]
+                with autocast():
+                    outputs, *rest = model(**batch)
+                acc = rest[0] if rest else -1
+                audio_acc = rest[1] if rest else -1   # seven layers of audio acc
+                layer_loss = rest[2] if rest else -1  # eight layers of loss (seven audio and one text)
+                loss = outputs.loss
+
+                loss = loss / gradient_accumulation_steps
+                layer_loss = [l / gradient_accumulation_steps for l in layer_loss]
+                acc = acc / gradient_accumulation_steps
+                audio_acc = [a / gradient_accumulation_steps for a in audio_acc]
+
+                if log_config.use_wandb and step % log_config.log_interval == 0:
+                    if train_config.enable_fsdp or train_config.enable_ddp:
+                        if rank==0:
+                            wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_text_accuracy":acc}, step=(epoch * total_length + step))
+                            for layer, layer_acc in enumerate(audio_acc):
+                                wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":layer_acc}, step=(epoch * total_length + step))
+                            for layer, l in enumerate(layer_loss[:-1]):
+                                wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
+                            wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
+                    else:
+                        wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
+                        for layer, layer_acc in enumerate(audio_acc):
+                            wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":layer_acc}, step=(epoch * total_length + step))
+                        for layer, l in enumerate(layer_loss[:-1]):
+                            wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
+                        wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
+                    
+                total_loss += loss.detach().float()
+                total_acc += acc
+                if train_config.use_fp16:
+                    # if fp16 is enabled, use gradient scaler to handle gradient update
+                    scaler.scale(loss).backward()
+                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        scaler.step(optimizer)
+                        scaler.update()
+                        if lr_scheduler is not None:
+                            lr_scheduler.step()
+                            current_lr = lr_scheduler.get_last_lr()[0]
+                        else:
+                            current_lr = optimizer.param_groups[0]["lr"]
+                        if current_lr == 0:
+                            break
+                        if log_config.use_wandb and step % log_config.log_interval == 0:
+                            if train_config.enable_fsdp or train_config.enable_ddp:
+                                if rank==0:
+                                    wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step))
+                            else:
+                                wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step))
+                        optimizer.zero_grad()
+                        pbar.update(1)
+                else:
+                    # regular backpropagation when fp16 is not used
+                    loss.backward()
+                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        optimizer.step()
+                        if lr_scheduler is not None:
+                            lr_scheduler.step()
+                            current_lr = lr_scheduler.get_last_lr()[0]
+                        else:
+                            current_lr = optimizer.param_groups[0]["lr"]
+                        if current_lr == 0:
+                            break
+                        if log_config.use_wandb and step % log_config.log_interval == 0:
+                            if train_config.enable_fsdp or train_config.enable_ddp:
+                                if rank==0:
+                                    wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step))
+                            else:
+                                wandb.log({"train_inner/lr":current_lr}, step=(epoch * total_length + step))
+                        optimizer.zero_grad()
+                        pbar.update(1)
+
+                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()}, acc: {acc})")
+                
+                if (epoch * total_length + step + 1) % train_config.validation_interval == 0 and train_config.run_validation:
+                    eval_ppl, eval_epoch_loss, *rest = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+                    eval_epoch_acc = rest[0] if rest else -1
+                    checkpoint_start_time = time.perf_counter()
+                    if train_config.save_model and (eval_epoch_loss < best_val_loss):
+                        checkpoint_name = f"{train_config.model_name}_epoch_{str(epoch+1)}_step_{step+1}"
+                        if train_config.enable_fsdp or train_config.enable_ddp:
+                            dist.barrier()
+                        if train_config.use_peft:
+                            if train_config.enable_fsdp or train_config.enable_ddp:
+                                if rank==0:
+                                    logger.info(f"we are about to save the PEFT modules")
+                            else:
+                                logger.info(f"we are about to save the PEFT modules")
+                            if train_config.enable_fsdp:
+                                if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD:
+                                    save_model_checkpoint_peft_full_shard(
+                                            model, optimizer, rank, train_config, epoch=epoch
+                                        )
+                                elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
+                                    if rank==0:
+                                        save_model_checkpoint_peft(
+                                            model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                        )
+                                    dist.barrier()
+                            elif train_config.enable_ddp:
+                                if rank==0:
+                                    save_model_checkpoint_peft(
+                                            model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                        )
+                                dist.barrier()
+                            else:
+                                # model.save_pretrained(train_config.output_dir)
+                                save_model_checkpoint_peft(
+                                        model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                    )
+                            if train_config.enable_fsdp or train_config.enable_ddp:
+                                if rank==0:
+                                    logger.info(f"PEFT modules are saved in {train_config.output_dir} directory")
+                            else:
+                                logger.info(f"PEFT modules are saved in {train_config.output_dir} directory")
+                        
+                        elif not train_config.use_peft and train_config.freeze_llm:
+                            logger.info(f"llm is frozen, we are about to save other parts.")
+                            if train_config.enable_fsdp:
+                                if fsdp_config.sharding_strategy == ShardingStrategy.FULL_SHARD:
+                                    save_model_checkpoint_peft_full_shard(
+                                            model, optimizer, rank, train_config, epoch=epoch
+                                        )
+                                elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
+                                    if rank==0:
+                                        save_model_checkpoint_peft(
+                                            model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                        )
+                                    dist.barrier()
+                            elif train_config.enable_ddp:
+                                if rank==0:
+                                    save_model_checkpoint_peft(
+                                            model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                        )
+                                dist.barrier()
+                            else:
+                                save_model_checkpoint_peft(
+                                        model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                    )
+
+                        else:
+                            if train_config.enable_fsdp:
+                                if getattr(StateDictType, fsdp_config.checkpoint_type) == StateDictType.FULL_STATE_DICT:
+                                    save_model_checkpoint(
+                                        model, optimizer, rank, train_config, epoch=epoch
+                                    )
+                                elif getattr(StateDictType, fsdp_config.checkpoint_type) == StateDictType.SHARDED_STATE_DICT:
+                                    logger.info(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                                    logger.info("=====================================================")
+
+                                    save_model_and_optimizer_sharded(model, rank, train_config)
+                                    if train_config.save_optimizer:
+                                        save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                                        logger.info(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
+                                        logger.info("=====================================================")
+
+                                if train_config.save_optimizer:
+                                    save_optimizer_checkpoint(
+                                        model, optimizer, rank, train_config, epoch=epoch
+                                    )
+                                    logger.info(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
+                                    logger.info("=====================================================")
+
+                            elif train_config.enable_ddp:
+                                if rank==0:
+                                    save_model_checkpoint_peft(
+                                            model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                        )
+                                dist.barrier()
+                                    
+                            else:
+                                save_model_checkpoint_peft(
+                                        model, optimizer, rank, train_config, checkpoint_name=checkpoint_name
+                                    )
+                                
+                        if train_config.enable_fsdp or train_config.enable_ddp:
+                            dist.barrier()
+                    checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+                    checkpoint_times.append(checkpoint_end_time)
+                    if eval_epoch_loss < best_val_loss:
+                        best_val_loss = eval_epoch_loss
+                        if train_config.enable_fsdp or train_config.enable_ddp:
+                            if rank==0:
+                                logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                        else:
+                            logger.info(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
+                    val_loss.append(eval_epoch_loss)
+                    val_prep.append(eval_ppl)
+                    if rest:
+                        if eval_epoch_acc > best_val_acc:
+                            best_val_acc = eval_epoch_acc
+                            if train_config.enable_fsdp or train_config.enable_ddp:
+                                if rank==0:
+                                    logger.info(f"best eval acc on epoch {epoch+1} is {best_val_acc}")
+                            else:
+                                logger.info(f"best eval acc on epoch {epoch+1} is {best_val_acc}")
+                        val_acc.append(rest[0]) 
+                    else: 
+                        val_acc.append(-1)
+                    
+                    if log_config.use_wandb:
+                        if train_config.enable_fsdp or train_config.enable_ddp:
+                            if rank==0:
+                                wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1], "valid/val_best_accuracy":best_val_acc})
+                        else:
+                            wandb.log({"valid/val_epoch_loss":eval_epoch_loss, "valid/val_perplexity":eval_ppl, "valid/best_val_loss":best_val_loss, "valid/val_accuracy":val_acc[-1], "valid/val_best_accuracy":best_val_acc})
+
+                if train_config.run_test_during_validation:
+                    if train_config.enable_fsdp or train_config.enable_ddp:
+                        if rank==0:
+                            logger.info("=====================================")
+                            logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:")
+                            with autocast():
+                                logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt))
+                            logger.info("=====================================")
+                        dist.barrier()
+                    else:
+                        logger.info("=====================================")
+                        logger.info(f"Test the file {train_config.run_test_during_validation_file} during validation:")
+                        with autocast():
+                            logger.info(model.inference(train_config.run_test_during_validation_file, train_config.run_test_during_validation_prompt))
+                        logger.info("=====================================")
+            pbar.close()
+
+        epoch_end_time = time.perf_counter()-epoch_start_time
+        epoch_times.append(epoch_end_time)
+        # Reducing total_loss across all devices if there's more than one CUDA device
+        if torch.cuda.device_count() > 1 and (train_config.enable_fsdp or train_config.enable_ddp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+            dist.all_reduce(total_acc, op=dist.ReduceOp.SUM)
+        train_epoch_loss = total_loss / len(train_dataloader)
+        train_epoch_acc = total_acc / len(train_dataloader)
+        if train_config.enable_fsdp or train_config.enable_ddp:
+            train_epoch_loss = train_epoch_loss/world_size
+            train_epoch_acc = train_epoch_acc/world_size
+        train_perplexity = torch.exp(train_epoch_loss)
+
+        train_prep.append(train_perplexity)
+        train_loss.append(train_epoch_loss)
+        train_acc.append(train_epoch_acc)
+
+        if log_config.use_wandb:
+            if train_config.enable_fsdp or train_config.enable_ddp:
+                if rank==0:
+                    wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc})
+            else:
+                wandb.log({"train/train_perplexity":train_perplexity, "train/train_epoch_loss":train_epoch_loss, "train/train_epoch_acc":train_epoch_acc})
+
+        if train_config.enable_fsdp or train_config.enable_ddp:
+            if rank==0:
+                logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+        else:
+            logger.info(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+
+        if train_config.enable_fsdp:
+            if rank==0:
+                logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                logger.info(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            logger.info(f"Max CUDA memory allocated was {memtrace.peak} GB")
+            logger.info(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+            logger.info(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+            logger.info(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+            logger.info(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+
+        # Update the learning rate as needed
+        # lr_scheduler.step()
+
+    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
+    avg_train_prep = sum(train_prep)/len(train_prep)
+    avg_train_loss = sum(train_loss)/len(train_loss)
+    avg_train_acc = sum(train_acc)/len(train_acc)
+    if train_config.run_validation:
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
+        avg_eval_acc = sum(val_acc)/len(val_acc)
+
+    results['avg_train_prep'] = avg_train_prep
+    results['avg_train_loss'] = avg_train_loss
+    results['avg_train_acc'] = avg_train_acc
+    if train_config.run_validation:
+        results['avg_eval_prep'] = avg_eval_prep
+        results['avg_eval_loss'] = avg_eval_loss
+        results['avg_eval_acc'] = avg_eval_acc
+    results["avg_epoch_time"] = avg_epoch_time
+    results["avg_checkpoint_time"] = avg_checkpoint_time
+
+    #saving the training params including fsdp setting for reference.
+    # if (train_config.enable_fsdp or train_config.enable_ddp)and not train_config.use_peft:
+    #     save_train_params(train_config, fsdp_config, rank)
+
+    return results
+
+def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
+    """
+    Evaluates the model on the given dataloader
+
+    Args:
+        model: The model to evaluate
+        train_config: The training configuration
+        eval_dataloader: The dataloader containing the evaluation data
+        local_rank: The rank of the current node in a distributed setting
+        tokenizer: The tokenizer used to decode predictions
+
+    Returns: eval_ppl, eval_epoch_loss, eval_epoch_acc
+    """
+    if train_config.enable_fsdp or train_config.enable_ddp:
+        world_size = int(os.environ["WORLD_SIZE"])
+    model.eval()
+    eval_preds = []
+    eval_loss = 0.0  # Initialize evaluation loss
+    eval_acc = 0.0
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext # (Fix:MZY): fix expected scalar type mismatch in norm 
+
+    with MemoryTrace() as memtrace:
+        total_length = len(eval_dataloader)
+        pbar = tqdm(colour="green", desc=f"Evaluating Epoch", total=total_length, dynamic_ncols=True)
+        for step, batch in enumerate(eval_dataloader):
+            for key in batch.keys():
+                if train_config.enable_fsdp or train_config.enable_ddp:
+                    batch[key] = batch[key].to(local_rank) if isinstance(batch[key], torch.Tensor) else batch[key]
+                else:
+                    batch[key] = batch[key].to('cuda:0') if isinstance(batch[key], torch.Tensor) else batch[key]
+            # Ensure no gradients are computed for this scope to save memory
+            with torch.no_grad():
+                # Forward pass and compute loss
+                with autocast(): # (Fix:MZY): fix expected scalar type mismatch in norm 
+                    outputs, *rest = model(**batch)
+                acc = rest[0] if rest else -1
+                loss = outputs.loss
+
+                eval_loss += loss.detach().float()
+                eval_acc += acc
+            # Decode predictions and add to evaluation predictions list
+            try:
+                preds = torch.argmax(outputs.logits, -1)
+                eval_preds.extend(
+                    tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
+                )
+            except Exception:
+                pass  # vallex does not need to show its result (we can't view anything from abstract acoustic tokens)
+            pbar.update(1)
+            pbar.set_description(f"step: {step+1}/{total_length}, eval_loss: {eval_loss/(step+1):.4f}, eval_acc: {eval_acc/(step+1):.4f}")
+
+    # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if torch.cuda.device_count() > 1 and (train_config.enable_fsdp or train_config.enable_ddp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+        dist.all_reduce(eval_acc, op=dist.ReduceOp.SUM)
+
+    # Compute average loss and perplexity
+    eval_epoch_loss = eval_loss / len(eval_dataloader)
+    eval_epoch_acc = eval_acc / len(eval_dataloader)
+    if train_config.enable_fsdp or train_config.enable_ddp:
+        eval_epoch_loss = eval_epoch_loss/world_size
+        eval_epoch_acc = eval_epoch_acc/world_size
+    eval_ppl = torch.exp(eval_epoch_loss)
+
+    # Print evaluation metrics
+    if train_config.enable_fsdp or train_config.enable_ddp:
+        if local_rank==0:
+            logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}")
+    else:
+        logger.info(f" {eval_ppl=} {eval_epoch_loss=} {eval_epoch_acc=}")
+
+    # Restore training mode (dropout etc.) so training can continue after a mid-epoch validation
+    model.train()
+    return eval_ppl, eval_epoch_loss, eval_epoch_acc
+
+def freeze_transformer_layers(model, num_layer):
+    for i, layer in enumerate(model.model.layers):
+        if i < num_layer:
+            for param in layer.parameters():
+                param.requires_grad = False
+
+
+def check_frozen_layers_peft_model(model):
+    for i, layer in enumerate(model.base_model.model.model.layers):
+        for name, param in layer.named_parameters():
+            logger.info(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
+
+
+def setup():
+    """Initialize the process group for distributed training"""
+    dist.init_process_group("nccl")
+
+
+def setup_environ_flags(rank):
+    """Set environment flags for debugging purposes"""
+    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
+    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    # This flag will help with CUDA memory fragmentation that can lead to OOM in some cases.
+    # Note this is only available in PyTorch Nightlies (as of July 30 2023)
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
+    if rank == 0:
+        logger.info(f"--> Running with torch dist debug set to detail")
+
+
+def cleanup():
+    """Clean up the process group after training"""
+    dist.destroy_process_group()
+
+
+def clear_gpu_cache(rank=None):
+    """Clear the GPU cache for all ranks"""
+    if rank == 0:
+        logger.info(f"Clearing GPU cache for all ranks")
+    torch.cuda.empty_cache()
+
+
+def get_parameter_dtypes(model):
+    """Get the data types of model parameters"""
+    parameter_dtypes = {}
+    for name, parameter in model.named_parameters():
+        parameter_dtypes[name] = parameter.dtype
+    return parameter_dtypes
+
+def print_model_size(model, config, rank: int = 0) -> None:
+    """
+    Log the model name and the number of trainable parameters.
+
+    Args:
+        model: The PyTorch model.
+        config: Config object that provides ``model_name``.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        logger.info(f"--> Model {config.model_name}")
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        logger.info(f"--> {config.model_name} has {total_params / 1e6} Million params\n")
+
+def print_module_size(module, module_name, rank: int = 0) -> None:
+    """
+    Log the module name and the number of trainable parameters.
+
+    Args:
+        module: The PyTorch module.
+        module_name (str): Name of the module.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        logger.info(f"--> Module {module_name}")
+        total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
+        logger.info(f"--> {module_name} has {total_params / 1e6} Million params\n")
+
+
+def get_policies(cfg, rank):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+    verify_bfloat_support = (
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    )
+
+
+    mixed_precision_policy = None
+    wrapping_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen_mixed
+            if rank == 0:
+                logger.info(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                logger.info(f"FP16 enabled")
+        else:
+            logger.info(f"bFloat16 support not present. Using FP32, and not mixed precision")
+    wrapping_policy = get_llama_wrapper()
+    return mixed_precision_policy, wrapping_policy
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml.
+    This will be used by the converter script in the inference folder to fetch the HF model name or path.
+    It would also be helpful as a log for future reference.
+    """
+    # Convert the train_config and fsdp_config objects to dictionaries,
+    # converting all values to strings to ensure they can be serialized into a YAML file
+    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
+    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
+    # Merge the two dictionaries into one
+    train_params_dict = {**train_config_dict, **fsdp_config_dict}
+    # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
+    folder_name = (
+        train_config.dist_checkpoint_root_folder
+        + "/"
+        + train_config.dist_checkpoint_folder
+        + "-"
+        + train_config.model_name
+    )
+
+    save_dir = Path.cwd() / folder_name
+    # If the directory does not exist, create it
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    # Convert the dictionary to a YAML string
+    config_yaml = yaml.dump(train_params_dict, indent=4)
+    file_name = os.path.join(save_dir,'train_params.yaml')
+
+    # Check if there's a directory with the same name as the file
+    if os.path.isdir(file_name):
+        logger.info(f"Error: {file_name} is a directory, not a file.")
+    else:
+        # Write the YAML string to the file
+        with open(file_name, 'w') as f:
+            f.write(config_yaml)
+        if rank==0:
+            logger.info(f"training params are saved in {file_name}")
diff --git a/slam_llm/utils/whisper_tn.py b/slam_llm/utils/whisper_tn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6748d6920f0fa2bbb42aee87848718507e488eb3
--- /dev/null
+++ b/slam_llm/utils/whisper_tn.py
@@ -0,0 +1,23 @@
+import sys
+import os
+import re
+import string
+from whisper_normalizer.english import EnglishTextNormalizer
+
+english_normalizer = EnglishTextNormalizer()
+
+def normalize_text(srcfn, dstfn):
+    with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write:
+        all_lines = f_read.readlines()
+        for line in all_lines:
+            line = line.strip()
+            line_arr = line.split()
+            key = line_arr[0]
+            conts = " ".join(line_arr[1:])
+            normalized_conts = english_normalizer(conts)
+            f_write.write("{0}\t{1}\n".format(key, normalized_conts))
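+
+# Same Kaldi-style "<utt-id> <transcript>" input format as llm_tn.py, but this variant
+# only applies the Whisper English normalizer without the repeated-word reduction step.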
+
+if __name__ == "__main__":
+    srcfn = sys.argv[1]
+    dstfn = sys.argv[2]
+    normalize_text(srcfn, dstfn)
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/codec_utils.py b/utils/codec_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b576414a59db176a84d548e0eae83b2da601a503
--- /dev/null
+++ b/utils/codec_utils.py
@@ -0,0 +1,12 @@
+from snac import SNAC
+from slam_llm.utils.train_utils import print_module_size
+import os
+
+def setup_codec(train_config, model_config, **kwargs):
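+    # Load the pretrained SNAC codec and keep it in eval mode; it is used to decode
+    # generated audio codes back into waveforms.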
+    if model_config.codec_decoder_type == "SNAC":
+        codec_decoder = SNAC.from_pretrained(model_config.codec_decoder_path).eval()
+    else:
+        raise NotImplementedError
+    print_module_size(
+        codec_decoder,
+        model_config.codec_decoder_type,
+        int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0,
+    )
+    
+    return codec_decoder
\ No newline at end of file
diff --git a/utils/snac_utils.py b/utils/snac_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..05d74a0176deac04c77ecc522ece0f4c20bbd4be
--- /dev/null
+++ b/utils/snac_utils.py
@@ -0,0 +1,168 @@
+import torch
+import time
+import numpy as np
+
+
+class SnacConfig:
+    audio_vocab_size = 4096
+    padded_vocab_size = 4160
+    end_of_audio = 4096
+    padding_token = 4097
+
+
+snac_config = SnacConfig()    
+
+
+def get_time_str():
+    time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
+    return time_str
+
+
+def layershift(input_id, layer, stride=4160, shift=152000):
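+    # Map a raw SNAC code id into the combined vocabulary: `shift` is assumed to be the size of the
+    # (padded) text vocabulary and `stride` the per-layer audio vocab size, so each codebook layer
+    # occupies its own contiguous id range.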
+    return input_id + shift + layer * stride
+
+    
+def generate_audio_data(snac_tokens, snacmodel, device=None):
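+    # Decode the SNAC codes back into a waveform and convert it to 16-bit PCM bytes
+    # (the codec is assumed to emit samples in [-1, 1], hence the 32768 scaling).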
+    audio = reconstruct_tensors(snac_tokens, device)
+    with torch.inference_mode():
+        audio_hat = snacmodel.decode(audio)
+    audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
+    # Clip to the int16 range before casting to avoid wrap-around on out-of-range samples.
+    audio_data = np.clip(audio_data, -32768, 32767).astype(np.int16)
+    audio_data = audio_data.tobytes()
+    return audio_data
+
+    
+def get_snac(list_output, index, nums_generate):
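+    # Collect the 7 codebook tokens for each of the `nums_generate` newly generated frames from the
+    # per-layer output lists; the index offset is assumed to undo the layer-wise delay of parallel decoding.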
+
+    snac = []
+    start = index
+    for i in range(nums_generate):
+        snac.append("#")
+        for j in range(7):
+            snac.append(list_output[j][start - nums_generate - 5 + j + i])
+    return snac
+
+
+def reconscruct_snac(output_list):
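+    # Flatten the per-layer outputs into the "#"-delimited frame format consumed by reconstruct_tensors:
+    # an 8th stream (assumed to be the text stream) is dropped if present, and the first i + 1 tokens of
+    # layer i are stripped, which is assumed to undo the layer-wise delay of parallel generation.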
+    if len(output_list) == 8:
+        output_list = output_list[:-1]
+    output = []
+    for i in range(7):
+        output_list[i] = output_list[i][i + 1 :]
+    for i in range(len(output_list[-1])):
+        output.append("#")
+        for j in range(7):
+            output.append(output_list[j][i])
+    return output
+
+
+def get_snac_answer_token(snac_tokens_str):
+    snac_tokens = snac_tokens_str.split()
+    audio_length = len(snac_tokens) // 8 + 8    # the extra 8 accounts for parallel generation: 7 padding tokens plus 1 end-of-audio token
+    eoa = snac_config.end_of_audio
+    padding_token = snac_config.padding_token
+    result = []
+
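+    # Each layer i is left-padded with i padding tokens and right-padded so that all layers share the
+    # same length, giving the diagonal layout used for layer-shifted parallel generation, e.g.
+    # layer 1: [pad, c1, ..., eoa, pad x 6]   layer 7: [pad x 7, c1, ..., eoa]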
+    for layer in range(1, 8):  # layers 1 through 7
+        layer_tokens = []
+        layer_tokens.extend([padding_token] * layer)
+        layer_tokens.extend([snac_tokens[i] for i in range(len(snac_tokens)) if i % 8 == layer])
+        layer_tokens.append(eoa)
+        if layer < 7:
+            layer_tokens.extend([padding_token] * (7 - layer))
+        result.append(torch.tensor([int(token) for token in layer_tokens]))
+        
+    result_tensor = torch.stack(result)
+    return result_tensor, audio_length
+
+
+def reconstruct_tensors(flattened_output, device=None):
+    """Reconstructs the list of tensors from the flattened output."""
+
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def count_elements_between_hashes(lst):
+        try:
+            # Find the index of the first '#'
+            first_index = lst.index("#")
+            # Find the index of the second '#' after the first
+            second_index = lst.index("#", first_index + 1)
+            # Count the elements between the two indices
+            return second_index - first_index - 1
+        except ValueError:
+            # Handle the case where there aren't enough '#' symbols
+            return "List does not contain two '#' symbols"
+
+    def remove_elements_before_hash(flattened_list):
+        try:
+            # Find the index of the first '#'
+            first_hash_index = flattened_list.index("#")
+            # Return the list starting from the first '#'
+            return flattened_list[first_hash_index:]
+        except ValueError:
+            # Handle the case where there is no '#'
+            return "List does not contain the symbol '#'"
+
+    def list_to_torch_tensor(tensor1):
+        # Convert the list to a torch tensor
+        tensor = torch.tensor(tensor1)
+        # Reshape the tensor to have size (1, n)
+        tensor = tensor.unsqueeze(0)
+        return tensor
+
+    flattened_output = remove_elements_before_hash(flattened_output)
+    codes = []
+    tensor1 = []
+    tensor2 = []
+    tensor3 = []
+    tensor4 = []
+
+    n_tensors = count_elements_between_hashes(flattened_output)
+    if n_tensors == 7:
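+        # A 7-code frame is assumed to follow SNAC's 3-codebook layout
+        # [c, m1, f1, f2, m2, f3, f4]: 1 coarse, 2 medium and 4 fine codes per frame.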
+        for i in range(0, len(flattened_output), 8):
+            tensor1.append(flattened_output[i + 1])
+            tensor2.append(flattened_output[i + 2])
+            tensor3.append(flattened_output[i + 3])
+            tensor3.append(flattened_output[i + 4])
+            tensor2.append(flattened_output[i + 5])
+            tensor3.append(flattened_output[i + 6])
+            tensor3.append(flattened_output[i + 7])
+        # Build the code tensors once, after all frames have been collected
+        codes = [
+            list_to_torch_tensor(tensor1).to(device),
+            list_to_torch_tensor(tensor2).to(device),
+            list_to_torch_tensor(tensor3).to(device),
+        ]
+
+    if n_tensors == 15:
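+        # A 15-code frame is assumed to follow SNAC's 4-codebook layout:
+        # 1 code for the coarsest book, then 2, 4 and 8 codes for the finer books.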
+        for i in range(0, len(flattened_output), 16):
+            tensor1.append(flattened_output[i + 1])
+            tensor2.append(flattened_output[i + 2])
+            tensor3.append(flattened_output[i + 3])
+            tensor4.append(flattened_output[i + 4])
+            tensor4.append(flattened_output[i + 5])
+            tensor3.append(flattened_output[i + 6])
+            tensor4.append(flattened_output[i + 7])
+            tensor4.append(flattened_output[i + 8])
+            tensor2.append(flattened_output[i + 9])
+            tensor3.append(flattened_output[i + 10])
+            tensor4.append(flattened_output[i + 11])
+            tensor4.append(flattened_output[i + 12])
+            tensor3.append(flattened_output[i + 13])
+            tensor4.append(flattened_output[i + 14])
+            tensor4.append(flattened_output[i + 15])
+        # Build the code tensors once, after all frames have been collected
+        codes = [
+            list_to_torch_tensor(tensor1).to(device),
+            list_to_torch_tensor(tensor2).to(device),
+            list_to_torch_tensor(tensor3).to(device),
+            list_to_torch_tensor(tensor4).to(device),
+        ]
+
+    return codes
+
diff --git a/utils/trick_utils.py b/utils/trick_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ad178272b059301174a832e7d82f5e8ac9e537
--- /dev/null
+++ b/utils/trick_utils.py
@@ -0,0 +1,45 @@
+import torch
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+def partial_freeze_weights(model, original_vocabsize, total_vocabsize):
+    if int(os.environ.get("RANK", "0")) == 0:
+        logger.info("Only training partial embedding layer")
+
+    trainable_range = (original_vocabsize, total_vocabsize)
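+    # Rows [original_vocabsize, total_vocabsize) of `lm_head` are assumed to be the newly appended
+    # audio-token rows; only these are re-initialized and kept trainable.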
+
+    # Define a hook to zero out the gradient for weights outside the trainable range during the backward pass
+    def zero_out_gradient(grad):
+        grad[:trainable_range[0], :] = 0
+        grad[trainable_range[1]:, :] = 0
+        return grad
+
+    # Freeze all layers first
+    for param in model.parameters():
+        param.requires_grad = False
+
+    # Assuming the output layer is `lm_head`
+    for param in model.llm.lm_head.parameters():
+        # Compute the standard deviation for He initialization
+        std_dev = (2.0 / param.size(1)) ** 0.5
+
+        # Initialize the specific rows with He initialization
+        param[original_vocabsize:total_vocabsize] = (
+            torch.randn((trainable_range[1] - trainable_range[0], param.size(1))) * std_dev
+        )
+        param.requires_grad = True
+
+        # Register the hook on the weight tensor
+        param.register_hook(zero_out_gradient)
+
+def train_embedding_layer_only(model):
+    if int(os.environ.get("RANK", "0")) == 0:
+        logger.info("Only training embedding layer")
+
+    for param in model.parameters():
+        param.requires_grad = False
+        
+    for param in model.llm.lm_head.parameters():
+        param.requires_grad = True
\ No newline at end of file
diff --git a/utils/tts_adapter_utils.py b/utils/tts_adapter_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..866257130ecdea134cf723b1ce38246f77a2033b
--- /dev/null
+++ b/utils/tts_adapter_utils.py
@@ -0,0 +1,323 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+
+"""Full definition of a decoder-only transformer-based language model, all of it in this single file.
+
+Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
+https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
+"""
+
+import math
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+def setup_tts_adapter(adapter_config, model_config, **kwargs):
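+    # A small stack of transformer Blocks, presumably applied to the LLM's hidden states, followed by
+    # a final norm and a linear head that projects to the audio-token vocabulary.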
+    return nn.ModuleDict(
+        dict(
+            post_adapter=nn.ModuleList(
+                Block(adapter_config) for _ in range(adapter_config.n_layer)
+            ),
+            post_adapter_audio_ln=adapter_config.norm_class(
+                model_config.llm_dim, eps=adapter_config.norm_eps
+            ),
+            post_adapter_audio_lm_head=nn.Linear(
+                model_config.llm_dim, model_config.vocab_config.total_audio_vocabsize, bias=adapter_config.lm_head_bias
+            ),
+        )
+    )
+
+class Block(nn.Module):
+
+    def __init__(self, config) -> None:
+        super().__init__()
+        if not config.parallel_residual and config.shared_attention_norm:
+            raise NotImplementedError(
+                "No checkpoint amongst the ones we support uses this configuration"
+                " (non-parallel residual and shared attention norm)."
+            )
+
+        if config.norm_class_name == "RMSNorm":
+            self.norm_class = RMSNorm
+
+        self.norm_1 = self.norm_class(config.n_embd, eps=config.norm_eps)
+        self.attn = CausalSelfAttention(config)
+        self.norm_2 = (
+            None
+            if config.shared_attention_norm
+            else self.norm_class(config.n_embd, eps=config.norm_eps)
+        )
+
+        if config.mlp_class_name == "GptNeoxMLP":
+            self.mlp_class = GptNeoxMLP
+
+        self.mlp = self.mlp_class(config)
+
+        self.config = config
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Non-parallel residual       Parallel residual
+           ┌─ x                     ┌─ x ────────────┐             Note: if `shared_attention_norm` is True,
+           │  ↓                     │  ↓             ↓                   the output from `norm_1` is reused
+           │  norm_1                │  norm_1  ───►  norm_2
+           │  ↓                     │  ↓             ↓
+           │  attn                  │  attn          mlp
+           │  ↓                     │  ↓             │
+        ┌─ └► +                     └► + ◄───────────┘
+        │     norm_2
+        │     ↓
+        │     mlp
+        │     ↓
+        └───► +
+        """
+
+        x_normed = self.norm_1(x)
+        attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
+
+        if self.config.parallel_residual:
+            x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
+            x = self.mlp(x_normed) + attention_output + x
+        else:
+            x = attention_output + x
+            x = self.mlp(self.norm_2(x)) + x
+        return x
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
+        # key, query, value projections for all heads, but in a batch
+        self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
+        # output projection
+        # if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
+        self.proj = nn.Linear(
+            config.head_size * config.n_head, config.n_embd, bias=config.bias
+        )
+        # disabled by default
+        self.kv_cache: Optional[KVCache] = None
+
+        self.config = config
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        B, T, C = (
+            x.size()
+        )  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        qkv = self.attn(x)
+
+        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
+        q_per_kv = self.config.n_head // self.config.n_query_groups
+        total_qkv = q_per_kv + 2  # each group has q_per_kv queries, 1 key, and 1 value
+        qkv = qkv.view(
+            B, T, self.config.n_query_groups, total_qkv, self.config.head_size
+        )
+        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)
+
+        # split batched computation into three
+        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
+
+        # maybe repeat k and v for the non-multi-head-attention cases (MQA / GQA)
+        # training: flash attention requires it
+        # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
+        if self.config.n_query_groups != self.config.n_head and (
+            input_pos is None or self.config.n_query_groups != 1
+        ):
+            k = k.expand(
+                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
+            )
+            v = v.expand(
+                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
+            )
+
+        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
+        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
+        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)
+
+        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
+        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
+        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
+        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
+
+        if input_pos is not None:
+            if not isinstance(self.kv_cache, KVCache):
+                raise TypeError("You need to call `gpt.set_kv_cache()`")
+            k, v = self.kv_cache(input_pos, k, v)
+
+        y = self.scaled_dot_product_attention(q, k, v, mask)
+
+        y = y.reshape(
+            B, T, self.config.head_size * self.config.n_head
+        )  # re-assemble all head outputs side by side
+
+        # output projection
+        return self.proj(y)
+
+    def scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        scale = 1.0 / math.sqrt(self.config.head_size)
+        y = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
+        )
+        return y.transpose(1, 2)
+
+    def build_kv_cache(
+        self,
+        batch_size: int,
+        max_seq_length: int,
+        rope_cache_length: Optional[int] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> "KVCache":
+        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
+        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
+        if rope_cache_length is None:
+            if self.config.rotary_percentage != 1.0:
+                raise TypeError(
+                    "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
+                )
+            k_shape = v_shape
+        else:
+            k_shape = (
+                batch_size,
+                heads,
+                max_seq_length,
+                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
+            )
+        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
+
+
+
+
+def build_rope_cache(
+    seq_len: int,
+    n_elem: int,
+    device: Optional[torch.device] = None,
+    base: int = 10000,
+    condense_ratio: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Enhanced Transformer with Rotary Position Embedding.
+
+    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+    transformers/rope/__init__.py. MIT License:
+    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+    """
+    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
+
+    # Create position indexes `[0, 1, ..., seq_len - 1]`
+    seq_idx = torch.arange(seq_len, device=device) / condense_ratio
+
+    # Calculate the product of position index and $\theta_i$
+    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
+
+    return torch.cos(idx_theta), torch.sin(idx_theta)
+
+
+def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+    head_size = x.size(-1)
+    x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+    x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+    roped = (x * cos) + (rotated * sin)
+    return roped.to(dtype=x.dtype)
+
+
+class KVCache(nn.Module):
+    def __init__(
+        self,
+        k_shape: Tuple[int, int, int, int],
+        v_shape: Tuple[int, int, int, int],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> None:
+        super().__init__()
+        self.register_buffer(
+            "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
+        )
+        self.register_buffer(
+            "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
+        )
+
+    def forward(
+        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # move the buffer to the activation dtype for when AMP is used
+        self.k = self.k.to(k.dtype)
+        self.v = self.v.to(v.dtype)
+        # update the cache
+        k = self.k.index_copy_(2, input_pos, k)
+        v = self.v.index_copy_(2, input_pos, v)
+        return k, v
+
+    def reset_parameters(self) -> None:
+        torch.nn.init.zeros_(self.k)
+        torch.nn.init.zeros_(self.v)
+
+
+
+class RMSNorm(torch.nn.Module):
+    """Root Mean Square Layer Normalization.
+
+    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
+    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
+    """
+
+    def __init__(
+        self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
+    ) -> None:
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.ones(size))
+        self.eps = eps
+        self.dim = dim
+        self.add_unit_offset = add_unit_offset
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        dtype = x.dtype
+        x = x.float()
+        # NOTE: the original RMSNorm paper implementation is not equivalent
+        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
+        x_normed = x * torch.rsqrt(norm_x + self.eps)
+        x_normed = x_normed.to(dtype=dtype)
+        if self.add_unit_offset:
+            # Gemma model requires a unit offset
+            # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
+            return x_normed * (1 + self.weight)
+        return x_normed * self.weight
+
+    def reset_parameters(self) -> None:
+        torch.nn.init.ones_(self.weight)
+
+
+class GptNeoxMLP(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+
+        self.config = config
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc(x)
+        x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
+        return self.proj(x)
\ No newline at end of file