Spaces:
Running
Running
File size: 1,986 Bytes
9791162 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# Adapted from https://github.com/ubisoft/ubisoft-laforge-daft-exprt (Apache License, Version 2.0).
# Gradient reversal follows "Unsupervised Domain Adaptation by Backpropagation" (Ganin & Lempitsky, ICML 2015).
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.nn.utils import weight_norm
class GradientReversalFunction(Function):
    ''' Autograd function behind the gradient reversal layer

    Forward returns a copy of the input unchanged.
    Backward returns the incoming gradient multiplied by -lambda_.
    '''

    @staticmethod
    def forward(ctx, x, lambda_):
        # lambda_ is kept as a plain attribute on the context so that
        # backward can read it back
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        # build the scale with the same dtype/device as the incoming gradient
        scale = grads.new_tensor(ctx.lambda_)
        reversed_grads = grads * (-scale)
        # second slot is the gradient w.r.t. lambda_, which is not provided
        return reversed_grads, None
class GradientReversal(torch.nn.Module):
    ''' Gradient Reversal Layer

    Reference:
        Y. Ganin, V. Lempitsky,
        "Unsupervised Domain Adaptation by Backpropagation",
        in ICML, 2015.

    The layer hands its input and the stored lambda to
    GradientReversalFunction: the forward pass is the identity function and,
    in the backward pass, upstream gradients are multiplied by -lambda
    (i.e. gradients are reversed).
    '''

    def __init__(self, lambda_reversal=1):
        super().__init__()
        # scale applied (negated) to the gradients during the backward pass
        self.lambda_ = lambda_reversal

    def forward(self, x):
        reversal_scale = self.lambda_
        return GradientReversalFunction.apply(x, reversal_scale)
class SpeakerClassifier(nn.Module):
    ''' Speaker classifier placed behind a gradient reversal layer

    The input goes through GradientReversal and then three weight-normalised
    1-D convolutions (kernel size 5, padding 2) separated by ReLU.
    The resulting per-position outputs are averaged over the last axis.

    Args:
        embed_dim: number of input channels, also used as the hidden width
        spk_dim: number of output channels of the last convolution
        lambda_reversal: scale handed to GradientReversal; the default of 1
            is the value that used to be hard-coded
    '''

    def __init__(self, embed_dim, spk_dim, lambda_reversal=1):
        super(SpeakerClassifier, self).__init__()
        # layer order is kept as-is so the nn.Sequential indices (and thus
        # the state_dict keys) stay the same
        self.classifier = nn.Sequential(
            GradientReversal(lambda_reversal=lambda_reversal),
            weight_norm(nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(embed_dim, embed_dim, kernel_size=5, padding=2)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(embed_dim, spk_dim, kernel_size=5, padding=2))
        )

    def forward(self, x):
        ''' Forward function of Speaker Classifier:
            x = (B, embed_dim, len)
            returns (B, spk_dim)
        '''
        # pass through classifier; kernel_size=5 with padding=2 keeps len
        outputs = self.classifier(x)  # (B, spk_dim, len)
        # average over the last axis
        outputs = torch.mean(outputs, dim=-1)  # (B, spk_dim)
        return outputs
|