SkillForge45 committed
Commit a7bbbd5 · verified · 1 Parent(s): 39aaddf

Create decoder.py

Files changed (1)
  1. de_en/decoder.py +49 -0
de_en/decoder.py ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class VideoDecoder(nn.Module):
    def __init__(self, embed_dim=512, hidden_dims=(512, 256, 128, 64), out_channels=3):
        super().__init__()

        # Project each embedding to an 8x8 feature map with hidden_dims[0] channels.
        self.init_channels = hidden_dims[0]
        self.fc = nn.Linear(embed_dim, hidden_dims[0] * 8 * 8)

        # Upsampling trunk: each block doubles the spatial resolution (8 -> 16 -> 32 -> 64).
        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1],
                                       kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU()
                )
            )

        self.decoder = nn.Sequential(*modules)

        # Final 2x upsample (64 -> 128), then map to out_channels with Tanh in [-1, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1],
                               kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        # z shape: (seq_len, batch, embed_dim) or (batch, embed_dim)
        if z.dim() == 3:
            batch_size, seq_len = z.size(1), z.size(0)
            z = z.reshape(-1, z.size(2))  # merge seq and batch dims for the linear layer
        else:
            batch_size, seq_len = z.size(0), 1

        x = self.fc(z)
        x = x.view(-1, self.init_channels, 8, 8)
        x = self.decoder(x)
        x = self.final_layer(x)

        # Restore the (seq_len, batch, C, H, W) layout for sequence inputs.
        if seq_len > 1:
            x = x.view(seq_len, batch_size, *x.shape[1:])

        return x
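For reference, a minimal usage sketch (not part of the commit). It assumes the shape convention noted in forward, a (seq_len, batch, embed_dim) latent, the default constructor arguments, and that the repo root is importable so de_en is a package; with those defaults the decoder produces 128x128, 3-channel frames.

import torch

from de_en.decoder import VideoDecoder  # assumes the repo root is on PYTHONPATH

decoder = VideoDecoder()  # embed_dim=512, out_channels=3 by default

# A batch of latent sequences: 16 frames, batch size 4, 512-dim embeddings.
z = torch.randn(16, 4, 512)
frames = decoder(z)
print(frames.shape)  # torch.Size([16, 4, 3, 128, 128]), values in [-1, 1] from the Tanh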