File size: 522 Bytes
a1dda5b
87c4954
 
0d39fea
87c4954
 
 
 
ba30de8
87c4954
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

import torch
from torch import nn
from torchvision import models

class FineTunedResNet(nn.Module):
    """ResNet-50 backbone with a small MLP classification head.

    The torchvision ResNet-50's final fully connected layer (``fc``) is
    replaced by ``Linear -> ReLU -> Dropout -> Linear``. The network
    therefore emits ``num_classes`` unnormalized scores; no softmax is
    applied, so pair it with a loss that expects logits.

    Args:
        num_classes: Number of outputs produced by the final linear layer.
        hidden_dim: Width of the hidden layer in the replacement head.
        dropout: Dropout probability applied after the hidden activation.
        pretrained: If True (the default, matching the original behavior),
            initialize the backbone with ``ResNet50_Weights.DEFAULT``.
            If False, build it with randomly initialized weights, which
            avoids loading pretrained weights when a saved checkpoint will
            be restored afterwards.
    """

    def __init__(
        self,
        num_classes: int = 3,
        *,
        hidden_dim: int = 512,
        dropout: float = 0.5,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet50(weights=weights)
        # Read the backbone's feature width before `fc` is replaced.
        in_features = self.resnet.fc.in_features
        # Keep this exact module order (indices 0..3): saved state_dict keys
        # such as "resnet.fc.0.weight" / "resnet.fc.3.weight" depend on it.
        self.resnet.fc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and head, returning unnormalized class scores.

        The last dimension of the result has size ``num_classes``.

        NOTE(review): assumes ``x`` is an image batch shaped
        ``(batch, 3, H, W)`` and preprocessed with the transforms matching
        the backbone weights — confirm against the data pipeline.
        """
        return self.resnet(x)