Spaces:
Runtime error
Runtime error
from config import MODEL_DIR, MODEL_INPUT_SIZE, TRANSFORMS_TO_APPLY, MODEL_BACKBONE, MODEL_OBJECTIVE, LAST_N_LAYERS_TO_TRAIN | |
import os | |
import torch | |
import json | |
from base import TransformationType, ModelBackbone, TrainingObjective | |
from torchvision import transforms | |
import torchvision | |
import torch.nn as nn | |
def save_model(model, config_json, model_dir=None):
    """Persist a model's weights and its configuration to a directory.

    Writes ``model.pth`` (``model.state_dict()`` via ``torch.save``) and
    ``config.json`` (``config_json`` via ``json.dump``) into ``model_dir``.

    Args:
        model: Object exposing ``state_dict()``.
        config_json: JSON-serialisable configuration to store next to the
            weights.
        model_dir: Target directory. It is created if missing and reused if
            it already exists. When ``None``, a fresh ``model_<n>``
            directory is allocated under ``MODEL_DIR``.

    Returns:
        The directory the files were written to.
    """
    if model_dir is None:
        base_dir = MODEL_DIR
        os.makedirs(base_dir, exist_ok=True)
        # Start from the entry count (the historical naming scheme), but skip
        # forward past names that are already taken: the count alone collides
        # as soon as an earlier model directory has been removed or the base
        # directory contains unrelated entries.
        index = len(os.listdir(base_dir))
        model_dir = os.path.join(base_dir, 'model_{}'.format(index))
        while os.path.exists(model_dir):
            index += 1
            model_dir = os.path.join(base_dir, 'model_{}'.format(index))

    # Works for both a brand-new path and an existing directory being re-used.
    os.makedirs(model_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(model_dir, 'model.pth'))
    with open(os.path.join(model_dir, 'config.json'), 'w') as f:
        json.dump(config_json, f)
    return model_dir
def get_transforms_to_apply_(transformation_type, config_json=None):
    """Build the torchvision transform for a single transformation type.

    Args:
        transformation_type: A ``TransformationType`` member selecting which
            transform to construct.
        config_json: Optional mapping; when truthy, its
            ``'MODEL_INPUT_SIZE'`` entry supplies the size for the resize and
            crop transforms. Otherwise the module-level ``MODEL_INPUT_SIZE``
            is used.

    Returns:
        The constructed transform object.

    Raises:
        Exception: If ``transformation_type`` matches none of the handled
            members.
    """
    size = config_json['MODEL_INPUT_SIZE'] if config_json else MODEL_INPUT_SIZE

    if transformation_type == TransformationType.RESIZE:
        return transforms.Resize(size)
    if transformation_type == TransformationType.TO_TENSOR:
        return transforms.ToTensor()
    if transformation_type == TransformationType.RANDOM_HORIZONTAL_FLIP:
        return transforms.RandomHorizontalFlip(p=0.5)
    if transformation_type == TransformationType.NORMALIZE:
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        return transforms.Normalize(mean=channel_mean, std=channel_std)
    if transformation_type == TransformationType.RANDOM_ROTATION:
        return transforms.RandomRotation(degrees=10)
    if transformation_type == TransformationType.RANDOM_CLIP:
        # RANDOM_CLIP maps onto torchvision's RandomCrop.
        return transforms.RandomCrop(size)

    raise Exception("Invalid transformation type")
def get_transforms_to_apply():
    """Compose the transforms named in ``TRANSFORMS_TO_APPLY``, in order.

    Each name is looked up on ``TransformationType`` and turned into a
    transform by ``get_transforms_to_apply_`` using the module defaults.

    Returns:
        A ``transforms.Compose`` wrapping the resulting list.
    """
    pipeline = [
        get_transforms_to_apply_(TransformationType[name])
        for name in TRANSFORMS_TO_APPLY
    ]
    return transforms.Compose(pipeline)
def get_model_architecture(config_json=None):
    """Construct the network for the configured backbone and objective.

    Args:
        config_json: Optional mapping. When truthy, ``'MODEL_BACKBONE'`` and
            ``'MODEL_OBJECTIVE'`` are read from it as member names of
            ``ModelBackbone`` / ``TrainingObjective``. Otherwise the
            module-level ``MODEL_BACKBONE`` and ``MODEL_OBJECTIVE`` are used.

    Returns:
        An ``efficientnet_b0`` (created with ``pretrained=True``) whose
        ``classifier[1]`` is replaced by a Linear -> ReLU -> Dropout(0.5) ->
        Linear head producing a single output.

    Raises:
        Exception: If the backbone or the objective is not a supported one.
    """
    if config_json:
        backbone = ModelBackbone[config_json['MODEL_BACKBONE']]
        objective = TrainingObjective[config_json['MODEL_OBJECTIVE']]
    else:
        backbone, objective = MODEL_BACKBONE, MODEL_OBJECTIVE

    # Backbone is validated before the objective, so an unsupported backbone
    # always wins when both are invalid.
    if backbone != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")
    if objective != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")

    model = torchvision.models.efficientnet_b0(pretrained=True)
    # Read the width of the stock classifier layer before overwriting it.
    head_inputs = model.classifier[1].in_features
    regression_head = nn.Sequential(
        nn.Linear(head_inputs, 2048),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(2048, 1),
    )
    model.classifier[1] = regression_head
    return model
def get_training_params(model):
    """Collect the parameters that should be handed to the optimizer.

    For the EfficientNet-B0 backbone, when ``LAST_N_LAYERS_TO_TRAIN`` is
    positive, every entry of ``model.features`` except the last N is frozen
    (``requires_grad = False``) and the parameters of the last N entries are
    collected. The parameters of ``model.classifier[1]`` are always collected.

    Args:
        model: Network exposing ``features`` and ``classifier``.

    Returns:
        A list of parameters: trailing feature layers first, then the
        classifier head.

    Raises:
        Exception: If ``MODEL_BACKBONE`` is not the supported backbone.
    """
    if MODEL_BACKBONE != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")

    trainable = []
    n_tail = LAST_N_LAYERS_TO_TRAIN
    # NOTE(review): when n_tail is 0 nothing is frozen here; the feature
    # parameters are merely left out of the returned list — confirm intended.
    if n_tail > 0:
        for frozen in model.features[:-n_tail].parameters():
            frozen.requires_grad = False
        trainable.extend(model.features[-n_tail:].parameters())

    trainable.extend(model.classifier[1].parameters())
    return trainable
def get_criterion():
    """Return the loss module for the configured ``MODEL_OBJECTIVE``.

    Returns:
        ``nn.MSELoss()`` for the regression objective.

    Raises:
        Exception: If ``MODEL_OBJECTIVE`` is anything other than regression.
    """
    if MODEL_OBJECTIVE != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")
    return nn.MSELoss()