Spaces:
Sleeping
Sleeping
update repo
Browse files- .DS_Store +0 -0
- .gitattributes +2 -0
- dnnlib/__pycache__/util.cpython-38.pyc +0 -0
- id_loss.py +1 -1
- metrics/__init__.py +0 -9
- metrics/frechet_inception_distance.py +0 -41
- metrics/inception_score.py +0 -38
- metrics/kernel_inception_distance.py +0 -46
- metrics/metric_main.py +0 -152
- metrics/metric_utils.py +0 -275
- metrics/perceptual_path_length.py +0 -131
- metrics/precision_recall.py +0 -62
- pretrained/.DS_Store +0 -0
- pretrained/ffhq.pkl +3 -0
- pretrained/metfaces.pkl +3 -0
- model_ir_se50.pth → pretrained/model_ir_se50.pth +0 -0
.DS_Store
CHANGED
|
Binary files a/.DS_Store and b/.DS_Store differ
|
|
|
.gitattributes
CHANGED
|
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*.pth* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*.pth* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.pkl* filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
filter=lfs diff=lfs merge=lfs -text
|
dnnlib/__pycache__/util.cpython-38.pyc
CHANGED
|
Binary files a/dnnlib/__pycache__/util.cpython-38.pyc and b/dnnlib/__pycache__/util.cpython-38.pyc differ
|
|
|
id_loss.py
CHANGED
|
@@ -15,7 +15,7 @@ class IDLoss(nn.Module):
|
|
| 15 |
super(IDLoss, self).__init__()
|
| 16 |
print('Loading ResNet ArcFace')
|
| 17 |
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
| 18 |
-
self.facenet.load_state_dict(torch.load("model_ir_se50.pth", map_location=device))
|
| 19 |
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
| 20 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
| 21 |
self.facenet.eval()
|
|
|
|
| 15 |
super(IDLoss, self).__init__()
|
| 16 |
print('Loading ResNet ArcFace')
|
| 17 |
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
| 18 |
+
self.facenet.load_state_dict(torch.load("./pretrained/model_ir_se50.pth", map_location=device))
|
| 19 |
self.pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
| 20 |
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
| 21 |
self.facenet.eval()
|
metrics/__init__.py
DELETED
|
@@ -1,9 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
# empty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics/frechet_inception_distance.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
"""Frechet Inception Distance (FID) from the paper
|
| 10 |
-
"GANs trained by a two time-scale update rule converge to a local Nash
|
| 11 |
-
equilibrium". Matches the original implementation by Heusel et al. at
|
| 12 |
-
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
|
| 13 |
-
|
| 14 |
-
import numpy as np
|
| 15 |
-
import scipy.linalg
|
| 16 |
-
from . import metric_utils
|
| 17 |
-
|
| 18 |
-
#----------------------------------------------------------------------------
|
| 19 |
-
|
| 20 |
-
def compute_fid(opts, max_real, num_gen):
    """Compute the Frechet Inception Distance for the generator in ``opts``.

    Mean and covariance of detector features are collected through
    ``metric_utils`` for the dataset (``max_items=max_real``) and for the
    generator (``max_items=num_gen``).

    Returns:
        The FID as a Python float on rank 0; ``float('nan')`` on every other
        rank.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    inception_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'

    # Arguments common to both feature passes.
    shared = dict(
        opts=opts,
        detector_url=inception_url,
        detector_kwargs=dict(return_features=True),  # Return raw features before the softmax layer.
        capture_mean_cov=True,
    )

    real_stats = metric_utils.compute_feature_stats_for_dataset(
        rel_lo=0, rel_hi=0, max_items=max_real, **shared)
    mu_real, sigma_real = real_stats.get_mean_cov()

    gen_stats = metric_utils.compute_feature_stats_for_generator(
        rel_lo=0, rel_hi=1, max_items=num_gen, **shared)
    mu_gen, sigma_gen = gen_stats.get_mean_cov()

    # Only rank 0 evaluates the closed-form distance.
    if opts.rank != 0:
        return float('nan')

    mean_term = np.square(mu_gen - mu_real).sum()
    cov_sqrt, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)  # pylint: disable=no-member
    # sqrtm may return a complex matrix; only the real part of the result is kept.
    distance = np.real(mean_term + np.trace(sigma_gen + sigma_real - cov_sqrt * 2))
    return float(distance)
