Update resampler.py (#3)
Browse files- Update resampler.py (cc31e963f29bb5e66f0d399572ba0650e7fc92d4)
Co-authored-by: ed <[email protected]>
- resampler.py +9 -0
resampler.py
CHANGED
@@ -160,6 +160,15 @@ class Resampler(nn.Module):
|
|
160 |
nn.init.constant_(m.bias, 0)
|
161 |
nn.init.constant_(m.weight, 1.0)
|
162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
def forward(self, x, tgt_sizes=None, temporal_ids=None):
|
164 |
assert x.shape[0] == tgt_sizes.shape[0]
|
165 |
bs = x.shape[0]
|
|
|
160 |
nn.init.constant_(m.bias, 0)
|
161 |
nn.init.constant_(m.weight, 1.0)
|
162 |
|
163 |
+
def _initialize_weights(self, m):
|
164 |
+
if isinstance(m, nn.Linear):
|
165 |
+
trunc_normal_(m.weight, std=.02)
|
166 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
167 |
+
nn.init.constant_(m.bias, 0)
|
168 |
+
elif isinstance(m, nn.LayerNorm):
|
169 |
+
nn.init.constant_(m.bias, 0)
|
170 |
+
nn.init.constant_(m.weight, 1.0)
|
171 |
+
|
172 |
def forward(self, x, tgt_sizes=None, temporal_ids=None):
|
173 |
assert x.shape[0] == tgt_sizes.shape[0]
|
174 |
bs = x.shape[0]
|