|
import torch
|
|
import torch.nn as nn
|
|
|
|
class ReConsLoss(nn.Module):
    """Reconstruction loss computed over a fixed-width prefix of the last axis.

    Args:
        recons_loss: Name of the element-wise criterion. One of ``'l1'``
            (``nn.L1Loss``), ``'l2'`` (``nn.MSELoss``) or ``'l1_smooth'``
            (``nn.SmoothL1Loss``). Each is built with its default arguments,
            i.e. ``reduction='mean'``.
        nb_joints: Number of joints. It decides how many trailing-axis
            channels ``forward`` and ``forward_vel`` compare.

    Raises:
        ValueError: If ``recons_loss`` is not a supported criterion name.
            Without this check the attribute ``Loss`` would be left unset,
            and the error would only surface later as an ``AttributeError``.
    """

    def __init__(self, recons_loss, nb_joints):
        super().__init__()

        criteria = {
            'l1': torch.nn.L1Loss,
            'l2': torch.nn.MSELoss,
            'l1_smooth': torch.nn.SmoothL1Loss,
        }
        if recons_loss not in criteria:
            raise ValueError(
                f"Unsupported recons_loss {recons_loss!r}; "
                f"expected one of {sorted(criteria)}"
            )
        self.Loss = criteria[recons_loss]()

        self.nb_joints = nb_joints
        # Number of leading channels compared by `forward`.
        # NOTE(review): presumably 4 root channels + 12 per non-root joint
        # + 3 + 4 extra channels; the feature layout is not visible here,
        # so confirm against the feature extractor.
        self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4

    def forward(self, motion_pred, motion_gt):
        """Return the criterion over the first ``motion_dim`` channels of the last axis.

        Channels beyond ``motion_dim`` in either tensor are ignored.
        """
        loss = self.Loss(motion_pred[..., :self.motion_dim], motion_gt[..., :self.motion_dim])
        return loss

    def forward_vel(self, motion_pred, motion_gt):
        """Return the criterion over channels ``[4, 4 + 3 * (nb_joints - 1))`` of the last axis.

        NOTE(review): despite the name, this method only selects that channel
        range. Whether those channels hold velocities or positions depends on
        the feature layout, which this file does not show. Verify with the caller.
        """
        channels = slice(4, (self.nb_joints - 1) * 3 + 4)
        loss = self.Loss(motion_pred[..., channels], motion_gt[..., channels])
        return loss
|
|
|
|
|