Spaces:
Building
Building
import torch | |
import pytorch_lightning as pl | |
from pathlib import Path | |
from typing import Any | |
import torchvision | |
import wandb | |
class EvalSaveCallback(pl.Callback):
    """Dump per-batch evaluation inputs and predictions to ``.pt`` files.

    For every validation/test batch whose step returned a truthy output, one
    file named ``{batch_idx:06d}_{name[0]}.pt`` is written under ``save_dir``
    holding the input image, the ground-truth masks, the batch names and the
    step's ``"output"`` / ``"valid_bev"`` entries.
    """

    def __init__(self, save_dir: Path) -> None:
        """
        Args:
            save_dir: Directory the ``.pt`` files are written to. A ``str`` is
                accepted too; the directory is created on first save if needed.
        """
        super().__init__()
        # Coerce so the ``/`` join in ``save`` also works when a str is passed.
        self.save_dir = Path(save_dir)

    @staticmethod
    def _to_cpu(value: Any) -> Any:
        """Return ``value`` detached and on CPU if it is a tensor, else unchanged."""
        # Serialising device tensors records their device in the file, which
        # forces readers without that device to pass ``map_location``.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu()
        return value

    def save(self, outputs, batch, batch_idx):
        """Write one batch's inputs and predictions to ``save_dir``.

        Args:
            outputs: Step output; indexed with ``"output"`` and ``"valid_bev"``.
            batch: Input batch; indexed with ``"name"``, ``"image"`` and
                ``"seg_masks"``. Only ``batch["name"][0]`` goes into the
                filename (presumably a per-sample name list — confirm).
            batch_idx: Batch index, zero-padded to 6 digits in the filename.
        """
        name = batch['name']
        filename = self.save_dir / f"{batch_idx:06d}_{name[0]}.pt"
        # Create save_dir (and any sub-path the name introduces) on demand.
        filename.parent.mkdir(parents=True, exist_ok=True)
        torch.save({
            "fpv": self._to_cpu(batch['image']),
            "seg_masks": self._to_cpu(batch['seg_masks']),
            'name': name,
            "output": self._to_cpu(outputs["output"]),
            "valid_bev": self._to_cpu(outputs["valid_bev"]),
        }, filename)

    def _save_if_any(self, outputs, batch, batch_idx) -> None:
        """Call ``save`` unless the step produced no output (None / empty)."""
        if not outputs:
            return
        self.save(outputs, batch, batch_idx)

    def on_test_batch_end(self, trainer: pl.Trainer,
                          pl_module: pl.LightningModule,
                          outputs: torch.Tensor | Any | None,
                          batch: Any,
                          batch_idx: int,
                          dataloader_idx: int = 0) -> None:
        """Save the outputs of a test batch (skipped when empty)."""
        self._save_if_any(outputs, batch, batch_idx)

    def on_validation_batch_end(self, trainer: pl.Trainer,
                                pl_module: pl.LightningModule,
                                outputs: torch.Tensor | Any | None,
                                batch: Any,
                                batch_idx: int,
                                dataloader_idx: int = 0) -> None:
        """Save the outputs of a validation batch (skipped when empty)."""
        self._save_if_any(outputs, batch, batch_idx)
class ImageLoggerCallback(pl.Callback):
    """Log input images and per-class GT / predicted masks to the logger.

    On the first batch (``batch_idx == 0``) of training and validation, the
    module is re-run on that batch without gradients. A list of
    ``wandb.Image`` grids is then logged under ``"train/images"`` or
    ``"val/images"``.
    """

    def __init__(self, num_classes):
        """
        Args:
            num_classes: Number of mask classes to visualise; one GT grid and
                one prediction grid are logged per class.
        """
        super().__init__()
        self.num_classes = num_classes

    def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
        """Build image grids for one batch and log them via ``trainer.logger``.

        Args:
            trainer: Lightning trainer; nothing is logged if it has no logger.
            pl_module: Unused; kept for interface compatibility.
            outputs: Model output indexed with ``"output"`` and ``"valid_bev"``.
                It is not modified.
            batch: Input batch indexed with ``"image"`` and ``"seg_masks"``.
            batch_idx: Unused; kept for interface compatibility.
            mode: Prefix of the logged key (``"{mode}/images"``).
        """
        if trainer.logger is None:
            return

        fpv_grid = torchvision.utils.make_grid(
            batch["image"], nrow=8, normalize=False)
        images = [wandb.Image(fpv_grid, caption="fpv")]

        # Move channels last so the validity mask indexes it; clone because the
        # in-place zeroing below would otherwise write through the permuted
        # view into ``outputs['output']``.
        pred = outputs['output'].permute(0, 2, 3, 1).clone()
        # Zero predictions wherever valid_bev (minus its last channel) is 0.
        pred[outputs["valid_bev"][..., :-1] == 0] = 0
        pred = (pred > 0.5).float().permute(0, 3, 1, 2)

        for i in range(self.num_classes):
            # GT masks are indexed by class on the last axis, predictions on axis 1.
            gt_class_i_grid = torchvision.utils.make_grid(
                batch['seg_masks'][..., i].unsqueeze(1),
                nrow=8, normalize=False, pad_value=0)
            pred_class_i_grid = torchvision.utils.make_grid(
                pred[:, i].unsqueeze(1),
                nrow=8, normalize=False, pad_value=0)
            images += [
                wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
                wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}"),
            ]

        trainer.logger.experiment.log(
            {
                "{}/images".format(mode): images
            }
        )

    def on_validation_batch_end(self, trainer, pl_module: pl.LightningModule,
                                outputs, batch, batch_idx, dataloader_idx=0):
        """Log images for the first validation batch.

        ``dataloader_idx`` matches Lightning's hook signature, so the call does
        not fail when Lightning passes it (multiple validation dataloaders).
        """
        if batch_idx != 0 or trainer.logger is None:
            return
        # The step's own ``outputs`` are not used; the module is re-run.
        with torch.no_grad():
            outputs = pl_module(batch)
        self.log_image(trainer, pl_module, outputs,
                       batch, batch_idx, mode="val")

    def on_train_batch_end(self, trainer, pl_module: pl.LightningModule,
                           outputs, batch, batch_idx):
        """Log images for the first training batch, evaluated in eval mode."""
        if batch_idx != 0 or trainer.logger is None:
            return
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                outputs = pl_module(batch)
            self.log_image(trainer, pl_module, outputs,
                           batch, batch_idx, mode="train")
        finally:
            # Restore the previous mode even if the forward or logging raised.
            pl_module.train(was_training)