Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| # pyre-unsafe | |
| import logging | |
| import numpy as np | |
| import pickle | |
| from enum import Enum | |
| from typing import Optional | |
| import torch | |
| from torch import nn | |
| from detectron2.config import CfgNode | |
| from detectron2.utils.file_io import PathManager | |
| from .vertex_direct_embedder import VertexDirectEmbedder | |
| from .vertex_feature_embedder import VertexFeatureEmbedder | |
class EmbedderType(Enum):
    """
    Supported ways of mapping mesh vertices into the embedding space.

    Each member's value is the string used to select it (for example,
    `EmbedderType("vertex_direct")` yields `EmbedderType.VERTEX_DIRECT`):
     - "vertex_direct": direct vertex embedding
     - "vertex_feature": embedding vertex features
    """

    # direct vertex embedding
    VERTEX_DIRECT = "vertex_direct"
    # embedding of vertex features
    VERTEX_FEATURE = "vertex_feature"
def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Build the embedder described by a configuration node.

    Args:
        embedder_spec (CfgNode): embedder configuration; the fields read here
            are TYPE, NUM_VERTICES, INIT_FILE and IS_TRAINABLE, plus FEATURE_DIM
            and FEATURES_TRAINABLE for the "vertex_feature" type
        embedder_dim (int): embedding space dimensionality
    Return:
        The constructed embedder module
    Raises ValueError, in case of unexpected embedder type
    """
    # Converting the configured string into the enum already rejects unknown
    # values with a ValueError.
    kind = EmbedderType(embedder_spec.TYPE)

    if kind is EmbedderType.VERTEX_DIRECT:
        result = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif kind is EmbedderType.VERTEX_FEATURE:
        result = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        raise ValueError(f"Unexpected embedder type {kind}")

    # Both embedder kinds are initialized from a file the same way, so the
    # optional load is done once here; an empty string means "no init file".
    if embedder_spec.INIT_FILE != "":
        result.load(embedder_spec.INIT_FILE)

    # Turn off gradients for embedders configured as non-trainable.
    if not embedder_spec.IS_TRAINABLE:
        result.requires_grad_(False)
    return result
class Embedder(nn.Module):
    """
    Container module holding one embedder per mesh. Being a Module, the
    per-mesh embedders are registered as submodules and therefore take part
    in state dict saving / loading.
    """

    # Prefix under which this module's entries are looked up in a full model
    # checkpoint when no explicit prefix is given.
    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Create and register the mesh embedders listed in the configuration.
        The embedder for a mesh named `name` becomes the submodule
        "embedder_{name}".

        Args:
            cfg (CfgNode): configuration options
        """
        super().__init__()
        self.mesh_names = set()
        cse_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE
        embedder_dim = cse_cfg.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cse_cfg.EMBEDDERS.items():
            module_name = self._submodule_name(mesh_name)
            logger.info(f"Adding embedder {module_name} with spec {embedder_spec}")
            self.add_module(module_name, create_embedder(embedder_spec, embedder_dim))
            self.mesh_names.add(mesh_name)
        # An empty string means no model weights are configured.
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    @staticmethod
    def _submodule_name(mesh_name: str) -> str:
        """Return the attribute name under which a mesh's embedder is stored."""
        return f"embedder_{mesh_name}"

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None):
        """
        Load embedder weights out of a full model checkpoint.

        Only the entries of the checkpoint's "model" dictionary whose keys
        start with `prefix` are used; the prefix is stripped from the keys
        before loading. Nothing is loaded if the checkpoint has no "model"
        entry.

        Args:
            fpath (str): checkpoint path; files ending with ".pkl" are read
                with pickle, everything else with torch.load (onto CPU)
            prefix (Optional[str]): key prefix selecting this module's
                entries; DEFAULT_MODEL_CHECKPOINT_PREFIX is used if None
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX

        # NOTE(review): both pickle.load and torch.load can execute arbitrary
        # code from the file; the checkpoint path must come from a trusted source.
        with PathManager.open(fpath, "rb") as hFile:
            if fpath.endswith(".pkl"):
                checkpoint = pickle.load(hFile, encoding="latin1")
            else:
                checkpoint = torch.load(hFile, map_location=torch.device("cpu"))

        if checkpoint is None or "model" not in checkpoint:
            return

        model_state = checkpoint["model"]
        stripped_state = {}
        for name in model_state:
            if not name.startswith(prefix):
                continue
            value = model_state[name]
            # numpy arrays are converted to tensors before loading
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            stripped_state[name[len(prefix) :]] = value
        # non-strict loading to finetune on different meshes
        self.load_state_dict(stripped_state, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Compute vertex embeddings for the given mesh by calling the embedder
        registered for it. Vertex embeddings are a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space

        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        mesh_embedder = getattr(self, self._submodule_name(mesh_name))
        return mesh_embedder()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Tell whether this module has an "embedder_{mesh_name}" attribute.

        Args:
            mesh_name (str): name of the mesh to look up
        Return:
            True if the attribute exists, False otherwise
        """
        return hasattr(self, self._submodule_name(mesh_name))