import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

from .utils import (euclidean_distance_matrix, calculate_top_k,
                    calculate_diversity_np, calculate_activation_statistics_np,
                    calculate_frechet_distance_np)


class TM2TMetrics(Metric):
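    """Text-to-motion evaluation metrics.

    Caches text, generated-motion, and ground-truth-motion embeddings across
    batches in `update`, then computes the text-motion Matching score,
    R-precision@top_k, FID, and Diversity in `compute`.

    Illustrative usage (the embedding tensors are assumed to come from a
    pretrained text/motion feature extractor, one row per sequence):

        metric = TM2TMetrics(top_k=3, R_size=32, diversity_times=300)
        for text_emb, rec_emb, gt_emb, lengths in batches:
            metric.update(text_emb, rec_emb, gt_emb, lengths)
        results = metric.compute()
    """
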
    def __init__(self,
                 top_k: int = 3,
                 R_size: int = 32,
                 diversity_times: int = 300,
                 dist_sync_on_step: bool = True) -> None:
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.name = "Matching, FID, and Diversity scores"
        self.top_k = top_k
        self.R_size = R_size
        self.diversity_times = diversity_times

        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("count_seq",
                       default=torch.tensor(0),
                       dist_reduce_fx="sum")
        self.metrics = []

        # Matching scores
        self.add_state("Matching_score",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.add_state("gt_Matching_score",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.Matching_metrics = ["Matching_score", "gt_Matching_score"]
        for k in range(1, top_k + 1):
            self.add_state(
                f"R_precision_top_{k}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            self.Matching_metrics.append(f"R_precision_top_{k}")
        for k in range(1, top_k + 1):
            self.add_state(
                f"gt_R_precision_top_{k}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            self.Matching_metrics.append(f"gt_R_precision_top_{k}")
        self.metrics.extend(self.Matching_metrics)

        # FID
        self.add_state("FID", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.metrics.append("FID")

        # Diversity
        self.add_state("Diversity",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.add_state("gt_Diversity",
                       default=torch.tensor(0.0),
                       dist_reduce_fx="sum")
        self.metrics.extend(["Diversity", "gt_Diversity"])

        # cached batches
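        # (list states with dist_reduce_fx="cat" are concatenated across
        # processes on sync, so compute() sees every cached sequence)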
self.add_state("text_embeddings", default=[], dist_reduce_fx='cat')
self.add_state("recmotion_embeddings", default=[], dist_reduce_fx='cat')
self.add_state("gtmotion_embeddings", default=[], dist_reduce_fx='cat')
    def compute(self) -> dict:
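        """Aggregate the cached embeddings and compute all metrics.

        Returns a dict mapping metric names (Matching_score, R_precision_top_k,
        FID, Diversity, and their gt_ counterparts) to scalar values.
        """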
        count_seq = self.count_seq.item()

        # init metrics
        metrics = {metric: getattr(self, metric) for metric in self.metrics}

        shuffle_idx = torch.randperm(count_seq)
        all_texts = dim_zero_cat(self.text_embeddings).cpu()[shuffle_idx, :]
        all_genmotions = dim_zero_cat(self.recmotion_embeddings).cpu()[shuffle_idx, :]
        all_gtmotions = dim_zero_cat(self.gtmotion_embeddings).cpu()[shuffle_idx, :]
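
        # Protocol: sequences were shuffled above, so consecutive slices of
        # R_size form random groups; within a group, each text is ranked
        # against all R_size motions by Euclidean distance. The diagonal of
        # dist_mat pairs each text with its own motion, so the trace
        # accumulates the Matching score and diagonal ranks give R-precision.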
        # Compute r-precision (generated motions vs. text)
        assert count_seq >= self.R_size, \
            "need at least R_size sequences to form one retrieval group"
        top_k_mat = torch.zeros((self.top_k,))
        matching_score = torch.tensor(0.0)
        for i in range(count_seq // self.R_size):
            group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size]
            group_motions = all_genmotions[i * self.R_size:(i + 1) * self.R_size]
            dist_mat = euclidean_distance_matrix(group_texts,
                                                 group_motions).nan_to_num()
            # accumulate into a local so repeated compute() calls cannot
            # double-count into the synced state
            matching_score += dist_mat.trace()
            sorted_idx = torch.argsort(dist_mat, dim=1)
            top_k_mat += calculate_top_k(sorted_idx, top_k=self.top_k).sum(dim=0)
        R_count = count_seq // self.R_size * self.R_size
        metrics["Matching_score"] = matching_score / R_count
        for k in range(self.top_k):
            metrics[f"R_precision_top_{k + 1}"] = top_k_mat[k] / R_count
        # Compute r-precision with ground-truth motions
        assert count_seq >= self.R_size
        top_k_mat = torch.zeros((self.top_k,))
        gt_matching_score = torch.tensor(0.0)
        for i in range(count_seq // self.R_size):
            group_texts = all_texts[i * self.R_size:(i + 1) * self.R_size]
            group_motions = all_gtmotions[i * self.R_size:(i + 1) * self.R_size]
            dist_mat = euclidean_distance_matrix(group_texts,
                                                 group_motions).nan_to_num()
            gt_matching_score += dist_mat.trace()
            sorted_idx = torch.argsort(dist_mat, dim=1)
            top_k_mat += calculate_top_k(sorted_idx, top_k=self.top_k).sum(dim=0)
        metrics["gt_Matching_score"] = gt_matching_score / R_count
        for k in range(self.top_k):
            metrics[f"gt_R_precision_top_{k + 1}"] = top_k_mat[k] / R_count
        all_genmotions = all_genmotions.numpy()
        all_gtmotions = all_gtmotions.numpy()

        # Compute FID from Gaussian statistics of the two embedding sets
        mu, cov = calculate_activation_statistics_np(all_genmotions)
        gt_mu, gt_cov = calculate_activation_statistics_np(all_gtmotions)
        metrics["FID"] = calculate_frechet_distance_np(gt_mu, gt_cov, mu, cov)
        # Compute diversity
        assert count_seq >= self.diversity_times, \
            "need at least diversity_times sequences to estimate Diversity"
        metrics["Diversity"] = calculate_diversity_np(all_genmotions,
                                                      self.diversity_times)
        metrics["gt_Diversity"] = calculate_diversity_np(all_gtmotions,
                                                         self.diversity_times)
        return metrics

    def update(self,
               text_embeddings: torch.Tensor,
               recmotion_embeddings: torch.Tensor,
               gtmotion_embeddings: torch.Tensor,
               lengths: list[int]) -> None:
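        """Cache one batch of embeddings.

        `lengths` holds one entry per sequence in the batch: its sum
        accumulates into `count` and its length into `count_seq`.
        """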
        self.count += sum(lengths)
        self.count_seq += len(lengths)

        # store all texts and motions
        self.text_embeddings.append(text_embeddings.detach())
        self.recmotion_embeddings.append(recmotion_embeddings.detach())
        self.gtmotion_embeddings.append(gtmotion_embeddings.detach())