import json
import os

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

from base import TransformationType, ModelBackbone, TrainingObjective
from config import (
    MODEL_DIR,
    MODEL_INPUT_SIZE,
    TRANSFORMS_TO_APPLY,
    MODEL_BACKBONE,
    MODEL_OBJECTIVE,
    LAST_N_LAYERS_TO_TRAIN,
)


def save_model(model, config_json, model_dir=None):
    """Write the model weights and its config to a directory.

    Args:
        model: Object exposing ``state_dict()``; the state dict is saved
            with ``torch.save`` as ``model.pth``.
        config_json: JSON-serialisable object dumped to ``config.json``.
        model_dir: Target directory. When ``None``, a new ``model_<n>``
            directory is created under ``MODEL_DIR``.

    Returns:
        The directory the files were written to.
    """
    if model_dir is None:
        os.makedirs(MODEL_DIR, exist_ok=True)
        # Start at the entry count (the historical naming scheme) but skip
        # names that are already taken: after a deletion, or with stray
        # files in MODEL_DIR, the count alone can point at an existing entry.
        index = len(os.listdir(MODEL_DIR))
        model_dir = os.path.join(MODEL_DIR, 'model_{}'.format(index))
        while os.path.exists(model_dir):
            index += 1
            model_dir = os.path.join(MODEL_DIR, 'model_{}'.format(index))
    # Also covers a caller-supplied directory, whether or not it exists yet.
    os.makedirs(model_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(model_dir, 'model.pth'))
    with open(os.path.join(model_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config_json, f)
    return model_dir


def get_transforms_to_apply_(transformation_type, config_json=None):
    """Build the torchvision transform for one ``TransformationType``.

    Args:
        transformation_type: Member of ``TransformationType`` to build.
        config_json: Optional mapping; when truthy its
            ``'MODEL_INPUT_SIZE'`` entry replaces the module-level
            ``MODEL_INPUT_SIZE`` for the size-dependent transforms.

    Returns:
        A single torchvision transform instance.

    Raises:
        ValueError: If ``transformation_type`` is not handled here.
    """
    model_input_size = (
        config_json['MODEL_INPUT_SIZE'] if config_json else MODEL_INPUT_SIZE
    )

    if transformation_type == TransformationType.RESIZE:
        return transforms.Resize(model_input_size)
    if transformation_type == TransformationType.TO_TENSOR:
        return transforms.ToTensor()
    if transformation_type == TransformationType.RANDOM_HORIZONTAL_FLIP:
        return transforms.RandomHorizontalFlip(p=0.5)
    if transformation_type == TransformationType.NORMALIZE:
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
    if transformation_type == TransformationType.RANDOM_ROTATION:
        return transforms.RandomRotation(degrees=10)
    if transformation_type == TransformationType.RANDOM_CLIP:
        # RANDOM_CLIP is the enum's name for a random crop to the input size.
        return transforms.RandomCrop(model_input_size)
    raise ValueError("Invalid transformation type")


def get_transforms_to_apply(config_json=None):
    """Compose every transform named in ``TRANSFORMS_TO_APPLY``.

    Args:
        config_json: Optional mapping forwarded to
            ``get_transforms_to_apply_`` so size-dependent transforms can
            use its ``'MODEL_INPUT_SIZE'``. The list of transforms itself
            always comes from the module-level ``TRANSFORMS_TO_APPLY``.

    Returns:
        A ``transforms.Compose`` applying the transforms in listed order.
    """
    return transforms.Compose([
        get_transforms_to_apply_(TransformationType[name], config_json)
        for name in TRANSFORMS_TO_APPLY
    ])


def get_model_architecture(config_json=None):
    """Instantiate the network for the configured backbone and objective.

    Args:
        config_json: Optional mapping; when truthy, its
            ``'MODEL_BACKBONE'`` and ``'MODEL_OBJECTIVE'`` entries are
            looked up by name in ``ModelBackbone`` / ``TrainingObjective``.
            Otherwise the module-level ``MODEL_BACKBONE`` and
            ``MODEL_OBJECTIVE`` are used.

    Returns:
        An EfficientNet-B0 whose ``classifier[1]`` is replaced by a
        two-layer head with a single output.

    Raises:
        ValueError: If the backbone or objective is not supported.
    """
    if config_json:
        model_backbone = ModelBackbone[config_json['MODEL_BACKBONE']]
        model_objective = TrainingObjective[config_json['MODEL_OBJECTIVE']]
    else:
        model_backbone = MODEL_BACKBONE
        model_objective = MODEL_OBJECTIVE

    if model_backbone != ModelBackbone.EFFICIENT_NET_B0:
        raise ValueError("Invalid model backbone")
    if model_objective != TrainingObjective.REGRESSION:
        raise ValueError("Invalid model objective")

    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept as-is so older installs keep working.
    model = torchvision.models.efficientnet_b0(pretrained=True)
    head_in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Sequential(
        nn.Linear(head_in_features, 2048),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(2048, 1),
    )
    return model


def get_training_params(model):
    """Freeze the untrained backbone layers and list the trainable parameters.

    The last ``LAST_N_LAYERS_TO_TRAIN`` entries of ``model.features`` plus
    ``model.classifier[1]`` are returned; every other ``model.features``
    parameter gets ``requires_grad = False``.

    Args:
        model: Model exposing ``features`` and ``classifier``, as built by
            ``get_model_architecture``.

    Returns:
        List of parameters to train.

    Raises:
        ValueError: If ``MODEL_BACKBONE`` is not supported.
    """
    if MODEL_BACKBONE != ModelBackbone.EFFICIENT_NET_B0:
        raise ValueError("Invalid model backbone")

    if LAST_N_LAYERS_TO_TRAIN > 0:
        frozen_layers = model.features[:-LAST_N_LAYERS_TO_TRAIN]
        training_params = list(
            model.features[-LAST_N_LAYERS_TO_TRAIN:].parameters()
        )
    else:
        # Slicing with -0 would select everything, so handle "train no
        # backbone layers" explicitly. The whole backbone is frozen here so
        # no gradients are computed for weights that are not returned.
        frozen_layers = model.features
        training_params = []

    for param in frozen_layers.parameters():
        param.requires_grad = False

    training_params.extend(model.classifier[1].parameters())
    return training_params


def get_criterion():
    """Return the loss module for the configured ``MODEL_OBJECTIVE``.

    Returns:
        ``nn.MSELoss()`` for the regression objective.

    Raises:
        ValueError: If ``MODEL_OBJECTIVE`` is not supported.
    """
    if MODEL_OBJECTIVE != TrainingObjective.REGRESSION:
        raise ValueError("Invalid model objective")
    return nn.MSELoss()