rsax committed on
Commit
0617941
·
verified ·
1 Parent(s): 8f87ddc

Update models/vqvae.py

Browse files
Files changed (1) hide show
  1. models/vqvae.py +117 -117
models/vqvae.py CHANGED
@@ -1,118 +1,118 @@
1
- import torch.nn as nn
2
- from models.encdec import Encoder, Decoder
3
- from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
4
-
5
-
6
class VQVAE_251(nn.Module):
    """Encoder -> vector quantizer -> Decoder model over feature sequences.

    The per-frame feature width handed to the encoder/decoder is 251 when
    ``args.dataname == 'kit'`` and 263 otherwise. The quantizer class is
    selected by ``args.quantizer``.
    """

    def __init__(self,
                 args,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        """Build the encoder, the decoder and the selected quantizer.

        Args:
            args: Namespace-like object. ``args.quantizer`` and
                ``args.dataname`` are read here; ``args`` itself is also
                forwarded to the "ema_reset", "ema" and "reset" quantizers.
            nb_code: Passed as the first argument of the quantizer
                (presumably the codebook size -- confirm in quantize_cnn).
            code_dim: Passed as the second argument of the quantizer and
                used by ``forward_decoder`` to reshape dequantized codes.
            output_emb_width, down_t, stride_t, width, depth,
            dilation_growth_rate, activation, norm: Forwarded unchanged to
                both ``Encoder`` and ``Decoder``.

        Raises:
            ValueError: If ``args.quantizer`` is not one of "ema_reset",
                "orig", "ema" or "reset".
        """
        super().__init__()
        self.code_dim = code_dim
        self.num_code = nb_code
        self.quant = args.quantizer

        # Factories are lazy so the name can be validated up front while the
        # quantizer is still constructed after the encoder and decoder.
        quantizer_factories = {
            "ema_reset": lambda: QuantizeEMAReset(nb_code, code_dim, args),
            "orig": lambda: Quantizer(nb_code, code_dim, 1.0),
            "ema": lambda: QuantizeEMA(nb_code, code_dim, args),
            "reset": lambda: QuantizeReset(nb_code, code_dim, args),
        }
        if args.quantizer not in quantizer_factories:
            # Previously an unknown name silently left ``self.quantizer``
            # undefined and only failed later with an AttributeError.
            raise ValueError(
                f"Unknown quantizer {args.quantizer!r}; expected one of "
                f"{sorted(quantizer_factories)}"
            )

        self.encoder = Encoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
        self.quantizer = quantizer_factories[args.quantizer]()

    def preprocess(self, x):
        """Swap the last two axes and cast to float.

        (bs, T, Jx3) -> (bs, Jx3, T)
        """
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        """Swap the last two axes back.

        (bs, Jx3, T) -> (bs, T, Jx3)
        """
        x = x.permute(0, 2, 1)
        return x

    def encode(self, x):
        """Return quantizer code indices for ``x``, one row per batch item.

        Args:
            x: 3-D tensor; the first axis is treated as the batch axis.

        Returns:
            The result of ``self.quantizer.quantize`` reshaped to ``(N, -1)``
            where ``N`` is ``x.shape[0]``.
        """
        # The 3-way unpack also rejects inputs that are not 3-D.
        N, _, _ = x.shape
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_encoder = self.postprocess(x_encoder)
        # Flatten batch and time so every row is one encoder vector: (NT, C)
        x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1])
        code_idx = self.quantizer.quantize(x_encoder)
        code_idx = code_idx.view(N, -1)
        return code_idx

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            Tuple ``(x_out, loss, perplexity)``: the decoder output after
            ``postprocess`` plus the second and third values returned by the
            quantizer call.
        """
        x_in = self.preprocess(x)

        # Encode
        x_encoder = self.encoder(x_in)

        # Quantization
        x_quantized, loss, perplexity = self.quantizer(x_encoder)

        # Decoder
        x_decoder = self.decoder(x_quantized)
        x_out = self.postprocess(x_decoder)
        return x_out, loss, perplexity

    def forward_decoder(self, x):
        """Decode quantizer codes ``x`` back through the decoder.

        The dequantized codes are always reshaped to a single batch entry of
        shape ``(1, L, code_dim)``, so every code passed in is decoded as one
        sequence.
        NOTE(review): batched index input would be concatenated along the
        time axis -- confirm this is what callers expect.
        """
        x_d = self.quantizer.dequantize(x)
        x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()

        # Decoder
        x_decoder = self.decoder(x_d)
        x_out = self.postprocess(x_decoder)
        return x_out
82
-
83
-
84
-
85
class HumanVQVAE(nn.Module):
    """Thin wrapper that owns a ``VQVAE_251`` and delegates every call to it."""

    def __init__(self,
                 args,
                 nb_code=512,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        """Create the wrapped ``VQVAE_251``.

        Args:
            args: Namespace-like object; ``args.dataname`` is read here and
                ``args`` is forwarded to ``VQVAE_251``.
            nb_code, code_dim, output_emb_width, down_t, stride_t, width,
            depth, dilation_growth_rate, activation, norm: Forwarded
                unchanged to ``VQVAE_251``.
        """
        super().__init__()

        # 21 for the 'kit' dataset name, 22 for anything else.
        self.nb_joints = 21 if args.dataname == 'kit' else 22
        self.vqvae = VQVAE_251(args, nb_code, code_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)

    def encode(self, x):
        """Return the code indices produced by the wrapped model's ``encode``."""
        # The wrapped encode() unpacks x.shape into three values itself, so
        # the unused ``b, t, c`` unpack that used to live here was redundant.
        quants = self.vqvae.encode(x)  # (N, T)
        return quants

    def forward(self, x):
        """Run the wrapped model; returns ``(x_out, loss, perplexity)``."""
        x_out, loss, perplexity = self.vqvae(x)
        return x_out, loss, perplexity

    def forward_decoder(self, x):
        """Decode codes ``x`` via the wrapped model's ``forward_decoder``."""
        x_out = self.vqvae.forward_decoder(x)
        return x_out
