import torch
import torch.nn as nn
from torchvision.models import ResNet50_Weights, resnet50
from transformers import PreTrainedModel

from .config import ResnetConfig


class ResNet50(nn.Module):
    """torchvision ResNet-50 trunk + global average pooling + linear projection.

    Args:
        weights: Weights enum forwarded to ``torchvision.models.resnet50``.
            Defaults to ``ResNet50_Weights.IMAGENET1K_V1`` (the original
            behaviour); pass ``None`` for random initialisation, e.g. when the
            parameters will be overwritten from a checkpoint anyway.
        out_features: Size of the projected embedding (default 768).
    """

    def __init__(self, weights=ResNet50_Weights.IMAGENET1K_V1, out_features: int = 768):
        super().__init__()
        self.cnn = resnet50(weights=weights)
        # Keep every child except torchvision's trailing ``avgpool`` and ``fc``
        # so the backbone outputs the 2048-channel feature map of layer4.
        self.backbone = nn.Sequential(*list(self.cnn.children())[:-2])
        # Adaptive pooling to 1x1 is identical to AvgPool2d(kernel_size=7) on the
        # 7x7 map produced by a 224x224 input, but also averages the full map for
        # other input sizes (a fixed 7x7 kernel either ignored edge cells or left
        # a >1x1 map that no longer matched the 2048-input Linear layer).
        # The misspelled attribute name ``flaten`` is kept for compatibility.
        self.flaten = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())
        self.fc_1 = nn.Linear(2048, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a single image or a batch of images.

        Args:
            x: Image tensor of shape ``(C, H, W)`` or ``(N, C, H, W)``.
                NOTE(review): presumably 3-channel and normalised the way the
                torchvision ImageNet weights expect — confirm with callers.

        Returns:
            A ``(out_features,)`` tensor for a 3-D (unbatched) input, otherwise
            an ``(N, out_features)`` tensor.
        """
        unbatched = x.dim() == 3
        if unbatched:
            x = x.unsqueeze(0)
        x = self.backbone(x)
        x = self.flaten(x)
        x = self.fc_1(x)
        # Only remove the batch axis we added above; an unconditional
        # squeeze(0) would also collapse a genuine batch of size 1.
        if unbatched:
            x = x.squeeze(0)
        return x


class ResNet50AffectiveFeatureExtractor(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`ResNet50`.

    ``config`` is only forwarded to ``PreTrainedModel``; the wrapped
    ``ResNet50`` is built with its default arguments.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ResNet50()
        # ``backbone`` already references the layers it needs; deleting the full
        # torchvision model drops its unused ``avgpool``/``fc`` from the module
        # tree and therefore from the saved state dict.
        del self.model.cnn

    def forward(self, tensor):
        """Return ``self.model(tensor)`` (see ``ResNet50.forward`` for shapes)."""
        return self.model(tensor)