resnet-train / model.py
Sreekanth Tangirala
first commit
de2aabe
raw
history blame contribute delete
694 Bytes
import torch
import torch.nn as nn
from torchvision.models import resnet50
def get_model(num_classes: int) -> nn.Module:
    """
    Initialize a ResNet50 model from scratch (no pretrained weights).

    Args:
        num_classes (int): Number of output classes; used as the output size
            of the replacement final fully connected layer.

    Returns:
        nn.Module: ResNet50 model whose ``fc`` layer maps the backbone's
        features to ``num_classes`` outputs.
    """
    # ``weights=None`` is the torchvision >= 0.13 spelling of
    # ``pretrained=False``; the old keyword is deprecated and warns.
    model = resnet50(weights=None)
    # Replace the classification head, keeping the backbone's feature width.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model
def save_model(model, path):
    """
    Write the model's state dict to ``path`` using ``torch.save``.

    Only the state dict is stored, not the module object itself, so the
    architecture must be rebuilt before these weights can be loaded back.
    """
    state = model.state_dict()
    torch.save(state, path)
def load_model(num_classes, path, map_location="cpu"):
    """
    Load a saved model.

    Rebuilds the architecture with ``get_model`` and restores its weights
    from a state dict previously written with ``torch.save``.

    Args:
        num_classes (int): Number of output classes. It must match the value
            used when the checkpoint was saved; otherwise the ``fc`` shapes
            differ and ``load_state_dict`` raises.
        path: File path of the saved state dict.
        map_location: Passed to ``torch.load`` to choose where the loaded
            tensors are placed. Defaults to ``"cpu"`` so checkpoints saved on
            a GPU can also be read on CPU-only machines. The values are then
            copied into the freshly built model's parameters.

    Returns:
        model: ResNet50 model with the saved weights loaded.
    """
    model = get_model(num_classes)
    # A state dict only holds tensors in plain containers, so restrict
    # unpickling to those instead of executing arbitrary pickled objects.
    state_dict = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    return model