Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .github/workflows/update_space.yml +28 -0
- README.md +2 -8
- __init__.py +0 -0
- __pycache__/base.cpython-39.pyc +0 -0
- __pycache__/config.cpython-39.pyc +0 -0
- __pycache__/dataset.cpython-39.pyc +0 -0
- __pycache__/gradio_test.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- base.py +27 -0
- config.py +28 -0
- dataset.py +30 -0
- debug.ipynb +0 -0
- flagged/input_img/a37df70aebf29843d866a7ad396a79ce94debfbd/tmpvvjl4871.jpg +0 -0
- flagged/log.csv +2 -0
- flagged/output/c1b5617618b88e89a6daa275fa1df7a05106d635/tmpdb5o51or.png +0 -0
- gradio_test.py +115 -0
- model/model_0/config.json +1 -0
- model/model_0/model.pth +3 -0
- model/model_1/config.json +1 -0
- model/model_1/model.pth +3 -0
- model/model_2/config.json +1 -0
- model/model_2/model.pth +3 -0
- requirements.txt +230 -0
- train.py +183 -0
- utils.py +111 -0
.github/workflows/update_space.yml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Run Python script
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
build:
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
|
12 |
+
steps:
|
13 |
+
- name: Checkout
|
14 |
+
uses: actions/checkout@v2
|
15 |
+
|
16 |
+
- name: Set up Python
|
17 |
+
uses: actions/setup-python@v2
|
18 |
+
with:
|
19 |
+
python-version: '3.9'
|
20 |
+
|
21 |
+
- name: Install Gradio
|
22 |
+
run: python -m pip install gradio
|
23 |
+
|
24 |
+
- name: Log in to Hugging Face
|
25 |
+
run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
|
26 |
+
|
27 |
+
- name: Deploy to Spaces
|
28 |
+
run: gradio deploy
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: red
|
5 |
-
colorTo: pink
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.41.2
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: acne_grading
|
3 |
+
app_file: gradio_test.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 3.41.2
|
|
|
|
|
6 |
---
|
|
|
|
__init__.py
ADDED
File without changes
|
__pycache__/base.cpython-39.pyc
ADDED
Binary file (1.02 kB). View file
|
|
__pycache__/config.cpython-39.pyc
ADDED
Binary file (998 Bytes). View file
|
|
__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (1.57 kB). View file
|
|
__pycache__/gradio_test.cpython-39.pyc
ADDED
Binary file (2.06 kB). View file
|
|
__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.15 kB). View file
|
|
base.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from enum import Enum
|
4 |
+
|
5 |
+
class TrainingObjective(Enum):
|
6 |
+
CLASSIFICATION = 0
|
7 |
+
REGRESSION = 1
|
8 |
+
|
9 |
+
class ModelBackbone(Enum):
|
10 |
+
INCEPTION_V3 = 0
|
11 |
+
EFFICIENT_NET_B0 = 1
|
12 |
+
MOBILE_NET_V3_LARGE = 2
|
13 |
+
|
14 |
+
|
15 |
+
class TransformationType(Enum):
|
16 |
+
RESIZE = 0
|
17 |
+
TO_TENSOR = 1
|
18 |
+
RANDOM_HORIZONTAL_FLIP = 2
|
19 |
+
NORMALIZE = 3
|
20 |
+
RANDOM_ROTATION = 4
|
21 |
+
RANDOM_CLIP = 5
|
22 |
+
|
23 |
+
model_input_dict = {
|
24 |
+
ModelBackbone.INCEPTION_V3: (299, 299),
|
25 |
+
ModelBackbone.EFFICIENT_NET_B0: (224, 224),
|
26 |
+
ModelBackbone.MOBILE_NET_V3_LARGE: (224, 224)
|
27 |
+
}
|
config.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from base import TrainingObjective, ModelBackbone, model_input_dict, TransformationType
|
3 |
+
|
4 |
+
DATASET_PATH = '/Users/suyashharlalka/Documents/workspace/gabit/acne_classification/dataset/Classification/JPEGImages'
|
5 |
+
MODEL_DIR = '/Users/suyashharlalka/Documents/workspace/gabit/acne_classification/model'
|
6 |
+
|
7 |
+
MODEL_OBJECTIVE = TrainingObjective.REGRESSION
|
8 |
+
MODEL_BACKBONE = ModelBackbone.EFFICIENT_NET_B0
|
9 |
+
MODEL_INPUT_SIZE = model_input_dict[MODEL_BACKBONE]
|
10 |
+
TRANSFORMS_TO_APPLY = [
|
11 |
+
TransformationType.RESIZE.name,
|
12 |
+
TransformationType.TO_TENSOR.name,
|
13 |
+
TransformationType.RANDOM_HORIZONTAL_FLIP.name,
|
14 |
+
TransformationType.NORMALIZE.name
|
15 |
+
]
|
16 |
+
NUM_CLASSES = 4
|
17 |
+
LAST_N_LAYERS_TO_TRAIN = 5
|
18 |
+
EPOCHS = 20
|
19 |
+
MODEL_TRAINING = True
|
20 |
+
IS_LIMITED = False
|
21 |
+
BATCH_SIZE = 64
|
22 |
+
SHUFFLE = True
|
23 |
+
NUM_WORKERS = 0
|
24 |
+
|
25 |
+
BASE_LR = 0.001
|
26 |
+
LR_DECAY_STEP_SIZE = 5
|
27 |
+
LR_DECAY_GAMMA = 0.1
|
28 |
+
|
dataset.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset, DataLoader
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
class AcneDataset(Dataset):
|
9 |
+
def __init__(self, dataDir, limit=True, transform=None):
|
10 |
+
self.dataDir = dataDir
|
11 |
+
self.image_names = os.listdir(self.dataDir)
|
12 |
+
self.image_names = [os.path.join(self.dataDir, x) for x in self.image_names]
|
13 |
+
self.image_names = [x for x in self.image_names if x.endswith('.jpg')]
|
14 |
+
self.image_names = sorted(self.image_names)
|
15 |
+
self.transform = transform
|
16 |
+
if limit:
|
17 |
+
self.image_names = self.image_names[1000:1200]
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
return len(self.image_names)
|
21 |
+
|
22 |
+
def __getitem__(self, idx):
|
23 |
+
imgName = self.image_names[idx]
|
24 |
+
label = imgName.split('/')[-1].split('.')[0].split('_')[0][-1]
|
25 |
+
label = int(label)
|
26 |
+
label = np.array(label).astype(np.float32)
|
27 |
+
img = Image.open(imgName)
|
28 |
+
if self.transform:
|
29 |
+
img = self.transform(img)
|
30 |
+
return img, label
|
debug.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
flagged/input_img/a37df70aebf29843d866a7ad396a79ce94debfbd/tmpvvjl4871.jpg
ADDED
![]() |
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
input_img,output,flag,username,timestamp
|
2 |
+
/Users/suyash.harlalka/Desktop/personal/acne_classification/code/flagged/input_img/a37df70aebf29843d866a7ad396a79ce94debfbd/tmpvvjl4871.jpg,/Users/suyash.harlalka/Desktop/personal/acne_classification/code/flagged/output/c1b5617618b88e89a6daa275fa1df7a05106d635/tmpdb5o51or.png,,,2023-08-31 17:29:17.512316
|
flagged/output/c1b5617618b88e89a6daa275fa1df7a05106d635/tmpdb5o51or.png
ADDED
![]() |
gradio_test.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
from base import TrainingObjective, ModelBackbone, model_input_dict, TransformationType
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
from utils import get_model_architecture, get_transforms_to_apply_, get_transforms_to_apply
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from torchvision import transforms
|
10 |
+
from matplotlib import pyplot as plt
|
11 |
+
from dataset import AcneDataset
|
12 |
+
import config
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
|
17 |
+
curr_dir = os.path.dirname(os.path.abspath(__file__))
|
18 |
+
#go up one level
|
19 |
+
# curr_dir = os.path.dirname(curr_dir)
|
20 |
+
model_dir = os.path.join(curr_dir, 'model')
|
21 |
+
model_name = 'model_0'
|
22 |
+
model_dir = os.path.join(model_dir, model_name)
|
23 |
+
# model_dir = '/Users/suyashharlalka/Documents/workspace/gabit/acne_classification/model/model_0'
|
24 |
+
model_config = os.path.join(model_dir, 'config.json')
|
25 |
+
|
26 |
+
with open(model_config, 'r') as f:
|
27 |
+
config_json = json.load(f)
|
28 |
+
|
29 |
+
model = get_model_architecture(config_json)
|
30 |
+
model_path = os.path.join(model_dir, 'model.pth')
|
31 |
+
model.load_state_dict(torch.load(model_path), map_location=torch.device('cpu'))
|
32 |
+
|
33 |
+
testing_transform = config_json['TRANSFORMS_TO_APPLY']
|
34 |
+
testing_transform = [TransformationType[transform] for transform in testing_transform if transform != 'RANDOM_HORIZONTAL_FLIP']
|
35 |
+
testing_transform = [get_transforms_to_apply_(transform) for transform in testing_transform]
|
36 |
+
testing_transform = transforms.Compose(testing_transform)
|
37 |
+
|
38 |
+
# data_dir = config.DATASET_PATH
|
39 |
+
# isLimited = config.IS_LIMITED
|
40 |
+
# dataset = AcneDataset(data_dir, limit=isLimited)
|
41 |
+
# dataset.transform = testing_transform
|
42 |
+
# dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)
|
43 |
+
device = torch.device('cpu')
|
44 |
+
|
45 |
+
def predict_acne(img):
|
46 |
+
img = testing_transform(img)
|
47 |
+
img = img.unsqueeze(0)
|
48 |
+
img = img.to(device)
|
49 |
+
output = model(img)
|
50 |
+
predicted = torch.round(output.data)
|
51 |
+
predicted = predicted.squeeze(1)
|
52 |
+
predicted = predicted.cpu().numpy()[0]
|
53 |
+
return predicted
|
54 |
+
|
55 |
+
|
56 |
+
# img_dir = '/Users/suyash.harlalka/Desktop/personal/acne_classification/dataset/Classification/JPEGImages'
|
57 |
+
# img_names = os.listdir(img_dir)
|
58 |
+
# for name in img_names:
|
59 |
+
# if name.endswith('.jpg'):
|
60 |
+
# img_path = os.path.join(img_dir, name)
|
61 |
+
# img = Image.open(img_path)
|
62 |
+
# img = img.convert('RGB')
|
63 |
+
# print(predict_acne(img))
|
64 |
+
|
65 |
+
correct = 0
|
66 |
+
total = 0
|
67 |
+
mainLabel = []
|
68 |
+
predictedLabel = []
|
69 |
+
model.to(device)
|
70 |
+
model.eval()
|
71 |
+
|
72 |
+
# for name in tqdm(img_names):
|
73 |
+
# if name.endswith('.jpg'):
|
74 |
+
# img_path = os.path.join(img_dir, name)
|
75 |
+
# img = Image.open(img_path)
|
76 |
+
# predicted = predict_acne(img)
|
77 |
+
# label = int(name.split('_')[0][-1])
|
78 |
+
# label = int(label)
|
79 |
+
# correct += (predicted == label).sum().item()
|
80 |
+
# mainLabel.append(label)
|
81 |
+
# predictedLabel.append(predicted)
|
82 |
+
# total += 1
|
83 |
+
|
84 |
+
|
85 |
+
import gradio as gr
|
86 |
+
|
87 |
+
|
88 |
+
gr.Interface(fn=predict_acne,
|
89 |
+
inputs=gr.Image(type="pil"),
|
90 |
+
outputs="number",
|
91 |
+
).launch(
|
92 |
+
share=True,
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
# for images, labels in tqdm(dataloader, desc="Processing", unit="batch"):
|
97 |
+
# images = images.to(device)
|
98 |
+
# labels = labels.to(device)
|
99 |
+
# outputs = model(images)
|
100 |
+
# # print('Outputs:{outputs}, label: {label}'.format(outputs=outputs, label=labels))
|
101 |
+
|
102 |
+
# predicted = torch.round(outputs.data)
|
103 |
+
# predicted = predicted.squeeze(1)
|
104 |
+
# total += labels.size(0)
|
105 |
+
# correct += (predicted == labels).sum().item()
|
106 |
+
|
107 |
+
# mainLabel.extend(labels.cpu().numpy())
|
108 |
+
# predictedLabel.extend(predicted.cpu().numpy())
|
109 |
+
|
110 |
+
# print(f'Accuracy of the network on the {total} test images: {100 * correct / total}%')
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
|
model/model_0/config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"DATASET_PATH": "/Users/suyash.harlalka/Desktop/personal/acne_classification/dataset/Classification/JPEGImages", "MODEL_DIR": "/Users/suyash.harlalka/Desktop/personal/acne_classification/model", "MODEL_OBJECTIVE": "REGRESSION", "MODEL_BACKBONE": "EFFICIENT_NET_B0", "MODEL_INPUT_SIZE": [224, 224], "TRANSFORMS_TO_APPLY": ["RESIZE", "TO_TENSOR", "RANDOM_HORIZONTAL_FLIP", "NORMALIZE"], "NUM_CLASSES": 4, "LAST_N_LAYERS_TO_TRAIN": 5, "EPOCHS": 10, "MODEL_TRAINING": true, "IS_LIMITED": false, "BATCH_SIZE": 64, "SHUFFLE": true, "NUM_WORKERS": 0, "BASE_LR": 0.001, "LR_DECAY_STEP_SIZE": 5, "LR_DECAY_GAMMA": 0.1, "LOSS": [0.6981074272290521, 0.2594577281371407, 0.16056970701269482, 0.12656819302102792, 0.10815351424009903, 0.10907206441397252, 0.08555250760653745, 0.07441781347860461, 0.07726245827001074, 0.0703853020525497], "EPOCHS_DONE": 10, "TRAINED_MODEL_DIR": "/Users/suyash.harlalka/Desktop/personal/acne_classification/model/model_0"}
|
model/model_0/model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7d84aa1c3284ac5dffc5fcf0e820c26ab684e87f3f00817bf1038aea4bd6ccb8
|
3 |
+
size 26806741
|
model/model_1/config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"DATASET_PATH": "/Users/suyash.harlalka/Desktop/personal/acne_classification/dataset/Classification/JPEGImages", "MODEL_DIR": "/Users/suyash.harlalka/Desktop/personal/acne_classification/model", "MODEL_OBJECTIVE": "REGRESSION", "MODEL_BACKBONE": "EFFICIENT_NET_B0", "MODEL_INPUT_SIZE": [224, 224], "TRANSFORMS_TO_APPLY": ["RESIZE", "TO_TENSOR", "RANDOM_HORIZONTAL_FLIP", "NORMALIZE"], "NUM_CLASSES": 4, "LAST_N_LAYERS_TO_TRAIN": 5, "EPOCHS": 10, "MODEL_TRAINING": true, "IS_LIMITED": false, "BATCH_SIZE": 64, "SHUFFLE": true, "NUM_WORKERS": 0, "BASE_LR": 0.001, "LR_DECAY_STEP_SIZE": 5, "LR_DECAY_GAMMA": 0.1, "LOSS": [0.8423960444174314, 0.2454124618517725, 0.16250047322950864, 0.15113061862556557, 0.12270178175286243, 0.0968482023791263, 0.10082359141425083, 0.07315939174670923, 0.06981939018556946, 0.06618086170209081], "EPOCHS_DONE": 10, "TRAINED_MODEL_DIR": "/Users/suyash.harlalka/Desktop/personal/acne_classification/model/model_1"}
|
model/model_1/model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:89e305f1729464767d019ffcf35f04ad715a06e5f193d4005ec8be1e678ebd29
|
3 |
+
size 26806741
|
model/model_2/config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"DATASET_PATH": "/Users/suyash.harlalka/Desktop/personal/acne_classification/dataset/Classification/JPEGImages", "MODEL_DIR": "/Users/suyash.harlalka/Desktop/personal/acne_classification/model", "MODEL_OBJECTIVE": "REGRESSION", "MODEL_BACKBONE": "EFFICIENT_NET_B0", "MODEL_INPUT_SIZE": [224, 224], "TRANSFORMS_TO_APPLY": ["RESIZE", "TO_TENSOR", "RANDOM_HORIZONTAL_FLIP", "NORMALIZE"], "NUM_CLASSES": 4, "LAST_N_LAYERS_TO_TRAIN": 5, "EPOCHS": 20, "MODEL_TRAINING": true, "IS_LIMITED": false, "BATCH_SIZE": 64, "SHUFFLE": true, "NUM_WORKERS": 0, "BASE_LR": 0.001, "LR_DECAY_STEP_SIZE": 5, "LR_DECAY_GAMMA": 0.1, "LOSS": [0.7182596655268418, 0.28382997606929977, 0.21820551078570516, 0.16277591256718887, 0.11228116758559879, 0.11437285848354038, 0.11287243585837514, 0.10567777721505416], "EPOCHS_DONE": 8, "TRAINED_MODEL_DIR": "/Users/suyash.harlalka/Desktop/personal/acne_classification/model/model_2"}
|
model/model_2/model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c96803649430ea5fd25c87e3876bf0e368b438c0f47e2c92f5926c9ffd6685c
|
3 |
+
size 26806741
|
requirements.txt
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==1.4.0
|
2 |
+
aiofiles==22.1.0
|
3 |
+
aiohttp==3.8.4
|
4 |
+
aiosignal==1.3.1
|
5 |
+
aiosqlite==0.18.0
|
6 |
+
altair==4.2.2
|
7 |
+
antlr4-python3-runtime==4.9.3
|
8 |
+
anyio==3.6.2
|
9 |
+
APScheduler==3.6.3
|
10 |
+
argon2-cffi==21.3.0
|
11 |
+
argon2-cffi-bindings==21.2.0
|
12 |
+
arrow==1.2.3
|
13 |
+
astunparse==1.6.3
|
14 |
+
async-timeout==4.0.2
|
15 |
+
attrs==22.2.0
|
16 |
+
Babel==2.12.1
|
17 |
+
beautifulsoup4==4.12.1
|
18 |
+
bleach==6.0.0
|
19 |
+
blinker==1.6
|
20 |
+
cachetools==4.2.2
|
21 |
+
certifi==2022.12.7
|
22 |
+
cffi==1.15.1
|
23 |
+
charset-normalizer==3.1.0
|
24 |
+
click==8.1.3
|
25 |
+
colorama==0.4.6
|
26 |
+
contourpy==1.0.7
|
27 |
+
cycler==0.11.0
|
28 |
+
dataclasses-json==0.5.7
|
29 |
+
datasets==2.11.0
|
30 |
+
defusedxml==0.7.1
|
31 |
+
dill==0.3.6
|
32 |
+
e==1.4.5
|
33 |
+
evaluate==0.4.0
|
34 |
+
fastapi==0.103.0
|
35 |
+
fastjsonschema==2.16.3
|
36 |
+
ffmpy==0.3.1
|
37 |
+
filelock==3.10.7
|
38 |
+
flatbuffers==23.3.3
|
39 |
+
fonttools==4.39.3
|
40 |
+
fqdn==1.5.1
|
41 |
+
frozenlist==1.3.3
|
42 |
+
fsspec==2023.3.0
|
43 |
+
ftfy==6.1.1
|
44 |
+
gast==0.4.0
|
45 |
+
gdown==4.7.1
|
46 |
+
gitdb==4.0.10
|
47 |
+
GitPython==3.1.31
|
48 |
+
google-auth==2.17.0
|
49 |
+
google-auth-oauthlib==1.0.0
|
50 |
+
google-pasta==0.2.0
|
51 |
+
gradio==3.41.2
|
52 |
+
gradio_client==0.5.0
|
53 |
+
grpcio==1.53.0
|
54 |
+
h11==0.14.0
|
55 |
+
h5py
|
56 |
+
httpcore==0.17.3
|
57 |
+
httpx==0.24.1
|
58 |
+
huggingface-hub==0.16.4
|
59 |
+
idna==3.4
|
60 |
+
imageio==2.31.0
|
61 |
+
imgaug==0.4.0
|
62 |
+
importlib-metadata==6.1.0
|
63 |
+
importlib-resources==5.12.0
|
64 |
+
ipykernel
|
65 |
+
ipython
|
66 |
+
ipython-genutils==0.2.0
|
67 |
+
isoduration==20.11.0
|
68 |
+
jax==0.4.8
|
69 |
+
jedi
|
70 |
+
Jinja2==3.1.2
|
71 |
+
joblib==1.2.0
|
72 |
+
json-fix==0.5.2
|
73 |
+
json5==0.9.11
|
74 |
+
jsonpointer==2.3
|
75 |
+
jsonschema==4.17.3
|
76 |
+
jupyter-events==0.6.3
|
77 |
+
jupyter-http-over-ws==0.0.8
|
78 |
+
jupyter-server==1.23.6
|
79 |
+
jupyter-ydoc==0.2.3
|
80 |
+
jupyter_client==7.4.1
|
81 |
+
jupyter_core
|
82 |
+
jupyter_server_fileid==0.8.0
|
83 |
+
jupyter_server_terminals==0.4.4
|
84 |
+
jupyter_server_ydoc==0.8.0
|
85 |
+
jupyterlab==3.6.3
|
86 |
+
jupyterlab-pygments==0.2.2
|
87 |
+
jupyterlab_server==2.22.0
|
88 |
+
keras==2.13.1
|
89 |
+
kiwisolver==1.4.4
|
90 |
+
langchain==0.0.158
|
91 |
+
lazy_loader==0.2
|
92 |
+
libclang==16.0.0
|
93 |
+
lightning-utilities==0.8.0
|
94 |
+
Markdown==3.4.3
|
95 |
+
markdown-it-py==2.2.0
|
96 |
+
MarkupSafe==2.1.2
|
97 |
+
marshmallow==3.19.0
|
98 |
+
marshmallow-enum==1.5.1
|
99 |
+
matplotlib==3.7.1
|
100 |
+
matplotlib-inline
|
101 |
+
mdurl==0.1.2
|
102 |
+
mistune==2.0.5
|
103 |
+
ml-dtypes==0.0.4
|
104 |
+
mpmath==1.3.0
|
105 |
+
multidict==6.0.4
|
106 |
+
multiprocess==0.70.14
|
107 |
+
mypy-extensions==1.0.0
|
108 |
+
nbclassic==0.5.5
|
109 |
+
nbclient==0.7.3
|
110 |
+
nbconvert==7.3.0
|
111 |
+
nbformat==5.8.0
|
112 |
+
nest-asyncio
|
113 |
+
networkx==3.0
|
114 |
+
notebook==6.5.3
|
115 |
+
notebook_shim==0.2.2
|
116 |
+
numexpr==2.8.4
|
117 |
+
numpy==1.23.5
|
118 |
+
oauthlib==3.2.2
|
119 |
+
omegaconf==2.3.0
|
120 |
+
openai==0.27.4
|
121 |
+
openapi-schema-pydantic==1.2.4
|
122 |
+
opencv-python==4.7.0.72
|
123 |
+
opt-einsum==3.3.0
|
124 |
+
orjson==3.9.5
|
125 |
+
packaging
|
126 |
+
pandas==1.5.3
|
127 |
+
pandocfilters==1.5.0
|
128 |
+
parso
|
129 |
+
pexpect
|
130 |
+
pickleshare
|
131 |
+
Pillow==9.4.0
|
132 |
+
platformdirs
|
133 |
+
plotly==5.14.0
|
134 |
+
powerlaw==1.5
|
135 |
+
prometheus-client==0.16.0
|
136 |
+
prompt-toolkit
|
137 |
+
protobuf==3.20.3
|
138 |
+
psutil
|
139 |
+
ptyprocess
|
140 |
+
pure-eval
|
141 |
+
pyarrow==11.0.0
|
142 |
+
pyasn1==0.4.8
|
143 |
+
pyasn1-modules==0.2.8
|
144 |
+
pycparser==2.21
|
145 |
+
pydantic==1.10.7
|
146 |
+
pydeck==0.8.0
|
147 |
+
pydub==0.25.1
|
148 |
+
Pygments
|
149 |
+
Pympler==1.0.1
|
150 |
+
pyparsing==3.0.9
|
151 |
+
pyrsistent==0.19.3
|
152 |
+
PySocks==1.7.1
|
153 |
+
python-dateutil
|
154 |
+
python-json-logger==2.0.7
|
155 |
+
python-multipart==0.0.6
|
156 |
+
python-telegram-bot==13.7
|
157 |
+
pytorch-lightning==2.0.1.post0
|
158 |
+
pytz==2023.3
|
159 |
+
pytz-deprecation-shim==0.1.0.post0
|
160 |
+
PyWavelets==1.4.1
|
161 |
+
PyYAML==6.0
|
162 |
+
pyzmq==25.0.2
|
163 |
+
regex==2023.3.23
|
164 |
+
requests==2.28.2
|
165 |
+
requests-oauthlib==1.3.1
|
166 |
+
responses==0.18.0
|
167 |
+
retina-face==0.0.13
|
168 |
+
rfc3339-validator==0.1.4
|
169 |
+
rfc3986-validator==0.1.1
|
170 |
+
rich==13.3.3
|
171 |
+
river==0.14.0
|
172 |
+
rsa==4.9
|
173 |
+
scikit-image==0.21.0
|
174 |
+
scikit-learn==1.2.2
|
175 |
+
scipy==1.10.1
|
176 |
+
semantic-version==2.10.0
|
177 |
+
semver==3.0.0
|
178 |
+
Send2Trash==1.8.0
|
179 |
+
shapely==2.0.1
|
180 |
+
smmap==5.0.0
|
181 |
+
sniffio==1.3.0
|
182 |
+
soupsieve==2.4
|
183 |
+
SQLAlchemy==2.0.12
|
184 |
+
starlette==0.27.0
|
185 |
+
streamlit==1.20.0
|
186 |
+
sympy==1.11.1
|
187 |
+
tabulate==0.9.0
|
188 |
+
tenacity==8.2.2
|
189 |
+
tensorboard==2.13.0
|
190 |
+
tensorboard-data-server==0.7.0
|
191 |
+
tensorboard-plugin-wit==1.8.1
|
192 |
+
tensorflow==2.13.0
|
193 |
+
tensorflow-estimator==2.13.0
|
194 |
+
termcolor==2.2.0
|
195 |
+
terminado==0.17.1
|
196 |
+
threadpoolctl==3.1.0
|
197 |
+
tifffile==2023.4.12
|
198 |
+
tinycss2==1.2.1
|
199 |
+
tokenizers==0.13.2
|
200 |
+
toml==0.10.2
|
201 |
+
tomli==2.0.1
|
202 |
+
toolz==0.12.0
|
203 |
+
torch==2.0.0
|
204 |
+
torchaudio==2.0.1
|
205 |
+
torchmetrics==0.11.4
|
206 |
+
torchvision==0.15.1
|
207 |
+
tornado==6.2
|
208 |
+
tqdm==4.65.0
|
209 |
+
transformers==4.27.4
|
210 |
+
typing-inspect==0.8.0
|
211 |
+
tzdata==2023.3
|
212 |
+
tzlocal==4.3
|
213 |
+
uri-template==1.2.0
|
214 |
+
urllib3==1.26.15
|
215 |
+
uvicorn==0.23.2
|
216 |
+
validators==0.20.0
|
217 |
+
voila==0.4.0
|
218 |
+
webcolors==1.13
|
219 |
+
webencodings==0.5.1
|
220 |
+
websocket-client==1.5.1
|
221 |
+
websockets==11.0.1
|
222 |
+
weightwatcher==0.7.0.8
|
223 |
+
Werkzeug==2.2.3
|
224 |
+
wikipedia==1.4.0
|
225 |
+
wrapt==1.14.1
|
226 |
+
xxhash==3.2.0
|
227 |
+
y-py==0.5.9
|
228 |
+
yarl==1.8.2
|
229 |
+
ypy-websocket==0.8.2
|
230 |
+
zipp==3.15.0
|
train.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from tqdm import tqdm
|
8 |
+
import torchvision
|
9 |
+
from torchvision import transforms
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.optim as optim
|
12 |
+
from torch.utils.data.sampler import SubsetRandomSampler
|
13 |
+
|
14 |
+
|
15 |
+
sys.path.append('/Users/suyashharlalka/Documents/workspace/gabit/acne_classification/code')
|
16 |
+
|
17 |
+
from dataset import AcneDataset
|
18 |
+
from utils import save_model, get_transforms_to_apply,get_model_architecture, get_training_params, get_criterion
|
19 |
+
import config
|
20 |
+
from base import TrainingObjective, ModelBackbone
|
21 |
+
import json
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
data_dir = config.DATASET_PATH
|
26 |
+
image_names = os.listdir(data_dir)
|
27 |
+
model_training = config.MODEL_TRAINING
|
28 |
+
isLimited = config.IS_LIMITED
|
29 |
+
batch_size = config.BATCH_SIZE
|
30 |
+
shuffle = config.SHUFFLE
|
31 |
+
num_workers = config.NUM_WORKERS
|
32 |
+
|
33 |
+
dataset = AcneDataset(data_dir, limit=isLimited)
|
34 |
+
|
35 |
+
validation_split = 0.2
|
36 |
+
dataset_size = len(dataset)
|
37 |
+
indices = list(range(dataset_size))
|
38 |
+
split = int(np.floor(validation_split * dataset_size))
|
39 |
+
if shuffle:
|
40 |
+
np.random.seed(42)
|
41 |
+
np.random.shuffle(indices)
|
42 |
+
|
43 |
+
train_indices, test_indices = indices[split:], indices[:split]
|
44 |
+
train_sampler = SubsetRandomSampler(train_indices)
|
45 |
+
test_sampler = SubsetRandomSampler(test_indices)
|
46 |
+
|
47 |
+
transform = get_transforms_to_apply()
|
48 |
+
dataset.transform = transform
|
49 |
+
|
50 |
+
train_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, sampler=train_sampler)
|
51 |
+
test_dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, sampler=test_sampler)
|
52 |
+
|
53 |
+
num_classes = config.NUM_CLASSES
|
54 |
+
model = get_model_architecture()
|
55 |
+
training_params = get_training_params(model)
|
56 |
+
criterion = get_criterion()
|
57 |
+
|
58 |
+
optimizer = optim.Adam(training_params, lr=config.BASE_LR)
|
59 |
+
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=config.LR_DECAY_STEP_SIZE, gamma=config.LR_DECAY_GAMMA)
|
60 |
+
|
61 |
+
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
|
62 |
+
model.to(device)
|
63 |
+
|
64 |
+
config_json = {}
|
65 |
+
config_json['DATASET_PATH'] = config.DATASET_PATH
|
66 |
+
config_json['MODEL_DIR'] = config.MODEL_DIR
|
67 |
+
config_json['MODEL_OBJECTIVE'] = config.MODEL_OBJECTIVE.name
|
68 |
+
config_json['MODEL_BACKBONE'] = config.MODEL_BACKBONE.name
|
69 |
+
config_json['MODEL_INPUT_SIZE'] = config.MODEL_INPUT_SIZE
|
70 |
+
config_json['TRANSFORMS_TO_APPLY'] = config.TRANSFORMS_TO_APPLY
|
71 |
+
config_json['NUM_CLASSES'] = config.NUM_CLASSES
|
72 |
+
config_json['LAST_N_LAYERS_TO_TRAIN'] = config.LAST_N_LAYERS_TO_TRAIN
|
73 |
+
config_json['EPOCHS'] = config.EPOCHS
|
74 |
+
config_json['MODEL_TRAINING'] = config.MODEL_TRAINING
|
75 |
+
config_json['IS_LIMITED'] = config.IS_LIMITED
|
76 |
+
config_json['BATCH_SIZE'] = config.BATCH_SIZE
|
77 |
+
config_json['SHUFFLE'] = config.SHUFFLE
|
78 |
+
config_json['NUM_WORKERS'] = config.NUM_WORKERS
|
79 |
+
config_json['BASE_LR'] = config.BASE_LR
|
80 |
+
config_json['LR_DECAY_STEP_SIZE'] = config.LR_DECAY_STEP_SIZE
|
81 |
+
config_json['LR_DECAY_GAMMA'] = config.LR_DECAY_GAMMA
|
82 |
+
|
83 |
+
|
84 |
+
if model_training:
|
85 |
+
num_epochs = config.EPOCHS
|
86 |
+
for epoch in range(num_epochs):
|
87 |
+
model.train()
|
88 |
+
runningLoss = 0.0
|
89 |
+
for i, (images, labels) in enumerate(tqdm(train_dataloader, desc="Processing", unit="batch")):
|
90 |
+
images = images.to(device)
|
91 |
+
labels = labels.to(device)
|
92 |
+
outputs = model(images)
|
93 |
+
loss = criterion(outputs.squeeze(), labels)
|
94 |
+
optimizer.zero_grad()
|
95 |
+
loss.backward()
|
96 |
+
optimizer.step()
|
97 |
+
runningLoss += loss.item()
|
98 |
+
|
99 |
+
# scheduler.step()
|
100 |
+
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {runningLoss/len(train_dataloader):.4f}')
|
101 |
+
|
102 |
+
if 'LOSS' not in config_json:
|
103 |
+
config_json['LOSS'] = []
|
104 |
+
config_json['LOSS'].append(runningLoss/len(train_dataloader))
|
105 |
+
|
106 |
+
if 'EPOCHS_DONE' not in config_json:
|
107 |
+
config_json['EPOCHS_DONE'] = 0
|
108 |
+
config_json['EPOCHS_DONE'] = epoch + 1
|
109 |
+
|
110 |
+
if epoch == 0 :
|
111 |
+
model_dir = save_model(model, config_json)
|
112 |
+
else:
|
113 |
+
model_dir = save_model(model, config_json, model_dir)
|
114 |
+
|
115 |
+
if 'TRAINED_MODEL_DIR' not in config_json:
|
116 |
+
config_json['TRAINED_MODEL_DIR'] = model_dir
|
117 |
+
|
118 |
+
config_save_path = os.path.join(model_dir, 'config.json')
|
119 |
+
with open(config_save_path, 'w') as f:
|
120 |
+
json.dump(config_json, f)
|
121 |
+
|
122 |
+
|
123 |
+
# config_path = os.path.join(model_dir, 'config.json')
|
124 |
+
config_path = '/Users/suyash.harlalka/Desktop/personal/acne_classification/model/model_1/config.json'
|
125 |
+
with open(config_path, 'r') as f:
|
126 |
+
config_loaded = json.load(f)
|
127 |
+
|
128 |
+
from sklearn.metrics import confusion_matrix
|
129 |
+
|
130 |
+
model_trained_path = os.path.join(config_loaded['TRAINED_MODEL_DIR'], 'model.pth')
|
131 |
+
model.load_state_dict(torch.load(model_trained_path))
|
132 |
+
model.eval()
|
133 |
+
with torch.no_grad():
|
134 |
+
correct = 0
|
135 |
+
total = 0
|
136 |
+
mainLabel = []
|
137 |
+
predictedLabel = []
|
138 |
+
for images, labels in tqdm(test_dataloader, desc="Processing", unit="batch"):
|
139 |
+
images = images.to(device)
|
140 |
+
labels = labels.to(device)
|
141 |
+
outputs = model(images)
|
142 |
+
predicted = torch.round(outputs.data)
|
143 |
+
predicted = predicted.squeeze(1)
|
144 |
+
total += labels.size(0)
|
145 |
+
correct += (predicted == labels).sum().item()
|
146 |
+
|
147 |
+
mainLabel.extend(labels.cpu().numpy())
|
148 |
+
predictedLabel.extend(predicted.cpu().numpy())
|
149 |
+
|
150 |
+
print(f'Accuracy of the network on the {total} test images: {100 * correct / total}%')
|
151 |
+
|
152 |
+
cft = confusion_matrix(mainLabel, predictedLabel, labels=[0, 1, 2, 3], normalize='true')
|
153 |
+
print(cft)
|
154 |
+
|
155 |
+
correct = 0
|
156 |
+
total = 0
|
157 |
+
mainLabel = []
|
158 |
+
predictedLabel = []
|
159 |
+
for images, labels in tqdm(train_dataloader, desc="Processing", unit="batch"):
|
160 |
+
images = images.to(device)
|
161 |
+
labels = labels.to(device)
|
162 |
+
outputs = model(images)
|
163 |
+
predicted = torch.round(outputs.data)
|
164 |
+
predicted = predicted.squeeze(1)
|
165 |
+
total += labels.size(0)
|
166 |
+
correct += (predicted == labels).sum().item()
|
167 |
+
|
168 |
+
mainLabel.extend(labels.cpu().numpy())
|
169 |
+
predictedLabel.extend(predicted.cpu().numpy())
|
170 |
+
|
171 |
+
print(f'Accuracy of the network on the {total} train images: {100 * correct / total}%')
|
172 |
+
cft = confusion_matrix(mainLabel, predictedLabel, labels=[0, 1, 2, 3], normalize='true')
|
173 |
+
print(cft)
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
|
utils.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from config import MODEL_DIR, MODEL_INPUT_SIZE, TRANSFORMS_TO_APPLY, MODEL_BACKBONE, MODEL_OBJECTIVE, LAST_N_LAYERS_TO_TRAIN
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import json
|
6 |
+
from base import TransformationType, ModelBackbone, TrainingObjective
|
7 |
+
from torchvision import transforms
|
8 |
+
import torchvision
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
def save_model(model, config_json,model_dir = None):
|
14 |
+
if model_dir is None:
|
15 |
+
model_basedir = MODEL_DIR
|
16 |
+
models_present_in_dir = os.listdir(model_basedir)
|
17 |
+
|
18 |
+
model_dir_name = 'model_{}'.format(len(models_present_in_dir))
|
19 |
+
model_dir = os.path.join(model_basedir, model_dir_name)
|
20 |
+
os.mkdir(model_dir)
|
21 |
+
|
22 |
+
model_path = os.path.join(model_dir, 'model.pth')
|
23 |
+
torch.save(model.state_dict(), model_path)
|
24 |
+
config_path = os.path.join(model_dir, 'config.json')
|
25 |
+
# import pdb; pdb.set_trace()
|
26 |
+
with open(config_path, 'w') as f:
|
27 |
+
json.dump(config_json, f)
|
28 |
+
|
29 |
+
return model_dir
|
30 |
+
|
31 |
+
def get_transforms_to_apply_(transformation_type, config_json = None):
|
32 |
+
if config_json:
|
33 |
+
model_input_size = config_json['MODEL_INPUT_SIZE']
|
34 |
+
else:
|
35 |
+
model_input_size = MODEL_INPUT_SIZE
|
36 |
+
|
37 |
+
if transformation_type == TransformationType.RESIZE:
|
38 |
+
return transforms.Resize(model_input_size)
|
39 |
+
elif transformation_type == TransformationType.TO_TENSOR:
|
40 |
+
return transforms.ToTensor()
|
41 |
+
elif transformation_type == TransformationType.RANDOM_HORIZONTAL_FLIP:
|
42 |
+
return transforms.RandomHorizontalFlip(p=0.5)
|
43 |
+
elif transformation_type == TransformationType.NORMALIZE:
|
44 |
+
return transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
45 |
+
std=[0.229, 0.224, 0.225])
|
46 |
+
elif transformation_type == TransformationType.RANDOM_ROTATION:
|
47 |
+
return transforms.RandomRotation(degrees=10)
|
48 |
+
elif transformation_type == TransformationType.RANDOM_CLIP:
|
49 |
+
return transforms.RandomCrop(model_input_size)
|
50 |
+
else:
|
51 |
+
raise Exception("Invalid transformation type")
|
52 |
+
|
53 |
+
def get_transforms_to_apply():
|
54 |
+
transforms_to_apply = []
|
55 |
+
for transform in TRANSFORMS_TO_APPLY:
|
56 |
+
transforms_to_apply.append(get_transforms_to_apply_(TransformationType[transform]))
|
57 |
+
return transforms.Compose(transforms_to_apply)
|
58 |
+
|
59 |
+
def get_model_architecture(config_json = None):
|
60 |
+
if config_json:
|
61 |
+
model_backbone = ModelBackbone[config_json['MODEL_BACKBONE']]
|
62 |
+
model_objective = TrainingObjective[config_json['MODEL_OBJECTIVE']]
|
63 |
+
else:
|
64 |
+
model_backbone = MODEL_BACKBONE
|
65 |
+
model_objective = MODEL_OBJECTIVE
|
66 |
+
if model_backbone == ModelBackbone.EFFICIENT_NET_B0:
|
67 |
+
if model_objective == TrainingObjective.REGRESSION:
|
68 |
+
model = torchvision.models.efficientnet_b0(pretrained=True)
|
69 |
+
model.classifier[1] = nn.Sequential(
|
70 |
+
nn.Linear(model.classifier[1].in_features, 2048),
|
71 |
+
nn.ReLU(),
|
72 |
+
nn.Dropout(0.5),
|
73 |
+
nn.Linear(2048, 1),
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
raise Exception("Invalid model objective")
|
77 |
+
else:
|
78 |
+
raise Exception("Invalid model backbone")
|
79 |
+
|
80 |
+
return model
|
81 |
+
|
82 |
+
def get_training_params(model):
|
83 |
+
training_params = []
|
84 |
+
if MODEL_BACKBONE == ModelBackbone.EFFICIENT_NET_B0:
|
85 |
+
if LAST_N_LAYERS_TO_TRAIN > 0:
|
86 |
+
for param in model.features[:-LAST_N_LAYERS_TO_TRAIN].parameters():
|
87 |
+
param.requires_grad = False
|
88 |
+
|
89 |
+
for param in model.features[-LAST_N_LAYERS_TO_TRAIN:].parameters():
|
90 |
+
training_params.append(param)
|
91 |
+
|
92 |
+
|
93 |
+
for param in model.classifier[1].parameters():
|
94 |
+
training_params.append(param)
|
95 |
+
else:
|
96 |
+
raise Exception("Invalid model backbone")
|
97 |
+
|
98 |
+
return training_params
|
99 |
+
|
100 |
+
def get_criterion():
|
101 |
+
if MODEL_OBJECTIVE == TrainingObjective.REGRESSION:
|
102 |
+
criterion = nn.MSELoss()
|
103 |
+
else:
|
104 |
+
raise Exception("Invalid model objective")
|
105 |
+
|
106 |
+
return criterion
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
|