Spaces:
Runtime error
Runtime error
File size: 1,374 Bytes
24615d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScoringModel(nn.Module):
    """Map a flat per-clip feature vector to a single scalar score.

    Each input row is expected to hold ``frames_per_clip * input_dim``
    features (presumably the per-frame feature vectors concatenated — TODO
    confirm against the feature extractor).

    Architecture:
        * ``num_hidden_layers == 0``: a single ``nn.Linear(F * D, 1)``.
        * ``num_hidden_layers > 0``: ``num_hidden_layers`` blocks of
          Linear -> ReLU -> BatchNorm1d, followed by ``nn.Linear(hidden_dim, 1)``.

    Args:
        frames_per_clip: Number of frames per clip (F). Must be >= 1.
        input_dim: Feature size per frame (D). Must be >= 1.
        hidden_dim: Width of every hidden layer. Ignored when
            ``num_hidden_layers == 0``; otherwise must be >= 1.
        num_hidden_layers: Number of hidden blocks. Must be >= 0.

    Raises:
        ValueError: If any size argument is out of range.
    """

    def __init__(self, frames_per_clip: int, input_dim: int, hidden_dim: int, num_hidden_layers: int):
        super().__init__()
        if frames_per_clip < 1:
            raise ValueError(f"frames_per_clip must be >= 1, got {frames_per_clip}")
        if input_dim < 1:
            raise ValueError(f"input_dim must be >= 1, got {input_dim}")
        if num_hidden_layers < 0:
            # Without this check a negative count falls into the MLP branch
            # and silently builds a one-hidden-layer network.
            raise ValueError(f"num_hidden_layers must be >= 0, got {num_hidden_layers}")
        if num_hidden_layers > 0 and hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1 when num_hidden_layers > 0, got {hidden_dim}")

        self.frames_per_clip = frames_per_clip
        in_features = input_dim * frames_per_clip

        if num_hidden_layers == 0:
            # A bare Linear (not a 1-element Sequential) keeps the state_dict
            # keys as "model.weight" / "model.bias".
            self.model = nn.Linear(in_features, 1)
        else:
            modules = []
            for _ in range(num_hidden_layers):
                modules.extend([
                    nn.Linear(in_features, hidden_dim),
                    nn.ReLU(True),  # in-place ReLU
                    nn.BatchNorm1d(hidden_dim),
                ])
                # Every block after the first maps hidden_dim -> hidden_dim.
                in_features = hidden_dim
            modules.append(nn.Linear(hidden_dim, 1))
            self.model = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one score per input row.

        Args:
            x: Tensor whose last dimension is ``frames_per_clip * input_dim``.
                Assumed 2-D ``(batch, features)``. With hidden layers,
                BatchNorm1d in training mode needs batch > 1.

        Returns:
            Tensor of shape ``(batch, 1)`` for 2-D input.
        """
        return self.model(x)
if __name__ == '__main__':
    # Smoke test: build a linear-only scorer and push one random batch through it.
    batch_size = 8
    input_dim = 512
    frames_per_clip = 3
    hidden_dim = 0          # unused when num_hidden_layers == 0
    num_hidden_layers = 0

    features = torch.rand(batch_size, input_dim * frames_per_clip)
    model = ScoringModel(
        frames_per_clip=frames_per_clip,
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_hidden_layers=num_hidden_layers,
    )
    print(model)

    scores = model(features)
    print(scores.size())  # expected: (batch_size, 1)