"""Pointnet++ Implementation.""" from __future__ import annotations import torch from torch import Tensor, nn from vis4d.common.ckpt import load_model_checkpoint from vis4d.common.typing import LossesType, ModelOutput from vis4d.data.const import CommonKeys as K from vis4d.op.base.pointnetpp import ( PointNet2Segmentation, PointNet2SegmentationOut, ) class PointNet2SegmentationModel(nn.Module): """PointNet++ Segmentation Model implementaiton.""" def __init__( self, num_classes: int, in_dimensions: int = 3, weights: str | None = None, ): """Creates a Pointnet+++ Model. Args: num_classes (int): Number of classes in_dimensions (int, optional): Input dimensions. Defaults to 3. weights (str, optional): Path to weights. Defaults to None. """ super().__init__() self.segmentation_model = PointNet2Segmentation( num_classes, in_dimensions ) if weights is not None: load_model_checkpoint(self, weights) def forward( self, points3d: Tensor, semantics3d: Tensor | None = None ) -> PointNet2SegmentationOut | ModelOutput: """Forward pass of the model. Extract semantic predictions. Args: points3d (Tensor): Input point shape [b, N, C]. semantics3d (torch.Tenosr): Groundtruth semantic labels of shape [b, N]. Defaults to None Returns: ModelOutput: Semantic predictions of the model. """ x = self.segmentation_model(points3d) if semantics3d is not None: return x class_pred = torch.argmax(x.class_logits, dim=1) return {K.semantics3d: class_pred} class Pointnet2SegmentationLoss(nn.Module): """Pointnet2SegmentationLoss Loss.""" def __init__( self, ignore_index: int = 255, semantic_weights: Tensor | None = None, ) -> None: """Creates an instance of the class. Args: ignore_index (int, optional): Class Index that should be ignored. Defaults to 255. semantic_weights (Tensor, optional): Weights for each class. """ super().__init__() self.segmentation_loss = nn.CrossEntropyLoss( weight=semantic_weights, ignore_index=ignore_index ) def forward( self, outputs: PointNet2SegmentationOut, semantics3d: Tensor ) -> LossesType: """Calculates the loss. Args: outputs (PointNet2SegmentationOut): Model outputs. semantics3d (Tensor): Groundtruth semantic labels. """ return dict( segmentation_loss=self.segmentation_loss( outputs.class_logits, semantics3d ), )