# embodiedmae-base / modeling_embodiedmae.py
from typing import Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder

from .configuration_embodiedmae import EmbodiedMAEConfig
from .modular_embodiedmae import (
    EmbodiedMAEDecoder,
    EmbodiedMAEDepthEmbeddings,
    EmbodiedMAEPointCloudEmbeddings,
    EmbodiedMAERGBEmbeddings,
    EncoderModelOutput,
    concat_sequence_with_dummy,
    prepare_shuffle_idx,
)


class EmbodiedMAEModel(PreTrainedModel):
    """Multimodal MAE encoder over RGB, depth, and point-cloud tokens."""

    config_class = EmbodiedMAEConfig

    def __init__(self, config: EmbodiedMAEConfig):
        super().__init__(config)
        self.config = config
        # Dirichlet prior over the three modalities; passed to
        # `prepare_shuffle_idx` when sampling the random mask.
        self.dirichlet = torch.distributions.Dirichlet(torch.full((3,), config.dirichlet_alpha))
        self.rgb_embeddings = EmbodiedMAERGBEmbeddings(config)
        self.depth_embeddings = EmbodiedMAEDepthEmbeddings(config)
        self.pc_embeddings = EmbodiedMAEPointCloudEmbeddings(config)
        self.encoder = Dinov2Encoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        num_patches = (config.image_size // config.patch_size) ** 2
        self.embedding_sz = (
            num_patches,
            num_patches,
            config.num_pc_centers,
        )  # token count for each modality (rgb, depth, pc)
        self.unmask_sz = config.unmask_sz  # number of unmasked tokens
    def get_input_embeddings(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        forward_pc: bool = True,
        shuffle_idx: Optional[torch.Tensor] = None,
    ):
        """Embed the available modalities and keep only the unmasked tokens."""
        # At least one modality must be provided.
        assert any([rgb is not None, depth is not None, pc is not None]), (
            "Provide at least one of rgb, depth, or pc."
        )
        # Per-modality embeddings; inputs may be None for absent modalities.
        rgb_emb = self.rgb_embeddings(rgb)
        depth_emb = self.depth_embeddings(depth)
        pc_emb, pc_centers, pc_knn = self.pc_embeddings(pc)
        if not forward_pc:
            pc = None
            pc_emb = None
        # Concatenate embeddings, padding absent modalities with dummy tokens.
        all_emb = concat_sequence_with_dummy([rgb_emb, depth_emb, pc_emb], self.embedding_sz)
        # Prepare shuffle indices for MAE-style random masking.
        shuffle_idx, restore_idx, unmask_sz = prepare_shuffle_idx(
            has_rgb=rgb is not None,
            has_depth=depth is not None,
            has_pc=pc is not None,
            batch_size=all_emb.shape[0],
            unmask_sz=self.unmask_sz if unmask_sz is None else unmask_sz,
            dirichlet=self.dirichlet,
            embedding_sz=self.embedding_sz,
            add_mask=add_mask,
            shuffle_idx=shuffle_idx,
            device=all_emb.device,
        )
        # Gather the first `unmask_sz` tokens of the shuffled sequence: these
        # are the visible (unmasked) tokens fed to the encoder.
        unmasked_emb = torch.gather(
            all_emb, 1, shuffle_idx[:, :unmask_sz, None].repeat(1, 1, all_emb.shape[-1])
        )
        return EncoderModelOutput(
            embedding=unmasked_emb,
            pc_centers=pc_centers,
            pc_knn=pc_knn,
            shuffle_idx=shuffle_idx,
            restore_idx=restore_idx,
            add_mask=add_mask,
            unmask_sz=unmask_sz,
        )
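
    # A minimal sketch (not part of the model) of the gather-based masking
    # above, with toy shapes: shuffle token indices per batch element, keep
    # the first `k`, and remember the inverse permutation for un-shuffling.
    #
    #     B, L, D, k = 2, 8, 4, 3
    #     tokens = torch.randn(B, L, D)
    #     shuffle_idx = torch.argsort(torch.rand(B, L), dim=1)  # random permutation
    #     restore_idx = torch.argsort(shuffle_idx, dim=1)       # inverse permutation
    #     visible = torch.gather(tokens, 1, shuffle_idx[:, :k, None].expand(-1, -1, D))
    #     assert visible.shape == (B, k, D)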

    def get_last_hidden_states(
        self,
        embedding_output: EncoderModelOutput,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        """Run the DINOv2 encoder over the unmasked embeddings."""
        embedding = embedding_output.embedding
        encoder_outputs = self.encoder(
            embedding,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        embedding_output.last_hidden_states = sequence_output
        embedding_output.hidden_states = encoder_outputs.hidden_states
        embedding_output.attentions = encoder_outputs.attentions
        return embedding_output

    def forward(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        embedding_output = self.get_input_embeddings(
            rgb, depth, pc, add_mask, unmask_sz, forward_pc
        )
        return self.get_last_hidden_states(
            embedding_output, output_attentions, output_hidden_states
        )
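

# Usage sketch (assumptions, not part of this file): the repo id placeholder
# and the input layouts are inferred, not confirmed — rgb as (B, 3, H, W)
# with H = W = config.image_size, depth as (B, 1, H, W), pc as (B, N, 3).
#
#     from transformers import AutoModel
#     model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)
#     out = model(rgb=rgb, depth=depth, pc=None, add_mask=False)
#     features = out.last_hidden_states  # (B, num_unmasked_tokens, hidden_size)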


class EmbodiedMAEForMaskedImageModeling(EmbodiedMAEModel):
    """Encoder plus decoder head for MAE-style reconstruction."""

    def __init__(self, config: EmbodiedMAEConfig):
        super().__init__(config)
        self.decoder = EmbodiedMAEDecoder(config)

    def forward(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        encoder_output = super().forward(
            rgb, depth, pc, add_mask, unmask_sz, output_attentions, output_hidden_states, forward_pc
        )
        decoder_input = self.decoder.get_decoder_input(encoder_output)
        return self.decoder(decoder_input)

    @torch.no_grad()
    def visualize(
        self,
        rgb: Optional[torch.Tensor],
        depth: Optional[torch.Tensor],
        pc: Optional[torch.Tensor],
        mask_rgb: bool = False,
        mask_depth: bool = False,
        mask_pc: bool = False,
        add_mask: bool = True,
        unmask_sz: Optional[int] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        forward_pc: bool = True,
    ):
        """Reconstruct inputs for inspection, optionally hiding whole modalities.

        The `mask_*` flags drop a modality from the encoder input, while the
        original tensors are still handed to `decoder.visualize`.
        """
        _rgb = None if mask_rgb else rgb
        _depth = None if mask_depth else depth
        _pc = None if mask_pc else pc
        encoder_output = super().forward(
            _rgb,
            _depth,
            _pc,
            add_mask,
            unmask_sz,
            output_attentions,
            output_hidden_states,
            forward_pc,
        )
        decoder_input = self.decoder.get_decoder_input(encoder_output)
        return self.decoder.visualize(decoder_input, rgb, depth, pc)


__all__ = ["EmbodiedMAEModel", "EmbodiedMAEForMaskedImageModeling"]
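

# The block below is a self-contained smoke test, not part of the original
# file: it builds a randomly initialized model from a default config and runs
# one masked forward pass. Tensor layouts and the availability of config
# defaults are assumptions inferred from the modules above.
if __name__ == "__main__":
    config = EmbodiedMAEConfig()  # assumes the config defines usable defaults
    model = EmbodiedMAEForMaskedImageModeling(config)
    B, S = 2, config.image_size
    rgb = torch.randn(B, 3, S, S)    # assumed (B, 3, H, W) layout
    depth = torch.randn(B, 1, S, S)  # assumed (B, 1, H, W) layout
    pc = torch.randn(B, 1024, 3)     # assumed (B, num_points, 3) layout
    out = model(rgb=rgb, depth=depth, pc=pc, add_mask=True)
    print(type(out).__name__)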