from typing import Dict

import pytorch_lightning as pl
import torch
from datasets import load_metric
from torch import nn
from transformers import SegformerForSemanticSegmentation


class SidewalkSegmentationModel(pl.LightningModule):
    """SegFormer fine-tuned for semantic segmentation of sidewalk scenes."""
    def __init__(
        self,
        num_labels: int,
        id2label: Dict[int, str],
        model_flavor: int = 0,
        learning_rate: float = 6e-5,
    ):
        super().__init__()
        self.id2label = id2label
        self.label2id = {v: k for k, v in id2label.items()}
        self.learning_rate = learning_rate
        # Separate metric instances keep train and val batches from mixing.
        self.metrics = {
            "train": load_metric("mean_iou"),
            "val": load_metric("mean_iou"),
        }
        # model_flavor selects the encoder size (nvidia/mit-b0 ... nvidia/mit-b5).
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-b{model_flavor}",
            num_labels=num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        self.save_hyperparameters()

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)
    def training_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        self.add_batch_to_metric("train", logits, labels)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        outputs = self(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        self.add_batch_to_metric("val", logits, labels)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"val_loss": loss}
    def training_epoch_end(self, training_step_outputs):
        """Compute and log the training metrics accumulated over the epoch."""
        metrics = self.metrics["train"].compute(
            num_labels=len(self.id2label), ignore_index=255, reduce_labels=False
        )
        self.log("train_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)

    def validation_epoch_end(self, validation_step_outputs):
        """Compute and log the validation metrics accumulated over the epoch."""
        metrics = self.metrics["val"].compute(
            num_labels=len(self.id2label), ignore_index=255, reduce_labels=False
        )
        self.log("val_mean_iou", metrics["mean_iou"], prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_mean_acc", metrics["mean_accuracy"], prog_bar=True, on_step=False, on_epoch=True)
    def add_batch_to_metric(self, stage: str, logits: torch.Tensor, labels: torch.Tensor):
        """
        Add the current batch to the metric.

        Parameters
        ----------
        stage : str
            Stage of the training. Either "train" or "val".
        logits : torch.Tensor
            Predicted logits.
        labels : torch.Tensor
            Ground truth labels.
        """
        with torch.no_grad():
            # SegFormer predicts at 1/4 of the input resolution, so upsample
            # the logits to the label mask size before taking the argmax.
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted = upsampled_logits.argmax(dim=1)
            self.metrics[stage].add_batch(
                predictions=predicted.detach().cpu().numpy(),
                references=labels.detach().cpu().numpy(),
            )
    def configure_optimizers(self) -> torch.optim.AdamW:
        """
        Configure the optimizer.

        Returns
        -------
        torch.optim.AdamW
            Optimizer for the model.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
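
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test that
# wires the LightningModule into a Trainer. The random-tensor dataset, label
# map, and image size below are illustrative assumptions, not the original
# sidewalk data pipeline; they only demonstrate the expected batch format
# ({"pixel_values": float tensor, "labels": long tensor}).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class RandomSegmentationDataset(Dataset):
        """Hypothetical stand-in for the real sidewalk dataset."""

        def __init__(self, num_samples: int = 8, num_labels: int = 2):
            self.num_samples = num_samples
            self.num_labels = num_labels

        def __len__(self):
            return self.num_samples

        def __getitem__(self, idx):
            return {
                "pixel_values": torch.randn(3, 128, 128),
                "labels": torch.randint(0, self.num_labels, (128, 128)),
            }

    id2label = {0: "background", 1: "sidewalk"}  # illustrative label map
    model = SidewalkSegmentationModel(num_labels=len(id2label), id2label=id2label)
    trainer = pl.Trainer(fast_dev_run=True)  # runs one train/val batch as a check
    trainer.fit(
        model,
        DataLoader(RandomSegmentationDataset(), batch_size=2),
        DataLoader(RandomSegmentationDataset(), batch_size=2),
    )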