Spaces:
Runtime error
Runtime error
import numpy as np | |
import os | |
import sys | |
import cv2 | |
import torch | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
import torchvision | |
from torchvision import transforms | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.utils.data.sampler import SubsetRandomSampler | |
sys.path.append('/Users/suyashharlalka/Documents/workspace/gabit/acne_classification/code') | |
from dataset import AcneDataset | |
from utils import save_model, get_transforms_to_apply,get_model_architecture, get_training_params, get_criterion | |
import config | |
from base import TrainingObjective, ModelBackbone | |
import json | |
# --- Data: load the dataset and carve out a held-out test split -------------
data_dir = config.DATASET_PATH
image_names = os.listdir(data_dir)  # not referenced below

model_training = config.MODEL_TRAINING
isLimited = config.IS_LIMITED
batch_size = config.BATCH_SIZE
shuffle = config.SHUFFLE
num_workers = config.NUM_WORKERS

dataset = AcneDataset(data_dir, limit=isLimited)

# The first 20% of the (optionally shuffled) index list is the test split;
# everything after it is the training split.
validation_split = 0.2
dataset_size = len(dataset)
split = int(np.floor(dataset_size * validation_split))
indices = list(range(dataset_size))
if shuffle:
    # Fixed seed, so the same split is reproduced on every run.
    np.random.seed(42)
    np.random.shuffle(indices)
test_indices = indices[:split]
train_indices = indices[split:]

# Each sampler yields only its own subset of indices, in random order.
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

# A single dataset object backs both loaders, so both splits receive this
# same transform (including any training-time augmentation it contains).
transform = get_transforms_to_apply()
dataset.transform = transform

# NOTE(review): with num_workers > 0 under the "spawn" start method (default on
# macOS/Windows), worker processes re-import this module and re-run all of its
# top-level code; PyTorch recommends guarding script code with
# `if __name__ == "__main__":`.
_loader_options = {'batch_size': batch_size, 'num_workers': num_workers}
train_dataloader = DataLoader(dataset, sampler=train_sampler, **_loader_options)
test_dataloader = DataLoader(dataset, sampler=test_sampler, **_loader_options)
# --- Model, loss, optimizer, scheduler and device ---------------------------
num_classes = config.NUM_CLASSES  # not referenced below
model = get_model_architecture()

# Prefer CUDA, then Apple MPS, then CPU. The original only checked MPS, so a
# CUDA host silently trained on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Move the model before building the optimizer, as the torch.optim docs advise.
model.to(device)

