import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.container import ModuleList
from torch.nn.modules.activation import MultiheadAttention
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTM
from torch.nn.modules.normalization import LayerNorm
from torch.autograd import Variable
import copy
import math

# adapted from https://github.com/ujscjj/DPTNet


class DPTNet_base(nn.Module):
    def __init__(
        self,
        enc_dim,
        feature_dim,
        hidden_dim,
        layer,
        segment_size=250,
        nspk=2,
        win_len=2,
    ):
        super().__init__()
        # parameters
        self.window = win_len
        self.stride = self.window // 2
        self.enc_dim = enc_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.segment_size = segment_size
        self.layer = layer
        self.num_spk = nspk
        self.eps = 1e-8

        self.dpt_encoder = DPTEncoder(
            n_filters=enc_dim,
            window_size=win_len,
        )
        self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=1e-8)
        self.dpt_separation = DPTSeparation(
            self.enc_dim,
            self.feature_dim,
            self.hidden_dim,
            self.num_spk,
            self.layer,
            self.segment_size,
        )

        self.mask_conv1x1 = nn.Conv1d(self.feature_dim, self.enc_dim, 1, bias=False)
        self.decoder = DPTDecoder(n_filters=enc_dim, window_size=win_len)

    def forward(self, mix):
        """
        mix: shape (batch, T)
        """
        batch_size = mix.shape[0]
        mix = self.dpt_encoder(mix)  # (B, E, L)

        score_ = self.enc_LN(mix)  # B, E, L
        score_ = self.dpt_separation(score_)  # B, nspk, T, N
        score_ = (
            score_.view(batch_size * self.num_spk, -1, self.feature_dim)
            .transpose(1, 2)
            .contiguous()
        )  # B*nspk, N, T
        score = self.mask_conv1x1(score_)  # [B*nspk, N, L] -> [B*nspk, E, L]
        score = score.view(
            batch_size, self.num_spk, self.enc_dim, -1
        )  # [B*nspk, E, L] -> [B, nspk, E, L]
        est_mask = F.relu(score)

        est_source = self.decoder(
            mix, est_mask
        )  # [B, E, L] + [B, nspk, E, L] --> [B, nspk, T]

        return est_source


class DPTEncoder(nn.Module):
    def __init__(self, n_filters: int = 64, window_size: int = 2):
        super().__init__()
        self.conv = nn.Conv1d(
            1, n_filters, kernel_size=window_size, stride=window_size // 2, bias=False
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = F.relu(self.conv(x))
        return x


class TransformerEncoderLayer(torch.nn.Module):
    def __init__(
        self, d_model, nhead, hidden_size, dim_feedforward, dropout, activation="relu"
    ):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of the improved part: the feed-forward network is
        # replaced by a bidirectional LSTM followed by a linear projection
        # back to d_model.
        self.lstm = LSTM(d_model, hidden_size, 1, bidirectional=True)
        self.dropout = Dropout(dropout)
        self.linear = Linear(hidden_size * 2, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
""" src2 = self.self_attn( src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask )[0] src = src + self.dropout1(src2) src = self.norm1(src) src2 = self.linear(self.dropout(self.activation(self.lstm(src)[0]))) src = src + self.dropout2(src2) src = self.norm2(src) return src def _get_clones(module, N): return ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation): if activation == "relu": return F.relu elif activation == "gelu": return F.gelu raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) class SingleTransformer(nn.Module): """ Container module for a single Transformer layer. args: input_size: int, dimension of the input feature. The input should have shape (batch, seq_len, input_size). """ def __init__(self, input_size, hidden_size, dropout): super(SingleTransformer, self).__init__() self.transformer = TransformerEncoderLayer( d_model=input_size, nhead=4, hidden_size=hidden_size, dim_feedforward=hidden_size * 2, dropout=dropout, ) def forward(self, input): # input shape: batch, seq, dim output = input transformer_output = ( self.transformer(output.permute(1, 0, 2).contiguous()) .permute(1, 0, 2) .contiguous() ) return transformer_output # dual-path transformer class DPT(nn.Module): """ Deep dual-path transformer. args: input_size: int, dimension of the input feature. The input should have shape (batch, seq_len, input_size). hidden_size: int, dimension of the hidden state. output_size: int, dimension of the output size. num_layers: int, number of stacked Transformer layers. Default is 1. dropout: float, dropout ratio. Default is 0. """ def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0): super(DPT, self).__init__() self.input_size = input_size self.output_size = output_size self.hidden_size = hidden_size # dual-path transformer self.row_transformer = nn.ModuleList([]) self.col_transformer = nn.ModuleList([]) for i in range(num_layers): self.row_transformer.append( SingleTransformer(input_size, hidden_size, dropout) ) self.col_transformer.append( SingleTransformer(input_size, hidden_size, dropout) ) # output layer self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1)) def forward(self, input): # input shape: batch, N, dim1, dim2 # apply transformer on dim1 first and then dim2 # output shape: B, output_size, dim1, dim2 # input = input.to(device) batch_size, _, dim1, dim2 = input.shape output = input for i in range(len(self.row_transformer)): row_input = ( output.permute(0, 3, 2, 1) .contiguous() .view(batch_size * dim2, dim1, -1) ) # B*dim2, dim1, N row_output = self.row_transformer[i](row_input) # B*dim2, dim1, H row_output = ( row_output.view(batch_size, dim2, dim1, -1) .permute(0, 3, 2, 1) .contiguous() ) # B, N, dim1, dim2 output = row_output col_input = ( output.permute(0, 2, 3, 1) .contiguous() .view(batch_size * dim1, dim2, -1) ) # B*dim1, dim2, N col_output = self.col_transformer[i](col_input) # B*dim1, dim2, H col_output = ( col_output.view(batch_size, dim1, dim2, -1) .permute(0, 3, 1, 2) .contiguous() ) # B, N, dim1, dim2 output = col_output output = self.output(output) # B, output_size, dim1, dim2 return output # base module for deep DPT class DPT_base(nn.Module): def __init__( self, input_dim, feature_dim, hidden_dim, num_spk=2, layer=6, segment_size=250 ): super(DPT_base, self).__init__() self.input_dim = input_dim self.feature_dim = feature_dim self.hidden_dim = hidden_dim self.layer = layer self.segment_size = segment_size self.num_spk = 
        self.eps = 1e-8

        # bottleneck
        self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)

        # DPT model
        self.DPT = DPT(
            self.feature_dim,
            self.hidden_dim,
            self.feature_dim * self.num_spk,
            num_layers=layer,
        )

    def pad_segment(self, input, segment_size):
        # input is the features: (B, N, T)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
        if rest > 0:
            pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
            input = torch.cat([input, pad], 2)

        pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(
            input.type()
        )
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def split_feature(self, input, segment_size):
        # split the feature into chunks of segment size
        # input is the features: (B, N, T)

        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        segments1 = (
            input[:, :, :-segment_stride]
            .contiguous()
            .view(batch_size, dim, -1, segment_size)
        )
        segments2 = (
            input[:, :, segment_stride:]
            .contiguous()
            .view(batch_size, dim, -1, segment_size)
        )
        segments = (
            torch.cat([segments1, segments2], 3)
            .view(batch_size, dim, -1, segment_size)
            .transpose(2, 3)
        )

        return segments.contiguous(), rest

    def merge_feature(self, input, rest):
        # merge the split features back into the full utterance
        # input is the features: (B, N, L, K)

        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = (
            input.transpose(2, 3)
            .contiguous()
            .view(batch_size, dim, -1, segment_size * 2)
        )  # B, N, K, L

        input1 = (
            input[:, :, :, :segment_size]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, segment_stride:]
        )
        input2 = (
            input[:, :, :, segment_size:]
            .contiguous()
            .view(batch_size, dim, -1)[:, :, :-segment_stride]
        )

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]

        return output.contiguous()  # B, N, T

    def forward(self, input):
        pass


class DPTSeparation(DPT_base):
    def __init__(self, *args, **kwargs):
        super(DPTSeparation, self).__init__(*args, **kwargs)

        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Sigmoid()
        )

    def forward(self, input):
        # input = input.to(device)
        # input: (B, E, T)
        batch_size, E, seq_length = input.shape

        enc_feature = self.BN(input)  # (B, E, L) --> (B, N, L)
        # split the encoder output into overlapped segments
        enc_segments, enc_rest = self.split_feature(
            enc_feature, self.segment_size
        )  # B, N, L, K: L is the segment_size
        # print('enc_segments.shape {}'.format(enc_segments.shape))
        # pass to DPT
        output = self.DPT(enc_segments).view(
            batch_size * self.num_spk, self.feature_dim, self.segment_size, -1
        )  # B*nspk, N, L, K

        # overlap-and-add of the outputs
        output = self.merge_feature(output, enc_rest)  # B*nspk, N, T

        # gated output layer for filter generation
        bf_filter = self.output(output) * self.output_gate(output)  # B*nspk, N, T
        bf_filter = (
            bf_filter.transpose(1, 2)
            .contiguous()
            .view(batch_size, self.num_spk, -1, self.feature_dim)
        )  # B, nspk, T, N

        return bf_filter


class DPTDecoder(nn.Module):
    def __init__(self, n_filters: int = 64, window_size: int = 2):
        super().__init__()
        self.W = window_size
        self.basis_signals = nn.Linear(n_filters, window_size, bias=False)

    def forward(self, mixture, mask):
        """
        mixture: (batch, n_filters, L)
        mask: (batch, sources, n_filters, L)
        """
        source_w = torch.unsqueeze(mixture, 1) * mask  # [B, C, E, L]
        source_w = torch.transpose(source_w, 2, 3)  # [B, C, L, E]
        # S = DV
        est_source = self.basis_signals(source_w)  # [B, C, L, W]
        est_source = overlap_and_add(est_source, self.W // 2)  # B x C x T
        return est_source


def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be
            unknown, and rank must be at least 2.
        frame_step: An integer denoting overlap offsets. Must be less than or
            equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added
        frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length

    Based on
    https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = greatest common divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes).unfold(
        0, subframes_per_frame, subframe_step
    )
    frame = signal.new_tensor(frame).long()  # signal may be on GPU or CPU
    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result
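

# ---------------------------------------------------------------------------
# Quick smoke test (not part of the original DPTNet code). A minimal sketch
# assuming small, illustrative hyperparameters: enc_dim, feature_dim,
# hidden_dim, layer, and segment_size below are placeholders chosen for fast
# CPU execution, not the configuration used in the DPTNet paper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    model = DPTNet_base(
        enc_dim=64,       # assumed encoder basis size
        feature_dim=32,   # assumed bottleneck size (must be divisible by nhead=4)
        hidden_dim=64,    # assumed LSTM hidden size inside each transformer layer
        layer=2,          # small stack so the test runs quickly
        segment_size=50,  # assumed chunk length for the dual-path split
        nspk=2,
        win_len=2,
    )
    model.eval()

    mix = torch.randn(2, 4000)  # (batch, T): two half-second mixtures at 8 kHz
    with torch.no_grad():
        est_sources = model(mix)  # (batch, nspk, T)
    print(est_sources.shape)      # expected: torch.Size([2, 2, 4000])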