|
| 40 |
-
|
| 41 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics/inception_score.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
"""Inception Score (IS) from the paper "Improved techniques for training
|
| 10 |
-
GANs". Matches the original implementation by Salimans et al. at
|
| 11 |
-
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
|
| 12 |
-
|
| 13 |
-
import numpy as np
|
| 14 |
-
from . import metric_utils
|
| 15 |
-
|
| 16 |
-
#----------------------------------------------------------------------------
|
| 17 |
-
|
| 18 |
-
def compute_is(opts, num_gen, num_splits):
    """Compute the Inception Score over ``num_gen`` generated images.

    The generator outputs returned by the detector are cut into
    ``num_splits`` consecutive slices; each slice yields one score.

    Returns:
        ``(mean, std)`` of the per-slice scores as Python floats on rank 0;
        ``(nan, nan)`` on every other rank.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    inception_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    # Match the original implementation by not applying bias in the softmax layer.
    inception_kwargs = dict(no_output_bias=True)

    gen_stats = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=inception_url, detector_kwargs=inception_kwargs,
        capture_all=True, max_items=num_gen)
    gen_probs = gen_stats.get_all()

    if opts.rank != 0:
        nan = float('nan')
        return nan, nan

    def split_score(index):
        # Slice bounds use integer arithmetic so the slices tile [0, num_gen).
        lo = index * num_gen // num_splits
        hi = (index + 1) * num_gen // num_splits
        part = gen_probs[lo:hi]
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        kl = np.mean(np.sum(kl, axis=1))
        return np.exp(kl)

    scores = [split_score(index) for index in range(num_splits)]
    return float(np.mean(scores)), float(np.std(scores))
|
| 37 |
-
|
| 38 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics/kernel_inception_distance.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
|
| 10 |
-
GANs". Matches the original implementation by Binkowski et al. at
|
| 11 |
-
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
|
| 12 |
-
|
| 13 |
-
import numpy as np
|
| 14 |
-
from . import metric_utils
|
| 15 |
-
|
| 16 |
-
#----------------------------------------------------------------------------
|
| 17 |
-
|
| 18 |
-
def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    """Compute the Kernel Inception Distance for the generator in ``opts``.

    All detector features are collected for the dataset
    (``max_items=max_real``) and the generator (``max_items=num_gen``).
    ``num_subsets`` random subsets, each of size
    ``min(#real, #generated, max_subset_size)``, are then compared with a
    cubic polynomial kernel and the estimates are averaged.

    Returns:
        The KID as a Python float on rank 0; ``float('nan')`` on other ranks.
    """
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    inception_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    inception_kwargs = dict(return_features=True)  # Return raw features before the softmax layer.

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=inception_url, detector_kwargs=inception_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=inception_url, detector_kwargs=inception_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan')

    num_real = real_features.shape[0]
    num_fake = gen_features.shape[0]
    n = real_features.shape[1]  # feature dimensionality
    m = min(num_real, num_fake, max_subset_size)  # subset size

    total = 0
    for _ in range(num_subsets):
        # Generated subset is drawn before the real subset (keeps RNG order).
        x = gen_features[np.random.choice(num_fake, m, replace=False)]
        y = real_features[np.random.choice(num_real, m, replace=False)]
        within = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        across = (x @ y.T / n + 1) ** 3
        # Diagonal (self-similarity) entries are excluded from the within-set term.
        total += (within.sum() - np.diag(within).sum()) / (m - 1) - across.sum() * 2 / m
    return float(total / num_subsets / m)
|
| 45 |
-
|
| 46 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics/metric_main.py
DELETED
|
@@ -1,152 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
import os
|
| 10 |
-
import time
|
| 11 |
-
import json
|
| 12 |
-
import torch
|
| 13 |
-
import dnnlib
|
| 14 |
-
|
| 15 |
-
from . import metric_utils
|
| 16 |
-
from . import frechet_inception_distance
|
| 17 |
-
from . import kernel_inception_distance
|
| 18 |
-
from . import precision_recall
|
| 19 |
-
from . import perceptual_path_length
|
| 20 |
-
from . import inception_score
|
| 21 |
-
|
| 22 |
-
#----------------------------------------------------------------------------
|
| 23 |
-
|
| 24 |
-
# Registry of metric functions, keyed by function name.
_metric_dict = {}  # name => fn


def register_metric(fn):
    """Decorator: record ``fn`` in the registry under ``fn.__name__``.

    Returns ``fn`` unchanged so the decorated name keeps pointing at it.
    """
    assert callable(fn)
    name = fn.__name__
    _metric_dict[name] = fn
    return fn


def is_valid_metric(metric):
    """Return True when ``metric`` is the name of a registered metric."""
    known = metric in _metric_dict
    return known


def list_valid_metrics():
    """Return the registered metric names as a new list, in registration order."""
    return list(_metric_dict)
|
| 36 |
-
|
| 37 |
-
#----------------------------------------------------------------------------
|
| 38 |
-
|
| 39 |
-
def calc_metric(metric, **kwargs):  # See metric_utils.MetricOptions for the full list of arguments.
    """Run the registered metric ``metric`` and return its results with metadata.

    ``kwargs`` are forwarded to ``metric_utils.MetricOptions``. When more than
    one GPU is in use, every result value is replaced by the value broadcast
    from rank 0.

    Returns:
        A ``dnnlib.EasyDict`` with keys ``results``, ``metric``,
        ``total_time``, ``total_time_str`` and ``num_gpus``.
    """
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    metric_fn = _metric_dict[metric]
    began = time.time()
    results = metric_fn(opts)
    total_time = time.time() - began

    # Broadcast results.
    if opts.num_gpus > 1:
        for key in list(results):
            shared = torch.as_tensor(results[key], dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=shared, src=0)
            results[key] = float(shared.cpu())

    # Decorate with metadata.
    report = dnnlib.EasyDict(
        results=dnnlib.EasyDict(results),
        metric=metric,
        total_time=total_time,
        total_time_str=dnnlib.util.format_time(total_time),
        num_gpus=opts.num_gpus,
    )
    return report
