File size: 3,911 Bytes
418196b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

from config import MODEL_DIR, MODEL_INPUT_SIZE, TRANSFORMS_TO_APPLY, MODEL_BACKBONE, MODEL_OBJECTIVE, LAST_N_LAYERS_TO_TRAIN
import os
import torch
import json
from base import TransformationType, ModelBackbone, TrainingObjective
from torchvision import transforms
import torchvision
import torch.nn as nn



def save_model(model, config_json, model_dir=None):
    """Persist a model's weights and its configuration to a directory.

    Writes ``model.pth`` (the model's ``state_dict``) and ``config.json``
    into ``model_dir``.

    Args:
        model: Object exposing ``state_dict()`` whose result is passed to
            ``torch.save``.
        config_json: JSON-serialisable configuration stored next to the
            weights.
        model_dir: Target directory. It is created (including parents) if it
            does not exist. When ``None``, a fresh ``model_<n>`` directory is
            created under ``MODEL_DIR``.

    Returns:
        The directory the files were written to.
    """
    if model_dir is None:
        model_basedir = MODEL_DIR
        # The base directory may not exist yet on a first run.
        os.makedirs(model_basedir, exist_ok=True)

        # Start from the entry count (the original naming scheme) and skip
        # forward past names that are already taken, e.g. after an earlier
        # run directory was deleted or unrelated files live in the base dir.
        model_index = len(os.listdir(model_basedir))
        while True:
            model_dir = os.path.join(
                model_basedir, 'model_{}'.format(model_index)
            )
            try:
                os.mkdir(model_dir)
                break
            except FileExistsError:
                model_index += 1
    else:
        # A caller-supplied directory is created on demand; re-saving into an
        # existing directory overwrites the previous files.
        os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, 'model.pth')
    torch.save(model.state_dict(), model_path)

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config_json, f)

    return model_dir

def get_transforms_to_apply_(transformation_type, config_json = None):
    """Build the torchvision transform for one ``TransformationType`` member.

    Args:
        transformation_type: The ``TransformationType`` value to translate.
        config_json: Optional mapping; when truthy its ``'MODEL_INPUT_SIZE'``
            entry is used as the target size, otherwise the module-level
            ``MODEL_INPUT_SIZE`` is used.

    Returns:
        The matching ``torchvision.transforms`` object.

    Raises:
        Exception: If ``transformation_type`` matches no known member.
    """
    # Size used by the resize and crop transforms.
    target_size = config_json['MODEL_INPUT_SIZE'] if config_json else MODEL_INPUT_SIZE

    if transformation_type == TransformationType.RESIZE:
        return transforms.Resize(target_size)

    if transformation_type == TransformationType.TO_TENSOR:
        return transforms.ToTensor()

    if transformation_type == TransformationType.RANDOM_HORIZONTAL_FLIP:
        return transforms.RandomHorizontalFlip(p=0.5)

    if transformation_type == TransformationType.NORMALIZE:
        # Per-channel statistics conventionally used with ImageNet-pretrained
        # torchvision models.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        return transforms.Normalize(mean=channel_mean, std=channel_std)

    if transformation_type == TransformationType.RANDOM_ROTATION:
        return transforms.RandomRotation(degrees=10)

    if transformation_type == TransformationType.RANDOM_CLIP:
        # RANDOM_CLIP is implemented as a random crop to the target size.
        return transforms.RandomCrop(target_size)

    raise Exception("Invalid transformation type")

def get_transforms_to_apply():
    """Compose the transforms named in ``TRANSFORMS_TO_APPLY``.

    Each entry of ``TRANSFORMS_TO_APPLY`` is looked up by name in
    ``TransformationType`` and converted with ``get_transforms_to_apply_``
    using the default input size.

    Returns:
        A ``transforms.Compose`` applying the transforms in configured order.
    """
    pipeline = [
        get_transforms_to_apply_(TransformationType[name])
        for name in TRANSFORMS_TO_APPLY
    ]
    return transforms.Compose(pipeline)

def get_model_architecture(config_json = None):
    """Construct the network for the configured backbone and objective.

    Args:
        config_json: Optional mapping. When truthy, its ``'MODEL_BACKBONE'``
            and ``'MODEL_OBJECTIVE'`` entries are looked up by name in
            ``ModelBackbone`` / ``TrainingObjective``; otherwise the
            module-level ``MODEL_BACKBONE`` and ``MODEL_OBJECTIVE`` are used.

    Returns:
        An EfficientNet-B0 created with ``pretrained=True`` whose
        ``classifier[1]`` is replaced by a two-layer head with one output.

    Raises:
        Exception: If the backbone or the objective is not supported.
    """
    if config_json:
        backbone = ModelBackbone[config_json['MODEL_BACKBONE']]
        objective = TrainingObjective[config_json['MODEL_OBJECTIVE']]
    else:
        backbone = MODEL_BACKBONE
        objective = MODEL_OBJECTIVE

    # The backbone is validated before the objective.
    if backbone != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")
    if objective != TrainingObjective.REGRESSION:
        raise Exception("Invalid model objective")

    net = torchvision.models.efficientnet_b0(pretrained=True)

    # Swap the stock final linear layer for a regression head with a single
    # output value.
    head_in_features = net.classifier[1].in_features
    net.classifier[1] = nn.Sequential(
        nn.Linear(head_in_features, 2048),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(2048, 1),
    )
    return net

def get_training_params(model):
    """Freeze the untrained part of the backbone and collect trainable params.

    The last ``LAST_N_LAYERS_TO_TRAIN`` entries of ``model.features`` plus
    ``model.classifier[1]`` are returned for optimisation. Every earlier
    entry of ``model.features`` gets ``requires_grad = False``. This also
    holds when ``LAST_N_LAYERS_TO_TRAIN`` is zero or negative: the whole
    backbone is then frozen, so autograd does not compute gradients for
    parameters that are never handed to the optimiser.

    Args:
        model: Network exposing ``features`` (sliceable, with
            ``parameters()``) and ``classifier[1]``.

    Returns:
        List of parameters to optimise: trainable backbone parameters first,
        then the parameters of ``model.classifier[1]``.

    Raises:
        Exception: If ``MODEL_BACKBONE`` is not ``EFFICIENT_NET_B0``.
    """
    if MODEL_BACKBONE != ModelBackbone.EFFICIENT_NET_B0:
        raise Exception("Invalid model backbone")

    feature_blocks = model.features

    # Index of the first trainable block. Computing it from the length avoids
    # the ``[:-0]`` / ``[-0:]`` slicing pitfall and clamps out-of-range values.
    n_trainable = max(LAST_N_LAYERS_TO_TRAIN, 0)
    split_at = max(len(feature_blocks) - n_trainable, 0)

    for param in feature_blocks[:split_at].parameters():
        param.requires_grad = False

    training_params = list(feature_blocks[split_at:].parameters())
    training_params.extend(model.classifier[1].parameters())

    return training_params

def get_criterion():
    """Return the loss function for the configured training objective.

    Returns:
        ``nn.MSELoss()`` when ``MODEL_OBJECTIVE`` is
        ``TrainingObjective.REGRESSION``.

    Raises:
        Exception: For any other objective.
    """
    if MODEL_OBJECTIVE == TrainingObjective.REGRESSION:
        return nn.MSELoss()

    raise Exception("Invalid model objective")