# 3D-MOOD / vis4d/model/segment3d/pointnetpp.py
# Last commit 9b33fca by RoyYang0714:
#   "feat: Try to build everything locally."
"""Pointnet++ Implementation."""
from __future__ import annotations
import torch
from torch import Tensor, nn
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.common.typing import LossesType, ModelOutput
from vis4d.data.const import CommonKeys as K
from vis4d.op.base.pointnetpp import (
PointNet2Segmentation,
PointNet2SegmentationOut,
)
class PointNet2SegmentationModel(nn.Module):
    """PointNet++ segmentation model.

    Thin wrapper around ``PointNet2Segmentation`` that returns the raw
    network output when groundtruth labels are supplied and hard class
    predictions otherwise.
    """

    def __init__(
        self,
        num_classes: int,
        in_dimensions: int = 3,
        weights: str | None = None,
    ) -> None:
        """Creates a PointNet++ model.

        Args:
            num_classes (int): Number of semantic classes to predict.
            in_dimensions (int, optional): Number of input dimensions.
                Defaults to 3.
            weights (str, optional): Checkpoint location handed to
                ``load_model_checkpoint``. Nothing is loaded when None.
                Defaults to None.
        """
        super().__init__()

        # Keep the attribute name stable: it is part of the module's
        # state-dict key prefix.
        self.segmentation_model = PointNet2Segmentation(
            num_classes, in_dimensions
        )

        # The checkpoint is loaded into the wrapper itself, not into the
        # inner network.
        if weights is not None:
            load_model_checkpoint(self, weights)

    def forward(
        self, points3d: Tensor, semantics3d: Tensor | None = None
    ) -> PointNet2SegmentationOut | ModelOutput:
        """Runs the network and returns logits or class predictions.

        Args:
            points3d (Tensor): Input points, expected shape [b, N, C]
                (TODO confirm against ``PointNet2Segmentation``).
            semantics3d (Tensor, optional): Groundtruth semantic labels,
                expected shape [b, N]. Only its presence is checked here;
                the values are never read. Defaults to None.

        Returns:
            PointNet2SegmentationOut | ModelOutput: The unmodified output
            of the wrapped network if ``semantics3d`` is given, otherwise
            a dict mapping ``K.semantics3d`` to the argmax of the class
            logits over dimension 1.
        """
        network_out = self.segmentation_model(points3d)

        # No labels supplied: collapse the logits into one class index per
        # point and return them under the common semantics key.
        if semantics3d is None:
            predicted_classes = torch.argmax(network_out.class_logits, dim=1)
            return {K.semantics3d: predicted_classes}

        # Labels supplied: hand back the raw output so a loss can consume
        # the logits.
        return network_out
class Pointnet2SegmentationLoss(nn.Module):
    """Cross-entropy segmentation loss for PointNet++ outputs."""

    def __init__(
        self,
        ignore_index: int = 255,
        semantic_weights: Tensor | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            ignore_index (int, optional): Label value that is excluded
                from the loss. Defaults to 255.
            semantic_weights (Tensor, optional): Per-class weights passed
                to the cross-entropy criterion. Defaults to None.
        """
        super().__init__()
        criterion = nn.CrossEntropyLoss(
            weight=semantic_weights, ignore_index=ignore_index
        )
        self.segmentation_loss = criterion

    def forward(
        self, outputs: PointNet2SegmentationOut, semantics3d: Tensor
    ) -> LossesType:
        """Calculates the loss.

        Args:
            outputs (PointNet2SegmentationOut): Model outputs; only the
                ``class_logits`` field is read.
            semantics3d (Tensor): Groundtruth semantic labels.

        Returns:
            LossesType: Dict with the single key ``segmentation_loss``
            holding the cross-entropy between the class logits and the
            groundtruth labels.
        """
        loss_value = self.segmentation_loss(outputs.class_logits, semantics3d)
        return {"segmentation_loss": loss_value}