Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,838 Bytes
9b33fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""Pointnet++ Implementation."""
from __future__ import annotations
import torch
from torch import Tensor, nn
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.common.typing import LossesType, ModelOutput
from vis4d.data.const import CommonKeys as K
from vis4d.op.base.pointnetpp import (
PointNet2Segmentation,
PointNet2SegmentationOut,
)
class PointNet2SegmentationModel(nn.Module):
    """PointNet++ segmentation model implementation.

    Wraps ``PointNet2Segmentation``. When no groundtruth labels are passed,
    the class logits are reduced to per-point class predictions.
    """

    def __init__(
        self,
        num_classes: int,
        in_dimensions: int = 3,
        weights: str | None = None,
    ):
        """Creates a PointNet++ model.

        Args:
            num_classes (int): Number of semantic classes.
            in_dimensions (int, optional): Input dimensions, forwarded to
                ``PointNet2Segmentation``. Defaults to 3.
            weights (str, optional): Path to a checkpoint that is loaded into
                this model after construction. Defaults to None, in which
                case no checkpoint is loaded.
        """
        super().__init__()
        self.segmentation_model = PointNet2Segmentation(
            num_classes, in_dimensions
        )
        # Load only after the submodule exists, so the checkpoint is applied
        # to the fully constructed model.
        if weights is not None:
            load_model_checkpoint(self, weights)

    def forward(
        self, points3d: Tensor, semantics3d: Tensor | None = None
    ) -> PointNet2SegmentationOut | ModelOutput:
        """Forward pass of the model.

        ``semantics3d`` only selects the output mode: its values are not read
        here. If it is given, the raw network output is returned (e.g. for a
        loss); otherwise per-point class predictions are returned.

        Args:
            points3d (Tensor): Input point cloud. NOTE(review): previously
                documented as shape [b, N, C], but the class axis of the
                logits is dim 1, which suggests a channel-first layout —
                confirm against ``PointNet2Segmentation``.
            semantics3d (Tensor, optional): Groundtruth semantic labels,
                documented as shape [b, N]. Defaults to None.

        Returns:
            PointNet2SegmentationOut | ModelOutput: The raw
            ``PointNet2SegmentationOut`` if ``semantics3d`` is given;
            otherwise a dict mapping ``K.semantics3d`` to the argmax of
            ``class_logits`` over dim 1.
        """
        out = self.segmentation_model(points3d)
        if semantics3d is not None:
            # Labels present: return raw logits so a loss can be computed.
            return out
        # Reduce over dim 1; assumes class_logits is [b, num_classes, N].
        class_pred = torch.argmax(out.class_logits, dim=1)
        return {K.semantics3d: class_pred}
class Pointnet2SegmentationLoss(nn.Module):
    """Cross-entropy loss for PointNet++ semantic segmentation outputs."""

    def __init__(
        self,
        ignore_index: int = 255,
        semantic_weights: Tensor | None = None,
    ) -> None:
        """Builds the underlying cross-entropy criterion.

        Args:
            ignore_index (int, optional): Label value whose points do not
                contribute to the loss. Defaults to 255.
            semantic_weights (Tensor, optional): Per-class rescaling weights
                passed to ``nn.CrossEntropyLoss``. Defaults to None, meaning
                all classes are weighted equally.
        """
        super().__init__()
        criterion = nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            weight=semantic_weights,
        )
        # Keep this attribute name stable: it is the registered submodule
        # name and therefore part of the state-dict keys.
        self.segmentation_loss = criterion

    def forward(
        self, outputs: PointNet2SegmentationOut, semantics3d: Tensor
    ) -> LossesType:
        """Computes the segmentation loss.

        Args:
            outputs (PointNet2SegmentationOut): Model outputs; only
                ``class_logits`` is read.
            semantics3d (Tensor): Groundtruth semantic labels.

        Returns:
            LossesType: Dict with a single ``"segmentation_loss"`` entry.
        """
        logits = outputs.class_logits
        loss_value = self.segmentation_loss(logits, semantics3d)
        return {"segmentation_loss": loss_value}
|