118
 
 
1
+ import torch.nn as nn
2
+ from models.encdec import Encoder, Decoder
3
+ from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
4
+
5
+
6
class VQVAE_251(nn.Module):
    """Encoder -> vector quantizer -> Decoder model over feature sequences.

    The per-frame feature width handed to the encoder/decoder is 251 when
    ``args.dataname == 'kit'`` and 263 otherwise. The quantizer class is
    selected by ``args.quantizer``.
    """

    def __init__(self,
                 args,
                 nb_code=1024,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        """Build the encoder, the decoder and the selected quantizer.

        Args:
            args: Namespace-like object. ``args.quantizer`` and
                ``args.dataname`` are read here; ``args`` itself is also
                forwarded to the "ema_reset", "ema" and "reset" quantizers.
            nb_code: Passed as the first argument of the quantizer
                (presumably the codebook size -- confirm in quantize_cnn).
            code_dim: Passed as the second argument of the quantizer and
                used by ``forward_decoder`` to reshape dequantized codes.
            output_emb_width, down_t, stride_t, width, depth,
            dilation_growth_rate, activation, norm: Forwarded unchanged to
                both ``Encoder`` and ``Decoder``.

        Raises:
            ValueError: If ``args.quantizer`` is not one of "ema_reset",
                "orig", "ema" or "reset".
        """
        super().__init__()
        self.code_dim = code_dim
        self.num_code = nb_code
        self.quant = args.quantizer

        # Factories are lazy so the name can be validated up front while the
        # quantizer is still constructed after the encoder and decoder.
        quantizer_factories = {
            "ema_reset": lambda: QuantizeEMAReset(nb_code, code_dim, args),
            "orig": lambda: Quantizer(nb_code, code_dim, 1.0),
            "ema": lambda: QuantizeEMA(nb_code, code_dim, args),
            "reset": lambda: QuantizeReset(nb_code, code_dim, args),
        }
        if args.quantizer not in quantizer_factories:
            # Previously an unknown name silently left ``self.quantizer``
            # undefined and only failed later with an AttributeError.
            raise ValueError(
                f"Unknown quantizer {args.quantizer!r}; expected one of "
                f"{sorted(quantizer_factories)}"
            )

        self.encoder = Encoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
        self.decoder = Decoder(251 if args.dataname == 'kit' else 263, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
        self.quantizer = quantizer_factories[args.quantizer]()

    def preprocess(self, x):
        """Swap the last two axes and cast to float.

        (bs, T, Jx3) -> (bs, Jx3, T)
        """
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        """Swap the last two axes back.

        (bs, Jx3, T) -> (bs, T, Jx3)
        """
        x = x.permute(0, 2, 1)
        return x

    def encode(self, x):
        """Return quantizer code indices for ``x``, one row per batch item.

        Args:
            x: 3-D tensor; the first axis is treated as the batch axis.

        Returns:
            The result of ``self.quantizer.quantize`` reshaped to ``(N, -1)``
            where ``N`` is ``x.shape[0]``.
        """
        # The 3-way unpack also rejects inputs that are not 3-D.
        N, _, _ = x.shape
        x_in = self.preprocess(x)
        x_encoder = self.encoder(x_in)
        x_encoder = self.postprocess(x_encoder)
        # Flatten batch and time so every row is one encoder vector: (NT, C)
        x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1])
        code_idx = self.quantizer.quantize(x_encoder)
        code_idx = code_idx.view(N, -1)
        return code_idx

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            Tuple ``(x_out, loss, perplexity)``: the decoder output after
            ``postprocess`` plus the second and third values returned by the
            quantizer call.
        """
        x_in = self.preprocess(x)

        # Encode
        x_encoder = self.encoder(x_in)

        # Quantization
        x_quantized, loss, perplexity = self.quantizer(x_encoder)

        # Decoder
        x_decoder = self.decoder(x_quantized)
        x_out = self.postprocess(x_decoder)
        return x_out, loss, perplexity

    def forward_decoder(self, x):
        """Decode quantizer codes ``x`` back through the decoder.

        The dequantized codes are always reshaped to a single batch entry of
        shape ``(1, L, code_dim)``, so every code passed in is decoded as one
        sequence.
        NOTE(review): batched index input would be concatenated along the
        time axis -- confirm this is what callers expect.
        """
        x_d = self.quantizer.dequantize(x)
        x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()

        # Decoder
        x_decoder = self.decoder(x_d)
        x_out = self.postprocess(x_decoder)
        return x_out
82
+
83
+
84
+
85
class HumanVQVAE(nn.Module):
    """Thin wrapper that owns a ``VQVAE_251`` and delegates every call to it."""

    def __init__(self,
                 args,
                 nb_code=512,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None):
        """Create the wrapped ``VQVAE_251``.

        Args:
            args: Namespace-like object; ``args.dataname`` is read here and
                ``args`` is forwarded to ``VQVAE_251``.
            nb_code, code_dim, output_emb_width, down_t, stride_t, width,
            depth, dilation_growth_rate, activation, norm: Forwarded
                unchanged to ``VQVAE_251``.
        """
        super().__init__()

        # 21 for the 'kit' dataset name, 22 for anything else.
        self.nb_joints = 21 if args.dataname == 'kit' else 22
        self.vqvae = VQVAE_251(args, nb_code, code_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)

    def encode(self, x):
        """Return the code indices produced by the wrapped model's ``encode``."""
        # The wrapped encode() unpacks x.shape into three values itself, so
        # the unused ``b, t, c`` unpack that used to live here was redundant.
        quants = self.vqvae.encode(x)  # (N, T)
        return quants

    def forward(self, x):
        """Run the wrapped model; returns ``(x_out, loss, perplexity)``."""
        x_out, loss, perplexity = self.vqvae(x)
        return x_out, loss, perplexity

    def forward_decoder(self, x):
        """Decode codes ``x`` via the wrapped model's ``forward_decoder``."""
        x_out = self.vqvae.forward_decoder(x)
        return x_out
118