|
|
|
|
|
import os
|
|
|
|
import math
|
|
from math import tan, pi
|
|
from typing import Dict
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from torchvision.transforms import Resize
|
|
import numpy as np
|
|
import time
|
|
import random
|
|
|
|
from utils.misc import (NestedTensor, nested_tensor_from_tensor_list,
|
|
accuracy, get_world_size, interpolate,
|
|
is_dist_avail_and_initialized, inverse_sigmoid)
|
|
|
|
from utils.transforms import rot6d_to_axis_angle, img2patch_flat, img2patch, to_zorder
|
|
from utils.map import build_z_map
|
|
from utils import constants
|
|
from configs.paths import smpl_mean_path
|
|
|
|
from models.encoders import build_encoder
|
|
from .matcher import build_matcher
|
|
from .decoder import build_decoder
|
|
from .position_encoding import position_encoding_xy
|
|
from .criterion import SetCriterion
|
|
from .dn_components import prepare_for_cdn, dn_post_process
|
|
import copy
|
|
|
|
from configs.paths import smpl_model_path
|
|
from models.human_models import SMPL_Layer
|
|
|
|
|
|
def _get_clones(module, N):
|
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
|
|
|
class Model(nn.Module):
|
|
""" One-stage Multi-person Human Mesh Estimation via Scale-adaptive Tokens """
|
|
def __init__(self, encoder, decoder,
|
|
num_queries,
|
|
input_size,
|
|
sat_cfg = {'use_sat': False},
|
|
dn_cfg = {'use_dn': False},
|
|
train_pos_embed = True,
|
|
aux_loss=True,
|
|
iter_update=True,
|
|
query_dim=4,
|
|
bbox_embed_diff_each_layer=True,
|
|
random_refpoints_xy=False,
|
|
num_poses=24,
|
|
dim_shape=10,
|
|
FOV=pi/3
|
|
):
|
|
""" Initializes the model.
|
|
Parameters:
|
|
encoder: torch module of the encoder to be used. See ./encoders.
|
|
decoder: torch module of the decoder architecture. See decoder.py
|
|
num_queries: number of object queries, i.e. detection slots. This is the maximal number of people
|
|
the model can detect in a single image.
|
|
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
|
iter_update: iterative update of boxes
|
|
query_dim: query dimension. 2 for point and 4 for box.
|
|
bbox_embed_diff_each_layer: if True, do not share the box prediction head weights across decoder layers (default here is True).
|
|
random_refpoints_xy: randomly initialize the x, y of the anchor boxes and freeze them (this sometimes improves performance).
|
|
"""
|
|
super().__init__()
|
|
|
|
|
|
self.input_size = input_size
|
|
hidden_dim = decoder.d_model
|
|
num_dec_layers = decoder.dec_layers
|
|
self.hidden_dim = hidden_dim
|
|
|
|
self.focal = input_size/(2*tan(FOV/2))
|
|
self.FOV = FOV
|
|
cam_intrinsics = torch.tensor([[self.focal,0.,self.input_size/2],
|
|
[0.,self.focal,self.input_size/2],
|
|
[0.,0.,1.]])
|
|
self.register_buffer('cam_intrinsics', cam_intrinsics)
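# The intrinsics above describe a pinhole camera whose principal point sits at the center of the
# (square) input image and whose focal length follows from the assumed field of view:
# focal = input_size / (2 * tan(FOV / 2)). For the default FOV = pi/3 this gives focal of about
# 0.866 * input_size.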
|
|
|
|
self.num_poses = num_poses
|
|
self.dim_shape = dim_shape
|
|
self.human_model = SMPL_Layer(model_path = smpl_model_path, with_genders = False)
|
|
|
|
smpl_mean_params = np.load(smpl_mean_path, allow_pickle = True)
|
|
self.register_buffer('mean_pose', torch.from_numpy(smpl_mean_params['pose']))
|
|
self.register_buffer('mean_shape', torch.from_numpy(smpl_mean_params['shape']))
|
|
|
|
|
|
|
|
|
|
self.encoder = encoder
|
|
|
|
self.patch_size = encoder.patch_size
|
|
assert self.patch_size == 14
|
|
|
|
self.use_sat = sat_cfg['use_sat']
|
|
self.sat_cfg = sat_cfg
|
|
|
|
if self.use_sat:
|
|
assert sat_cfg['num_lvls'] >= 2
|
|
assert self.input_size % (self.patch_size<<2) == 0
|
|
|
|
self.feature_size = []
|
|
for lvl in range(sat_cfg['num_lvls']):
|
|
patch_size = self.patch_size<<lvl
|
|
self.feature_size.append(self.input_size / patch_size)
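# Each scale level l tokenizes the image with patches of size patch_size * 2**l, so the per-level
# feature maps shrink by a factor of 2 per level. For example (assuming an input size of 896 with
# the 14-pixel base patch), the level sizes would be 64, 32, 16, ...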
|
|
|
|
|
|
z_depth = math.ceil(math.log2(self.feature_size[1]))
|
|
z_map, ys, xs = build_z_map(z_depth)
|
|
self.register_buffer('z_order_map', z_map)
|
|
self.register_buffer('y_coords', ys)
|
|
self.register_buffer('x_coords', xs)
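# The buffers above define a Z-order (Morton) traversal of the level-1 token grid: along this
# curve, the four children of any coarser cell occupy four consecutive positions, which is what
# lets lvl_pooling() merge groups of 4 neighboring tokens into a parent token and lets the SAT
# logic in forward_encoder() track which tokens get split or pooled.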
|
|
|
|
self.enc_inter_norm = copy.deepcopy(encoder.norm)
|
|
self.scale_head = MLP(encoder.embed_dim, encoder.embed_dim, 2, 4)
|
|
self.encoder_patch_proj = _get_clones(encoder.patch_embed.proj, 2)
|
|
self.encoder_patch_norm = _get_clones(encoder.patch_embed.norm, 2)
|
|
|
|
if sat_cfg['lvl_embed']:
|
|
|
|
self.level_embed = nn.Parameter(torch.Tensor(sat_cfg['num_lvls'],hidden_dim))
|
|
nn.init.normal_(self.level_embed)
|
|
else:
|
|
assert self.input_size % self.patch_size == 0
|
|
self.feature_size = [self.input_size // self.patch_size]
|
|
self.encoder_patch_proj = copy.deepcopy(encoder.patch_embed.proj)
|
|
self.encoder_patch_norm = copy.deepcopy(encoder.patch_embed.norm)
|
|
|
|
|
|
encoder_cr_token = self.encoder.cls_token.view(1,-1) + self.encoder.pos_embed.float()[:,0].view(1,-1)
|
|
if self.encoder.register_tokens is not None:
|
|
encoder_cr_token = torch.cat([encoder_cr_token, self.encoder.register_tokens.view(self.encoder.num_register_tokens,-1)], dim=0)
|
|
self.encoder_cr_token = nn.Parameter(encoder_cr_token)
|
|
|
|
self.encoder_pos_embeds = nn.Parameter(self.encoder.interpolate_pos_encoding3(self.feature_size[0]).detach())
|
|
if not train_pos_embed:
|
|
self.encoder_pos_embeds.requires_grad = False
|
|
|
|
self.preprocessed_pos_lvl1 = None
|
|
|
|
|
|
del self.encoder.mask_token


del self.encoder.pos_embed


del self.encoder.patch_embed


del self.encoder.cls_token


del self.encoder.register_tokens
|
|
|
|
|
|
|
|
|
|
|
|
self.num_queries = num_queries
|
|
self.decoder = decoder
|
|
|
|
|
|
self.feature_proj = nn.Linear(encoder.embed_dim, hidden_dim)
|
|
|
|
|
|
self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
|
|
if bbox_embed_diff_each_layer:
|
|
self.bbox_embed = nn.ModuleList([MLP(hidden_dim, hidden_dim, 4, 3) for i in range(num_dec_layers)])
|
|
else:
|
|
self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
|
|
|
|
self.pose_head = MLP(hidden_dim, hidden_dim, num_poses*6, 6)
|
|
|
|
self.shape_head = MLP(hidden_dim, hidden_dim, dim_shape, 5)
|
|
|
|
self.cam_head = MLP(hidden_dim, hidden_dim//2, 3, 3)
|
|
|
|
self.conf_head = nn.Linear(hidden_dim, 1)
|
|
|
|
prior_prob = 0.01
|
|
bias_value = -math.log((1 - prior_prob) / prior_prob)
|
|
self.conf_head.bias.data = torch.ones(1) * bias_value
|
|
|
|
|
|
self.pose_head = _get_clones(self.pose_head, num_dec_layers)
|
|
self.shape_head = _get_clones(self.shape_head, num_dec_layers)
|
|
|
|
|
|
self.query_dim = query_dim
|
|
assert query_dim == 4
|
|
self.refpoint_embed = nn.Embedding(num_queries, query_dim)
|
|
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
|
|
|
|
self.random_refpoints_xy = random_refpoints_xy
|
|
if random_refpoints_xy:
|
|
|
|
self.refpoint_embed.weight.data[:, :2].uniform_(0,1)
|
|
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
|
|
self.refpoint_embed.weight.data[:, :2].requires_grad = False
|
|
|
|
self.aux_loss = aux_loss
|
|
self.iter_update = iter_update
|
|
assert iter_update
|
|
if self.iter_update:
|
|
self.decoder.decoder.bbox_embed = self.bbox_embed
|
|
|
|
assert bbox_embed_diff_each_layer
|
|
if bbox_embed_diff_each_layer:
|
|
for bbox_embed in self.bbox_embed:
|
|
nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
|
|
nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
|
|
else:
|
|
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
|
|
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
|
|
|
|
|
|
|
|
|
|
self.use_dn = dn_cfg['use_dn']
|
|
self.dn_cfg = dn_cfg
|
|
if self.use_dn:
|
|
assert dn_cfg['dn_number'] > 0
|
|
if dn_cfg['tgt_embed_type'] == 'labels':
|
|
self.dn_enc = nn.Embedding(dn_cfg['dn_labelbook_size'], hidden_dim)
|
|
elif dn_cfg['tgt_embed_type'] == 'params':
|
|
self.dn_enc = nn.Linear(num_poses*3 + dim_shape, hidden_dim)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
|
|
def lvl_pooling(self, tokens):
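# Merge every 4 consecutive tokens (siblings under the Z-order layout) into one coarser token via
# a channel-wise max; an input of shape (4*M, C) becomes (M, C).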
|
|
assert len(tokens)%4 == 0
|
|
C = tokens.shape[-1]
|
|
return torch.max(tokens.view(-1, 4, C), dim=1)[0]
|
|
|
|
def get_scale_map(self, x_list):
|
|
if self.sat_cfg['use_additional_blocks']:
|
|
x_list = self.encoder.forward_additional_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
|
|
else:
|
|
x_list = self.encoder.forward_specific_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
|
|
|
|
cr_token_list = [x[:, :1 + self.encoder.num_register_tokens, :].squeeze(0) for x in x_list]
|
|
x_tokens = torch.cat([x[:, 1 + self.encoder.num_register_tokens:, :].squeeze(0) for x in x_list], dim=0)
|
|
scale_map = self.scale_head(self.enc_inter_norm(x_tokens)).sigmoid()
|
|
return scale_map, cr_token_list, x_tokens
|
|
|
|
def pad_mask(self, mask):
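# Expand a per-token mask in groups of 4 Z-order siblings: if any sibling is marked, the whole
# group is marked. This keeps sibling groups complete, so only fully unmarked groups can later be
# pooled into a single coarser token.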
|
|
mask = mask.reshape(-1,4)
|
|
mask[torch.any(mask, dim=1)] = True
|
|
return mask.flatten()
|
|
|
|
def forward_encoder(self, samples, targets, use_gt = False):
|
|
B = len(samples)
|
|
C = self.encoder.embed_dim
|
|
cr_token_list = [self.encoder_cr_token]*len(samples)
|
|
|
|
if not self.use_sat:
|
|
|
|
lvl0_feature_hw = [(img.shape[1]//self.patch_size, img.shape[2]//self.patch_size) for img in samples]
|
|
lvl0_token_lens = [h*w for (h,w) in lvl0_feature_hw]
|
|
lvl0_img_patches = torch.cat([img2patch_flat(img, patch_size = self.patch_size)\
|
|
for img in samples], dim=0)
|
|
lvl0_tokens = self.encoder_patch_norm(self.encoder_patch_proj(lvl0_img_patches).flatten(1))
|
|
|
|
|
|
full_grids = torch.meshgrid(torch.arange(self.feature_size[0]), torch.arange(self.feature_size[0]), indexing='ij')
|
|
lvl0_pos_y = torch.cat([full_grids[0][:h,:w].flatten() for (h,w) in lvl0_feature_hw]).to(device = lvl0_tokens.device)
|
|
lvl0_pos_x = torch.cat([full_grids[1][:h,:w].flatten() for (h,w) in lvl0_feature_hw]).to(device = lvl0_tokens.device)
|
|
|
|
|
|
full_pos_embed = self.encoder_pos_embeds
|
|
lvl0_pos_embed = torch.cat([full_pos_embed[:h,:w].flatten(0,1)\
|
|
for (h,w) in lvl0_feature_hw], dim=0)
|
|
lvl0_tokens = lvl0_tokens + lvl0_pos_embed
|
|
|
|
|
|
x_list = [torch.cat([cr, lvl0],dim=0).unsqueeze(0)\
|
|
for (cr, lvl0) \
|
|
in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens))]
|
|
|
|
|
|
lvl0_pos_y_norm = (lvl0_pos_y.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
|
|
lvl0_pos_x_norm = (lvl0_pos_x.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
|
|
pos_y_list = list(lvl0_pos_y_norm.split(lvl0_token_lens))
|
|
pos_x_list = list(lvl0_pos_x_norm.split(lvl0_token_lens))
|
|
scale_map_dict = None
|
|
|
|
lvl_list = [torch.zeros_like(pos,dtype=int) for pos in pos_x_list]
|
|
|
|
else:
|
|
lvl1_feature_hw = [(img.shape[1]//(2*self.patch_size), img.shape[2]//(2*self.patch_size)) for img in samples]
|
|
lvl1_token_lens = [h*w for (h,w) in lvl1_feature_hw]
|
|
|
|
lvl1_img_patches_28, lvl1_zorders = [], []
|
|
lvl1_pos_y, lvl1_pos_x = [], []
|
|
lvl1_bids = []
|
|
|
|
for i, img in enumerate(samples):
|
|
z_patches, z_order, pos_y, pos_x = to_zorder(img2patch(img, patch_size = 2*self.patch_size),
|
|
z_order_map = self.z_order_map,
|
|
y_coords = self.y_coords,
|
|
x_coords = self.x_coords)
|
|
|
|
lvl1_img_patches_28.append(z_patches)
|
|
|
|
lvl1_zorders.append(z_order)
|
|
lvl1_pos_y.append(pos_y)
|
|
lvl1_pos_x.append(pos_x)
|
|
lvl1_bids.append(torch.full_like(pos_y, i, dtype=torch.int64))
|
|
|
|
|
|
|
|
lvl1_img_patches_28 = torch.cat(lvl1_img_patches_28, dim=0)
|
|
lvl1_zorders = torch.cat(lvl1_zorders, dim=0)
|
|
lvl1_pos_y = torch.cat(lvl1_pos_y, dim=0)
|
|
lvl1_pos_x = torch.cat(lvl1_pos_x, dim=0)
|
|
lvl1_bids = torch.cat(lvl1_bids, dim=0)
|
|
|
|
|
|
|
|
|
|
assert len(lvl1_img_patches_28) == sum(lvl1_token_lens)
|
|
lvl1_img_patches = F.interpolate(lvl1_img_patches_28, size = (14,14), mode='bilinear', align_corners=False)
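# Level-1 tokens cover 28x28 pixel patches; they are resized down to 14x14 here so that the
# (cloned) 14-pixel ViT patch embedding can project them into the same token dimension as the
# fine level-0 tokens.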
|
|
|
|
lvl1_tokens = self.encoder_patch_norm[1](self.encoder_patch_proj[1](lvl1_img_patches).flatten(1))
|
|
|
|
|
|
|
|
|
|
assert len(lvl1_pos_y) == len(lvl1_tokens)
|
|
full_pos_embed = self.preprocessed_pos_lvl1 if not self.training\
|
|
else F.interpolate(self.encoder_pos_embeds.unsqueeze(0).permute(0, 3, 1, 2),
|
|
mode="bicubic",
|
|
antialias=self.encoder.interpolate_antialias,
|
|
size = (int(self.feature_size[1]),int(self.feature_size[1]))).squeeze(0).permute(1,2,0)
|
|
lvl1_pos_embed = torch.cat([full_pos_embed[ys,xs]\
|
|
for (ys,xs) in zip(lvl1_pos_y.split(lvl1_token_lens), lvl1_pos_x.split(lvl1_token_lens))], dim=0)
|
|
lvl1_tokens = lvl1_tokens + lvl1_pos_embed
|
|
|
|
|
|
x_list = [torch.cat([cr, lvl1],dim=0).unsqueeze(0)\
|
|
for (cr, lvl1) \
|
|
in zip(cr_token_list, lvl1_tokens.split(lvl1_token_lens))]
|
|
scale_map, updated_cr_list, updated_lvl1 = self.get_scale_map(x_list)
|
|
|
|
scale_map_dict = {'scale_map': scale_map, 'lens': lvl1_token_lens, 'hw': lvl1_feature_hw,
|
|
'pos_y': lvl1_pos_y, 'pos_x': lvl1_pos_x}
|
|
|
|
|
|
conf_thresh = self.sat_cfg['conf_thresh']
|
|
scale_thresh = self.sat_cfg['scale_thresh']
|
|
if use_gt:
|
|
scale_map = torch.cat([tgt['scale_map'].view(-1,2) for tgt in targets], dim=0)
|
|
|
|
lvl1_valid_mask = scale_map[:,0] > conf_thresh
|
|
lvl1_sat_mask = lvl1_valid_mask & (scale_map[:,1] < scale_thresh)
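# The masks above read a two-channel scale map per level-1 token: channel 0 is treated as a
# person-confidence score and channel 1 as a scale estimate. Tokens that are confident but
# small-scale (conf > conf_thresh and scale < scale_thresh) are selected for splitting into four
# finer level-0 tokens below.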
|
|
|
|
|
|
lvl0_token_lens = [msk.sum().item()<<2 for msk in lvl1_sat_mask.split(lvl1_token_lens)]
|
|
lvl1_sat_patches_28 = lvl1_img_patches_28[lvl1_sat_mask]
|
|
lvl0_tokens = self.encoder_patch_norm[0](self.encoder_patch_proj[0](lvl1_sat_patches_28).permute(0, 2, 3, 1).flatten(0,2))
|
|
|
|
assert len(lvl0_tokens) == sum(lvl0_token_lens)
|
|
|
|
lvl0_pos_y, lvl0_pos_x = lvl1_pos_y[lvl1_sat_mask], lvl1_pos_x[lvl1_sat_mask]
|
|
lvl0_pos_y = (lvl0_pos_y<<1)[:,None].repeat(1,4).flatten()
|
|
lvl0_pos_x = (lvl0_pos_x<<1)[:,None].repeat(1,4).flatten()
|
|
lvl0_pos_y[2::4] += 1
|
|
lvl0_pos_y[3::4] += 1
|
|
lvl0_pos_x[1::2] += 1
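# Each selected level-1 token at grid position (y, x) is replaced by its four level-0 children at
# (2y, 2x), (2y, 2x+1), (2y+1, 2x), (2y+1, 2x+1); the three in-place adds above lay the children
# out in exactly that (Z-order) order.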
|
|
assert len(lvl0_pos_x) == len(lvl0_tokens)
|
|
|
|
|
|
full_pos_embed = self.encoder_pos_embeds
|
|
lvl0_pos_embed = torch.cat([full_pos_embed[ys,xs]\
|
|
for (ys,xs) in zip(lvl0_pos_y.split(lvl0_token_lens), lvl0_pos_x.split(lvl0_token_lens))], dim=0)
|
|
lvl0_tokens = lvl0_tokens + lvl0_pos_embed
|
|
|
|
|
|
|
|
x_list = [torch.cat([cr, lvl0],dim=0).unsqueeze(0)\
|
|
for (cr, lvl0) \
|
|
in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens))]
|
|
x_list = self.encoder.forward_specific_layers_list(x_list, end=self.sat_cfg['get_map_layer'], get_feature=False)
|
|
lvl0_tokens = torch.cat([x[:, 1 + self.encoder.num_register_tokens:, :].squeeze(0) for x in x_list], dim=0)
|
|
assert len(lvl0_pos_x) == len(lvl0_tokens)
|
|
|
|
lvl1_tokens = updated_lvl1
|
|
cr_token_list = updated_cr_list
|
|
|
|
|
|
|
|
if self.sat_cfg['num_lvls'] == 2:
|
|
|
|
lvl1_keep = ~lvl1_sat_mask
|
|
lvl1_token_lens = [msk.sum().item() for msk in lvl1_keep.split(lvl1_token_lens)]
|
|
lvl1_tokens = lvl1_tokens[lvl1_keep]
|
|
lvl1_pos_y = lvl1_pos_y[lvl1_keep]
|
|
lvl1_pos_x = lvl1_pos_x[lvl1_keep]
|
|
|
|
|
|
lvl0_pos_y_norm = (lvl0_pos_y.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
|
|
lvl0_pos_x_norm = (lvl0_pos_x.to(dtype=lvl0_tokens.dtype) + 0.5) / self.feature_size[0]
|
|
lvl1_pos_y_norm = (lvl1_pos_y.to(dtype=lvl1_tokens.dtype) + 0.5) / self.feature_size[1]
|
|
lvl1_pos_x_norm = (lvl1_pos_x.to(dtype=lvl1_tokens.dtype) + 0.5) / self.feature_size[1]
|
|
|
|
|
|
x_list = [torch.cat([cr, lvl0, lvl1]).unsqueeze(0) \
|
|
for cr, lvl0, lvl1 \
|
|
in zip(cr_token_list, lvl0_tokens.split(lvl0_token_lens), lvl1_tokens.split(lvl1_token_lens))]
|
|
pos_y_list = [torch.cat([lvl0, lvl1]) \
|
|
for lvl0, lvl1 \
|
|
in zip(lvl0_pos_y_norm.split(lvl0_token_lens), lvl1_pos_y_norm.split(lvl1_token_lens))]
|
|
pos_x_list = [torch.cat([lvl0, lvl1]) \
|
|
for lvl0, lvl1 \
|
|
in zip(lvl0_pos_x_norm.split(lvl0_token_lens), lvl1_pos_x_norm.split(lvl1_token_lens))]
|
|
lvl_list = [torch.cat([torch.zeros_like(lvl0, dtype=int), torch.ones_like(lvl1, dtype=int)]) \
|
|
for lvl0, lvl1 \
|
|
in zip(lvl0_pos_x_norm.split(lvl0_token_lens), lvl1_pos_x_norm.split(lvl1_token_lens))]
|
|
|
|
|
|
else:
|
|
|
|
lvl1_valid_mask = self.pad_mask(lvl1_valid_mask)
|
|
lvl1_keep = lvl1_valid_mask & (~lvl1_sat_mask)
|
|
lvl1_to_lvl2 = ~lvl1_valid_mask
|
|
|
|
token_lvls = [lvl0_tokens, lvl1_tokens]
|
|
token_lens_lvls = [lvl0_token_lens, lvl1_token_lens]
|
|
pos_y_lvls = [lvl0_pos_y, lvl1_pos_y]
|
|
pos_x_lvls = [lvl0_pos_x, lvl1_pos_x]
|
|
|
|
to_next_lvl = lvl1_to_lvl2
|
|
keep = lvl1_keep
|
|
lvl_zorders = lvl1_zorders
|
|
lvl_bids = lvl1_bids
|
|
pad_vals = torch.full((3,), -1, dtype=lvl_zorders.dtype, device=lvl_zorders.device)
|
|
for lvl in range(self.sat_cfg['num_lvls']-2):
|
|
if to_next_lvl.sum() == 0:
|
|
break
|
|
next_tokens = self.lvl_pooling(token_lvls[-1][to_next_lvl])
|
|
|
|
next_pos_y = pos_y_lvls[-1][to_next_lvl][::4]>>1
|
|
next_pos_x = pos_x_lvls[-1][to_next_lvl][::4]>>1
|
|
next_lens = [msk.sum().item()//4 for msk in to_next_lvl.split(token_lens_lvls[-1])]
|
|
|
|
|
|
token_lvls[-1] = token_lvls[-1][keep]
|
|
pos_y_lvls[-1] = pos_y_lvls[-1][keep]
|
|
pos_x_lvls[-1] = pos_x_lvls[-1][keep]
|
|
token_lens_lvls[-1] = [msk.sum().item() for msk in keep.split(token_lens_lvls[-1])]
|
|
|
|
token_lvls.append(next_tokens)
|
|
token_lens_lvls.append(next_lens)
|
|
pos_y_lvls.append(next_pos_y)
|
|
pos_x_lvls.append(next_pos_x)
|
|
|
|
if lvl < self.sat_cfg['num_lvls']-3:
|
|
lvl_zorders = lvl_zorders[to_next_lvl][::4]>>2
|
|
lvl_bids = lvl_bids[to_next_lvl][::4]
|
|
|
|
z_starts_idx = torch.where((lvl_zorders&3)==0)[0]
|
|
padded_z = torch.cat([lvl_zorders, pad_vals])
|
|
padded_bids = torch.cat([lvl_bids, pad_vals])
|
|
valids = (padded_z[z_starts_idx] + 3 == padded_z[z_starts_idx + 3]) & (padded_bids[z_starts_idx] == padded_bids[z_starts_idx + 3])
|
|
valid_starts = z_starts_idx[valids]
|
|
|
|
to_next_lvl = torch.zeros_like(lvl_zorders, dtype=bool)
|
|
to_next_lvl[valid_starts] = True
|
|
to_next_lvl[valid_starts+1] = True
|
|
to_next_lvl[valid_starts+2] = True
|
|
to_next_lvl[valid_starts+3] = True
|
|
|
|
keep = ~to_next_lvl
|
|
|
|
norm_pos_y_lvls = [(pos_y.to(dtype=lvl0_tokens.dtype) + 0.5)/self.feature_size[i] for i, pos_y in enumerate(pos_y_lvls)]
|
|
norm_pos_x_lvls = [(pos_x.to(dtype=lvl0_tokens.dtype) + 0.5)/self.feature_size[i] for i, pos_x in enumerate(pos_x_lvls)]
|
|
|
|
x_list = [torch.cat([cr, *lvls]).unsqueeze(0) \
|
|
for cr, *lvls \
|
|
in zip(cr_token_list, *[tokens.split(lens) for (tokens, lens) in zip(token_lvls, token_lens_lvls)])]
|
|
pos_y_list = [torch.cat([*lvls]) \
|
|
for lvls \
|
|
in zip(*[pos_y.split(lens) for (pos_y, lens) in zip(norm_pos_y_lvls, token_lens_lvls)])]
|
|
pos_x_list = [torch.cat([*lvls]) \
|
|
for lvls \
|
|
in zip(*[pos_x.split(lens) for (pos_x, lens) in zip(norm_pos_x_lvls, token_lens_lvls)])]
|
|
lvl_list = [torch.cat([torch.full_like(lvl, i, dtype=torch.int64) for i, lvl in enumerate(lvls)]) \
|
|
for lvls \
|
|
in zip(*[pos_x.split(lens) for (pos_x, lens) in zip(norm_pos_x_lvls, token_lens_lvls)])]
|
|
|
|
|
|
|
|
start = self.sat_cfg['get_map_layer'] if self.use_sat else 0
|
|
_, final_feature_list = self.encoder.forward_specific_layers_list(x_list, start = start, norm=True)
|
|
|
|
|
|
token_lens = [feature.shape[1] for feature in final_feature_list]
|
|
final_features = self.feature_proj(torch.cat(final_feature_list,dim=1).squeeze(0))
|
|
assert tuple(final_features.shape) == (sum(token_lens), self.hidden_dim)
|
|
|
|
pos_embeds = position_encoding_xy(torch.cat(pos_x_list,dim=0), torch.cat(pos_y_list,dim=0), embedding_dim=self.hidden_dim)
|
|
if self.use_sat and self.sat_cfg['lvl_embed']:
|
|
lvl_embeds = self.level_embed[torch.cat(lvl_list,dim=0)]
|
|
pos_embeds = pos_embeds + lvl_embeds
|
|
|
|
sat_dict = {'pos_y': pos_y_list, 'pos_x': pos_x_list, 'lvl': lvl_list,
|
|
|
|
'lens': token_lens}
|
|
|
|
return final_features, pos_embeds, token_lens, scale_map_dict, sat_dict
|
|
|
|
def process_smpl(self, poses, shapes, cam_xys, cam_intrinsics, detach_j3ds = False):
|
|
bs, num_queries, _ = poses.shape
|
|
|
|
|
|
poses = poses.flatten(0,1)
|
|
shapes = shapes.flatten(0,1)
|
|
verts, joints = self.human_model(poses=poses,
|
|
betas=shapes)
|
|
num_verts = verts.shape[1]
|
|
num_joints = joints.shape[1]
|
|
verts = verts.reshape(bs,num_queries,num_verts,3)
|
|
joints = joints.reshape(bs,num_queries,num_joints,3)
|
|
|
|
|
|
scale = 2*cam_xys[:,:,2:].sigmoid() + 1e-6
|
|
t_xy = cam_xys[:,:,:2]/scale
|
|
t_z = (2*self.focal)/(scale*self.input_size)
|
|
transl = torch.cat([t_xy,t_z],dim=2)[:,:,None,:]
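# Camera decoding: the cam head outputs (x, y, s). The scale s is squashed to roughly (0, 2) via
# 2*sigmoid(s); the translation is t_xy = (x, y)/s and t_z = 2*focal/(s*input_size), so a larger
# predicted scale places the person closer to the camera.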
|
|
|
|
verts_cam = verts + transl
|
|
j3ds_cam = joints + transl
|
|
|
|
if detach_j3ds:
|
|
j2ds_homo = torch.matmul(joints.detach() + transl, cam_intrinsics.transpose(2,3))
|
|
else:
|
|
j2ds_homo = torch.matmul(j3ds_cam, cam_intrinsics.transpose(2,3))
|
|
j2ds_img = (j2ds_homo[..., :2] / (j2ds_homo[..., 2, None] + 1e-6)).reshape(bs,num_queries,num_joints,2)
|
|
|
|
depths = j3ds_cam[:,:,0,2:]
|
|
depths = torch.cat([depths, depths/self.focal], dim=-1)
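# depths holds the camera-space depth of joint 0 (typically the pelvis/root joint in SMPL) and the
# same depth normalized by the focal length.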
|
|
|
|
return verts_cam, j3ds_cam, j2ds_img, depths, transl.flatten(2)
|
|
|
|
|
|
def forward(self, samples: NestedTensor, targets, sat_use_gt = False, detach_j3ds = False):
|
|
""" The forward expects a NestedTensor, which consists of:
|
|
- samples.tensor: batched images, of shape [batch_size x 3 x H x W]
|
|
- samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
|
|
|
|
It returns a dict with the following elements:
|
|
- "pred_logits": the classification logits (including no-object) for all queries.
|
|
Shape= [batch_size x num_queries x num_classes]
|
|
- "pred_boxes": The normalized boxes coordinates for all queries, represented as
|
|
(center_x, center_y, width, height). These values are normalized in [0, 1],
|
|
relative to the size of each individual image (disregarding possible padding).
|
|
See PostProcess for information on how to retrieve the unnormalized bounding box.
|
|
- "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
|
|
dictionnaries containing the two above keys for each decoder layer.
|
|
"""
|
|
|
|
assert isinstance(samples, (list, torch.Tensor))
|
|
|
|
if self.training:
|
|
self.preprocessed_pos_lvl1 = None
|
|
|
|
elif self.preprocessed_pos_lvl1 is None and self.use_sat:
|
|
self.preprocessed_pos_lvl1 = F.interpolate(self.encoder_pos_embeds.unsqueeze(0).permute(0, 3, 1, 2),
|
|
mode="bicubic",
|
|
antialias=self.encoder.interpolate_antialias,
|
|
size = (int(self.feature_size[1]),int(self.feature_size[1]))).squeeze(0).permute(1,2,0)
|
|
|
|
|
|
bs = len(targets)
|
|
|
|
|
|
img_size = torch.stack([t['img_size'].flip(0) for t in targets])
|
|
valid_ratio = img_size/self.input_size
|
|
|
|
cam_intrinsics = self.cam_intrinsics.repeat(bs, 1, 1, 1)
|
|
cam_intrinsics[...,:2,2] = cam_intrinsics[...,:2,2] * valid_ratio[:, None, :]
|
|
|
|
|
|
final_features, pos_embeds, token_lens, scale_map_dict, sat_dict\
|
|
= self.forward_encoder(samples, targets, use_gt = sat_use_gt)
|
|
|
|
|
|
embedweight = (self.refpoint_embed.weight).unsqueeze(0).repeat(bs,1,1)
|
|
tgt = (self.tgt_embed.weight).unsqueeze(0).repeat(bs,1,1)
|
|
|
|
if self.training and self.use_dn:
|
|
input_query_tgt, input_query_bbox, attn_mask, dn_meta =\
|
|
prepare_for_cdn(targets = targets, dn_cfg = self.dn_cfg,
|
|
num_queries = self.num_queries, hidden_dim = self.hidden_dim, dn_enc = self.dn_enc)
|
|
tgt = torch.cat([input_query_tgt, tgt], dim=1)
|
|
embedweight = torch.cat([input_query_bbox, embedweight], dim=1)
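# Denoising (DN) training: noised ground-truth queries produced by prepare_for_cdn are prepended
# to the learnable matching queries, and attn_mask keeps information from leaking between the
# denoising groups and the matching queries; dn_post_process later splits the stacked predictions
# back into the two parts.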
|
|
else:
|
|
attn_mask = None
|
|
|
|
tgt_lens = [tgt.shape[1]]*bs
|
|
|
|
hs, reference = self.decoder(memory=final_features, memory_lens=token_lens,
|
|
tgt=tgt.flatten(0,1), tgt_lens=tgt_lens,
|
|
refpoint_embed=embedweight.flatten(0,1),
|
|
pos_embed=pos_embeds,
|
|
self_attn_mask = attn_mask)
|
|
|
|
reference_before_sigmoid = inverse_sigmoid(reference)
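# Iterative box refinement: at every decoder layer, the layer-specific bbox head predicts a delta
# that is added to the inverse-sigmoid of the reference points, and the sum is pushed back through
# a sigmoid, i.e. box_l = sigmoid(MLP_l(h_l) + inverse_sigmoid(ref_l)).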
|
|
outputs_coords = []
|
|
for lvl in range(hs.shape[0]):
|
|
tmp = self.bbox_embed[lvl](hs[lvl])
|
|
tmp[..., :self.query_dim] += reference_before_sigmoid[lvl]
|
|
outputs_coord = tmp.sigmoid()
|
|
outputs_coords.append(outputs_coord)
|
|
pred_boxes = torch.stack(outputs_coords)
|
|
|
|
outputs_poses = []
|
|
outputs_shapes = []
|
|
outputs_confs = []
|
|
outputs_j3ds = []
|
|
outputs_j2ds = []
|
|
outputs_depths = []
|
|
|
|
|
|
outputs_pose_6d = self.mean_pose.view(1, 1, -1)
|
|
outputs_shape = self.mean_shape.view(1, 1, -1)
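# SMPL parameter prediction is cumulative: starting from the registered mean pose (in its 6D
# rotation form) and mean shape, each decoder layer adds a residual, and the running 6D pose is
# converted to axis-angle before being passed through the SMPL layer.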
|
|
for lvl in range(hs.shape[0]):
|
|
|
|
outputs_pose_6d = outputs_pose_6d + self.pose_head[lvl](hs[lvl])
|
|
outputs_shape = outputs_shape + self.shape_head[lvl](hs[lvl])
|
|
|
|
if self.training or lvl == hs.shape[0] - 1:
|
|
outputs_pose = rot6d_to_axis_angle(outputs_pose_6d)
|
|
|
|
outputs_conf = self.conf_head(hs[lvl]).sigmoid()
|
|
|
|
|
|
cam_xys = self.cam_head(hs[lvl])
|
|
|
|
outputs_vert, outputs_j3d, outputs_j2d, depth, transl\
|
|
= self.process_smpl(poses = outputs_pose,
|
|
shapes = outputs_shape,
|
|
cam_xys = cam_xys,
|
|
cam_intrinsics = cam_intrinsics,
|
|
detach_j3ds = detach_j3ds)
|
|
|
|
outputs_poses.append(outputs_pose)
|
|
outputs_shapes.append(outputs_shape)
|
|
outputs_confs.append(outputs_conf)
|
|
|
|
outputs_j3ds.append(outputs_j3d)
|
|
outputs_j2ds.append(outputs_j2d)
|
|
outputs_depths.append(depth)
|
|
|
|
pred_poses = torch.stack(outputs_poses)
|
|
pred_betas = torch.stack(outputs_shapes)
|
|
pred_confs = torch.stack(outputs_confs)
|
|
pred_verts = outputs_vert
|
|
pred_transl = transl
|
|
pred_intrinsics = cam_intrinsics
|
|
pred_j3ds = torch.stack(outputs_j3ds)
|
|
pred_j2ds = torch.stack(outputs_j2ds)
|
|
pred_depths = torch.stack(outputs_depths)
|
|
|
|
|
|
|
|
if self.training and self.use_dn:
|
|
pred_poses, pred_betas,\
|
|
pred_boxes, pred_confs,\
|
|
pred_j3ds, pred_j2ds, pred_depths,\
|
|
pred_verts, pred_transl =\
|
|
dn_post_process(pred_poses, pred_betas,
|
|
pred_boxes, pred_confs,
|
|
pred_j3ds, pred_j2ds, pred_depths,
|
|
pred_verts, pred_transl,
|
|
dn_meta, self.aux_loss, self._set_aux_loss)
|
|
|
|
|
|
out = {'pred_poses': pred_poses[-1], 'pred_betas': pred_betas[-1],
|
|
'pred_boxes': pred_boxes[-1], 'pred_confs': pred_confs[-1],
|
|
'pred_j3ds': pred_j3ds[-1], 'pred_j2ds': pred_j2ds[-1],
|
|
'pred_verts': pred_verts, 'pred_intrinsics': pred_intrinsics,
|
|
'pred_depths': pred_depths[-1], 'pred_transl': pred_transl}
|
|
|
|
if self.aux_loss and self.training:
|
|
out['aux_outputs'] = self._set_aux_loss(pred_poses, pred_betas,
|
|
pred_boxes, pred_confs,
|
|
pred_j3ds, pred_j2ds, pred_depths)
|
|
|
|
if self.use_sat:
|
|
out['enc_outputs'] = scale_map_dict
|
|
|
|
out['sat'] = sat_dict
|
|
|
|
if self.training and self.use_dn:
|
|
out['dn_meta'] = dn_meta
|
|
|
|
return out
|
|
|
|
@torch.jit.unused
|
|
def _set_aux_loss(self, pred_poses, pred_betas, pred_boxes,
|
|
pred_confs, pred_j3ds,
|
|
pred_j2ds, pred_depths):
|
|
|
|
|
|
|
|
return [{'pred_poses': a, 'pred_betas': b,
|
|
'pred_boxes': c, 'pred_confs': d,
|
|
'pred_j3ds': e, 'pred_j2ds': f, 'pred_depths': g}
|
|
for a, b, c, d, e, f, g in zip(pred_poses[:-1], pred_betas[:-1],
|
|
pred_boxes[:-1], pred_confs[:-1], pred_j3ds[:-1], pred_j2ds[:-1], pred_depths[:-1])]
|
|
|
|
|
|
|
|
class MLP(nn.Module):
|
|
""" Very simple multi-layer perceptron (also called FFN)"""
|
|
|
|
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
|
super().__init__()
|
|
self.num_layers = num_layers
|
|
h = [hidden_dim] * (num_layers - 1)
|
|
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
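# For example, MLP(256, 256, 4, 3) builds Linear(256, 256) -> ReLU -> Linear(256, 256) -> ReLU ->
# Linear(256, 4), which is the shape used for the per-layer box heads above when hidden_dim is 256.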
|
|
|
|
def forward(self, x):
|
|
for i, layer in enumerate(self.layers):
|
|
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
|
return x
|
|
|
|
|
|
def build_sat_model(args, set_criterion=True):
|
|
encoder = build_encoder(args)
|
|
decoder = build_decoder(args)
|
|
|
|
model = Model(
|
|
encoder,
|
|
decoder,
|
|
num_queries=args.num_queries,
|
|
input_size=args.input_size,
|
|
sat_cfg=args.sat_cfg,
|
|
dn_cfg=args.dn_cfg,
|
|
train_pos_embed=getattr(args,'train_pos_embed',True)
|
|
)
|
|
|
|
|
|
if set_criterion:
|
|
matcher = build_matcher(args)
|
|
weight_dict = args.weight_dict
|
|
losses = args.losses
|
|
|
|
if args.dn_cfg['use_dn']:
|
|
dn_weight_dict = {}
|
|
dn_weight_dict.update({f'{k}_dn': v for k, v in weight_dict.items()})
|
|
weight_dict.update(dn_weight_dict)
|
|
|
|
aux_weight_dict = {}
|
|
for i in range(args.dec_layers - 1):
|
|
aux_weight_dict.update({f'{k}.{i}': v for k, v in weight_dict.items()})
|
|
weight_dict.update(aux_weight_dict)
|
|
|
|
if args.sat_cfg['use_sat']:
|
|
if 'map_confs' not in weight_dict:
|
|
weight_dict.update({'map_confs': weight_dict['confs']})
|
|
|
|
|
|
criterion = SetCriterion(matcher, weight_dict, losses = losses, j2ds_norm_scale = args.input_size)
|
|
return model, criterion
|
|
else:
|
|
return model, None
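# Rough usage sketch (hypothetical; the exact fields expected on `args` are defined by the configs
# and the encoder/decoder builders elsewhere in this repo):
#
#   model, criterion = build_sat_model(args)
#   outputs = model(list_of_images, targets)   # see Model.forward for the returned keys
#   losses = criterion(outputs, targets)       # hypothetical call; losses weighted by args.weight_dict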
|
|
|