import torch
from torchmetrics import Metric

from mld.utils.temos_utils import remove_padding

from .utils import calculate_mpjpe


class PosMetrics(Metric):
    """Accumulates MPJPE (root-aligned and unaligned) and feature-space L2 error."""

    def __init__(self, dist_sync_on_step: bool = True) -> None:
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.name = "MPJPE (aligned & unaligned), Feature l2 error"

        # Running sums; reduced with `sum` across processes when syncing.
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("mpjpe_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("mpjpe_unaligned_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("feature_error_sum", default=torch.tensor(0.), dist_reduce_fx="sum")

    def compute(self) -> dict:
        # Average the accumulated errors over the total number of frames seen.
        metric = dict(MPJPE=self.mpjpe_sum / self.count,
                      MPJPE_unaligned=self.mpjpe_unaligned_sum / self.count,
                      FeaError=self.feature_error_sum / self.count)
        return metric

    def update(self, joints_ref: torch.Tensor, joints_rst: torch.Tensor,
               feats_ref: torch.Tensor, feats_rst: torch.Tensor,
               lengths: list[int]) -> None:
        self.count += sum(lengths)

        # Crop padded sequences to their true lengths before measuring errors.
        joints_rst = remove_padding(joints_rst, lengths)
        joints_ref = remove_padding(joints_ref, lengths)
        feats_ref = remove_padding(feats_ref, lengths)
        feats_rst = remove_padding(feats_rst, lengths)

        # L2 distance between reference and reconstructed feature sequences.
        for f1, f2 in zip(feats_ref, feats_rst):
            self.feature_error_sum += torch.norm(f1 - f2, p=2)

        # Per-sequence MPJPE, both with and without aligning the root joint.
        for j1, j2 in zip(joints_ref, joints_rst):
            mpjpe = torch.sum(calculate_mpjpe(j1, j2))
            self.mpjpe_sum += mpjpe
            mpjpe_unaligned = torch.sum(calculate_mpjpe(j1, j2, align_root=False))
            self.mpjpe_unaligned_sum += mpjpe_unaligned
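

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): drives the
# metric with random tensors of plausible shapes. The joint count (22),
# feature width (263), and padded length (48) are illustrative assumptions in
# a HumanML3D-like layout, not values taken from this repository; the sketch
# also assumes `remove_padding` crops each batch element to its true length.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    metric = PosMetrics()
    lengths = [40, 32]                      # true frame counts per sequence
    joints_ref = torch.randn(2, 48, 22, 3)  # (batch, frames, joints, xyz)
    joints_rst = joints_ref + 0.01 * torch.randn_like(joints_ref)
    feats_ref = torch.randn(2, 48, 263)     # (batch, frames, feature dim)
    feats_rst = feats_ref + 0.01 * torch.randn_like(feats_ref)

    metric.update(joints_ref, joints_rst, feats_ref, feats_rst, lengths)
    print(metric.compute())  # {'MPJPE': ..., 'MPJPE_unaligned': ..., 'FeaError': ...}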