whisper-base-spkreg / modeling_whisper_spkreg.py
yangwang825's picture
Upload model
214b742 verified
raw
history blame
25.8 kB
import math
import warnings
from typing import Union, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import (
SequenceClassifierOutput,
Wav2Vec2BaseModelOutput,
Seq2SeqModelOutput,
BaseModelOutput
)
from transformers.cache_utils import (
Cache,
DynamicCache,
EncoderDecoderCache,
StaticCache
)
from transformers.models.whisper.modeling_whisper import (
WhisperEncoder,
WhisperEncoderLayer,
WhisperDecoderLayer,
WhisperDecoder,
_HIDDEN_STATES_START_POSITION
)
from .configuration_whisper_spkreg import WhisperSpkRegConfig
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller then
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
class WhisperSpkRegPreTrainedModel(PreTrainedModel):
config_class = WhisperSpkRegConfig
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, WhisperEncoder):
with torch.no_grad():
embed_positions = module.embed_positions.weight
embed_positions.copy_(sinusoids(*embed_positions.shape))
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
input_lengths = (input_lengths - 1) // 2 + 1
return input_lengths
class WhisperSpkRegModel(WhisperSpkRegPreTrainedModel):
def __init__(self, config: WhisperSpkRegConfig):
super().__init__(config)
self.encoder = WhisperEncoder(config)
self.decoder = WhisperDecoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
not be updated during training.
"""
self.encoder._freeze_parameters()
def _mask_input_features(
self,
input_features: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return input_features
# generate indices & apply SpecAugment along time axis
batch_size, hidden_size, sequence_length = input_features.size()
if self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
input_features[mask_time_indices] = 0
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
input_features[mask_feature_indices] = 0
return input_features
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperModel
>>> from datasets import load_dataset
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
class AngularLinear(nn.Module):
def __init__(self, in_features: int, out_features: int):
super(AngularLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = torch.nn.Parameter(
torch.FloatTensor(out_features, in_features), requires_grad=True
)
nn.init.xavier_normal_(self.weight, gain=1)
def forward(
self,
inputs: torch.Tensor,
):
# Calculation of cos(theta)
cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
return cosine
def extra_repr(self) -> str:
return 'in_features={}, out_features={}'.format(
self.in_features, self.out_features
)
class AMSoftmaxLoss(nn.Module):
"""Additive Margin Softmax (CosFace).
Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
IEEE Signal Processing Letters 25.7 (2018): 926-930.
"""
def __init__(
self,
scale: float = 30.0,
margin: float = 0.35,
label_smoothing: float = 0.0,
reduction: str = "mean"
):
"""
Args:
num_classes: Number of classes (output dimension)
scale: Scaling factor for logits (default: 30.0)
margin: Angular margin (default: 0.35)
"""
super(AMSoftmaxLoss, self).__init__()
self.scale = scale
self.margin = margin
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(
self,
inputs: torch.Tensor,
targets: torch.Tensor,
):
"""
Args:
inputs: Input features of shape (batch_size, num_labels)
targets: Ground truth labels of shape (batch_size)
label_smoothing: Label smoothing factor (default: 0.0)
reduction: Reduction method (default: "mean")
Returns:
Loss value
"""
_, num_labels = inputs.shape
# `inputs` are the outputs from AngularLinear()
cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
psi = cos_theta - self.margin
one_hot = nn.functional.one_hot(targets, num_labels)
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
loss = F.cross_entropy(
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
)
return loss
class AAMSoftmaxLoss(nn.Module):
"""Additive Angular Margin Softmax (ArcFace).
Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
"""
def __init__(
self,
scale: float = 30.0,
margin: float = 0.35,
easy_margin: bool = False,
label_smoothing: float = 0.0,
reduction: str = "mean"
):
"""
Args:
num_classes: Number of classes (output dimension)
scale: Scaling factor for logits (default: 30.0)
margin: Angular margin (default: 0.35)
easy_margin: Use the easy margin loss (default: False)
"""
super(AAMSoftmaxLoss, self).__init__()
self.scale = scale
self.margin = margin
self.easy_margin = easy_margin
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(
self,
inputs: torch.Tensor,
targets: torch.Tensor,
):
"""
Args:
inputs: Input features of shape (batch_size, num_labels)
targets: Ground truth labels of shape (batch_size)
Returns:
Loss value
"""
_, num_labels = inputs.shape
# `inputs` are the outputs from AngularLinear()
cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
theta = torch.acos(cos_theta)
psi = torch.cos(theta + self.margin)
one_hot = nn.functional.one_hot(targets, num_labels)
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
loss = F.cross_entropy(
outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
)
return loss
class WhisperSpkRegForSequenceClassification(WhisperSpkRegPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.encoder = WhisperEncoder(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
not be updated during training. Only the projection layers and classification head will be updated.
"""
self.encoder._freeze_parameters()
def get_input_embeddings(self) -> nn.Module:
return self.encoder.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module):
self.encoder.set_input_embeddings(value)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
>>> sample = next(iter(ds))
>>> inputs = feature_extractor(
... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
... )
>>> input_features = inputs.input_features
>>> with torch.no_grad():
... logits = model(input_features).logits
>>> predicted_class_ids = torch.argmax(logits).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'Afrikaans'
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if self.config.use_weighted_layer_sum:
output_hidden_states = True
elif output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = encoder_outputs[0]
hidden_states = self.projector(hidden_states)
pooled_output = hidden_states.mean(dim=1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.loss_fct == 'cross_entropy':
loss_fct = nn.CrossEntropyLoss(
label_smoothing=self.config.label_smoothing,
reduction=self.config.reduction
)
elif self.config.loss_fct == 'additive_margin':
loss_fct = AMSoftmaxLoss(
scale=self.config.scale,
margin=self.config.margin,
label_smoothing=self.config.label_smoothing,
reduction=self.config.reduction
)
elif self.config.loss_fct == 'additive_angular_margin':
loss_fct = AAMSoftmaxLoss(
scale=self.config.scale,
margin=self.config.margin,
easy_margin=self.config.easy_margin,
label_smoothing=self.config.label_smoothing,
reduction=self.config.reduction
)
loss = loss_fct(
logits.view(-1, self.config.num_labels),
labels.view(-1).to(logits.device),
)
if not return_dict:
output = (logits,) + encoder_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)