import torch
import torch.nn as nn
import torch.nn.functional as F


class ScoringModel(nn.Module):
    """MLP that maps the concatenated features of one clip to a single score.

    Args:
        frames_per_clip: Number of frames whose features are concatenated
            into one input vector.
        input_dim: Feature size of a single frame.
        hidden_dim: Width of every hidden layer (ignored when
            ``num_hidden_layers`` is 0).
        num_hidden_layers: Number of ``Linear -> ReLU -> BatchNorm1d`` stages
            placed before the final ``Linear(..., 1)``. ``0`` yields a single
            linear layer.

    Raises:
        ValueError: If ``num_hidden_layers`` is negative.
    """

    def __init__(self, frames_per_clip: int, input_dim: int, hidden_dim: int, num_hidden_layers: int):
        super().__init__()
        # range() of a negative count is empty, which would silently build a
        # one-hidden-layer model; reject the invalid configuration instead.
        if num_hidden_layers < 0:
            raise ValueError(f"num_hidden_layers must be >= 0, got {num_hidden_layers}")
        self.frames_per_clip = frames_per_clip
        self.input_dim = input_dim
        in_features = input_dim * frames_per_clip

        if num_hidden_layers == 0:
            self.model = nn.Linear(in_features, 1)
        else:
            modules = [
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(True),
                nn.BatchNorm1d(hidden_dim),
            ]
            for _ in range(num_hidden_layers - 1):
                modules.extend([
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(True),
                    nn.BatchNorm1d(hidden_dim),
                ])
            modules.append(nn.Linear(hidden_dim, 1))
            self.model = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of clips.

        Args:
            x: Either already-flattened clip features whose last dimension is
                ``frames_per_clip * input_dim`` (e.g. ``(batch, F*D)``), or
                per-frame features of shape ``(..., frames_per_clip,
                input_dim)`` with rank >= 3, which are flattened frame-major
                before scoring.

        Returns:
            The model output; ``(batch, 1)`` for a 2-D / 3-D batched input.
        """
        flat_size = self.frames_per_clip * self.input_dim
        # Only flatten inputs that could not be consumed as-is; anything whose
        # last dim already equals flat_size goes through unchanged.
        if (
            x.dim() >= 3
            and x.shape[-1] != flat_size
            and tuple(x.shape[-2:]) == (self.frames_per_clip, self.input_dim)
        ):
            x = x.flatten(start_dim=-2)
        return self.model(x)


if __name__ == '__main__':
    batch_size, input_dim, frames_per_clip = 8, 512, 3
    hidden_dim, num_hidden_layers = 0, 0
    x = torch.rand(batch_size, input_dim * frames_per_clip)
    scoring_model = ScoringModel(frames_per_clip, input_dim, hidden_dim, num_hidden_layers)
    print(scoring_model)
    y = scoring_model(x)
    print(y.size())  # should be (batch_size, 1)