import math
import warnings
from typing import Union, Tuple, Optional

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    Wav2Vec2BaseModelOutput,
    Seq2SeqModelOutput,
    BaseModelOutput
)
from transformers.cache_utils import (
    Cache,
    DynamicCache,
    EncoderDecoderCache,
    StaticCache
)
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
    WhisperDecoderLayer,
    WhisperDecoder,
    _HIDDEN_STATES_START_POSITION
)

from .configuration_whisper_spkreg import WhisperSpkRegConfig


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Returns sinusoids for positional embedding"""
    if channels % 2 != 0:
        raise ValueError(
            f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
        )
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where
            the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
            independently generated mask spans of length `mask_length` is computed by
            `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
            actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
            each batch dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure the total masked length does not exceed sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked_span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans per batch item
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute number of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick a dummy index used to pad the index vector so that all batch items
        # have the same number of spans despite probabilistic rounding
        if len(spec_aug_mask_idx) == 0:
            # this can only happen if `input_length` is strictly smaller than
            # `sequence_length`, in which case the last position is a padding
            # position that can safely serve as a dummy mask index
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked starting indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offsets to the starting indices so that the indices now describe a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that indices cannot exceed sequence_length - 1
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices into the boolean mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
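

# Illustrative sketch only (not part of the model API): a minimal call to
# `_compute_mask_indices`, using arbitrary example shapes and SpecAugment
# hyperparameters.
def _example_compute_mask_indices() -> np.ndarray:
    """Compute SpecAugment-style time masks for a batch of 2 sequences of length 100,
    masking spans of 10 frames with `mask_prob=0.05` and at least one span per
    sequence. Returns a boolean array of shape (2, 100)."""
    return _compute_mask_indices(
        shape=(2, 100), mask_prob=0.05, mask_length=10, min_masks=1
    )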


class WhisperSpkRegPreTrainedModel(PreTrainedModel):

    config_class = WhisperSpkRegConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True
    _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, WhisperEncoder):
            with torch.no_grad():
                embed_positions = module.embed_positions.weight
                embed_positions.copy_(sinusoids(*embed_positions.shape))

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """
        input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths


class WhisperSpkRegModel(WhisperSpkRegPreTrainedModel):

    def __init__(self, config: WhisperSpkRegConfig):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        self.decoder = WhisperDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.decoder.embed_tokens = value

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training.
        """
        self.encoder._freeze_parameters()

    def _mask_input_features(
        self,
        input_features: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        """
        # `config.apply_spec_augment` can disable masking entirely
        if not getattr(self.config, "apply_spec_augment", True):
            return input_features

        batch_size, hidden_size, sequence_length = input_features.size()

        if self.config.mask_time_prob > 0 and self.training:
            # generate indices & apply SpecAugment along time axis
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
            mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
            input_features[mask_time_indices] = 0

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
            input_features[mask_feature_indices] = 0

        return input_features

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:
        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, WhisperModel
        >>> from datasets import load_dataset

        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features
        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 2, 512]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            encoder_outputs = self.encoder(
                input_features,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # wrap user-provided tuple outputs in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consist of (dec_features, past_key_values, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class AngularLinear(nn.Module):
    """Bias-free linear layer that returns the cosine similarity between
    L2-normalized input embeddings and L2-normalized class weight vectors."""

    def __init__(self, in_features: int, out_features: int):
        super(AngularLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.FloatTensor(out_features, in_features), requires_grad=True
        )
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        # cosine similarity between normalized embeddings and normalized class weights,
        # i.e. logits in [-1, 1] suitable for margin-based softmax losses
        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
        return cosine

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )


class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor for logits (default: 30.0)
            margin: Additive cosine margin (default: 0.35)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method for the cross-entropy loss (default: "mean")
        """
        super(AMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities (e.g. from `AngularLinear`) of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape

        # cos(theta), clamped for numerical stability
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        # subtract the additive margin from the target-class cosine only
        psi = cos_theta - self.margin
        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss
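

# Illustrative sketch only (not part of the module API): how `AngularLinear` and
# `AMSoftmaxLoss` are meant to be paired. The embedding dimension, number of
# speakers, and batch size below are arbitrary example values.
def _example_am_softmax_usage() -> torch.Tensor:
    """Project random 256-dim embeddings to cosine similarities over 10 classes
    with `AngularLinear`, then apply the additive-margin (CosFace) loss."""
    embeddings = torch.randn(4, 256)     # (batch, embedding_dim)
    labels = torch.randint(0, 10, (4,))  # (batch,)
    head = AngularLinear(in_features=256, out_features=10)
    cosine = head(embeddings)            # values in [-1, 1]
    loss_fct = AMSoftmaxLoss(scale=30.0, margin=0.35)
    return loss_fct(cosine, labels)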


class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "ArcFace: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.2,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor for logits (default: 30.0)
            margin: Additive angular margin in radians (default: 0.2)
            easy_margin: Use the easy margin variant (default: False)
            label_smoothing: Label smoothing factor (default: 0.0)
            reduction: Reduction method for the cross-entropy loss (default: "mean")
        """
        super(AAMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities (e.g. from `AngularLinear`) of shape (batch_size, num_labels)
            targets: Ground truth labels of shape (batch_size)
        Returns:
            Loss value
        """
        _, num_labels = inputs.shape

        epsilon = 1e-6

        # cos(theta) and sin(theta), clamped for numerical stability
        cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)

        # cos(theta + margin) via the angle-addition formula
        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        psi = cos_theta * cos_m - sin_theta * sin_m

        if self.easy_margin:
            # only apply the margin when cos(theta) > 0
            psi = torch.where(cos_theta > 0, psi, cos_theta)
        else:
            # fall back to a linear penalty when theta + margin would exceed pi
            psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)

        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss
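

# Illustrative sketch only: a numeric check that the margin term used above,
# cos(theta) * cos(m) - sin(theta) * sin(m), equals cos(theta + m), i.e. the
# additive *angular* margin of ArcFace. The angle below is an arbitrary example.
def _example_aam_margin_identity() -> bool:
    """Verify the angle-addition identity behind `AAMSoftmaxLoss` for one angle."""
    margin = 0.2
    theta = torch.tensor(1.0)  # example angle in radians
    psi = torch.cos(theta) * math.cos(margin) - torch.sin(theta) * math.sin(margin)
    return torch.allclose(psi, torch.cos(theta + margin))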


class WhisperSpkRegForSequenceClassification(WhisperSpkRegPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        self.encoder = WhisperEncoder(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_encoder(self):
        """
        Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
        not be updated during training. Only the projection layers and classification head will be updated.
        """
        self.encoder._freeze_parameters()

    def get_input_embeddings(self) -> nn.Module:
        return self.encoder.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module):
        self.encoder.set_input_embeddings(value)

    def forward(
        self,
        input_features: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
        >>> from datasets import load_dataset

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
        >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")

        >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
        >>> sample = next(iter(ds))

        >>> inputs = feature_extractor(
        ...     sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
        ... )
        >>> input_features = inputs.input_features

        >>> with torch.no_grad():
        ...     logits = model(input_features).logits

        >>> predicted_class_ids = torch.argmax(logits).item()
        >>> predicted_label = model.config.id2label[predicted_class_ids]
        >>> predicted_label
        'Afrikaans'
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        if self.config.use_weighted_layer_sum:
            # the weighted layer sum requires all encoder hidden states
            output_hidden_states = True
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        if self.config.use_weighted_layer_sum:
            # learned softmax-weighted sum over all encoder hidden states
            hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = encoder_outputs[0]

        hidden_states = self.projector(hidden_states)
        pooled_output = hidden_states.mean(dim=1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # select the classification loss configured for speaker registration
            if self.config.loss_fct == 'cross_entropy':
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_margin':
                loss_fct = AMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == 'additive_angular_margin':
                loss_fct = AAMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    easy_margin=self.config.easy_margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1).to(logits.device),
            )

        if not return_dict:
            output = (logits,) + encoder_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
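

# Illustrative sketch only (assumed usage, not part of the module API): building the
# classification model from a default config and computing a training loss with the
# additive angular margin objective. It assumes `WhisperSpkRegConfig()` is constructible
# with defaults and exposes the `loss_fct`, `scale`, `margin`, `easy_margin`,
# `label_smoothing`, and `reduction` fields read by `forward` above; the speaker count
# and input shapes are arbitrary example values.
def _example_spkreg_classification_step() -> torch.Tensor:
    """Run one forward pass with labels and return the resulting scalar loss."""
    config = WhisperSpkRegConfig()
    config.num_labels = 4                             # number of speakers in this sketch
    config.loss_fct = "additive_angular_margin"

    model = WhisperSpkRegForSequenceClassification(config)
    input_features = torch.randn(1, config.num_mel_bins, 3000)  # (batch, mels, frames)
    labels = torch.tensor([2])
    outputs = model(input_features, labels=labels)
    return outputs.loss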