# Postpone annotation evaluation so `torch.Tensor | Any | None` hints also work on Python < 3.10.
from __future__ import annotations

from pathlib import Path
from typing import Any

import pytorch_lightning as pl
import torch
import torchvision
import wandb


class EvalSaveCallback(pl.Callback):
    """Saves model outputs together with their inputs during validation and test."""

    def __init__(self, save_dir: Path) -> None:
        super().__init__()
        self.save_dir = save_dir
        # Create the target directory up front so the first save cannot fail.
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def save(self, outputs, batch, batch_idx):
        name = batch["name"]
        filename = self.save_dir / f"{batch_idx:06d}_{name[0]}.pt"
        torch.save({
            "fpv": batch["image"],
            "seg_masks": batch["seg_masks"],
            "name": name,
            "output": outputs["output"],
            "valid_bev": outputs["valid_bev"],
        }, filename)

    def on_test_batch_end(self, trainer: pl.Trainer,
                          pl_module: pl.LightningModule,
                          outputs: torch.Tensor | Any | None,
                          batch: Any,
                          batch_idx: int,
                          dataloader_idx: int = 0) -> None:
        if not outputs:
            return
        self.save(outputs, batch, batch_idx)

    def on_validation_batch_end(self, trainer: pl.Trainer,
                                pl_module: pl.LightningModule,
                                outputs: torch.Tensor | Any | None,
                                batch: Any,
                                batch_idx: int,
                                dataloader_idx: int = 0) -> None:
        if not outputs:
            return
        self.save(outputs, batch, batch_idx)


class ImageLoggerCallback(pl.Callback):
    """Logs FPV inputs plus per-class ground-truth and predicted BEV masks to W&B."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
        fpv_rgb = batch["image"]
        fpv_grid = torchvision.utils.make_grid(fpv_rgb, nrow=8, normalize=False)
        images = [wandb.Image(fpv_grid, caption="fpv")]

        # Clone so the masking/thresholding below does not mutate outputs["output"]
        # in place through the permuted view.
        pred = outputs["output"].permute(0, 2, 3, 1).clone()
        pred[outputs["valid_bev"][..., :-1] == 0] = 0  # zero out invalid BEV cells
        pred = (pred > 0.5).float()
        pred = pred.permute(0, 3, 1, 2)

        for i in range(self.num_classes):
            gt_class_i = batch["seg_masks"][..., i]
            gt_class_i_grid = torchvision.utils.make_grid(
                gt_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0)
            pred_class_i = pred[:, i]
            pred_class_i_grid = torchvision.utils.make_grid(
                pred_class_i.unsqueeze(1), nrow=8, normalize=False, pad_value=0)
            images += [
                wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
                wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}"),
            ]

        trainer.logger.experiment.log({f"{mode}/images": images})

    def on_validation_batch_end(self, trainer, pl_module: pl.LightningModule,
                                outputs, batch, batch_idx,
                                dataloader_idx: int = 0):
        # Only log the first validation batch of each epoch.
        if batch_idx == 0:
            with torch.no_grad():
                outputs = pl_module(batch)
            self.log_image(trainer, pl_module, outputs,
                           batch, batch_idx, mode="val")

    def on_train_batch_end(self, trainer, pl_module: pl.LightningModule,
                           outputs, batch, batch_idx):
        # Only log the first training batch; run a clean forward pass in eval mode.
        if batch_idx == 0:
            pl_module.eval()
            with torch.no_grad():
                outputs = pl_module(batch)
            self.log_image(trainer, pl_module, outputs,
                           batch, batch_idx, mode="train")
            pl_module.train()
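

# Minimal usage sketch (not part of the original module): shows how these callbacks
# could be attached to a pl.Trainer. The save directory and class count below are
# hypothetical stand-ins; the actual model/datamodule come from the rest of the project.
if __name__ == "__main__":
    callbacks = [
        EvalSaveCallback(save_dir=Path("eval_outputs")),  # hypothetical output directory
        ImageLoggerCallback(num_classes=4),               # hypothetical number of BEV classes
    ]
    trainer = pl.Trainer(max_epochs=1, callbacks=callbacks)
    # trainer.fit(model, datamodule=dm)  # model / dm are defined elsewhere in the project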