File size: 2,656 Bytes
f717329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from typing import Optional
from torch import nn

from detectron2.config import CfgNode

from .cse.embedder import Embedder
from .filter import DensePoseDataFilter


def build_densepose_predictor(cfg: CfgNode, input_channels: int):
    """
    Instantiate the DensePose predictor selected in the configuration.

    The predictor is looked up in `DENSEPOSE_PREDICTOR_REGISTRY` under the name
    stored in `cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME` and then called
    with `(cfg, input_channels)`.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of DensePose predictor
    """
    # Imported at call time rather than at module level, as in the original
    # (NOTE(review): presumably to avoid a circular import - confirm).
    from .predictors import DENSEPOSE_PREDICTOR_REGISTRY

    selected_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.PREDICTOR_NAME
    predictor_factory = DENSEPOSE_PREDICTOR_REGISTRY.get(selected_name)
    return predictor_factory(cfg, input_channels)


def build_densepose_data_filter(cfg: CfgNode):
    """
    Create the DensePose data filter, which selects data for training.

    Args:
        cfg (CfgNode): configuration options, passed as-is to
            `DensePoseDataFilter`

    Return:
        Callable: list(Tensor), list(Instances) -> list(Tensor), list(Instances)
        An instance of DensePose filter, which takes feature tensors and
        proposals as an input and returns filtered features and proposals
    """
    return DensePoseDataFilter(cfg)


def build_densepose_head(cfg: CfgNode, input_channels: int):
    """
    Instantiate the DensePose head selected in the configuration.

    The head is looked up in `ROI_DENSEPOSE_HEAD_REGISTRY` under the name
    stored in `cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME` and then called with
    `(cfg, input_channels)`.

    Args:
        cfg (CfgNode): configuration options
        input_channels (int): input tensor size along the channel dimension
    Return:
        An instance of DensePose head
    """
    # Imported at call time rather than at module level, as in the original
    # (NOTE(review): presumably to avoid a circular import - confirm).
    from .roi_heads.registry import ROI_DENSEPOSE_HEAD_REGISTRY

    selected_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.NAME
    head_factory = ROI_DENSEPOSE_HEAD_REGISTRY.get(selected_name)
    return head_factory(cfg, input_channels)


def build_densepose_losses(cfg: CfgNode):
    """
    Instantiate the DensePose loss selected in the configuration.

    The loss is looked up in `DENSEPOSE_LOSS_REGISTRY` under the name stored
    in `cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME` and then called with `cfg`.

    Args:
        cfg (CfgNode): configuration options
    Return:
        An instance of DensePose loss
    """
    # Imported at call time rather than at module level, as in the original
    # (NOTE(review): presumably to avoid a circular import - confirm).
    from .losses import DENSEPOSE_LOSS_REGISTRY

    selected_name = cfg.MODEL.ROI_DENSEPOSE_HEAD.LOSS_NAME
    loss_factory = DENSEPOSE_LOSS_REGISTRY.get(selected_name)
    return loss_factory(cfg)


def build_densepose_embedder(cfg: CfgNode) -> Optional[nn.Module]:
    """
    Create the embedder used to embed mesh vertices into an embedding space.
    Embedder contains sub-embedders, one for each mesh ID.

    Args:
        cfg (CfgNode): configuration options
    Return:
        Embedding module, or None when
        `cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS` is empty / falsy
    """
    # Guard clause: no embedders are configured, so there is nothing to build.
    if not cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS:
        return None
    return Embedder(cfg)