import logging
import json
import contextlib
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaTokenizer, StoppingCriteriaList
from peft import LoraConfig, TaskType, get_peft_model

from .Qformer import BertConfig, BertLMHeadModel
from .modeling_llama import LlamaForCausalLM
from .modeling_whisper import WhisperModel
from .beats.BEATs import BEATsConfig, BEATs
from .utils import StoppingCriteriaSub

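# TINYOCTOPUS: a speech/audio-to-text model. Whisper encoder features (optionally
# fused with BEATs audio features) are passed through a window-level Q-Former,
# projected into the LLaMA embedding space, and prepended to a text prompt so that
# the (optionally LoRA-adapted) LLaMA decoder generates the target text.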
class TINYOCTOPUS(nn.Module):
    @classmethod
    def init_speech_Qformer(cls, num_query_token, speech_width, num_hidden_layers=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width

        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 1
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens
    @property
    def device(self):
        return next(self.parameters()).device

    def maybe_autocast(self, dtype=torch.float16):
        # On CPU autocast is disabled and a null context is returned;
        # on GPU a CUDA autocast context with the given dtype is used.
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()
    def __init__(
        self,
        llama_path="",
        whisper_path="",
        freeze_whisper=True,
        beats_path="",
        freeze_beats=True,

        use_speech_Qformer=True,
        num_speech_query_token=1,
        freeze_speech_QFormer=False,
        window_level_Qformer=True,
        second_per_window=0.333333,
        second_stride=0.333333,

        speech_llama_proj_model="",
        freeze_speech_llama_proj=False,

        lora=True,
        lora_rank=8,
        lora_alpha=32,
        lora_dropout=0.1,

        multi_prompt=False,
        prompt_path="",
        prompt_template="",
        max_txt_len=128,
        end_sym="</s>",
        low_resource=False,
        device_8bit=0,
    ):
        super().__init__()

        self.beats_path = beats_path
        self.use_speech_Qformer = use_speech_Qformer
        self.window_level_Qformer = window_level_Qformer
        self.second_per_window = second_per_window
        self.second_stride = second_stride
        self.lora = lora
        self.multi_prompt = multi_prompt
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym
        self.low_resource = low_resource

        logging.info('Loading LLaMA Tokenizer')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_path, use_fast=False)
        self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.llama_tokenizer.padding_side = "right"
        logging.info('Loading LLaMA Model')
        if self.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={"": device_8bit},
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                llama_path,
                torch_dtype=torch.float16,
            )

        self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
        for param in self.llama_model.parameters():
            param.requires_grad = False
        logging.info('Loading LLaMA Done')
        if self.lora:
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            )
            self.llama_model = get_peft_model(self.llama_model, self.peft_config)
            self.llama_model.print_trainable_parameters()
            logging.info('LoRA Training')
        assert whisper_path
        logging.info('Loading Whisper Model')
        self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder
        self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model)
        if freeze_whisper:
            for param in self.speech_encoder.parameters():
                param.requires_grad = False
            self.speech_encoder.eval()
            logging.info("freeze Whisper")
        if self.beats_path:
            logging.info("Loading BEATs Model")
            beats_ckpt = torch.load(self.beats_path, map_location='cpu')
            beats_cfg = BEATsConfig(beats_ckpt['cfg'])
            self.beats = BEATs(beats_cfg)
            self.beats.load_state_dict(beats_ckpt['model'])
            self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim)
            if freeze_beats:
                for param in self.beats.parameters():
                    param.requires_grad = False
                self.beats.eval()
                logging.info("freeze BEATs")
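        # The Q-Former cross-attends over the encoder features, so its encoder width
        # is Whisper's d_model, plus BEATs' embedding dim when BEATs features are fused.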
        if self.use_speech_Qformer:
            if self.beats_path:
                self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
                    num_query_token=num_speech_query_token,
                    speech_width=self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim,
                )
            else:
                self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
                    num_query_token=num_speech_query_token,
                    speech_width=self.speech_encoder.config.d_model,
                )
            # Strip the BERT components that are not used when the Q-Former only
            # processes query tokens (word/position embeddings, text branch, LM head).
            self.speech_Qformer.bert.embeddings.word_embeddings = None
            self.speech_Qformer.bert.embeddings.position_embeddings = None
            for layer in self.speech_Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            self.speech_Qformer.cls = None
            if freeze_speech_QFormer:
                for param in self.speech_Qformer.parameters():
                    param.requires_grad = False
                self.speech_Qformer.eval()
                self.speech_query_tokens.requires_grad = False
                logging.info("freeze Speech QFormer")
            logging.info('Loading speech LLAMA proj')
            self.speech_llama_proj = nn.Linear(
                self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size
            )
            if speech_llama_proj_model:
                logging.info("Loading speech LLAMA proj from {}".format(speech_llama_proj_model))
                speech_llama_proj_weight = torch.load(speech_llama_proj_model, map_location="cpu")
                self.load_state_dict(speech_llama_proj_weight['model'], strict=False)
            if freeze_speech_llama_proj:
                for param in self.speech_llama_proj.parameters():
                    param.requires_grad = False
                self.speech_llama_proj.eval()
                logging.info("freeze speech LLAMA proj")
        else:
            raise NotImplementedError

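        # Prompts are read from a JSON file mapping task names to prompt strings;
        # only prompts containing the "<SpeechHere>" placeholder are kept, and each
        # is wrapped with the provided prompt_template.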
        self.prompt_dict = {}
        if prompt_path:
            try:
                with open(prompt_path, "r") as f:
                    raw_prompts = json.load(f)
            except UnicodeDecodeError:
                print("Failed to load prompt! Try to use utf-8 encoding.")
                with open(prompt_path, "r", encoding='utf-8') as f:
                    raw_prompts = json.load(f)
            for task in raw_prompts.keys():
                filtered_prompts = [raw_prompt for raw_prompt in raw_prompts[task] if "<SpeechHere>" in raw_prompt]
                self.prompt_dict[task] = [prompt_template.format(p) for p in filtered_prompts]
            print("Loading training prompts done!")

    def _encode_auditory_feature(self, speech_embeds, audio_embeds=None):
        with self.maybe_autocast():
            if self.use_speech_Qformer:
                speech_embeds = self.ln_speech(speech_embeds)
                if audio_embeds is not None:
                    audio_embeds = self.ln_audio(audio_embeds)
                    # Pad the shorter sequence along the time axis, then fuse Whisper
                    # and BEATs features by concatenating along the channel dimension.
                    if audio_embeds.size(1) < speech_embeds.size(1):
                        audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
                    elif audio_embeds.size(1) > speech_embeds.size(1):
                        speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
                    speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
                speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device)

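                # Window-level Q-Former: Whisper emits 1500 encoder frames for a 30 s
                # input, so 1500 * seconds / 30 converts window size and stride from
                # seconds to frames; unfold slices the sequence into overlapping
                # windows, which are flattened into the batch dimension.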
                if self.window_level_Qformer:
                    B, T, C = speech_embeds.shape
                    kernel = round(1500 * self.second_per_window / 30.0)
                    stride = round(1500 * self.second_stride / 30.0)
                    kernel = (1, kernel)
                    stride = (1, stride)
                    speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
                    speech_embeds_overlap = F.unfold(speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride)
                    _, _, L = speech_embeds_overlap.shape
                    speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
                    speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
                    speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
                    speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device)

                query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
                query_output = self.speech_Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=speech_embeds,
                    encoder_attention_mask=speech_atts,
                    return_dict=True,
                )
                speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)

                if self.window_level_Qformer:
                    # Re-assemble the per-window query outputs back into [B, T', D].
                    speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()

                speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device)
            else:
                raise NotImplementedError

        return speech_embeds, speech_atts

    def encode_speech(self, spectrogram, raw_wav=None, audio_padding_mask=None):
        with self.maybe_autocast():
            speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state

            if self.beats_path and raw_wav is not None:
                audio_embeds, _ = self.beats.extract_features(raw_wav, padding_mask=audio_padding_mask, feature_only=True)
            else:
                audio_embeds = None

        return self._encode_auditory_feature(speech_embeds, audio_embeds=audio_embeds)

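    # prompt_wrap splits each prompt at the "<SpeechHere>" placeholder, embeds the
    # text on both sides with the LLaMA token embeddings, and concatenates
    # [before | speech embeddings | after] along the sequence dimension.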
    def prompt_wrap(self, embeds, atts, prompt, multi_prompt=False):
        if prompt:
            # With LoRA the LLaMA model is wrapped by PEFT, which adds one more
            # .model level in front of the token embedding table.
            if self.lora:
                embed_tokens = self.llama_model.model.model.embed_tokens
            else:
                embed_tokens = self.llama_model.model.embed_tokens

            if multi_prompt:
                p_before = []
                p_after = []
                for p in prompt:
                    b, a = p.split("<SpeechHere>")
                    p_before.append(b)
                    p_after.append(a)

                p_before_tokens = self.llama_tokenizer(
                    p_before, return_tensors="pt", add_special_tokens=False
                ).to(embeds.device)
                p_before_embeds = embed_tokens(p_before_tokens.input_ids)

                p_after_tokens = self.llama_tokenizer(
                    p_after, return_tensors="pt", padding="longest", add_special_tokens=False
                ).to(embeds.device)
                p_after_embeds = embed_tokens(p_after_tokens.input_ids)

                wrapped_embeds = torch.cat([p_before_embeds, embeds, p_after_embeds], dim=1)
                wrapped_atts = torch.cat([p_before_tokens.attention_mask, atts, p_after_tokens.attention_mask], dim=1)
            else:
                batch_size = embeds.shape[0]
                p_before, p_after = prompt.split("<SpeechHere>")

                p_before_tokens = self.llama_tokenizer(
                    p_before, return_tensors="pt", add_special_tokens=False
                ).to(embeds.device)
                p_after_tokens = self.llama_tokenizer(
                    p_after, return_tensors="pt", add_special_tokens=False
                ).to(embeds.device)
                p_before_embeds = embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
                p_after_embeds = embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)

                wrapped_embeds = torch.cat([p_before_embeds, embeds, p_after_embeds], dim=1)
                wrapped_atts = torch.cat([p_before_tokens.attention_mask, atts, p_after_tokens.attention_mask], dim=1)
            return wrapped_embeds, wrapped_atts
        else:
            return embeds, atts

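    # Training forward pass: build [BOS | speech wrapped in a task prompt | target text]
    # input embeddings and compute the LM loss on the target text tokens only.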
    def forward(self, samples, verbose=False):
        # Use per-sample prompts when the batch mixes tasks or contains QA samples.
        task = list(set(samples["task"]))
        if len(task) > 1 or "QA" in task:
            self.multi_prompt = True

        if self.prompt_dict:
            if self.multi_prompt:
                prompt = [random.choice(self.prompt_dict[task]) for task in samples["task"]]
                if "Q" in samples:
                    prompt = [p.format(q) if '{}' in p else p for p, q in zip(prompt, samples["Q"])]
            else:
                prompt = random.choice(self.prompt_dict[samples["task"][0]])

        spectrogram = samples["spectrogram"]
        raw_wav = samples.get("raw_wav", None)
        audio_padding_mask = samples.get("padding_mask", None)

        speech_embeds, speech_atts = self.encode_speech(spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask)

        if self.prompt_dict:
            speech_embeds, speech_atts = self.prompt_wrap(speech_embeds, speech_atts, prompt, multi_prompt=self.multi_prompt)

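        # Tokenize the target text; padding positions and every position covered by
        # the BOS token and the speech/prompt prefix are labeled -100 so they are
        # ignored by the LM loss.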
        text = [t + self.end_sym for t in samples["text"]]
        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(spectrogram.device)

        if self.lora:
            embed_tokens = self.llama_model.model.model.embed_tokens
        else:
            embed_tokens = self.llama_model.model.embed_tokens
        to_regress_embeds = embed_tokens(to_regress_tokens.input_ids)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        empty_targets = (
            torch.ones(
                [speech_atts.shape[0], speech_atts.shape[1] + 1],
                dtype=torch.long
            ).to(spectrogram.device).fill_(-100)
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = speech_embeds.shape[0]
        bos = torch.ones(
            [batch_size, 1],
            dtype=to_regress_tokens.input_ids.dtype,
            device=to_regress_tokens.input_ids.device,
        ) * self.llama_tokenizer.bos_token_id
        bos_embeds = embed_tokens(bos)
        atts_bos = speech_atts[:, :1]

        inputs_embeds = torch.cat([bos_embeds, speech_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, speech_atts, to_regress_tokens.attention_mask], dim=1)

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss

        if verbose:
            # Token-level accuracy over the non-ignored target positions.
            nvocab = self.llama_model.config.vocab_size
            results = outputs.logits[:, empty_targets.size(1) - 1: -1, :].contiguous().view(-1, nvocab).argmax(dim=-1)
            labels = targets[:, empty_targets.size(1):].contiguous().view(-1)
            mask = (labels != -100)
            correct = (results[mask] == labels[mask]).float().sum()
            total = len(labels[mask])
            return {"loss": loss, "correct": correct, "total": total}

        return {"loss": loss}

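    # Inference: encode the speech, wrap it in the given prompts, prepend BOS, and
    # let LLaMA generate autoregressively from the resulting embedding prefix.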
    def generate(self, samples, generate_cfg, prompts=None):
        batch_size = samples["spectrogram"].shape[0]

        spectrogram = samples["spectrogram"]
        raw_wav = samples.get("raw_wav", None)
        audio_padding_mask = samples.get("padding_mask", None)

        speech_embeds, speech_atts = self.encode_speech(spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask)

        if prompts is not None:
            speech_embeds, speech_atts = self.prompt_wrap(speech_embeds, speech_atts, prompts, multi_prompt=True)

        bos = torch.ones(
            [batch_size, 1],
            dtype=torch.int32,
            device=speech_embeds.device,
        ) * self.llama_tokenizer.bos_token_id
        if self.lora:
            embed_tokens = self.llama_model.model.model.embed_tokens
        else:
            embed_tokens = self.llama_model.model.embed_tokens
        bos_embeds = embed_tokens(bos)
        atts_bos = speech_atts[:, :1]

        embeds = torch.cat([bos_embeds, speech_embeds], dim=1)
        attns = torch.cat([atts_bos, speech_atts], dim=1)

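        # Stop generation at the LLaMA EOS token (id 2).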
        stop_words_ids = [torch.tensor([2]).to(speech_embeds.device)]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
        outputs = self.llama_model.generate(
            inputs_embeds=embeds,
            max_new_tokens=generate_cfg.get("max_new_tokens", 200),
            stopping_criteria=stopping_criteria,
            num_beams=generate_cfg.get("num_beams", 4),
            do_sample=generate_cfg.get("do_sample", False),
            min_length=generate_cfg.get("min_length", 1),
            temperature=generate_cfg.get("temperature", 1.0),
            top_p=generate_cfg.get("top_p", 0.9),
            repetition_penalty=generate_cfg.get("repetition_penalty", 1.0),
            length_penalty=generate_cfg.get("length_penalty", 1.0),
            attention_mask=attns,
        )
        text = self.llama_tokenizer.batch_decode(outputs, add_special_tokens=False)

        return text

    @classmethod
    def from_config(cls, config):
        llama_path = config.get("llama_path")
        whisper_path = config.get("whisper_path")
        freeze_whisper = config.get("freeze_whisper", True)
        beats_path = config.get("beats_path", "")
        freeze_beats = config.get("freeze_beats", True)

        use_speech_Qformer = config.get("use_speech_Qformer", True)
        num_speech_query_token = config.get("num_speech_query_token", 1)
        freeze_speech_QFormer = config.get("freeze_speech_QFormer", False)
        window_level_Qformer = config.get("window_level_Qformer", True)
        second_per_window = config.get("second_per_window", 0.333333)
        second_stride = config.get("second_stride", 0.333333)

        speech_llama_proj_model = config.get("speech_llama_proj_model", "")
        freeze_speech_llama_proj = config.get("freeze_speech_llama_proj", False)

        lora = config.get("lora", True)
        lora_rank = config.get("lora_rank", 8)
        lora_alpha = config.get("lora_alpha", 32)
        lora_dropout = config.get("lora_dropout", 0.1)

        multi_prompt = config.get("multi_prompt", False)
        prompt_path = config.get("prompt_path", "")
        prompt_template = config.get("prompt_template", "")
        max_txt_len = config.get("max_txt_len", 128)
        end_sym = config.get("end_sym", "</s>")
        low_resource = config.get("low_resource", False)
        device_8bit = config.get("device_8bit", 0)

        model = cls(
            llama_path=llama_path,
            whisper_path=whisper_path,
            freeze_whisper=freeze_whisper,
            beats_path=beats_path,
            freeze_beats=freeze_beats,
            use_speech_Qformer=use_speech_Qformer,
            num_speech_query_token=num_speech_query_token,
            freeze_speech_QFormer=freeze_speech_QFormer,
            window_level_Qformer=window_level_Qformer,
            second_per_window=second_per_window,
            second_stride=second_stride,
            speech_llama_proj_model=speech_llama_proj_model,
            freeze_speech_llama_proj=freeze_speech_llama_proj,
            lora=lora,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            multi_prompt=multi_prompt,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
        )

        ckpt_path = config.get("ckpt", "")
        if ckpt_path:
            logging.info("Load TinyOctopus ckpt from: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            model.load_state_dict(ckpt['model'], strict=False)

        return model
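

# Example usage (illustrative sketch only; the paths and config values below are
# placeholders, not shipped defaults):
#
#   config = {
#       "llama_path": "/path/to/llama",
#       "whisper_path": "/path/to/whisper",
#       "beats_path": "/path/to/beats.pt",
#       "prompt_path": "/path/to/prompts.json",
#       "prompt_template": "USER: {}\nASSISTANT:",
#       "ckpt": "/path/to/tinyoctopus.pth",
#   }
#   model = TINYOCTOPUS.from_config(config).to("cuda").eval()
#   text = model.generate(samples, generate_cfg={"num_beams": 4}, prompts=prompts)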