from config import MODEL_DIR, MODEL_INPUT_SIZE, TRANSFORMS_TO_APPLY, MODEL_BACKBONE, MODEL_OBJECTIVE, LAST_N_LAYERS_TO_TRAIN
import os
import torch
import json
from base import TransformationType, ModelBackbone, TrainingObjective
from torchvision import transforms
import torchvision
import torch.nn as nn
def save_model(model, config_json, model_dir=None):
    """Persist model weights and their config to a directory.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose ``state_dict`` is saved as ``model.pth``.
    config_json : dict
        JSON-serializable training config, written alongside as ``config.json``.
    model_dir : str, optional
        Target directory. When omitted, a fresh ``model_<k>`` directory is
        created under ``MODEL_DIR``, where ``k`` is the current entry count.

    Returns
    -------
    str
        The directory the artifacts were written to.
    """
    if model_dir is None:
        # Auto-name by counting existing entries in MODEL_DIR.
        # NOTE(review): this can collide if earlier runs were deleted,
        # in which case the existing directory will be reused.
        existing = os.listdir(MODEL_DIR)
        model_dir = os.path.join(MODEL_DIR, 'model_{}'.format(len(existing)))
    # makedirs (vs mkdir) also creates missing parents; exist_ok allows an
    # explicitly supplied, already-existing model_dir.
    os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(model_dir, 'model.pth'))
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config_json, f)
    return model_dir
def get_transforms_to_apply_(transformation_type, config_json=None):
    """Map a single ``TransformationType`` member to a torchvision transform.

    Parameters
    ----------
    transformation_type : TransformationType
        Which transform to build.
    config_json : dict, optional
        Saved model config; when given, its ``'MODEL_INPUT_SIZE'`` overrides
        the module-level ``MODEL_INPUT_SIZE`` constant.

    Returns
    -------
    A torchvision transform instance.

    Raises
    ------
    ValueError
        If ``transformation_type`` is not a supported member. (ValueError is
        a subclass of Exception, so existing broad handlers still catch it.)
    """
    if config_json:
        model_input_size = config_json['MODEL_INPUT_SIZE']
    else:
        model_input_size = MODEL_INPUT_SIZE
    # Builders are lambdas so size-dependent transforms pick up the resolved
    # model_input_size. Normalize uses the standard ImageNet mean/std.
    builders = {
        TransformationType.RESIZE:
            lambda: transforms.Resize(model_input_size),
        TransformationType.TO_TENSOR:
            transforms.ToTensor,
        TransformationType.RANDOM_HORIZONTAL_FLIP:
            lambda: transforms.RandomHorizontalFlip(p=0.5),
        TransformationType.NORMALIZE:
            lambda: transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
        TransformationType.RANDOM_ROTATION:
            lambda: transforms.RandomRotation(degrees=10),
        # Despite the name, RANDOM_CLIP maps to a random *crop* at the
        # model input size.
        TransformationType.RANDOM_CLIP:
            lambda: transforms.RandomCrop(model_input_size),
    }
    try:
        return builders[transformation_type]()
    except KeyError:
        raise ValueError(
            "Invalid transformation type: {}".format(transformation_type)
        ) from None
def get_transforms_to_apply(config_json=None):
    """Build the composed transform pipeline named in ``TRANSFORMS_TO_APPLY``.

    Parameters
    ----------
    config_json : dict, optional
        Saved model config, forwarded to the per-transform builder so a
        stored ``MODEL_INPUT_SIZE`` can override the module constant.
        Defaults to ``None`` (use module constants), preserving the original
        call signature.

    Returns
    -------
    transforms.Compose
        The transforms, composed in the order listed in the config.
    """
    return transforms.Compose([
        get_transforms_to_apply_(TransformationType[name], config_json)
        for name in TRANSFORMS_TO_APPLY
    ])
def get_model_architecture(config_json=None):
    """Construct the model described by the config (or module constants).

    Parameters
    ----------
    config_json : dict, optional
        Saved model config with ``'MODEL_BACKBONE'`` and ``'MODEL_OBJECTIVE'``
        enum member names; when omitted, the module-level constants are used.

    Returns
    -------
    torch.nn.Module
        An EfficientNet-B0 with its classification head replaced by a
        single-output regression MLP.

    Raises
    ------
    ValueError
        For unsupported backbone/objective combinations. (ValueError is a
        subclass of Exception, so existing broad handlers still catch it.)
    """
    if config_json:
        model_backbone = ModelBackbone[config_json['MODEL_BACKBONE']]
        model_objective = TrainingObjective[config_json['MODEL_OBJECTIVE']]
    else:
        model_backbone = MODEL_BACKBONE
        model_objective = MODEL_OBJECTIVE
    # Guard clauses: only EfficientNet-B0 + regression is supported today.
    if model_backbone != ModelBackbone.EFFICIENT_NET_B0:
        raise ValueError("Invalid model backbone: {}".format(model_backbone))
    if model_objective != TrainingObjective.REGRESSION:
        raise ValueError("Invalid model objective: {}".format(model_objective))
    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favor
    # of `weights=`; kept for compatibility with the pinned version — confirm.
    model = torchvision.models.efficientnet_b0(pretrained=True)
    # Swap the stock classifier's Linear for a small regression head
    # (2048-unit hidden layer, dropout, single scalar output).
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Sequential(
        nn.Linear(in_features, 2048),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(2048, 1),
    )
    return model
def get_training_params(model):
    """Select the parameters to optimize, freezing the rest of the backbone.

    Freezes all but the last ``LAST_N_LAYERS_TO_TRAIN`` blocks of
    ``model.features`` and returns the still-trainable feature parameters
    plus the regression head's parameters.

    Parameters
    ----------
    model : torch.nn.Module
        An EfficientNet-style model with ``features`` and ``classifier``
        attributes, as built by ``get_model_architecture``.

    Returns
    -------
    list of torch.nn.Parameter
        Parameters to hand to the optimizer.

    Raises
    ------
    ValueError
        If ``MODEL_BACKBONE`` is unsupported. (ValueError subclasses
        Exception, so existing broad handlers still catch it.)
    """
    if MODEL_BACKBONE != ModelBackbone.EFFICIENT_NET_B0:
        raise ValueError("Invalid model backbone: {}".format(MODEL_BACKBONE))
    training_params = []
    if LAST_N_LAYERS_TO_TRAIN > 0:
        # Freeze everything except the last N feature blocks.
        for param in model.features[:-LAST_N_LAYERS_TO_TRAIN].parameters():
            param.requires_grad = False
        training_params.extend(model.features[-LAST_N_LAYERS_TO_TRAIN:].parameters())
    else:
        # BUG FIX: previously, with LAST_N_LAYERS_TO_TRAIN == 0 the feature
        # extractor was neither frozen nor passed to the optimizer, so it
        # still accumulated gradients every step. Freeze it explicitly.
        for param in model.features.parameters():
            param.requires_grad = False
    # The regression head is always trained.
    training_params.extend(model.classifier[1].parameters())
    return training_params
def get_criterion():
    """Return the loss function matching ``MODEL_OBJECTIVE``.

    Returns
    -------
    torch.nn.Module
        ``nn.MSELoss`` for the regression objective.

    Raises
    ------
    Exception
        If ``MODEL_OBJECTIVE`` is not a supported objective.
    """
    if MODEL_OBJECTIVE != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")
    return nn.MSELoss()