"""Trainers for semantic segmentation."""

import os
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Sequence
from typing import Any, Optional, Union

import lightning
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import Callback
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum


def get_weight(name: str) -> WeightsEnum:
    """Get the weights enum value by its full name.

    .. versionadded:: 0.4

    Args:
        name: Name of the weight enum entry.

    Returns:
        The requested weight enum.
    """
    # NOTE: eval() resolves the enum by name, so the corresponding WeightsEnum
    # must be importable/visible in this module's namespace.
    return eval(name)
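
# Example usage (a sketch; the enum name below is illustrative and resolves only
# if that WeightsEnum is visible in this namespace):
#
#     weights = get_weight("ResNet50_Weights.IMAGENET1K_V2")
#     state_dict = weights.get_state_dict(progress=True)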


def extract_backbone(path: str) -> tuple[str, "OrderedDict[str, Tensor]"]:
    """Extracts a backbone from a lightning checkpoint file.

    Args:
        path: path to checkpoint file (.ckpt)

    Returns:
        tuple containing model name and state dict

    Raises:
        ValueError: if 'model' or 'backbone' not in
            checkpoint['hyper_parameters']

    .. versionchanged:: 0.4
        Renamed from *extract_encoder* to *extract_backbone*
    """
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    if "model" in checkpoint["hyper_parameters"]:
        name = checkpoint["hyper_parameters"]["model"]
        state_dict = checkpoint["state_dict"]
        state_dict = OrderedDict({k: v for k, v in state_dict.items() if "model." in k})
        state_dict = OrderedDict(
            {k.replace("model.", ""): v for k, v in state_dict.items()}
        )
    elif "backbone" in checkpoint["hyper_parameters"]:
        name = checkpoint["hyper_parameters"]["backbone"]
        state_dict = checkpoint["state_dict"]
        state_dict = OrderedDict(
            {k: v for k, v in state_dict.items() if "model.backbone.model" in k}
        )
        state_dict = OrderedDict(
            {k.replace("model.backbone.model.", ""): v for k, v in state_dict.items()}
        )
    else:
        raise ValueError(
            "Unknown checkpoint task. Only backbone or model extraction is supported"
        )

    return name, state_dict
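
# Example usage (a sketch; "path/to/last.ckpt" is a placeholder). Passing a
# checkpoint path as *weights* to SemanticSegmentationTask below triggers the
# same extraction internally:
#
#     name, state_dict = extract_backbone("path/to/last.ckpt")
#     task = SemanticSegmentationTask(backbone=name, num_classes=10)
#     task.model.encoder.load_state_dict(state_dict)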


class BaseTask(LightningModule, ABC):
    """Abstract base class for all TorchGeo trainers.

    .. versionadded:: 0.5
    """

    #: Model to train.
    model: Any

    #: Performance metric to monitor in learning rate scheduler and callbacks.
    monitor = "val_loss"

    #: Whether the goal is to minimize or maximize the performance metric to monitor.
    mode = "min"

    def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None:
        """Initialize a new BaseTask instance.

        Args:
            ignore: Arguments to skip when saving hyperparameters.
        """
        super().__init__()
        self.save_hyperparameters(ignore=ignore)
        self.configure_losses()
        self.configure_metrics()
        self.configure_models()

    def configure_losses(self) -> None:
        """Initialize the loss criterion."""

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""

    @abstractmethod
    def configure_models(self) -> None:
        """Initialize the model."""

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        optimizer = AdamW(self.parameters(), lr=self.hparams["lr"])
        scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams["patience"])
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
        }

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass of the model.

        Args:
            args: Arguments to pass to model.
            kwargs: Keyword arguments to pass to model.

        Returns:
            Output of the model.
        """
        return self.model(*args, **kwargs)
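
# A minimal sketch of a concrete subclass (illustrative only; the class name and
# hyperparameters are hypothetical). configure_optimizers() above reads *lr* and
# *patience* from self.hparams, which save_hyperparameters() populates from the
# subclass __init__ signature:
#
#     class MyTask(BaseTask):
#         def __init__(self, lr: float = 1e-3, patience: int = 10) -> None:
#             super().__init__()
#
#         def configure_models(self) -> None:
#             self.model = nn.Linear(10, 1)
#
#         def configure_losses(self) -> None:
#             self.criterion = nn.MSELoss()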


class SemanticSegmentationTask(BaseTask):
    """Semantic Segmentation."""

    def __init__(
        self,
        model: str = "unet",
        backbone: str = "resnet50",
        weights: Optional[Union[WeightsEnum, str, bool]] = None,
        in_channels: int = 3,
        num_classes: int = 1000,
        num_filters: int = 3,
        loss: str = "ce",
        class_weights: Optional[Tensor] = None,
        ignore_index: Optional[int] = None,
        lr: float = 1e-3,
        patience: int = 10,
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
    ) -> None:
        """Initialize a new SemanticSegmentationTask instance.

        Args:
            model: Name of the
                `smp <https://smp.readthedocs.io/en/latest/models.html>`__ model to use.
            backbone: Name of the `timm
                <https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp
                <https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use.
            weights: Initial model weights. Either a weight enum, the string
                representation of a weight enum, True for ImageNet weights, False or
                None for random weights, or the path to a saved model state dict. FCN
                model does not support pretrained weights. Pretrained ViT weight enums
                are not supported yet.
            in_channels: Number of input channels to model.
            num_classes: Number of prediction classes.
            num_filters: Number of filters. Only applicable when model='fcn',
                which this trainer does not currently implement.
            loss: Name of the loss function, currently supports
                'ce', 'jaccard' or 'focal' loss.
            class_weights: Optional rescaling weight given to each
                class and used with 'ce' loss.
            ignore_index: Optional integer class index to ignore in the loss and
                metrics.
            lr: Learning rate for optimizer.
            patience: Patience for learning rate scheduler.
            freeze_backbone: Freeze the backbone network to fine-tune the
                decoder and segmentation head.
            freeze_decoder: Freeze the decoder network to linear probe
                the segmentation head.

        Warns:
            UserWarning: When loss='jaccard' and ignore_index is specified.

        .. versionchanged:: 0.3
            *ignore_zeros* was renamed to *ignore_index*.

        .. versionchanged:: 0.4
            *segmentation_model*, *encoder_name*, and *encoder_weights*
            were renamed to *model*, *backbone*, and *weights*.

        .. versionadded:: 0.5
            The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.

        .. versionchanged:: 0.5
            The *weights* parameter now supports WeightEnums and checkpoint paths.
            *learning_rate* and *learning_rate_schedule_patience* were renamed to
            *lr* and *patience*.
        """
        if ignore_index is not None and loss == "jaccard":
            warnings.warn(
                "ignore_index has no effect on training when loss='jaccard'",
                UserWarning,
            )

        self.weights = weights
        super().__init__(ignore="weights")
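
    # Example instantiation (a sketch; the values are placeholders, not tuned):
    #
    #     task = SemanticSegmentationTask(
    #         model="unet",
    #         backbone="resnet50",
    #         weights=True,        # ImageNet-pretrained encoder
    #         in_channels=3,
    #         num_classes=10,
    #         loss="ce",
    #         lr=1e-4,
    #         freeze_backbone=True,
    #     )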

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams["loss"]
        ignore_index = self.hparams["ignore_index"]
        if loss == "ce":
            ignore_value = -1000 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(
                ignore_index=ignore_value, weight=self.hparams["class_weights"]
            )
        elif loss == "jaccard":
            self.criterion = smp.losses.JaccardLoss(
                mode="multiclass", classes=self.hparams["num_classes"]
            )
        elif loss == "focal":
            self.criterion = smp.losses.FocalLoss(
                "multiclass", ignore_index=ignore_index, normalized=True
            )
        else:
            raise ValueError(
                f"Loss type '{loss}' is not valid. "
                "Currently supports 'ce', 'jaccard' or 'focal' loss."
            )
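
    # Example (a sketch): up-weighting a rare class with 'ce' loss. The weights
    # below are placeholders, not derived from any dataset:
    #
    #     class_weights = torch.tensor([0.5, 2.0, 1.0])
    #     task = SemanticSegmentationTask(
    #         num_classes=3, loss="ce", class_weights=class_weights
    #     )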

    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
          (OA) using 'micro' averaging. The number of true positives divided by the
          dataset size. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
          over union (IoU). Uses 'micro' averaging. Higher values are better.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not
             reflect minority class accuracy.
           * 'Macro' averaging, not used here, gives equal weight to each class,
             useful for balanced performance assessment across imbalanced classes.
        """
        num_classes: int = self.hparams["num_classes"]
        ignore_index: Optional[int] = self.hparams["ignore_index"]
        metrics = MetricCollection(
            [
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average="global",
                    average="micro",
                ),
                MulticlassJaccardIndex(
                    num_classes=num_classes, ignore_index=ignore_index, average="micro"
                ),
            ]
        )
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")
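
    # Example (a sketch): a subclass could additionally track 'macro' IoU to
    # surface minority-class performance alongside the 'micro' defaults above:
    #
    #     class MacroIoUTask(SemanticSegmentationTask):
    #         def configure_metrics(self) -> None:
    #             super().configure_metrics()
    #             macro_iou = MulticlassJaccardIndex(
    #                 num_classes=self.hparams["num_classes"], average="macro"
    #             )
    #             # logged as 'val_MacroIoU' via the collection's prefix
    #             self.val_metrics.add_metrics({"MacroIoU": macro_iou})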

    def configure_models(self) -> None:
        """Initialize the model.

        Raises:
            ValueError: If *model* is invalid.
        """
        model: str = self.hparams["model"]
        backbone: str = self.hparams["backbone"]
        weights = self.weights
        in_channels: int = self.hparams["in_channels"]
        num_classes: int = self.hparams["num_classes"]

        if model == "unet":
            self.model = smp.Unet(
                encoder_name=backbone,
                encoder_weights="imagenet" if weights is True else None,
                in_channels=in_channels,
                classes=num_classes,
            )
        elif model == "deeplabv3+":
            self.model = smp.DeepLabV3Plus(
                encoder_name=backbone,
                encoder_weights="imagenet" if weights is True else None,
                in_channels=in_channels,
                classes=num_classes,
            )
        else:
            raise ValueError(
                f"Model type '{model}' is not valid. "
                "Currently only supports 'unet' and 'deeplabv3+'."
            )

        if weights and weights is not True:
            if isinstance(weights, WeightsEnum):
                state_dict = weights.get_state_dict(progress=True)
            elif os.path.exists(weights):
                _, state_dict = extract_backbone(weights)
            else:
                state_dict = get_weight(weights).get_state_dict(progress=True)
            self.model.encoder.load_state_dict(state_dict)

        if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.encoder.parameters():
                param.requires_grad = False

        if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
            for param in self.model.decoder.parameters():
                param.requires_grad = False

    def training_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss tensor.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        loss: Tensor = self.criterion(y_hat, y)
        self.log("train_loss", loss)
        self.train_metrics(y_hat, y)
        self.log_dict(self.train_metrics)
        return loss

    def validation_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Compute the validation loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("val_loss", loss)
        self.val_metrics(y_hat, y)
        self.log_dict(self.val_metrics)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("test_loss", loss)
        self.test_metrics(y_hat, y)
        self.log_dict(self.test_metrics)

    def predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities.
        """
        x = batch["image"]
        y_hat: Tensor = self(x).softmax(dim=1)
        return y_hat
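
    # Example (a sketch): converting predicted probabilities to hard labels.
    #
    #     probs = task.predict_step(batch, 0)  # (B, C, H, W) softmax scores
    #     preds = probs.argmax(dim=1)          # (B, H, W) class indices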


class CustomSemanticSegmentationTask(SemanticSegmentationTask):
    """A custom trainer for semantic segmentation tasks."""

    def configure_callbacks(self) -> list[Callback]:
        """Configure the callbacks for the trainer.

        Returns:
            An empty list to override the default callbacks; these are set in the
            Trainer instead.
        """
        return []
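
# Example pairing (a sketch; values are placeholders): because
# configure_callbacks() returns no defaults, callbacks are passed to the Trainer
# explicitly:
#
#     from lightning.pytorch import Trainer
#     from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
#
#     trainer = Trainer(
#         callbacks=[
#             ModelCheckpoint(monitor="val_loss", mode="min"),
#             EarlyStopping(monitor="val_loss", mode="min", patience=20),
#         ],
#         max_epochs=100,
#     )
#     task = CustomSemanticSegmentationTask(model="unet", num_classes=10)
#     trainer.fit(task, datamodule=...)  # datamodule is user-supplied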