Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
class ResStack(nn.Module):
    """Stack of dilated residual 1-D convolution blocks.

    Layer ``i`` (``i = 0 .. num_layers - 1``) consists of two parallel paths
    whose outputs are summed:

    * a residual branch: ``LeakyReLU(0.2)`` -> ``ReflectionPad1d(3**i)`` ->
      weight-normalized ``Conv1d(kernel_size=3, dilation=3**i)`` ->
      ``LeakyReLU(0.2)`` -> weight-normalized ``Conv1d(kernel_size=1)``;
    * a shortcut: a weight-normalized ``Conv1d(kernel_size=1)``.

    The reflection padding of ``3**i`` on each side exactly offsets the
    shrinkage caused by the dilated kernel-3 convolution, so every layer
    keeps both the channel count and the length of its input.

    Args:
        channel: Number of input and output channels of every convolution.
        num_layers: Number of residual layers in the stack. Defaults to 3,
            which gives dilations 1, 3 and 9. With 0 layers ``forward``
            returns its input unchanged.
    """

    def __init__(self, channel, num_layers=3):
        super().__init__()

        # All residual branches are created before any shortcut so that the
        # order in which parameters are initialised (and therefore the
        # result of a seeded construction) does not depend on refactoring.
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.ReflectionPad1d(3 ** i),
                nn.utils.weight_norm(
                    nn.Conv1d(channel, channel, kernel_size=3, dilation=3 ** i)
                ),
                nn.LeakyReLU(0.2),
                nn.utils.weight_norm(
                    nn.Conv1d(channel, channel, kernel_size=1)
                ),
            )
            for i in range(num_layers)
        ])

        self.shortcuts = nn.ModuleList([
            nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
            for _ in range(num_layers)
        ])

    def forward(self, x):
        """Run ``x`` through every layer, summing shortcut and residual branch.

        Args:
            x: Input tensor accepted by ``nn.Conv1d`` with ``channel``
                channels (presumably ``(batch, channel, time)``). Its last
                dimension must be larger than the largest reflection
                padding, i.e. ``3 ** (num_layers - 1)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for block, shortcut in zip(self.blocks, self.shortcuts):
            x = shortcut(x) + block(x)
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterization from every convolution.

        Each convolution is left with a plain ``weight`` parameter instead
        of the ``weight_g`` / ``weight_v`` pair. ``nn.utils.remove_weight_norm``
        raises ``ValueError`` when a layer has no weight norm attached, so
        calling this method a second time fails.
        """
        for block, shortcut in zip(self.blocks, self.shortcuts):
            # Locate the convolutions by type rather than by position so the
            # method keeps working if the branch layout is ever changed.
            for layer in block:
                if isinstance(layer, nn.Conv1d):
                    nn.utils.remove_weight_norm(layer)
            nn.utils.remove_weight_norm(shortcut)