Spaces:
Runtime error
Runtime error
""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import argparse | |
import numpy as np | |
from scipy import linalg | |
def calculate_diversity(
    activation: np.ndarray,
    diversity_times: int = 10_000,
    rng: "np.random.Generator | None" = None,
) -> np.ndarray:
    """Return Euclidean distances between randomly paired rows of ``activation``.

    Two index sets of size ``diversity_times`` are drawn independently, each
    without replacement, from the rows of ``activation``. The distance between
    every resulting pair of rows is returned; reduce it (e.g. ``.mean()``) to
    obtain a single diversity score.

    Args:
        activation: 2-D array whose rows are the feature vectors compared.
        diversity_times: Number of random row pairs to evaluate. Must not
            exceed the number of rows, because each index set is sampled
            without replacement.
        rng: Optional NumPy ``Generator`` for reproducible sampling. When
            ``None`` (the default) the global ``np.random`` state is used.

    Returns:
        1-D array of shape ``(diversity_times,)`` with the per-pair distances.

    Raises:
        ValueError: if ``activation`` is not 2-D, or has fewer rows than
            ``diversity_times``.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if activation.ndim != 2:
        raise ValueError(
            f"activation must be a 2-D array, got shape {activation.shape}"
        )
    num_samples = activation.shape[0]
    # Sampling without replacement only needs num_samples >= diversity_times;
    # equality is a valid case.
    if num_samples < diversity_times:
        raise ValueError(
            f"need at least {diversity_times} rows to sample without "
            f"replacement, got {num_samples}"
        )

    choose = np.random.choice if rng is None else rng.choice
    first_indices = choose(num_samples, diversity_times, replace=False)
    second_indices = choose(num_samples, diversity_times, replace=False)
    return linalg.norm(
        activation[first_indices] - activation[second_indices], axis=1
    )
def calculate_activation_statistics(
    activations: np.ndarray,
) -> "tuple[np.ndarray, np.ndarray]":
    """Return the per-feature mean and covariance of ``activations``.

    Rows are treated as observations and columns as features (``axis=0`` for
    the mean, ``rowvar=False`` for the covariance). For an ``(N, D)`` input
    with ``D > 1`` this gives a mean of shape ``(D,)`` and a covariance of
    shape ``(D, D)``. ``np.cov`` uses its default ``N - 1`` normalization.

    Args:
        activations: 2-D array of observations (rows) by features (columns).

    Returns:
        ``(mu, cov)``, the mean vector and covariance matrix.
    """
    mu = np.mean(activations, axis=0)
    cov = np.cov(activations, rowvar=False)
    return mu, cov
def calculate_frechet_distance(
    mu1: np.ndarray,
    sigma1: np.ndarray,
    mu2: np.ndarray,
    sigma2: np.ndarray,
    eps: float = 1e-6,
) -> float:
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Computes ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2)
    - 2 * Tr(sqrtm(sigma1 @ sigma2))``.

    Args:
        mu1, mu2: Mean vectors; scalars are promoted to 1-D.
        sigma1, sigma2: Covariance matrices; scalars are promoted to 2-D.
        eps: Jitter added to both covariance diagonals when the matrix square
            root of their product contains non-finite values.

    Returns:
        The scalar distance.

    Raises:
        AssertionError: if the mean vectors or covariance matrices differ in
            shape.
        ValueError: if the square root keeps an imaginary part on its
            diagonal larger than ``1e-3`` (absolute tolerance).
    """
    mean_a = np.atleast_1d(mu1)
    mean_b = np.atleast_1d(mu2)
    cov_a = np.atleast_2d(sigma1)
    cov_b = np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, (
        "Training and test mean vectors have different lengths"
    )
    assert cov_a.shape == cov_b.shape, (
        "Training and test covariances have different dimensions"
    )

    mean_delta = mean_a - mean_b

    # A (nearly) singular covariance product makes sqrtm return non-finite
    # values; in that case retry once with a small diagonal jitter.
    sqrt_product, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(sqrt_product).all():
        notice = (
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps
        print(notice)
        jitter = eps * np.eye(cov_a.shape[0])
        sqrt_product = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Round-off may leave a tiny imaginary part. Accept it only when the
    # diagonal is essentially real, then keep just the real component.
    if np.iscomplexobj(sqrt_product):
        diagonal_is_real = np.allclose(
            np.diagonal(sqrt_product).imag, 0, atol=1e-3
        )
        if not diagonal_is_real:
            worst = np.max(np.abs(sqrt_product.imag))
            raise ValueError("Imaginary component {}".format(worst))
        sqrt_product = sqrt_product.real

    return (
        mean_delta.dot(mean_delta)
        + np.trace(cov_a)
        + np.trace(cov_b)
        - 2 * np.trace(sqrt_product)
    )
def _features_last(arr: np.ndarray, feature_dim: int) -> np.ndarray:
    """Swap the last two axes of a 4-D array and flatten everything else into rows."""
    return arr.transpose((0, 1, 3, 2)).reshape(-1, feature_dim)


def _fid_between(gt_rows: np.ndarray, pred_rows: np.ndarray) -> float:
    """Frechet distance between Gaussians fitted to ground-truth and predicted rows."""
    pred_mu, pred_cov = calculate_activation_statistics(pred_rows)
    gt_mu, gt_cov = calculate_activation_statistics(gt_rows)
    return calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov)


