tc-mb ssdsdfsdf commited on
Commit
dacabb6
·
verified ·
1 Parent(s): 52e6b29

Update resampler.py (#3)

Browse files

- Update resampler.py (cc31e963f29bb5e66f0d399572ba0650e7fc92d4)


Co-authored-by: ed <[email protected]>

Files changed (1) hide show
  1. resampler.py +9 -0
resampler.py CHANGED
@@ -160,6 +160,15 @@ class Resampler(nn.Module):
160
  nn.init.constant_(m.bias, 0)
161
  nn.init.constant_(m.weight, 1.0)
162
 
 
 
 
 
 
 
 
 
 
163
  def forward(self, x, tgt_sizes=None, temporal_ids=None):
164
  assert x.shape[0] == tgt_sizes.shape[0]
165
  bs = x.shape[0]
 
160
  nn.init.constant_(m.bias, 0)
161
  nn.init.constant_(m.weight, 1.0)
162
 
163
+ def _initialize_weights(self, m):
164
+ if isinstance(m, nn.Linear):
165
+ trunc_normal_(m.weight, std=.02)
166
+ if isinstance(m, nn.Linear) and m.bias is not None:
167
+ nn.init.constant_(m.bias, 0)
168
+ elif isinstance(m, nn.LayerNorm):
169
+ nn.init.constant_(m.bias, 0)
170
+ nn.init.constant_(m.weight, 1.0)
171
+
172
  def forward(self, x, tgt_sizes=None, temporal_ids=None):
173
  assert x.shape[0] == tgt_sizes.shape[0]
174
  bs = x.shape[0]