# Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import json
import contextlib
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaTokenizer, StoppingCriteriaList
from peft import LoraConfig, TaskType, get_peft_model

from .Qformer import BertConfig, BertLMHeadModel
from .modeling_llama import LlamaForCausalLM
from .modeling_whisper import WhisperModel
from .beats.BEATs import BEATsConfig, BEATs
from .utils import StoppingCriteriaSub


class TINYOCTOPUS(nn.Module):
    @classmethod
    def init_speech_Qformer(cls, num_query_token, speech_width, num_hidden_layers=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width
        # insert a cross-attention layer into every Q-Former block (cross_attention_freq = 1)
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 1
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @property
    def device(self):
        return list(self.parameters())[0].device

    def maybe_autocast(self, dtype=torch.float16):
        # if on cpu, don't use autocast
        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    def __init__(
        self,
        llama_path="",
        whisper_path="",
        freeze_whisper=True,
        beats_path="",
        freeze_beats=True,

        use_speech_Qformer=True,
        num_speech_query_token=1,
        freeze_speech_QFormer=False,
        window_level_Qformer=True,
        second_per_window=0.333333,
        second_stride=0.333333,

        speech_llama_proj_model="",
        freeze_speech_llama_proj=False,

        lora=True,
        lora_rank=8,
        lora_alpha=32,
        lora_dropout=0.1,

        multi_prompt=False,
        prompt_path="",
        prompt_template="",
        max_txt_len=128,
        end_sym="</s>",  # sequence-end marker appended to target text
        low_resource=False,  # use 8 bit
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
    ):
        super().__init__()

        self.beats_path = beats_path
        self.use_speech_Qformer = use_speech_Qformer
        self.window_level_Qformer = window_level_Qformer
        self.second_per_window = second_per_window
        self.second_stride = second_stride
        self.lora = lora
        self.multi_prompt = multi_prompt
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.low_resource = low_resource

        logging.info('Loading LLaMA Tokenizer')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_path, use_fast=False)
        self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llama_tokenizer.padding_side = "right"

        logging.info('Loading LLaMA Model')
        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={"": device_8bit},
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_path,
                torch_dtype=torch.float16,
            )

        self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
        for name, param in self.llama_model.named_parameters():
            param.requires_grad = False
        logging.info('Loading LLaMA Done')

        if self.lora:
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            )
            self.llama_model = get_peft_model(self.llama_model, self.peft_config)
            self.llama_model.print_trainable_parameters()
            logging.info('LoRA Training')

        assert whisper_path
        logging.info('Loading Whisper Model')
        self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder
        self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model)
        if freeze_whisper:
            for name, param in self.speech_encoder.named_parameters():
                param.requires_grad = False
            self.speech_encoder.eval()
            logging.info("freeze Whisper")

        if self.beats_path:
            logging.info("Loading BEATs Model")
            beats_ckpt = torch.load(self.beats_path, map_location='cpu')
            beats_cfg = BEATsConfig(beats_ckpt['cfg'])
            self.beats = BEATs(beats_cfg)
            self.beats.load_state_dict(beats_ckpt['model'])
            self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
            if freeze_beats:
                for name, param in self.beats.named_parameters():
                    param.requires_grad = False
                self.beats.eval()
                logging.info("freeze BEATs")

        if self.use_speech_Qformer:
            if self.beats_path:
                self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
                    num_query_token=num_speech_query_token,
                    speech_width=self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim
                )
            else:
                self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
                    num_query_token=num_speech_query_token,
                    speech_width=self.speech_encoder.config.d_model
                )
            self.speech_Qformer.bert.embeddings.word_embeddings = None
            self.speech_Qformer.bert.embeddings.position_embeddings = None
            for layer in self.speech_Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.speech_Qformer.cls = None
            if freeze_speech_QFormer:
                for name, param in self.speech_Qformer.named_parameters():
                    param.requires_grad = False
                self.speech_Qformer.eval()
                self.speech_query_tokens.requires_grad = False
                logging.info("freeze Speech QFormer")

            logging.info('Loading speech LLAMA proj')
            self.speech_llama_proj = nn.Linear(
                self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size
            )
            if speech_llama_proj_model:
                logging.info("Loading speech LLAMA proj from {}".format(speech_llama_proj_model))
                speech_llama_proj_weight = torch.load(speech_llama_proj_model, map_location="cpu")
                self.load_state_dict(speech_llama_proj_weight['model'], strict=False)
            if freeze_speech_llama_proj:
                for name, param in self.speech_llama_proj.named_parameters():
                    param.requires_grad = False
                self.speech_llama_proj.eval()
                logging.info("freeze speech LLAMA proj")
        else:
            # feel free to add other aligners here
            raise NotImplementedError

        # prepare prompts
        self.prompt_dict = {}
        if prompt_path:
            try:
                raw_prompts = json.load(open(prompt_path, "r"))
            except:
                print("Failed to load prompt! Try to use utf-8 encoding.")
                raw_prompts = json.load(open(prompt_path, "r", encoding='utf-8'))
            for task in raw_prompts.keys():
                # keep only prompts that contain the "<SpeechHere>" placeholder,
                # which marks where the speech embeddings are inserted
                filtered_prompts = [raw_prompt for raw_prompt in raw_prompts[task] if "<SpeechHere>" in raw_prompt]
                self.prompt_dict[task] = [prompt_template.format(p) for p in filtered_prompts]
            print("Loading training prompts done!")

    def _encode_auditory_feature(self, speech_embeds, audio_embeds=None):
        with self.maybe_autocast():
            if self.use_speech_Qformer:
                speech_embeds = self.ln_speech(speech_embeds)
                if audio_embeds is not None:
                    audio_embeds = self.ln_audio(audio_embeds)
                    if audio_embeds.size(1) < speech_embeds.size(1):
                        audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
                    elif audio_embeds.size(1) > speech_embeds.size(1):
                        speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
                    speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
                speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device)

                if self.window_level_Qformer:
                    B, T, C = speech_embeds.shape
                    kernel = round(1500 * self.second_per_window / 30.0)
                    stride = round(1500 * self.second_stride / 30.0)
                    kernel = (1, kernel)
                    stride = (1, stride)
                    speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
                    speech_embeds_overlap = F.unfold(speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride)
                    _, _, L = speech_embeds_overlap.shape
                    speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
                    speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
                    speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
                    speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device)

                query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
                query_output = self.speech_Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=speech_embeds,
                    encoder_attention_mask=speech_atts,
                    return_dict=True,
                )
                speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)

                if self.window_level_Qformer:
                    speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()

                speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device)
            else:
                raise NotImplementedError

        return speech_embeds, speech_atts

    def encode_speech(self, spectrogram, raw_wav=None, audio_padding_mask=None):
        with self.maybe_autocast():
            speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state
            if self.beats_path and raw_wav is not None:
                audio_embeds, _ = self.beats.extract_features(raw_wav, padding_mask=audio_padding_mask, feature_only=True)
            else:
                audio_embeds = None

        return self._encode_auditory_feature(speech_embeds, audio_embeds=audio_embeds)

    def prompt_wrap(self, embeds, atts, prompt, multi_prompt=False):
        if prompt:
            if multi_prompt:
                p_before = []
                p_after = []
                for i, p in enumerate(prompt):
                    b, a = p.split("<SpeechHere>")
                    p_before.append(b)
                    p_after.append(a)

                p_before_tokens = self.llama_tokenizer(
                    p_before, return_tensors="pt", add_special_tokens=False
                ).to(embeds.device)
                p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids) if not self.lora else self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids)

                # speech_embeds wrapped with prompts_embeds are padded to the same length here
                p_after_tokens = self.llama_tokenizer(
                    p_after, return_tensors="pt", padding="longest", add_special_tokens=False
                ).to(embeds.device)
                p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids) if not self.lora else self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids)

                wrapped_embeds = torch.cat([p_before_embeds, embeds, p_after_embeds], dim=1)
                wrapped_atts = torch.cat([p_before_tokens.attention_mask, atts, p_after_tokens.attention_mask], dim=1)
            else:
                batch_size = embeds.shape[0]
                p_before, p_after = prompt.split("<SpeechHere>")

                p_before_tokens = self.llama_tokenizer(
                    p_before, return_tensors="pt", add_special_tokens=False
                ).to(embeds.device)
                p_after_tokens = self.llama_tokenizer(
                    p_after, return_tensors="pt", add_special_tokens=False
                ).to(embeds.device)
                p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) if not self.lora else self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
                p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) if not self.lora else self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)

                wrapped_embeds = torch.cat([p_before_embeds, embeds, p_after_embeds], dim=1)
                wrapped_atts = torch.cat([p_before_tokens.attention_mask, atts, p_after_tokens.attention_mask], dim=1)
            return wrapped_embeds, wrapped_atts
        else:
            return embeds, atts

    def forward(self, samples, verbose=False):
        # detect whether there are multiple tasks in this batch
        task = list(set(samples["task"]))
        if len(task) > 1 or "QA" in task:
            self.multi_prompt = True

        # prepare prompts
        if self.prompt_dict:
            if self.multi_prompt:
                prompt = [random.choice(self.prompt_dict[task]) for task in samples["task"]]
                if "Q" in samples:
                    prompt = [p.format(q) if '{}' in p else p for p, q in zip(prompt, samples["Q"])]
            else:
                prompt = random.choice(self.prompt_dict[samples["task"][0]])

        # use speech/audio encoder to encode speech/audio
        spectrogram = samples["spectrogram"]
        raw_wav = samples.get("raw_wav", None)
        # print(raw_wav)
        audio_padding_mask = samples.get("padding_mask", None)
        speech_embeds, speech_atts = self.encode_speech(spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask)

        # wrap speech_embeds with prompts
        if self.prompt_dict:
            speech_embeds, speech_atts = self.prompt_wrap(speech_embeds, speech_atts, prompt, multi_prompt=self.multi_prompt)

        # prepare inputs for LLM
        text = [t + self.end_sym for t in samples["text"]]
        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(spectrogram.device)
        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) if not self.lora else self.llama_model.model.model.embed_tokens(to_regress_tokens.input_ids)
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        empty_targets = (
            torch.ones(
                [speech_atts.shape[0], speech_atts.shape[1] + 1],
                dtype=torch.long
            ).to(spectrogram.device).fill_(-100)
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = speech_embeds.shape[0]
        bos = torch.ones(
            [batch_size, 1],
            dtype=to_regress_tokens.input_ids.dtype,
            device=to_regress_tokens.input_ids.device,
        ) * self.llama_tokenizer.bos_token_id
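        # The LLM input is assembled as [BOS] + (prompt-wrapped) speech embeddings + target
        # text embeddings; empty_targets above masks the BOS and speech/prompt positions with
        # -100 so the loss is computed only on the target text tokens.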
        bos_embeds = self.llama_model.model.embed_tokens(bos) if not self.lora else self.llama_model.model.model.embed_tokens(bos)
        atts_bos = speech_atts[:, :1]

        inputs_embeds = torch.cat([bos_embeds, speech_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, speech_atts, to_regress_tokens.attention_mask], dim=1)

        # calculate loss
        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss

        if verbose:
            nvocab = self.llama_model.config.vocab_size
            results = outputs.logits[:, empty_targets.size(1) - 1: -1, :].contiguous().view(-1, nvocab).argmax(dim=-1)
            labels = targets[:, empty_targets.size(1):].contiguous().view(-1)
            mask = (labels != -100)
            correct = (results[mask] == labels[mask]).float().sum()
            total = len(labels[mask])

        if verbose:
            return {"loss": loss, "correct": correct, "total": total}

        return {"loss": loss}

    def generate(self, samples, generate_cfg, prompts=None):
        batch_size = samples["spectrogram"].shape[0]

        spectrogram = samples["spectrogram"]
        raw_wav = samples.get("raw_wav", None)
        audio_padding_mask = samples.get("padding_mask", None)

        speech_embeds, speech_atts = self.encode_speech(spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask)
        if prompts is not None:
            speech_embeds, speech_atts = self.prompt_wrap(speech_embeds, speech_atts, prompts, multi_prompt=True)

        bos = torch.ones(
            [batch_size, 1],
            dtype=torch.int32,
            device=speech_embeds.device,
        ) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos) if not self.lora else self.llama_model.model.model.embed_tokens(bos)
        atts_bos = speech_atts[:, :1]

        embeds = torch.cat([bos_embeds, speech_embeds], dim=1)
        attns = torch.cat([atts_bos, speech_atts], dim=1)

        stop_words_ids = [torch.tensor([2]).cuda()]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        outputs = self.llama_model.generate(
            inputs_embeds=embeds,
            max_new_tokens=generate_cfg.get("max_new_tokens", 200),
            stopping_criteria=stopping_criteria,
            num_beams=generate_cfg.get("num_beams", 4),
            do_sample=generate_cfg.get("do_sample", False),
            min_length=generate_cfg.get("min_length", 1),
            temperature=generate_cfg.get("temperature", 1.0),
            top_p=generate_cfg.get("top_p", 0.9),
            repetition_penalty=generate_cfg.get("repetition_penalty", 1.0),
            length_penalty=generate_cfg.get("length_penalty", 1.0),
            attention_mask=attns,
        )
        text = self.llama_tokenizer.batch_decode(outputs, add_special_tokens=False)

        return text

    @classmethod
    def from_config(cls, config):
        llama_path = config.get("llama_path")
        whisper_path = config.get("whisper_path")
        freeze_whisper = config.get("freeze_whisper", True)
        beats_path = config.get("beats_path", "")
        freeze_beats = config.get("freeze_beats", True)

        use_speech_Qformer = config.get("use_speech_Qformer", True)
        num_speech_query_token = config.get("num_speech_query_token", 1)
        freeze_speech_QFormer = config.get("freeze_speech_QFormer", False)
        window_level_Qformer = config.get("window_level_Qformer", True)
        second_per_window = config.get("second_per_window", 0.333333)
        second_stride = config.get("second_stride", 0.333333)

        speech_llama_proj_model = config.get("speech_llama_proj_model", "")
        freeze_speech_llama_proj = config.get("freeze_speech_llama_proj", False)

        lora = config.get("lora", True)
        lora_rank = config.get("lora_rank", 8)
        lora_alpha = config.get("lora_alpha", 32)
        lora_dropout = config.get("lora_dropout", 0.1)

        multi_prompt = config.get("multi_prompt", False)
config.get("prompt_path", "") prompt_template = config.get("prompt_template", "") max_txt_len = config.get("max_txt_len", 128) end_sym = config.get("end_sym", "") low_resource = config.get("low_resource", False) device_8bit = config.get("device_8bit", 0) model = cls( llama_path=llama_path, whisper_path=whisper_path, freeze_whisper=freeze_whisper, beats_path=beats_path, freeze_beats=freeze_beats, use_speech_Qformer=use_speech_Qformer, num_speech_query_token=num_speech_query_token, freeze_speech_QFormer=freeze_speech_QFormer, window_level_Qformer=window_level_Qformer, second_per_window=second_per_window, second_stride=second_stride, speech_llama_proj_model=speech_llama_proj_model, freeze_speech_llama_proj=freeze_speech_llama_proj, lora=lora, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, multi_prompt=multi_prompt, prompt_path=prompt_path, prompt_template=prompt_template, max_txt_len=max_txt_len, end_sym=end_sym, low_resource=low_resource, device_8bit=device_8bit, ) ckpt_path = config.get("ckpt", "") if ckpt_path: logging.info("Load TinyOctopus ckpt from: {}".format(ckpt_path)) ckpt = torch.load(ckpt_path, map_location="cpu") model.load_state_dict(ckpt['model'], strict=False) return model