|
| 64 |
-
|
| 65 |
-
#----------------------------------------------------------------------------
|
| 66 |
-
|
| 67 |
-
def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    """Print a metric result as one JSON line and append it to the run directory.

    Args:
        result_dict: mapping that contains at least the key ``'metric'``.
        run_dir: optional directory; when it exists, the line is appended to
            ``metric-<name>.jsonl`` inside it.
        snapshot_pkl: optional snapshot path; stored relative to ``run_dir``
            when both are given.
    """
    metric = result_dict['metric']
    assert is_valid_metric(metric)

    have_run_dir = run_dir is not None
    if have_run_dir and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    record = dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())
    jsonl_line = json.dumps(record)
    print(jsonl_line)

    # Nothing to persist unless an existing run directory was supplied.
    if not (have_run_dir and os.path.isdir(run_dir)):
        return
    log_path = os.path.join(run_dir, f'metric-{metric}.jsonl')
    with open(log_path, 'at') as f:
        f.write(jsonl_line + '\n')
|
| 78 |
-
|
| 79 |
-
#----------------------------------------------------------------------------
|
| 80 |
-
# Primary metrics.
|
| 81 |
-
|
| 82 |
-
@register_metric
def fid50k_full(opts):
    """FID with no cap on real images (``max_real=None``) and 50000 generated images.

    Sets ``max_size=None`` and ``xflip=False`` in ``opts.dataset_kwargs`` first.
    """
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    score = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return {'fid50k_full': score}
|
| 87 |
-
|
| 88 |
-
@register_metric
def kid50k_full(opts):
    """KID with up to 1000000 real images and 50000 generated images.

    Uses 100 subsets of at most 1000 items. Sets ``max_size=None`` and
    ``xflip=False`` in ``opts.dataset_kwargs`` first.
    """
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    settings = dict(max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    score = kernel_inception_distance.compute_kid(opts, **settings)
    return {'kid50k_full': score}
|
| 93 |
-
|
| 94 |
-
@register_metric
def pr50k3_full(opts):
    """Precision/recall with up to 200000 real and 50000 generated images, neighbourhood size 3.

    Sets ``max_size=None`` and ``xflip=False`` in ``opts.dataset_kwargs`` first.
    """
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    settings = dict(max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    precision, recall = precision_recall.compute_pr(opts, **settings)
    return {'pr50k3_full_precision': precision, 'pr50k3_full_recall': recall}
|
| 99 |
-
|
| 100 |
-
@register_metric
def ppl2_wend(opts):
    """Perceptual path length with ``space='w'``, ``sampling='end'`` and ``crop=False``.

    Uses 50000 samples, ``epsilon=1e-4`` and ``batch_size=2``.
    """
    settings = dict(num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    score = perceptual_path_length.compute_ppl(opts, **settings)
    return {'ppl2_wend': score}
|
| 104 |
-
|
| 105 |
-
@register_metric
def is50k(opts):
    """Inception Score over 50000 generated images, split into 10 slices.

    Sets ``max_size=None`` and ``xflip=False`` in ``opts.dataset_kwargs`` first.
    Reports the mean and standard deviation under separate keys.
    """
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    score_mean, score_std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return {'is50k_mean': score_mean, 'is50k_std': score_std}
|
| 110 |
-
|
| 111 |
-
#----------------------------------------------------------------------------
|
| 112 |
-
# Legacy metrics.
|
| 113 |
-
|
| 114 |
-
@register_metric
def fid50k(opts):
    """Legacy FID: at most 50000 real images against 50000 generated images.

    Sets ``max_size=None`` in ``opts.dataset_kwargs`` first.
    """
    opts.dataset_kwargs.update(max_size=None)
    score = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return {'fid50k': score}
|
| 119 |
-
|
| 120 |
-
@register_metric
def kid50k(opts):
    """Legacy KID: at most 50000 real images against 50000 generated images.

    Uses 100 subsets of at most 1000 items. Sets ``max_size=None`` in
    ``opts.dataset_kwargs`` first.
    """
    opts.dataset_kwargs.update(max_size=None)
    settings = dict(max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    score = kernel_inception_distance.compute_kid(opts, **settings)
    return {'kid50k': score}
|
| 125 |
-
|
| 126 |
-
@register_metric
def pr50k3(opts):
    """Legacy precision/recall: at most 50000 real and 50000 generated images, neighbourhood size 3.

    Sets ``max_size=None`` in ``opts.dataset_kwargs`` first.
    """
    opts.dataset_kwargs.update(max_size=None)
    settings = dict(max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    precision, recall = precision_recall.compute_pr(opts, **settings)
    return {'pr50k3_precision': precision, 'pr50k3_recall': recall}
|
| 131 |
-
|
| 132 |
-
@register_metric
def ppl_zfull(opts):
    """Legacy perceptual path length with ``space='z'``, ``sampling='full'`` and ``crop=True``.

    Uses 50000 samples, ``epsilon=1e-4`` and ``batch_size=2``.
    """
    settings = dict(num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
    score = perceptual_path_length.compute_ppl(opts, **settings)
    return {'ppl_zfull': score}
|
| 136 |
-
|
| 137 |
-
@register_metric
def ppl_wfull(opts):
    """Legacy perceptual path length with ``space='w'``, ``sampling='full'`` and ``crop=True``.

    Uses 50000 samples, ``epsilon=1e-4`` and ``batch_size=2``.
    """
    settings = dict(num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
    score = perceptual_path_length.compute_ppl(opts, **settings)
    return {'ppl_wfull': score}
|
| 141 |
-
|
| 142 |
-
@register_metric
def ppl_zend(opts):
    """Legacy perceptual path length with ``space='z'``, ``sampling='end'`` and ``crop=True``.

    Uses 50000 samples, ``epsilon=1e-4`` and ``batch_size=2``.
    """
    settings = dict(num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
    score = perceptual_path_length.compute_ppl(opts, **settings)
    return {'ppl_zend': score}
|
| 146 |
-
|
| 147 |
-
@register_metric
def ppl_wend(opts):
    """Legacy perceptual path length with ``space='w'``, ``sampling='end'`` and ``crop=True``.

    Uses 50000 samples, ``epsilon=1e-4`` and ``batch_size=2``.
    """
    settings = dict(num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
    score = perceptual_path_length.compute_ppl(opts, **settings)
    return {'ppl_wend': score}
|
| 151 |
-
|
| 152 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics/metric_utils.py
DELETED
|
@@ -1,275 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
import os
|
| 10 |
-
import time
|
| 11 |
-
import hashlib
|
| 12 |
-
import pickle
|
| 13 |
-
import copy
|
| 14 |
-
import uuid
|
| 15 |
-
import numpy as np
|
| 16 |
-
import torch
|
| 17 |
-
import dnnlib
|
| 18 |
-
|
| 19 |
-
#----------------------------------------------------------------------------
|
| 20 |
-
|
| 21 |
-
class MetricOptions:
    """Bundle of settings handed to every metric function.

    ``G_kwargs`` and ``dataset_kwargs`` are copied into fresh
    ``dnnlib.EasyDict`` objects, so the dictionaries passed in (including the
    shared default ``{}``) are not stored directly.
    """

    def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
        assert 0 <= rank < num_gpus

        # Default device: the CUDA device whose index equals this rank.
        if device is None:
            device = torch.device('cuda', rank)

        # Only rank 0 derives a child of the caller's monitor; every other
        # case gets a fresh default ProgressMonitor.
        if progress is not None and rank == 0:
            monitor = progress.sub()
        else:
            monitor = ProgressMonitor()

        self.G = G
        self.G_kwargs = dnnlib.EasyDict(G_kwargs)
        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
        self.num_gpus = num_gpus
        self.rank = rank
        self.device = device
        self.progress = monitor
        self.cache = cache
|
| 32 |
-
|
| 33 |
-
#----------------------------------------------------------------------------
|
| 34 |
-
|
| 35 |
-
# Loaded detectors, keyed by (url, device) in get_feature_detector().
_feature_detector_cache = {}


def get_feature_detector_name(url):
    """Return the last '/'-separated component of ``url`` without its extension."""
    filename = url.rsplit('/', 1)[-1]
    stem, _extension = os.path.splitext(filename)
    return stem
|
| 39 |
-
|
| 40 |
-
def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Return the TorchScript detector stored at ``url``, loading it at most once per (url, device).

    Args:
        url: location passed to ``dnnlib.util.open_url``.
        device: device the loaded module is moved to; also part of the cache key.
        num_gpus: number of participating processes; barriers are only used when > 1.
        rank: index of this process; must satisfy ``0 <= rank < num_gpus``.
        verbose: forwarded to ``open_url``, but only on rank 0.

    Returns:
        The cached module, in eval mode, on ``device``.
    """
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        # Non-leader ranks block here until rank 0 reaches its barrier below,
        # i.e. until rank 0 has finished opening and loading the file.
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
        # Rank 0 releases the waiting ranks once its own load is complete.
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]
|
| 52 |
-
|
| 53 |
-
#----------------------------------------------------------------------------
|
| 54 |
-
|
| 55 |
-
class FeatureStats:
    """Accumulator for batches of 2-D feature arrays.

    Depending on the flags it keeps every appended batch (``capture_all``)
    and/or running sums from which mean and covariance are derived
    (``capture_mean_cov``). ``max_items`` optionally caps how many rows are
    accepted in total.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        # Allocated lazily by set_num_features() on the first append.
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Fix the feature width on first use; afterwards only check that it matches."""
        if self.num_features is None:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
            return
        assert num_features == self.num_features

    def is_full(self):
        """Return True when a cap is set and it has been reached."""
        if self.max_items is None:
            return False
        return self.num_items >= self.max_items

    def append(self, x):
        """Add a 2-D batch (rows = items), dropping rows that exceed ``max_items``."""
        batch = np.asarray(x, dtype=np.float32)
        assert batch.ndim == 2

        if self.max_items is not None:
            room = self.max_items - self.num_items
            if batch.shape[0] > room:
                if room <= 0:
                    return
                batch = batch[:room]

        rows, width = batch.shape
        self.set_num_features(width)
        self.num_items += rows
        if self.capture_all:
            self.all_features.append(batch)
        if self.capture_mean_cov:
            # Sums are accumulated in float64; see get_mean_cov() for normalization.
            batch64 = batch.astype(np.float64)
            self.raw_mean += batch64.sum(axis=0)
            self.raw_cov += batch64.T @ batch64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Add a 2-D tensor; with several GPUs, first gather every rank's batch.

        Each rank's tensor is broadcast in turn and the results are stacked
        so that rows from the different ranks are interleaved.
        """
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            gathered = []
            for src in range(num_gpus):
                received = x.clone()
                torch.distributed.broadcast(received, src=src)
                gathered.append(received)
            x = torch.stack(gathered, dim=1).flatten(0, 1)  # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return every captured row as one array (requires ``capture_all``)."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Return get_all() wrapped as a torch tensor."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Return ``(mean, cov)`` from the running sums (requires ``capture_mean_cov``).

        The covariance divides by ``num_items`` (no ``n - 1`` correction).
        """
        assert self.capture_mean_cov
        count = self.num_items
        mean = self.raw_mean / count
        cov = self.raw_cov / count - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle this object's attribute dictionary to ``pkl_file``."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Rebuild a FeatureStats from a file written by save()."""
        with open(pkl_file, 'rb') as f:
            state = dnnlib.EasyDict(pickle.load(f))
        restored = FeatureStats(capture_all=state.capture_all, max_items=state.max_items)
        # The remaining attributes (including capture_mean_cov) come from the saved dict.
        restored.__dict__.update(state)
        return restored
|
| 132 |
-
|
| 133 |
-
#----------------------------------------------------------------------------
|
| 134 |
-
|
| 135 |
-
class ProgressMonitor:
    """Throttled progress reporter.

    ``update()`` only does work once ``flush_interval`` more items have been
    processed or the total has been reached. It can print a status line and
    call ``progress_fn(value, pfn_total)``, where ``value`` is interpolated
    between ``pfn_lo`` and ``pfn_hi``.
    """

    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
        self.tag = tag
        self.num_items = num_items
        self.verbose = verbose
        self.flush_interval = flush_interval
        self.progress_fn = progress_fn
        self.pfn_lo = pfn_lo
        self.pfn_hi = pfn_hi
        self.pfn_total = pfn_total
        now = time.time()
        self.start_time = now
        self.batch_time = now
        self.batch_items = 0
        # Report the starting position right away.
        if progress_fn is not None:
            progress_fn(pfn_lo, pfn_total)

    def update(self, cur_items):
        """Record that ``cur_items`` items are done; report if a flush is due."""
        total = self.num_items
        assert (total is None) or (cur_items <= total)

        interval_pending = cur_items < self.batch_items + self.flush_interval
        unfinished = total is None or cur_items < total
        if interval_pending and unfinished:
            return

        cur_time = time.time()
        total_time = cur_time - self.start_time
        time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
        if self.verbose and self.tag is not None:
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
        self.batch_time = cur_time
        self.batch_items = cur_items

        if self.progress_fn is None or total is None:
            return
        span = self.pfn_hi - self.pfn_lo
        self.progress_fn(self.pfn_lo + span * (cur_items / total), self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        """Create a child monitor covering the ``[rel_lo, rel_hi]`` fraction of this one's range.

        The child shares ``verbose``, ``progress_fn`` and ``pfn_total``.
        """
        span = self.pfn_hi - self.pfn_lo
        child_lo = self.pfn_lo + span * rel_lo
        child_hi = self.pfn_lo + span * rel_hi
        return ProgressMonitor(
            tag=tag,
            num_items=num_items,
            flush_interval=flush_interval,
            verbose=self.verbose,
            progress_fn=self.progress_fn,
            pfn_lo=child_lo,
            pfn_hi=child_hi,
            pfn_total=self.pfn_total,
        )
|
| 177 |
-
|
| 178 |
-
#----------------------------------------------------------------------------
|
| 179 |
-
|
| 180 |
-
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
    """Run the feature detector over the dataset and collect the outputs in a FeatureStats.

    Args:
        opts: options object; this function reads ``dataset_kwargs``, ``cache``,
            ``rank``, ``num_gpus``, ``device`` and ``progress`` from it.
        detector_url: location of the detector, passed to ``get_feature_detector``.
        detector_kwargs: extra keyword arguments for each detector call.
        rel_lo, rel_hi: fraction of the parent progress range given to this pass.
        batch_size: DataLoader batch size.
        data_loader_kwargs: DataLoader keyword arguments; a default set is used when None.
        max_items: optional cap on the number of dataset items processed.
        **stats_kwargs: forwarded to ``FeatureStats`` (e.g. ``capture_all``).

    Returns:
        A ``FeatureStats``, either loaded from the on-disk cache or freshly computed.
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Try to lookup from cache.
    cache_file = None
    if opts.cache:
        # Choose cache file name.
        # NOTE(review): max_items is not part of the hashed arguments, so two calls
        # that differ only in max_items map to the same cache file — confirm this
        # is intended.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')

        # Check if the file exists (all processes must agree).
        # Only rank 0 looks at the filesystem; its answer is broadcast to the others.
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)

        # Load.
        if flag:
            return FeatureStats.load(cache_file)

    # Initialize.
    num_items = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = FeatureStats(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    # Each rank visits indices rank, rank + num_gpus, ... (wrapped modulo num_items),
    # so every rank iterates over the same number of items.
    item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
    for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
        # Repeat dimension 1 three times when its size is 1
        # (presumably grayscale -> 3 channels; confirm the image layout).
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)

    # Save to cache.
    # Written by rank 0 only, to a uniquely named temp file that is then moved into place.
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file) # atomic
    return stats
|
| 229 |
-
|
| 230 |
-
#----------------------------------------------------------------------------
|
| 231 |
-
|
| 232 |
-
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
    """Generate images with ``opts.G`` and collect detector outputs in a FeatureStats.

    Args:
        opts: options object; this function reads ``G``, ``G_kwargs``,
            ``dataset_kwargs``, ``device``, ``num_gpus``, ``rank`` and ``progress``.
        detector_url: location of the detector, passed to ``get_feature_detector``.
        detector_kwargs: extra keyword arguments for each detector call.
        rel_lo, rel_hi: fraction of the parent progress range given to this pass.
        batch_size: number of images per detector call.
        batch_gen: number of images per generator call; defaults to
            ``min(batch_size, 4)`` and must divide ``batch_size``.
        jit: when True, the generation function is traced with ``torch.jit.trace``.
        **stats_kwargs: forwarded to ``FeatureStats``; must include a non-None
            ``max_items`` (asserted below), which decides when the loop stops.

    Returns:
        The filled ``FeatureStats``.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and load labels.
    # Works on a deep copy, so opts.G itself is not modified.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Image generation func.
    def run_generator(z, c):
        img = G(z=z, c=c, **opts.G_kwargs)
        # Scale, clamp to [0, 255] and convert to uint8
        # (presumably the generator output is in [-1, 1]; confirm).
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img

    # JIT.
    if jit:
        # Zero-filled example inputs are used only for tracing.
        z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
        c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
        run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)

    # Initialize.
    stats = FeatureStats(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    while not stats.is_full():
        images = []
        # Build one detector batch from batch_size // batch_gen generator calls.
        for _i in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            # Labels are taken from randomly chosen dataset entries.
            c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
            images.append(run_generator(z, c))
        images = torch.cat(images)
        # Repeat dimension 1 three times when its size is 1
        # (presumably grayscale -> 3 channels; confirm the image layout).
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats
|
| 274 |
-
|
| 275 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics/perceptual_path_length.py
DELETED
|
@@ -1,131 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
|
| 10 |
-
Architecture for Generative Adversarial Networks". Matches the original
|
| 11 |
-
implementation by Karras et al. at
|
| 12 |
-
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
|
| 13 |
-
|
| 14 |
-
import copy
|
| 15 |
-
import numpy as np
|
| 16 |
-
import torch
|
| 17 |
-
import dnnlib
|
| 18 |
-
from . import metric_utils
|
| 19 |
-
|
| 20 |
-
#----------------------------------------------------------------------------
|
| 21 |
-
|
| 22 |
-
# Spherical interpolation of a batch of vectors.
|
| 23 |
-
def slerp(a, b, t):
    """Spherically interpolate between two batches of vectors.

    Both inputs are normalized along the last dimension, so the result lies
    on the unit sphere.

    Args:
        a: Start vectors, shape (..., dim).
        b: End vectors, broadcastable against `a`.
        t: Interpolation parameter, broadcastable against `a[..., :1]`;
            0 yields normalized `a`, 1 yields normalized `b`.

    Returns:
        Unit-length interpolated vectors.
    """
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    d = (a * b).sum(dim=-1, keepdim=True)
    # Rounding can push the dot product of two unit vectors slightly outside
    # [-1, 1], where acos returns NaN; clamp only the value fed to acos.
    p = t * torch.acos(d.clamp(-1, 1))
    c = b - d * a
    # For (anti)parallel inputs the orthogonal component is exactly zero;
    # flooring the norm turns the 0/0 NaN into a zero vector and leaves every
    # non-degenerate input untouched.
    c = c / c.norm(dim=-1, keepdim=True).clamp_min(torch.finfo(c.dtype).tiny)
    d = a * torch.cos(p) + c * torch.sin(p)
    d = d / d.norm(dim=-1, keepdim=True)
    return d
|
| 33 |
-
|
| 34 |
-
#----------------------------------------------------------------------------
|
| 35 |
-
|
| 36 |
-
class PPLSampler(torch.nn.Module):
    """Module that draws per-sample perceptual path length distances.

    Each forward pass samples latent pairs, evaluates the generator at two
    interpolation points `epsilon` apart, and returns the squared LPIPS
    difference scaled by 1 / epsilon**2.
    """

    def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
        """Store private copies of the generator and the VGG16 network.

        Args:
            G: Generator; deep-copied.
            G_kwargs: Extra keyword arguments for `G.synthesis`.
            epsilon: Step between the two interpolation points.
            space: 'z' or 'w', the latent space to interpolate in.
            sampling: 'full' for random t in [0, 1), 'end' for t = 0.
            crop: Whether to crop the images before measuring the distance.
            vgg16: LPIPS feature network; deep-copied.
        """
        assert space in ['z', 'w']
        assert sampling in ['full', 'end']
        super().__init__()
        self.G = copy.deepcopy(G)
        self.vgg16 = copy.deepcopy(vgg16)
        self.G_kwargs = G_kwargs
        self.epsilon = epsilon
        self.space = space
        self.sampling = sampling
        self.crop = crop

    def forward(self, c):
        """Return one distance value per row of the label batch `c`."""
        batch = c.shape[0]
        labels = torch.cat([c, c])

        # Random interpolation positions (all zero for 'end' sampling) and latent pairs.
        t_scale = 1 if self.sampling == 'full' else 0
        t = torch.rand([batch], device=c.device) * t_scale
        z0, z1 = torch.randn([batch * 2, self.G.z_dim], device=c.device).chunk(2)

        if self.space == 'w':
            # Map first, then interpolate linearly in W.
            w0, w1 = self.G.mapping(z=torch.cat([z0, z1]), c=labels).chunk(2)
            t_w = t.unsqueeze(1).unsqueeze(2)
            wt0 = w0.lerp(w1, t_w)
            wt1 = w0.lerp(w1, t_w + self.epsilon)
        else:
            # Interpolate spherically in Z, then map.
            t_z = t.unsqueeze(1)
            zt0 = slerp(z0, z1, t_z)
            zt1 = slerp(z0, z1, t_z + self.epsilon)
            wt0, wt1 = self.G.mapping(z=torch.cat([zt0, zt1]), c=labels).chunk(2)

        # Draw fresh values for every constant-noise buffer of the generator.
        for buf_name, noise in self.G.named_buffers():
            if buf_name.endswith('.noise_const'):
                noise.copy_(torch.randn_like(noise))

        img = self.G.synthesis(ws=torch.cat([wt0, wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)

        if self.crop:
            # Keep rows 3/8..7/8 and columns 2/8..6/8 of the square image.
            assert img.shape[2] == img.shape[3]
            unit = img.shape[2] // 8
            img = img[:, :, unit * 3 : unit * 7, unit * 2 : unit * 6]

        # Box-filter downsampling by the integer factor img_resolution // 256.
        factor = self.G.img_resolution // 256
        if factor > 1:
            blocks = [-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]
            img = img.reshape(blocks).mean([3, 5])

        # Rescale from [-1, 1] to [0, 255]; grayscale is replicated to three channels.
        img = (img + 1) * (255 / 2)
        if self.G.img_channels == 1:
            img = img.repeat([1, 3, 1, 1])

        lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
        return (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
|
| 92 |
-
|
| 93 |
-
#----------------------------------------------------------------------------
|
| 94 |
-
|
| 95 |
-
def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
    """Compute the perceptual path length of `opts.G`.

    Args:
        opts: Metric options; this function reads `G`, `G_kwargs`,
            `dataset_kwargs`, `device`, `num_gpus`, `rank` and `progress`.
        num_samples: Number of distance samples to keep.
        epsilon, space, sampling, crop: Forwarded to `PPLSampler`.
        batch_size: Samples drawn per process and iteration.
        jit: When true, the sampler is traced with `torch.jit.trace`.

    Returns:
        The mean of the samples lying between the 1st and 99th percentile
        on rank 0, NaN on every other rank.
    """
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
    vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)

    sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
    sampler.eval().requires_grad_(False).to(opts.device)
    if jit:
        example_c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
        sampler = torch.jit.trace(sampler, [example_c], check_trace=False)

    samples_per_step = batch_size * opts.num_gpus
    multi_gpu = opts.num_gpus > 1
    collected = []
    progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
    for batch_start in range(0, num_samples, samples_per_step):
        progress.update(batch_start)
        labels = [dataset.get_label(np.random.randint(len(dataset))) for _ in range(batch_size)]
        c = torch.from_numpy(np.stack(labels)).pin_memory().to(opts.device)
        local_dist = sampler(c)
        # One entry per process: with several GPUs each source rank broadcasts its batch in turn.
        for src in range(opts.num_gpus):
            shared = local_dist.clone()
            if multi_gpu:
                torch.distributed.broadcast(shared, src=src)
            collected.append(shared)
    progress.update(num_samples)

    # Only rank 0 reports a value.
    if opts.rank != 0:
        return float('nan')
    dist = torch.cat(collected)[:num_samples].cpu().numpy()
    lo = np.percentile(dist, 1, interpolation='lower')
    hi = np.percentile(dist, 99, interpolation='higher')
    inside = np.logical_and(dist >= lo, dist <= hi)
    return float(np.extract(inside, dist).mean())
|
| 130 |
-
|
| 131 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
metrics/precision_recall.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
| 1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
-
#
|
| 3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
| 4 |
-
# and proprietary rights in and to this software, related documentation
|
| 5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
| 6 |
-
# distribution of this software and related documentation without an express
|
| 7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
| 8 |
-
|
| 9 |
-
"""Precision/Recall (PR) from the paper "Improved Precision and Recall
|
| 10 |
-
Metric for Assessing Generative Models". Matches the original implementation
|
| 11 |
-
by Kynkaanniemi et al. at
|
| 12 |
-
https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
from . import metric_utils
|
| 16 |
-
|
| 17 |
-
#----------------------------------------------------------------------------
|
| 18 |
-
|
| 19 |
-
def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
    """Compute pairwise distances between row and column feature vectors.

    The columns are split into batches that are dealt out round-robin to the
    participating processes; each process broadcasts its results so that
    rank 0 can assemble the full matrix.

    Args:
        row_features: Tensor of shape (num_rows, dim).
        col_features: Tensor of shape (num_cols, dim).
        num_gpus: Number of participating processes.
        rank: Index of the current process, 0 <= rank < num_gpus.
        col_batch_size: Column batch size used to derive the batch count.

    Returns:
        A (num_rows, num_cols) CPU tensor on rank 0, None on other ranks.
    """
    assert 0 <= rank < num_gpus
    is_main = rank == 0
    num_cols = col_features.shape[0]

    # Round the batch count up to a multiple of num_gpus so every rank gets
    # the same number of batches, then zero-pad the columns to fill them.
    batches_per_gpu = (num_cols - 1) // col_batch_size // num_gpus + 1
    num_batches = batches_per_gpu * num_gpus
    pad_rows = -num_cols % num_batches
    padded = torch.nn.functional.pad(col_features, [0, 0, 0, pad_rows])
    own_batches = padded.chunk(num_batches)[rank::num_gpus]

    gathered = []
    for col_batch in own_batches:
        local = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
        for src in range(num_gpus):
            shared = local.clone()
            if num_gpus > 1:
                torch.distributed.broadcast(shared, src=src)
            gathered.append(shared.cpu() if is_main else None)

    if not is_main:
        return None
    # Drop the columns that belong to the zero padding.
    return torch.cat(gathered, dim=1)[:, :num_cols]
|
| 33 |
-
|
| 34 |
-
#----------------------------------------------------------------------------
|
| 35 |
-
|
| 36 |
-
def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
    """Compute precision and recall between real and generated features.

    Args:
        opts: Metric options; this function reads `device`, `num_gpus` and
            `rank` and forwards `opts` to the feature-statistics helpers.
        max_real: Maximum number of real items (passed as `max_items`).
        num_gen: Number of generated items (passed as `max_items`).
        nhood_size: The (nhood_size + 1)-th smallest distance within the
            manifold is used as each point's radius.
        row_batch_size: Number of rows processed per distance computation.
        col_batch_size: Column batch size forwarded to `compute_distances`.

    Returns:
        Tuple (precision, recall); both are NaN on ranks other than 0.
    """
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    detector_kwargs = dict(return_features=True)
    is_main = opts.rank == 0

    real_stats = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real)
    real_features = real_stats.get_all_torch().to(torch.float16).to(opts.device)

    gen_stats = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen)
    gen_features = gen_stats.get_all_torch().to(torch.float16).to(opts.device)

    def distances_to(manifold, rows):
        return compute_distances(row_features=rows, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)

    # Precision probes the real manifold with generated samples; recall does the reverse.
    setups = [('precision', real_features, gen_features), ('recall', gen_features, real_features)]
    results = dict()
    for name, manifold, probes in setups:
        # Radius of every manifold point.
        kth = []
        for manifold_batch in manifold.split(row_batch_size):
            dist = distances_to(manifold, manifold_batch)
            if is_main:
                kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16))
            else:
                kth.append(None)
        kth = torch.cat(kth) if is_main else None

        # A probe counts when it lies within the radius of at least one manifold point.
        pred = []
        for probes_batch in probes.split(row_batch_size):
            dist = distances_to(manifold, probes_batch)
            pred.append((dist <= kth).any(dim=1) if is_main else None)
        results[name] = float(torch.cat(pred).to(torch.float32).mean() if is_main else 'nan')
    return results['precision'], results['recall']
|
| 61 |
-
|
| 62 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretrained/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
pretrained/ffhq.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a205a346e86a9ddaae702e118097d014b7b8bd719491396a162cca438f2f524c
|
| 3 |
+
size 381624121
|
pretrained/metfaces.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:880a460d011a3696c088f58f5844b44271b17903963f2671f96f72dfbce5f76f
|
| 3 |
+
size 381624133
|
model_ir_se50.pth → pretrained/model_ir_se50.pth
RENAMED
|
File without changes
|