Woleek commited on
Commit
68a574e
·
verified ·
1 Parent(s): fc8b3f4

Upload feature extractor

Browse files
Files changed (2) hide show
  1. model.py +3 -0
  2. model.safetensors +1 -1
model.py CHANGED
@@ -12,9 +12,12 @@ class ResNet50(nn.Module):
12
  self.fc_1 = nn.Linear(2048, 768)
13
 
14
  def forward(self, x):
 
 
15
  x = self.backbone(x)
16
  x = self.flaten(x)
17
  x = self.fc_1(x)
 
18
  return x
19
 
20
  class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
 
12
  self.fc_1 = nn.Linear(2048, 768)
13
 
14
  def forward(self, x):
15
+ if len(x.shape) == 3:
16
+ x = x.unsqueeze(0)
17
  x = self.backbone(x)
18
  x = self.flaten(x)
19
  x = self.fc_1(x)
20
+ x = x.squeeze(0)
21
  return x
22
 
23
  class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:042169f2ec560953412d12f8aad4b164a10ad317c25f834244a83a1c3ec70da4
3
  size 100571344
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b7bd2885cff231ee9908c75e70797ed5557b797cbe6ea479e306ae676863e00
3
  size 100571344