# Spaces: Running on Zero
"""Pointnet++ Implementation.""" | |
from __future__ import annotations | |
import torch | |
from torch import Tensor, nn | |
from vis4d.common.ckpt import load_model_checkpoint | |
from vis4d.common.typing import LossesType, ModelOutput | |
from vis4d.data.const import CommonKeys as K | |
from vis4d.op.base.pointnetpp import ( | |
PointNet2Segmentation, | |
PointNet2SegmentationOut, | |
) | |
class PointNet2SegmentationModel(nn.Module):
    """PointNet++ model for per-point semantic segmentation."""

    def __init__(
        self,
        num_classes: int,
        in_dimensions: int = 3,
        weights: str | None = None,
    ):
        """Build the segmentation network and optionally restore weights.

        Args:
            num_classes (int): Number of semantic classes.
            in_dimensions (int, optional): Input dimensions, forwarded to
                the underlying segmentation network. Defaults to 3.
            weights (str, optional): Path to weights. When given, the
                checkpoint is loaded into this module right after the
                network is created. Defaults to None.
        """
        super().__init__()
        network = PointNet2Segmentation(num_classes, in_dimensions)
        self.segmentation_model = network
        if weights is not None:
            # The checkpoint is loaded into the wrapper (self), not into
            # the inner network directly.
            load_model_checkpoint(self, weights)

    def forward(
        self, points3d: Tensor, semantics3d: Tensor | None = None
    ) -> PointNet2SegmentationOut | ModelOutput:
        """Run the segmentation network on the given points.

        Args:
            points3d (Tensor): Input points of shape [b, N, C].
            semantics3d (Tensor, optional): Groundtruth semantic labels of
                shape [b, N]. Only its presence is checked here; the
                values themselves are not read. Defaults to None.

        Returns:
            PointNet2SegmentationOut | ModelOutput: The unmodified network
            output when ``semantics3d`` is provided. Otherwise a dict
            mapping ``K.semantics3d`` to the argmax of the class logits
            along dimension 1.
        """
        seg_out = self.segmentation_model(points3d)
        if semantics3d is None:
            # No labels available: reduce the logits to class indices.
            logits = seg_out.class_logits
            return {K.semantics3d: torch.argmax(logits, dim=1)}
        # Labels available: return the raw output untouched so the caller
        # can work with the logits (e.g. feed them to a loss).
        return seg_out
class Pointnet2SegmentationLoss(nn.Module):
    """Cross-entropy segmentation loss for PointNet++ outputs."""

    def __init__(
        self,
        ignore_index: int = 255,
        semantic_weights: Tensor | None = None,
    ) -> None:
        """Set up the underlying cross-entropy criterion.

        Args:
            ignore_index (int, optional): Class index that should be
                ignored, passed to ``nn.CrossEntropyLoss``. Defaults to
                255.
            semantic_weights (Tensor, optional): Weights for each class,
                passed to ``nn.CrossEntropyLoss`` as ``weight``. Defaults
                to None.
        """
        super().__init__()
        criterion = nn.CrossEntropyLoss(
            weight=semantic_weights, ignore_index=ignore_index
        )
        self.segmentation_loss = criterion

    def forward(
        self, outputs: PointNet2SegmentationOut, semantics3d: Tensor
    ) -> LossesType:
        """Compute the segmentation loss.

        Args:
            outputs (PointNet2SegmentationOut): Model outputs; only the
                ``class_logits`` attribute is read.
            semantics3d (Tensor): Groundtruth semantic labels.

        Returns:
            LossesType: Dict with the single key ``segmentation_loss``
            holding the cross-entropy between the class logits and the
            groundtruth labels.
        """
        loss_value = self.segmentation_loss(outputs.class_logits, semantics3d)
        return {"segmentation_loss": loss_value}