def main(args) -> None:
    """Print diversity and FID metrics for a saved results file.

    ``args.results`` is the path to a ``.npy`` file containing a pickled dict
    with ``"motion"`` (predictions) and ``"gt"`` (ground truth) arrays. Both
    arrays are squeezed and reshaped to
    ``(num_samples, -1, pose_dim, num_frames)``. The last axis is treated as
    the sequence axis: consecutive differences along it feed the kinematic
    metrics.

    The layout defaults to 5 samples, ``pose_dim=104`` and
    ``num_frames=600``. Each value can be overridden by setting
    ``num_samples``, ``pose_dim`` or ``num_frames`` on ``args``.

    Prints, in order: cross-sample variance, ``var_g``, ``var_k``, ``fid_g``
    and ``fid_k``.
    """
    num_samples = getattr(args, "num_samples", 5)
    pose_dim = getattr(args, "pose_dim", 104)
    num_frames = getattr(args, "num_frames", 600)
    layout = (num_samples, -1, pose_dim, num_frames)

    # SECURITY NOTE(review): allow_pickle=True will unpickle arbitrary objects
    # from this file; only evaluate results files from trusted sources.
    results = np.load(args.results, allow_pickle=True).item()
    pred_reshaped = results["motion"].squeeze().reshape(layout)
    gt_reshaped = results["gt"].squeeze().reshape(layout)

    # Variance across the different generated samples, per element.
    cross_sample_var = np.var(pred_reshaped.reshape((num_samples, -1)), axis=0)
    print("cross var", cross_sample_var.mean())

    pred_pose_last = _features_last(pred_reshaped, pose_dim)
    gt_pose_last = _features_last(gt_reshaped, pose_dim)

    # Static and kinematic diversity.
    var_g = calculate_diversity(pred_pose_last)
    print("var_g", var_g.mean())
    var_k = np.var(pred_reshaped, axis=-1)
    print("var_k", var_k.mean())

    # Static FID over the per-step feature rows.
    fid_g = _fid_between(gt_pose_last, pred_pose_last)
    print("fid_g", fid_g)

    # Kinematic FID over consecutive differences along the last axis.
    pred_motion_last = _features_last(np.diff(pred_reshaped, axis=-1), pose_dim)
    gt_motion_last = _features_last(np.diff(gt_reshaped, axis=-1), pose_dim)
    fid_k = _fid_between(gt_motion_last, pred_motion_last)
    print("fid_k", fid_k)
if __name__ == "__main__":
    # Script entry point: the only CLI input is the path to the saved results
    # file, which main() loads and evaluates.
    cli = argparse.ArgumentParser()
    cli.add_argument("--results", type=str, required=True)
    main(cli.parse_args())