# File size: 5,568 bytes (8026e91)
import torch as t
import torch.nn as nn
from jukebox.vqvae.resnet import Resnet, Resnet1D
from jukebox.utils.torch_utils import assert_shape
class EncoderConvBlock(nn.Module):
    """Downsampling stack used for a single encoder level.

    When ``down_t > 0`` the stack has two parts:

    * ``down_t`` stages, each a strided ``Conv1d`` followed by ``Resnet1D``.
    * A final ``Conv1d`` (kernel 3, stride 1, padding 1) that projects
      ``width`` channels to ``output_emb_width``.

    Every strided conv uses kernel ``2 * stride_t`` and padding
    ``stride_t // 2``. By the PyTorch ``Conv1d`` length formula, an even
    ``stride_t`` therefore maps length ``T`` to ``T // stride_t``. An odd
    stride gives ``(T - 1) // stride_t`` instead.

    When ``down_t <= 0`` no layers are built: ``self.model`` is an empty
    ``nn.Sequential``, so ``forward`` returns its input unchanged.

    Args:
        input_emb_width: channel count fed to the first strided conv.
        output_emb_width: channel count produced by the final projection.
        down_t: number of downsampling stages.
        stride_t: stride shared by every downsampling conv.
        width: hidden channel count inside the stack.
        depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out,
        res_scale: passed positionally, in this order, to ``Resnet1D``.
    """

    def __init__(self, input_emb_width, output_emb_width, down_t,
                 stride_t, width, depth, m_conv,
                 dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
                 res_scale=False):
        super().__init__()
        kernel, padding = 2 * stride_t, stride_t // 2
        layers = []
        if down_t > 0:
            # Stage 0 reads the block input; later stages stay at `width`.
            # The conv is built before the Resnet1D in every stage, keeping
            # module registration (and init RNG) order stable.
            layers.extend(
                nn.Sequential(
                    nn.Conv1d(width if stage > 0 else input_emb_width, width,
                              kernel, stride_t, padding),
                    Resnet1D(width, depth, m_conv, dilation_growth_rate,
                             dilation_cycle, zero_out, res_scale),
                )
                for stage in range(down_t)
            )
            # Channel projection to the level's embedding width.
            layers.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` (identity when no layers were built)."""
        return self.model(x)
class DecoderConvBock(nn.Module):
    """Upsampling stack used for a single decoder level.

    The class name is spelled ``DecoderConvBock`` (sic). Renaming it would
    break anything that imports this name.

    When ``down_t > 0`` the stack has two parts:

    * A ``Conv1d`` (kernel 3, stride 1, padding 1) mapping
      ``output_emb_width`` channels to ``width``.
    * ``down_t`` stages, each ``Resnet1D`` followed by ``ConvTranspose1d``.

    Every transposed conv uses kernel ``2 * stride_t``, stride ``stride_t``
    and padding ``stride_t // 2``. By the PyTorch ``ConvTranspose1d`` length
    formula, an even stride maps length ``T`` to ``T * stride_t``. The last
    stage emits ``input_emb_width`` channels; earlier stages emit ``width``.

    When ``down_t <= 0`` no layers are built, so ``forward`` returns its
    input unchanged.

    Args:
        input_emb_width: channel count produced by the last transposed conv.
        output_emb_width: channel count of the input to the first conv.
        down_t: number of upsampling stages.
        stride_t: stride shared by every transposed conv.
        width: hidden channel count inside the stack.
        depth, m_conv, dilation_growth_rate, dilation_cycle: passed
            positionally to ``Resnet1D``.
        zero_out, res_scale, checkpoint_res: passed by keyword to
            ``Resnet1D``.
        reverse_decoder_dilation: passed to ``Resnet1D`` as
            ``reverse_dilation``.
    """

    def __init__(self, input_emb_width, output_emb_width, down_t,
                 stride_t, width, depth, m_conv, dilation_growth_rate=1,
                 dilation_cycle=None, zero_out=False, res_scale=False,
                 reverse_decoder_dilation=False, checkpoint_res=False):
        super().__init__()
        layers = []
        if down_t > 0:
            kernel, padding = 2 * stride_t, stride_t // 2
            final_stage = down_t - 1
            # Bring the latent channels up to the working width first.
            layers.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
            for stage in range(down_t):
                out_channels = input_emb_width if stage == final_stage else width
                # Resnet1D is built before the transposed conv, matching the
                # module registration (and init RNG) order of the stage.
                resnet = Resnet1D(width, depth, m_conv, dilation_growth_rate,
                                  dilation_cycle, zero_out=zero_out,
                                  res_scale=res_scale,
                                  reverse_dilation=reverse_decoder_dilation,
                                  checkpoint_res=checkpoint_res)
                upsample = nn.ConvTranspose1d(width, out_channels, kernel,
                                              stride_t, padding)
                layers.append(nn.Sequential(resnet, upsample))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` (identity when no layers were built)."""
        return self.model(x)
class Encoder(nn.Module):
    """Multi-level encoder built from one ``EncoderConvBlock`` per level.

    Levels run in sequence. Level 0 consumes the ``input_emb_width``-channel
    input; every later level consumes the previous level's
    ``output_emb_width``-channel output. ``forward`` returns the output of
    every level, in level order.

    Args:
        input_emb_width: channel count of the tensor passed to ``forward``.
        output_emb_width: channel count produced by every level.
        levels: number of levels. ``zip`` pairs it with ``downs_t`` and
            ``strides_t``, so the shortest of the three sets the block count.
        downs_t: per-level number of downsampling stages.
        strides_t: per-level downsampling stride.
        **block_kwargs: forwarded to every ``EncoderConvBlock``. The
            decoder-only keys ``reverse_decoder_dilation`` and
            ``checkpoint_res`` are dropped first, so one kwargs dict can be
            shared with the decoder.
    """

    def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
                 strides_t, **block_kwargs):
        super().__init__()
        self.input_emb_width = input_emb_width
        self.output_emb_width = output_emb_width
        self.levels = levels
        self.downs_t = downs_t
        self.strides_t = strides_t

        # EncoderConvBlock has no parameter for these decoder-side options.
        # Forwarding them would raise TypeError, so filter them out.
        # Previously only `reverse_decoder_dilation` was removed, and a shared
        # kwargs dict containing `checkpoint_res` crashed here.
        decoder_only = ('reverse_decoder_dilation', 'checkpoint_res')
        encoder_kwargs = {key: value for key, value in block_kwargs.items()
                          if key not in decoder_only}

        self.level_blocks = nn.ModuleList()
        for level, down_t, stride_t in zip(range(self.levels), downs_t, strides_t):
            # Level 0 reads the raw input; deeper levels read the previous
            # level's output.
            in_width = input_emb_width if level == 0 else output_emb_width
            self.level_blocks.append(
                EncoderConvBlock(in_width, output_emb_width, down_t, stride_t,
                                 **encoder_kwargs))

    def forward(self, x):
        """Encode ``x`` through every level and collect each level's output.

        ``x`` is passed to ``assert_shape`` with expected shape
        ``(N, input_emb_width, T)``, where ``N`` and ``T`` are read from ``x``
        itself.

        Each level's output is passed to ``assert_shape`` with expected shape
        ``(N, output_emb_width, T')``, where ``T'`` is the previous expected
        length integer-divided by ``stride_t ** down_t``. That length is exact
        only if each block shrinks time by exactly its stride per stage.
        NOTE(review): presumably only even strides are used; confirm.

        Returns:
            List of per-level output tensors, finest level first.
        """
        N, T = x.shape[0], x.shape[-1]
        emb = self.input_emb_width
        assert_shape(x, (N, emb, T))
        xs = []

        # Lengths shrink level by level (e.g. 64, 32, ...).
        for level, down_t, stride_t in zip(range(self.levels), self.downs_t, self.strides_t):
            x = self.level_blocks[level](x)
            emb, T = self.output_emb_width, T // (stride_t ** down_t)
            assert_shape(x, (N, emb, T))
            xs.append(x)

        return xs
class Decoder(nn.Module):
    """Multi-level decoder: one ``DecoderConvBock`` per level plus an output conv.

    Every level block maps ``output_emb_width`` channels to
    ``output_emb_width`` channels. A final ``Conv1d`` (kernel 3, stride 1,
    padding 1) projects to ``input_emb_width`` channels.

    Args:
        input_emb_width: channel count produced by the final projection.
        output_emb_width: channel count of the latents and of every level
            block.
        levels: number of levels. ``zip`` pairs it with ``downs_t`` and
            ``strides_t``, so the shortest of the three sets the block count.
        downs_t: per-level number of upsampling stages.
        strides_t: per-level upsampling stride.
        **block_kwargs: forwarded unchanged to every ``DecoderConvBock``.
    """

    def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
                 strides_t, **block_kwargs):
        super().__init__()
        self.input_emb_width = input_emb_width
        self.output_emb_width = output_emb_width
        self.levels = levels
        self.downs_t = downs_t
        self.strides_t = strides_t

        # The level index is irrelevant here: every level block has the same
        # channel widths and differs only in its down_t / stride_t.
        level_specs = zip(range(self.levels), downs_t, strides_t)
        self.level_blocks = nn.ModuleList(
            DecoderConvBock(output_emb_width, output_emb_width, down_t, stride_t,
                            **block_kwargs)
            for _level, down_t, stride_t in level_specs
        )
        # Built after the level blocks to keep registration order unchanged.
        self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1)

    def forward(self, xs, all_levels=True):
        """Decode from the last entry of ``xs`` back to ``input_emb_width`` channels.

        Args:
            xs: sequence of latent tensors. Decoding starts from ``xs[-1]``.
                When ``all_levels`` is true it must hold ``self.levels``
                entries; otherwise exactly one.
            all_levels: when true, after decoding level ``level`` (for
                ``level != 0``) the tensor ``xs[level - 1]`` is added to the
                running result.

        The starting tensor and each level's output are passed to
        ``assert_shape`` with expected shape ``(N, output_emb_width, T)``.
        ``T`` starts at ``xs[-1].shape[-1]`` and is multiplied by
        ``stride_t ** down_t`` after every level.

        Returns:
            Output of the final ``self.out`` projection.
        """
        expected_count = self.levels if all_levels else 1
        assert len(xs) == expected_count

        x = xs[-1]
        N, T = x.shape[0], x.shape[-1]
        assert_shape(x, (N, self.output_emb_width, T))

        # Walk from the coarsest level down to level 0 (e.g. 32, 64, ...).
        schedule = list(zip(range(self.levels), self.downs_t, self.strides_t))
        for level, down_t, stride_t in reversed(schedule):
            x = self.level_blocks[level](x)
            T *= stride_t ** down_t
            assert_shape(x, (N, self.output_emb_width, T))
            if all_levels and level != 0:
                # Add the next-finer level's latent to the upsampled result.
                x = x + xs[level - 1]

        return self.out(x)