Create decoder.py
de_en/decoder.py +49 -0
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn


class VideoDecoder(nn.Module):
    """Decodes embedding vectors into RGB frames with a transposed-convolution stack."""

    def __init__(self, embed_dim=512, hidden_dims=(512, 256, 128, 64), out_channels=3):
        super().__init__()
        self.init_channels = hidden_dims[0]

        # Project the embedding to an 8x8 feature map with hidden_dims[0] channels.
        self.fc = nn.Linear(embed_dim, hidden_dims[0] * 8 * 8)

        # Each stage doubles the spatial size: 8 -> 16 -> 32 -> 64 with the defaults.
        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1],
                                       kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU()
                )
            )

        self.decoder = nn.Sequential(*modules)

        # One more 2x upsample (64 -> 128), then project to RGB in [-1, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1],
                               kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        # z shape: (seq_len, batch, embed_dim) or (batch, embed_dim)
        is_sequence = z.dim() == 3
        if is_sequence:
            seq_len, batch_size = z.size(0), z.size(1)
            z = z.reshape(-1, z.size(2))  # merge seq and batch dims for the linear layer

        x = self.fc(z)
        x = x.view(-1, self.init_channels, 8, 8)
        x = self.decoder(x)
        x = self.final_layer(x)

        # Restore the sequence dimension for sequence inputs.
        if is_sequence:
            x = x.view(seq_len, batch_size, *x.shape[1:])

        return x
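
For reference, a minimal usage sketch of the decoder's shape contract (not part of the commit; it assumes the file is importable as de_en.decoder and that the defaults above are used, where the four stride-2 stages grow the initial 8x8 grid to 128x128):

import torch

from de_en.decoder import VideoDecoder

decoder = VideoDecoder().eval()  # eval mode so BatchNorm uses running stats

with torch.no_grad():
    # Single-frame input: (batch, embed_dim) -> (batch, 3, 128, 128)
    z = torch.randn(4, 512)
    frames = decoder(z)
    assert frames.shape == (4, 3, 128, 128)

    # Sequence input: (seq_len, batch, embed_dim) -> (seq_len, batch, 3, 128, 128)
    z_seq = torch.randn(10, 4, 512)
    clip = decoder(z_seq)
    assert clip.shape == (10, 4, 3, 128, 128)

    # Tanh bounds every output value to [-1, 1].
    assert clip.abs().max() <= 1.0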