Spaces:
Sleeping
Sleeping
File size: 846 Bytes
ce2eaae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch.nn as nn
import torchvision
class Resnet50Flower102(nn.Module):
    """Torchvision ResNet-50 with its ``fc`` layer replaced by an MLP head.

    The head is ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.2)`` twice
    (1024 then 512 hidden units), followed by a final ``Linear`` that emits
    ``num_classes`` raw logits (no softmax is applied). The default of 102
    classes matches the class name (presumably the Oxford Flowers-102
    dataset).

    The head keeps the same ``nn.Sequential`` layer order as before, so
    ``state_dict`` keys (``model.fc.0.weight`` ... ``model.fc.8.bias``) stay
    compatible with existing checkpoints.

    Args:
        device: Device the wrapped model is moved to; also stored on
            ``self.device``.
        pretrained: If True, load torchvision's ``IMAGENET1K_V1`` weights for
            the backbone; otherwise the backbone is randomly initialized.
        freeze_backbone: If True, disable gradients for every parameter
            outside the new head, so only the head is trained.
        num_classes: Number of output logits (default 102).
    """

    def __init__(self, device, pretrained=True, freeze_backbone=True, num_classes=102):
        super().__init__()
        self.device = device
        if pretrained:
            weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
        else:
            weights = None
        self.model = torchvision.models.resnet50(weights=weights)
        # Read the feature width from the stock head instead of hard-coding 2048.
        self.model.fc = self._build_head(self.model.fc.in_features, num_classes)
        if freeze_backbone:
            # The flag used to be accepted but ignored, leaving the whole
            # network trainable; honor it so only the head gets gradients.
            self._freeze_backbone(self.model)
        self.model.to(device)

    @staticmethod
    def _build_head(in_features, num_classes):
        """Return the MLP head mapping ``in_features`` features to ``num_classes`` logits.

        Note: ``BatchNorm1d`` raises an error in training mode when a batch
        contains a single sample, so training batches must have size > 1.
        """
        return nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _freeze_backbone(model):
        """Disable gradients for every parameter of *model* except those under ``model.fc``.

        Only ``requires_grad`` is changed. BatchNorm layers in the backbone
        still update their running statistics while the model is in
        ``train()`` mode; put the backbone in ``eval()`` as well if those
        must stay fixed.
        """
        for param in model.parameters():
            param.requires_grad_(False)
        for param in model.fc.parameters():
            param.requires_grad_(True)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for the image batch ``x``.

        ``x`` is passed unchanged to the torchvision ResNet-50, which expects
        ``(N, 3, H, W)`` float tensors. It is not moved to ``self.device`` here.
        """
        return self.model(x)