training_params = get_training_params(model)
criterion = get_criterion()
optimizer = optim.Adam(training_params, lr=config.BASE_LR)
# StepLR multiplies the LR by LR_DECAY_GAMMA every LR_DECAY_STEP_SIZE calls to
# scheduler.step().
# NOTE(review): the training loop's scheduler.step() call is commented out, so
# this decay does not currently take effect.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=config.LR_DECAY_STEP_SIZE, gamma=config.LR_DECAY_GAMMA)
# Run configuration that is persisted alongside the trained model. When training
# runs, it adds LOSS, EPOCHS_DONE and TRAINED_MODEL_DIR and dumps this dict to
# config.json. MODEL_OBJECTIVE / MODEL_BACKBONE are stored via `.name` so they
# can be JSON-encoded (presumably Enum members of TrainingObjective /
# ModelBackbone — confirm in config).
config_json = {
    'DATASET_PATH': config.DATASET_PATH,
    'MODEL_DIR': config.MODEL_DIR,
    'MODEL_OBJECTIVE': config.MODEL_OBJECTIVE.name,
    'MODEL_BACKBONE': config.MODEL_BACKBONE.name,
    'MODEL_INPUT_SIZE': config.MODEL_INPUT_SIZE,
    'TRANSFORMS_TO_APPLY': config.TRANSFORMS_TO_APPLY,
    'NUM_CLASSES': config.NUM_CLASSES,
    'LAST_N_LAYERS_TO_TRAIN': config.LAST_N_LAYERS_TO_TRAIN,
    'EPOCHS': config.EPOCHS,
    'MODEL_TRAINING': config.MODEL_TRAINING,
    'IS_LIMITED': config.IS_LIMITED,
    'BATCH_SIZE': config.BATCH_SIZE,
    'SHUFFLE': config.SHUFFLE,
    'NUM_WORKERS': config.NUM_WORKERS,
    'BASE_LR': config.BASE_LR,
    'LR_DECAY_STEP_SIZE': config.LR_DECAY_STEP_SIZE,
    'LR_DECAY_GAMMA': config.LR_DECAY_GAMMA,
}
# --- Training ---------------------------------------------------------------
# Trains for config.EPOCHS epochs. After every epoch it saves the model and
# writes the updated run config (per-epoch LOSS, EPOCHS_DONE, TRAINED_MODEL_DIR)
# to <model_dir>/config.json.
if model_training:
    num_epochs = config.EPOCHS
    num_batches = len(train_dataloader)
    for epoch in range(num_epochs):
        model.train()
        runningLoss = 0.0
        for images, labels in tqdm(train_dataloader, desc="Processing", unit="batch"):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # squeeze(1) rather than squeeze(): a final batch of size 1 keeps its
            # batch dimension instead of collapsing to a scalar. This matches the
            # squeeze(1) applied to predictions at evaluation time.
            loss = criterion(outputs.squeeze(1), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            runningLoss += loss.item()
        # scheduler.step()  # LR decay currently disabled
        epoch_loss = runningLoss / num_batches  # mean per-batch loss
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

        config_json.setdefault('LOSS', []).append(epoch_loss)
        config_json['EPOCHS_DONE'] = epoch + 1

        # The first save creates the model directory; later epochs save into it.
        if epoch == 0:
            model_dir = save_model(model, config_json)
        else:
            model_dir = save_model(model, config_json, model_dir)
        config_json.setdefault('TRAINED_MODEL_DIR', model_dir)

        config_save_path = os.path.join(model_dir, 'config.json')
        with open(config_save_path, 'w') as f:
            json.dump(config_json, f)
# --- Evaluation -------------------------------------------------------------
from sklearn.metrics import confusion_matrix

# Checkpoint config to evaluate when this run did not train a model itself.
# It can be overridden by defining EVAL_CONFIG_PATH in config.
_DEFAULT_EVAL_CONFIG_PATH = '/Users/suyash.harlalka/Desktop/personal/acne_classification/model/model_1/config.json'


def _evaluate_split(dataloader, split_name):
    """Print accuracy and a row-normalised confusion matrix of ``model`` on one split.

    Predictions are the model outputs rounded to the nearest integer and
    squeezed along dim 1, so this expects outputs shaped (batch, 1) — TODO
    confirm for the configured backbone. Because of the fixed
    ``labels=[0, 1, 2, 3]``, rounded predictions outside 0-3 still count as
    wrong in the accuracy but are left out of the confusion matrix.

    Args:
        dataloader: loader yielding ``(images, labels)`` batches.
        split_name: word used in the printed message (e.g. ``'test'``).
    """
    correct = 0
    total = 0
    true_labels = []
    predicted_labels = []
    for images, labels in tqdm(dataloader, desc="Processing", unit="batch"):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = torch.round(outputs).squeeze(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())
    if total == 0:
        # An empty split would otherwise divide by zero below.
        print(f'No {split_name} images to evaluate.')
        return
    print(f'Accuracy of the network on the {total} {split_name} images: {100 * correct / total}%')
    cft = confusion_matrix(true_labels, predicted_labels, labels=[0, 1, 2, 3], normalize='true')
    print(cft)


# Evaluate the model trained in this run when there is one. The original always
# loaded a fixed, machine-specific path, so it evaluated a different model than
# the one just trained.
_trained_dir = config_json.get('TRAINED_MODEL_DIR')
if _trained_dir:
    config_path = os.path.join(_trained_dir, 'config.json')
else:
    config_path = getattr(config, 'EVAL_CONFIG_PATH', _DEFAULT_EVAL_CONFIG_PATH)
with open(config_path, 'r') as f:
    config_loaded = json.load(f)

model_trained_path = os.path.join(config_loaded['TRAINED_MODEL_DIR'], 'model.pth')
# map_location lets a checkpoint saved on one device (e.g. MPS) load on whatever
# device this run selected.
model.load_state_dict(torch.load(model_trained_path, map_location=device))
model.eval()
with torch.no_grad():
    _evaluate_split(test_dataloader, 'test')
    _evaluate_split(train_dataloader, 'train')