Spaces:
Runtime error
Runtime error
| from torch.utils.data import Dataset, DataLoader | |
| from loss import YoloLoss | |
| import config | |
| import torch | |
| from dataset import YOLODataset | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| import random | |
| from model import YOLOv3 | |
| import lightning.pytorch as pl | |
def criterion(out, y, anchors):
    """Return the total YOLOv3 loss summed over the three prediction scales.

    ``out``, ``y`` and ``anchors`` are indexed in lockstep: element ``k`` of
    each belongs to scale ``k``; only scales 0, 1 and 2 are used. A single
    ``YoloLoss`` instance scores every scale.
    """
    yolo_loss = YoloLoss()
    total = yolo_loss(out[0], y[0], anchors[0])
    for scale in (1, 2):
        total = total + yolo_loss(out[scale], y[scale], anchors[scale])
    return total
def get_loader(train_dataset, test_dataset):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Both loaders take batch size, worker count and pin-memory settings from
    ``config`` and keep the final partial batch (``drop_last=False``); only
    the training loader shuffles.

    Returns:
        ``(train_loader, test_loader)``.
    """
    shared = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        drop_last=False,
    )
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **shared)
    return train_loader, test_loader
def accuracy_fn(y, out, threshold,
                correct_class, correct_obj,
                correct_noobj, tot_class_preds,
                tot_obj, tot_noobj):
    """Compute class, no-object and object accuracies (in percent) over all scales.

    For each scale, the target ``y[i]`` is read at index 0 of its last
    dimension (objectness: 1 = object, 0 = no object) and index 5 (class
    index). The prediction ``out[i]`` is read at index 0 (objectness logit)
    and indices 5 onward (class logits). Every scale present in both ``out``
    and ``y`` is counted.

    Args:
        y: Sequence of target tensors, one per scale.
        out: Sequence of model outputs, one per scale, aligned with ``y``.
        threshold: Probability above which ``sigmoid(objectness_logit)``
            counts as "object predicted".
        correct_class, correct_obj, correct_noobj: Starting counts of correct
            predictions (e.g. ``0``).
        tot_class_preds, tot_obj, tot_noobj: Starting totals matching the
            counts above.

    Returns:
        Tuple ``(class_acc, no_obj_acc, obj_acc)`` in percent. A tiny epsilon
        in each denominator yields 0 instead of a division error when a total
        is 0.

    Note:
        Plain-int counters are only rebound locally, so the caller's values are
        unchanged. Tensor counters would be updated in place by ``+=``.
    """
    for out_i, y_i in zip(out, y):
        target_obj = y_i[..., 0]
        obj = target_obj == 1  # cells with an object (Iobj_i in the paper)
        noobj = target_obj == 0  # cells without an object (Inoobj_i in the paper)
        n_obj = torch.sum(obj)

        # Class accuracy is only measured where an object is present.
        correct_class += torch.sum(
            torch.argmax(out_i[..., 5:][obj], dim=-1) == y_i[..., 5][obj]
        )
        tot_class_preds += n_obj

        obj_preds = torch.sigmoid(out_i[..., 0]) > threshold
        correct_obj += torch.sum(obj_preds[obj] == target_obj[obj])
        tot_obj += n_obj
        correct_noobj += torch.sum(obj_preds[noobj] == target_obj[noobj])
        tot_noobj += torch.sum(noobj)

    return ((correct_class / (tot_class_preds + 1e-16)) * 100,
            (correct_noobj / (tot_noobj + 1e-16)) * 100,
            (correct_obj / (tot_obj + 1e-16)) * 100)
def get_datasets(train_loc="/train.csv", test_loc="/test.csv"):
    """Build the train and test ``YOLODataset`` objects.

    Each CSV path is ``config.DATASET`` with ``train_loc`` / ``test_loc``
    appended. Both datasets share the image directory, label directory and
    anchors from ``config``. The training set uses ``config.train_transform``;
    the test set uses ``config.test_transform`` and is built with
    ``train=False``.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    common = {
        "img_dir": config.IMG_DIR,
        "label_dir": config.LABEL_DIR,
        "anchors": config.ANCHORS,
    }
    train_dataset = YOLODataset(
        config.DATASET + train_loc,
        transform=config.train_transform,
        **common,
    )
    test_dataset = YOLODataset(
        config.DATASET + test_loc,
        transform=config.test_transform,
        train=False,
        **common,
    )
    return train_dataset, test_dataset
| class YOLOv3Lightning(pl.LightningModule): | |
| def __init__(self, dataset=None, lr=config.LEARNING_RATE): | |
| super().__init__() | |
| self.save_hyperparameters() | |
| self.model = YOLOv3(num_classes=config.NUM_CLASSES) | |
| self.lr = lr | |
| self.criterion = criterion | |
| self.losses = [] | |
| self.threshold = config.CONF_THRESHOLD | |
| self.iou_threshold = config.NMS_IOU_THRESH | |
| self.train_idx = 0 | |
| self.box_format="midpoint" | |
| self.dataset = dataset | |
| self.criterion = criterion | |
| self.accuracy_fn = accuracy_fn | |
| self.tot_class_preds, self.correct_class = 0, 0 | |
| self.tot_noobj, self.correct_noobj = 0, 0 | |
| self.tot_obj, self.correct_obj = 0, 0 | |
| self.scaled_anchors = 0 | |
| def forward(self, x): | |
| return self.model(x) | |
| def set_scaled_anchor(self, scaled_anchors): | |
| self.scaled_anchors = scaled_anchors | |
| def on_train_epoch_start(self): | |
| # Set a new image size for the dataset at the beginning of each epoch | |
| size_idx = random.choice(range(len(config.IMAGE_SIZES))) | |
| self.dataset.set_image_size(size_idx) | |
| self.set_scaled_anchor(( | |
| torch.tensor(config.ANCHORS) | |
| * torch.tensor(config.S[size_idx]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) | |
| )) | |
| def on_validation_epoch_start(self): | |
| self.set_scaled_anchor(( | |
| torch.tensor(config.ANCHORS) | |
| * torch.tensor(config.S[1]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2) | |
| )) | |
| def training_step(self, batch, batch_idx): | |
| x, y = batch | |
| out = self(x) | |
| loss = self.criterion(out, y, self.scaled_anchors) | |
| self.log('train_loss', loss, prog_bar=True, on_epoch=True, on_step=True, logger=True) | |
| return loss | |
| def validation_step(self, val_batch, batch_idx): | |
| x, labels = val_batch | |
| out = self(x) | |
| loss = self.criterion(out, labels, self.scaled_anchors) | |
| self.log('val_loss', loss, prog_bar=True, on_epoch=True) | |
| self.evaluate(x, labels, out, 'val') | |
| def evaluate(self, x, y, out, stage=None): | |
| # Class Accuracy | |
| class_accuracy, no_obj_accuracy, obj_accuracy = self.accuracy_fn(y, | |
| out, | |
| self.threshold, | |
| self.correct_class, | |
| self.correct_obj, | |
| self.correct_noobj, | |
| self.tot_class_preds, | |
| self.tot_obj, | |
| self.tot_noobj, ) | |
| if stage: | |
| self.log(f'{stage}_class_accuracy', class_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True) | |
| self.log(f'{stage}_no_obj_accuracy', no_obj_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True) | |
| self.log(f'{stage}_obj_accuracy', obj_accuracy, prog_bar=True, on_epoch=True, on_step=True, logger=True) | |