File size: 694 Bytes
de2aabe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from torchvision.models import resnet50

def get_model(num_classes):
    """
    Initialize a ResNet50 model from scratch (random init, no pretrained weights).

    Args:
        num_classes (int): Number of output classes; used as the output size
            of the replacement final fully connected layer.
    Returns:
        model: ResNet50 model whose ``fc`` layer is replaced by a freshly
        initialized ``nn.Linear(in_features, num_classes)``.
    """
    # `weights=None` is the torchvision >= 0.13 way to request no pretrained
    # weights; the legacy `pretrained=False` keyword is deprecated and warns.
    model = resnet50(weights=None)
    # Keep the backbone's feature width and swap only the classifier head.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def save_model(model, path):
    """
    Persist the model's state dict (parameters and buffers) to ``path``.

    Only the state dict is written, not the module object itself, so the
    architecture must be rebuilt before the weights can be loaded back.
    """
    weights = model.state_dict()
    torch.save(weights, path)

def load_model(num_classes, path, map_location="cpu"):
    """
    Load a saved model.

    Rebuilds the architecture with ``get_model(num_classes)`` and restores
    the state dict stored at ``path`` into it.

    Args:
        num_classes (int): Must match the value used when the checkpoint was
            created; otherwise ``load_state_dict`` fails on the ``fc`` shapes.
        path: File path passed to ``torch.load``.
        map_location: Device mapping for the loaded tensors, forwarded to
            ``torch.load``. Defaults to ``"cpu"`` so checkpoints saved from a
            GPU can be loaded on machines without CUDA.
    Returns:
        model: The rebuilt model with the loaded weights.
    """
    model = get_model(num_classes)
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints, or pass weights_only=True (torch >= 1.13).
    state = torch.load(path, map_location=map_location)
    # load_state_dict copies values into the model's existing parameters, so
    # where `state` was mapped does not change where the model lives.
    model.load_state_dict(state)
    return model