|
|
|
|
|
|
|
|
|
import logging
|
|
import numpy as np
|
|
import pickle
|
|
from enum import Enum
|
|
from typing import Optional
|
|
import torch
|
|
from torch import nn
|
|
|
|
from detectron2.config import CfgNode
|
|
from detectron2.utils.file_io import PathManager
|
|
|
|
from .vertex_direct_embedder import VertexDirectEmbedder
|
|
from .vertex_feature_embedder import VertexFeatureEmbedder
|
|
|
|
|
|
class EmbedderType(Enum):
    """
    Kinds of embedders, i.e. strategies for mapping mesh vertices into the
    embedding space. Members are looked up by their string value, e.g.
    ``EmbedderType("vertex_direct")``:

    - "vertex_direct": each vertex receives its own directly learned embedding
    - "vertex_feature": vertex embeddings are computed from vertex features
    """

    # direct per-vertex embedding
    VERTEX_DIRECT = "vertex_direct"
    # embedding derived from per-vertex features
    VERTEX_FEATURE = "vertex_feature"
|
|
|
|
|
|
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build a single mesh embedder from its configuration node.

    Args:
        embedder_spec (CfgNode): embedder configuration. Fields read here:
            TYPE, NUM_VERTICES, INIT_FILE and IS_TRAINABLE; additionally
            FEATURE_DIM and FEATURES_TRAINABLE for the "vertex_feature" type
        embedder_dim (int): dimensionality of the embedding space
    Return:
        The constructed embedder module. When INIT_FILE is a non-empty string,
        the embedder's `load` method is called with it; when IS_TRAINABLE is
        false, `requires_grad_(False)` is applied to the whole embedder.
    Raises:
        ValueError: if TYPE is not a valid EmbedderType value
    """
    # EmbedderType(...) itself raises ValueError for unknown type strings
    kind = EmbedderType(embedder_spec.TYPE)

    if kind is EmbedderType.VERTEX_DIRECT:
        module = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif kind is EmbedderType.VERTEX_FEATURE:
        module = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        raise ValueError(f"Unexpected embedder type {kind}")

    # Both embedder kinds share the same optional initialization from file;
    # an empty string means "no initialization file".
    init_file = embedder_spec.INIT_FILE
    if init_file != "":
        module.load(init_file)

    # Freeze the whole embedder when it is configured as non-trainable
    if not embedder_spec.IS_TRAINABLE:
        module.requires_grad_(False)

    return module
|
|
|
|
|
|
class Embedder(nn.Module):
    """
    Embedder module that serves as a container for embedders to use with different
    meshes. Extends Module to automatically save / load state dict.
    """

    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Initialize mesh embedders. An embedder for mesh `i` is stored in a submodule
        "embedder_{i}".

        One embedder is created (via `create_embedder`) for every entry of
        cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS, all sharing the embedding size
        cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE. If cfg.MODEL.WEIGHTS is a
        non-empty path, embedder weights are then loaded from that checkpoint.

        Args:
            cfg (CfgNode): configuration options
        """
        super().__init__()
        # names of all meshes for which an embedder submodule was registered
        self.mesh_names = set()
        embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items():
            logger.info(f"Adding embedder embedder_{mesh_name} with spec {embedder_spec}")
            self.add_module(f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim))
            self.mesh_names.add(mesh_name)
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights from a full-model checkpoint.

        Only entries of checkpoint["model"] whose keys start with `prefix` are
        used; the prefix is stripped from each key, numpy arrays are converted to
        torch tensors, and the result is loaded non-strictly (missing or
        unexpected keys are not an error). If the checkpoint has no "model"
        entry, nothing is loaded.

        Args:
            fpath (str): checkpoint path, opened through PathManager. Files ending
                in ".pkl" are read with `pickle`; anything else with `torch.load`
                mapped to CPU.
            prefix (Optional[str]): key prefix selecting embedder weights;
                defaults to `Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX`
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        logger = logging.getLogger(__name__)

        # NOTE(security): both pickle.load and torch.load unpickle the file and can
        # execute arbitrary code; only load checkpoints from trusted sources.
        with PathManager.open(fpath, "rb") as hFile:
            if fpath.endswith(".pkl"):
                # latin1 allows reading pickles written by Python 2
                state_dict = pickle.load(hFile, encoding="latin1")
            else:
                state_dict = torch.load(hFile, map_location=torch.device("cpu"))

        if state_dict is None or "model" not in state_dict:
            # Previously a silent no-op; make the skip visible in the logs.
            logger.info(f"No 'model' entry in checkpoint {fpath}, embedder weights not loaded")
            return

        state_dict_local = {}
        for key, value in state_dict["model"].items():
            if not key.startswith(prefix):
                continue
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            state_dict_local[key[len(prefix) :]] = value

        # strict=False: the checkpoint may cover only a subset of the meshes
        self.load_state_dict(state_dict_local, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Produce vertex embeddings for the specific mesh; vertex embeddings are
        a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space
        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        return getattr(self, f"embedder_{mesh_name}")()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Check whether an embedder submodule was registered for `mesh_name`.

        Args:
            mesh_name (str): mesh name to check
        Return:
            True if the attribute "embedder_{mesh_name}" exists on this module
        """
        return hasattr(self, f"embedder_{mesh_name}")
|
|
|