import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class ResStack(nn.Module):
    """Stack of dilated 1-D convolution blocks, each with a learned shortcut.

    Layer ``i`` (0-based) computes ``shortcut_i(x) + block_i(x)`` where
    ``block_i`` is::

        LeakyReLU(0.2) -> ReflectionPad1d(3**i)
                       -> Conv1d(kernel_size=3, dilation=3**i)
                       -> LeakyReLU(0.2)
                       -> Conv1d(kernel_size=1)

    and ``shortcut_i`` is a ``Conv1d(kernel_size=1)``.  Every convolution is
    wrapped in ``nn.utils.weight_norm`` and maps ``channel`` -> ``channel``.

    The reflection padding of ``3**i`` on each side offsets the ``2 * 3**i``
    samples consumed by the dilated kernel-3 convolution, so the output has
    the same length as the input.

    Args:
        channel: Number of input and output channels of every convolution.
        num_layers: Number of (block, shortcut) pairs.  Defaults to 3, which
            yields dilations 1, 3 and 9.
    """

    # Positions of the weight-normalised convolutions inside each block's
    # ``nn.Sequential``; must stay in sync with ``_make_block``.
    _DILATED_CONV_INDEX = 2
    _POINTWISE_CONV_INDEX = 4

    def __init__(self, channel: int, *, num_layers: int = 3) -> None:
        super().__init__()

        # A single ``num_layers`` drives both lists so they cannot drift
        # apart (``forward`` zips them, which would truncate silently).
        # Blocks are built before shortcuts to keep the parameter
        # initialisation order stable for a given RNG seed.
        self.blocks = nn.ModuleList(
            [self._make_block(channel, 3 ** i) for i in range(num_layers)]
        )
        self.shortcuts = nn.ModuleList(
            [
                nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
                for _ in range(num_layers)
            ]
        )

    @staticmethod
    def _make_block(channel: int, dilation: int) -> nn.Sequential:
        """Build one residual branch using the given dilation (and padding)."""
        return nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.ReflectionPad1d(dilation),
            nn.utils.weight_norm(
                nn.Conv1d(channel, channel, kernel_size=3, dilation=dilation)
            ),
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer in order and return the result.

        Args:
            x: Input accepted by ``nn.Conv1d`` with ``channel`` channels.
                NOTE(review): reflection padding needs the last dimension to
                be longer than the largest dilation -- confirm with callers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for block, shortcut in zip(self.blocks, self.shortcuts):
            x = shortcut(x) + block(x)
        return x

    def remove_weight_norm(self) -> None:
        """Strip the weight-norm reparameterisation from every convolution.

        Calls ``nn.utils.remove_weight_norm`` on both convolutions of each
        block and on each shortcut convolution.
        """
        for block, shortcut in zip(self.blocks, self.shortcuts):
            nn.utils.remove_weight_norm(block[self._DILATED_CONV_INDEX])
            nn.utils.remove_weight_norm(block[self._POINTWISE_CONV_INDEX])
            nn.utils.remove_weight_norm(shortcut)