yangwang825 committed on
Commit
794401a
·
verified ·
1 Parent(s): be62aae

Update modeling_whisper_spkreg.py

Browse files
Files changed (1) hide show
  1. modeling_whisper_spkreg.py +18 -4
modeling_whisper_spkreg.py CHANGED
@@ -451,7 +451,7 @@ class AAMSoftmaxLoss(nn.Module):
451
  def __init__(
452
  self,
453
  scale: float = 30.0,
454
- margin: float = 0.35,
455
  easy_margin: bool = False,
456
  label_smoothing: float = 0.0,
457
  reduction: str = "mean"
@@ -484,9 +484,23 @@ class AAMSoftmaxLoss(nn.Module):
484
  """
485
  _, num_labels = inputs.shape
486
  # `inputs` are the outputs from AngularLinear()
487
- cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
488
- theta = torch.acos(cos_theta)
489
- psi = torch.cos(theta + self.margin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  one_hot = nn.functional.one_hot(targets, num_labels)
491
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
492
  loss = F.cross_entropy(
 
451
  def __init__(
452
  self,
453
  scale: float = 30.0,
454
+ margin: float = 0.2,
455
  easy_margin: bool = False,
456
  label_smoothing: float = 0.0,
457
  reduction: str = "mean"
 
484
  """
485
  _, num_labels = inputs.shape
486
  # `inputs` are the outputs from AngularLinear()
487
+ epsilon = 1e-6
488
+ # theta = torch.acos(cos_theta)
489
+ # psi = torch.cos(theta + self.margin)
490
+ cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
491
+ sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
492
+ sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
493
+
494
+ cos_m = math.cos(self.margin)
495
+ sin_m = math.sin(self.margin)
496
+ psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
497
+
498
+ if self.easy_margin:
499
+ psi = torch.where(cos_theta > 0, psi, cos_theta)
500
+ else:
501
+ # Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
502
+ psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
503
+
504
  one_hot = nn.functional.one_hot(targets, num_labels)
505
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
506
  loss = F.cross_entropy(