# acne_grading/utils.py
# Uploaded by suyash94 via huggingface_hub ("Upload folder using huggingface_hub", commit 418196b).
from config import MODEL_DIR, MODEL_INPUT_SIZE, TRANSFORMS_TO_APPLY, MODEL_BACKBONE, MODEL_OBJECTIVE, LAST_N_LAYERS_TO_TRAIN
import os
import torch
import json
from base import TransformationType, ModelBackbone, TrainingObjective
from torchvision import transforms
import torchvision
import torch.nn as nn
def save_model(model, config_json, model_dir=None):
    """Save a model's weights and its configuration to a directory.

    Writes ``model.state_dict()`` to ``<model_dir>/model.pth`` and
    ``config_json`` (as JSON) to ``<model_dir>/config.json``.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        config_json: JSON-serialisable object stored next to the weights.
        model_dir: destination directory. When ``None``, a new directory
            named ``model_<k>`` is created under ``MODEL_DIR``. When given,
            it is created (including missing parents) if it does not exist.

    Returns:
        The directory the two files were written to.
    """
    if model_dir is None:
        model_basedir = MODEL_DIR
        # Start from the number of entries already in the base directory
        # (the original naming scheme). That count can name a directory that
        # already exists, e.g. after an earlier model_<k> was deleted, which
        # used to make os.mkdir raise FileExistsError. Skip taken indices.
        index = len(os.listdir(model_basedir))
        model_dir = os.path.join(model_basedir, 'model_{}'.format(index))
        while os.path.exists(model_dir):
            index += 1
            model_dir = os.path.join(model_basedir, 'model_{}'.format(index))
        os.mkdir(model_dir)
    elif model_dir:
        # torch.save and open() do not create missing directories.
        # An empty string still means "current directory", as before.
        os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, 'model.pth')
    torch.save(model.state_dict(), model_path)
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config_json, f)
    return model_dir
def get_transforms_to_apply_(transformation_type, config_json=None):
    """Translate a single TransformationType member into a torchvision transform.

    Args:
        transformation_type: the ``TransformationType`` member to build.
        config_json: optional saved config mapping. When truthy, its
            ``'MODEL_INPUT_SIZE'`` entry is used for the size-dependent
            transforms (resize, random crop). Otherwise the module-level
            ``MODEL_INPUT_SIZE`` is used. The lookup happens before dispatch,
            so a truthy config lacking that key raises ``KeyError`` for every
            type.

    Returns:
        One freshly constructed torchvision transform.

    Raises:
        Exception: if ``transformation_type`` matches none of the handled types.
    """
    size = config_json['MODEL_INPUT_SIZE'] if config_json else MODEL_INPUT_SIZE

    if transformation_type == TransformationType.RESIZE:
        return transforms.Resize(size)
    if transformation_type == TransformationType.TO_TENSOR:
        return transforms.ToTensor()
    if transformation_type == TransformationType.RANDOM_HORIZONTAL_FLIP:
        return transforms.RandomHorizontalFlip(p=0.5)
    if transformation_type == TransformationType.NORMALIZE:
        # The commonly used ImageNet RGB mean/std. Presumably chosen to match
        # the ImageNet-pretrained backbone; confirm if the backbone changes.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        return transforms.Normalize(mean=channel_mean, std=channel_std)
    if transformation_type == TransformationType.RANDOM_ROTATION:
        return transforms.RandomRotation(degrees=10)
    if transformation_type == TransformationType.RANDOM_CLIP:
        # Despite the member name, this builds a random *crop*.
        return transforms.RandomCrop(size)

    raise Exception("Invalid transformation type")
