File size: 2,838 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""Pointnet++ Implementation."""

from __future__ import annotations

import torch
from torch import Tensor, nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.common.typing import LossesType, ModelOutput
from vis4d.data.const import CommonKeys as K
from vis4d.op.base.pointnetpp import (
    PointNet2Segmentation,
    PointNet2SegmentationOut,
)


class PointNet2SegmentationModel(nn.Module):
    """PointNet++ segmentation model.

    Thin wrapper around ``PointNet2Segmentation`` that optionally restores
    weights from a checkpoint and turns the predicted class logits into
    per-point class indices when no ground truth is supplied.
    """

    def __init__(
        self,
        num_classes: int,
        in_dimensions: int = 3,
        weights: str | None = None,
    ):
        """Creates a PointNet++ segmentation model.

        Args:
            num_classes (int): Number of classes to predict.
            in_dimensions (int, optional): Input dimensions. Defaults to 3.
            weights (str, optional): Path to weights. If given, it is handed
                to ``load_model_checkpoint`` together with this module.
                Defaults to None.
        """
        super().__init__()

        # NOTE: the attribute name becomes part of the state-dict keys, so
        # it has to stay stable for existing checkpoints to load.
        self.segmentation_model = PointNet2Segmentation(
            num_classes, in_dimensions
        )

        if weights is not None:
            load_model_checkpoint(self, weights)

    def forward(
        self, points3d: Tensor, semantics3d: Tensor | None = None
    ) -> PointNet2SegmentationOut | ModelOutput:
        """Runs the segmentation network on a batch of points.

        Args:
            points3d (Tensor): Input points of shape [b, N, C].
            semantics3d (Tensor, optional): Groundtruth semantic labels of
                shape [b, N]. Only its presence is checked here; the values
                are not used. Defaults to None.

        Returns:
            PointNet2SegmentationOut | ModelOutput: The unmodified network
                output if ``semantics3d`` is given, otherwise a dict mapping
                ``K.semantics3d`` to the argmax of the class logits taken
                over dimension 1.
        """
        network_out = self.segmentation_model(points3d)

        if semantics3d is None:
            # No groundtruth: reduce the logits to hard class predictions.
            # Presumably dim 1 of class_logits is the class axis - confirm
            # against PointNet2Segmentation.
            predicted_classes = network_out.class_logits.argmax(dim=1)
            return {K.semantics3d: predicted_classes}

        # Groundtruth supplied: hand back the raw output (including logits)
        # so that the caller can compute a loss on it.
        return network_out


class Pointnet2SegmentationLoss(nn.Module):
    """Cross-entropy segmentation loss for PointNet++ outputs."""

    def __init__(
        self,
        ignore_index: int = 255,
        semantic_weights: Tensor | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            ignore_index (int, optional): Label value that is excluded from
                the loss. Defaults to 255.
            semantic_weights (Tensor, optional): Per-class weights passed to
                the cross-entropy criterion. Defaults to None.
        """
        super().__init__()
        self.segmentation_loss = nn.CrossEntropyLoss(
            ignore_index=ignore_index, weight=semantic_weights
        )

    def forward(
        self, outputs: PointNet2SegmentationOut, semantics3d: Tensor
    ) -> LossesType:
        """Calculates the loss.

        Args:
            outputs (PointNet2SegmentationOut): Model outputs. Only the
                ``class_logits`` field is read.
            semantics3d (Tensor): Groundtruth semantic labels.

        Returns:
            LossesType: Dict with the single entry ``segmentation_loss``
                holding the cross-entropy between the class logits and the
                groundtruth labels.
        """
        cross_entropy = self.segmentation_loss(
            outputs.class_logits, semantics3d
        )
        return {"segmentation_loss": cross_entropy}