def get_transforms_to_apply(config_json=None):
    """Compose the preprocessing pipeline from a list of transform names.

    Args:
        config_json: optional saved config mapping. When omitted, behaviour is
            unchanged: names come from the module-level ``TRANSFORMS_TO_APPLY``
            and each transform uses the module-level input size. When given,
            it is forwarded to ``get_transforms_to_apply_``, so its
            ``MODEL_INPUT_SIZE`` is used. If it also contains a
            ``'TRANSFORMS_TO_APPLY'`` entry, that list replaces the module
            constant.

    Returns:
        ``transforms.Compose`` of the named transforms, in list order.

    Raises:
        KeyError: presumably, when a name is not a ``TransformationType``
            member (name lookup via ``TransformationType[...]``; confirm it
            is an Enum).
        Exception: propagated from ``get_transforms_to_apply_`` for
            unsupported members.
    """
    transform_names = TRANSFORMS_TO_APPLY
    if config_json and 'TRANSFORMS_TO_APPLY' in config_json:
        transform_names = config_json['TRANSFORMS_TO_APPLY']
    return transforms.Compose([
        get_transforms_to_apply_(TransformationType[name], config_json)
        for name in transform_names
    ])
def get_model_architecture(config_json=None):
    """Build the network for the configured backbone and training objective.

    Args:
        config_json: optional saved config mapping. When truthy, its
            ``'MODEL_BACKBONE'`` and ``'MODEL_OBJECTIVE'`` strings are looked
            up by name in ``ModelBackbone`` / ``TrainingObjective``.
            Otherwise the module-level ``MODEL_BACKBONE`` / ``MODEL_OBJECTIVE``
            are used.

    Returns:
        A torchvision EfficientNet-B0 (``pretrained=True``) whose
        ``classifier[1]`` is replaced by a Linear -> ReLU -> Dropout(0.5) ->
        Linear head with a single output.

    Raises:
        Exception: "Invalid model backbone" for any backbone other than
            EFFICIENT_NET_B0 (checked first), then "Invalid model objective"
            for any objective other than REGRESSION.
    """
    if config_json:
        backbone = ModelBackbone[config_json['MODEL_BACKBONE']]
        objective = TrainingObjective[config_json['MODEL_OBJECTIVE']]
    else:
        backbone, objective = MODEL_BACKBONE, MODEL_OBJECTIVE

    if backbone != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")
    if objective != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")

    # NOTE(review): `pretrained=True` is torchvision's legacy keyword; newer
    # releases expect `weights=...`. Confirm against the pinned version.
    net = torchvision.models.efficientnet_b0(pretrained=True)
    head_in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Sequential(
        nn.Linear(head_in_features, 2048),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(2048, 1),
    )
    return net
def get_training_params(model):
    """Freeze early backbone blocks and return the parameters to optimise.

    Only the EFFICIENT_NET_B0 backbone is supported (checked against the
    module-level ``MODEL_BACKBONE``). When ``LAST_N_LAYERS_TO_TRAIN > 0``,
    every block of ``model.features`` except the last N gets
    ``requires_grad = False``, and the last N blocks' parameters are returned.
    The parameters of the ``classifier[1]`` head are always returned.

    NOTE(review): when ``LAST_N_LAYERS_TO_TRAIN <= 0``, no feature block is
    frozen and none is returned, so only the head is optimised while the
    backbone still keeps ``requires_grad = True``. Confirm this is intended.

    Args:
        model: network built by ``get_model_architecture``. It is mutated:
            ``requires_grad`` is cleared on the frozen blocks.

    Returns:
        A list of parameter tensors to hand to an optimiser.

    Raises:
        Exception: "Invalid model backbone" for any other backbone.
    """
    if MODEL_BACKBONE != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")

    params_to_train = []
    n_last = LAST_N_LAYERS_TO_TRAIN
    if n_last > 0:
        frozen_blocks = model.features[:-n_last]
        trainable_blocks = model.features[-n_last:]
        for p in frozen_blocks.parameters():
            p.requires_grad = False
        params_to_train.extend(trainable_blocks.parameters())
    params_to_train.extend(model.classifier[1].parameters())
    return params_to_train
def get_criterion():
    """Return the loss function for the module-level MODEL_OBJECTIVE.

    Only ``TrainingObjective.REGRESSION`` is supported, which maps to
    mean-squared error (``nn.MSELoss``).

    Raises:
        Exception: "Invalid model objective" for any other objective.
    """
    if MODEL_OBJECTIVE != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")
    return nn.MSELoss()