Mehdi Cherti committed · Commit 572f947
1 Parent(s): 06c5f0c

support clip score and higher resolution at test time

Files changed: test_ddgan.py (+45 -13)

test_ddgan.py CHANGED
@@ -12,7 +12,7 @@ import os
 import json
 import torchvision
 from score_sde.models.ncsnpp_generator_adagn import NCSNpp
-import
+from encoder import build_encoder
 
 #%% Diffusion coefficients
 def var_func_vp(t, beta_min, beta_max):
@@ -130,13 +130,13 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None
 def sample_from_model_classifier_free_guidance(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0):
     x = x_init
     null = text_encoder([""] * len(x_init), return_only_pooled=False)
-    latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+    #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
     with torch.no_grad():
         for i in reversed(range(n_time)):
             t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
             t_time = t
 
-
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
 
             x_0_uncond = generator(x, t_time, latent_z, cond=null)
 
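Note on the hunk above: the generator latent latent_z used to be drawn once before the denoising loop; the commit re-samples it at every timestep of the classifier-free-guidance sampler. A minimal sketch of the resulting loop structure, with names taken from the surrounding code and the guidance combination itself elided (it is unchanged by this commit):

```python
import torch

def cfg_denoise_loop(generator, x, n_time, nz, cond, null):
    # Sketch only: the key difference is that latent_z is re-sampled inside the loop.
    with torch.no_grad():
        for i in reversed(range(n_time)):
            t = torch.full((x.size(0),), i, dtype=torch.int64, device=x.device)
            latent_z = torch.randn(x.size(0), nz, device=x.device)  # fresh latent per step
            x_0_uncond = generator(x, t, latent_z, cond=null)
            x_0_cond = generator(x, t, latent_z, cond=cond)
            # ... guidance mixing and posterior sampling as in the full function ...
    return x
```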
@@ -184,10 +184,8 @@ def sample_from_model_classifier_free_guidance(coefficients, generator, n_time,
 def sample_and_test(args):
     torch.manual_seed(args.seed)
     device = 'cuda:0'
-    text_encoder
+    text_encoder =build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
     args.cond_size = text_encoder.output_size
-    # cond = text_encoder([str(yi%10) for yi in range(args.batch_size)])
-
     if args.dataset == 'cifar10':
         real_img_dir = 'pytorch_fid/cifar10_train_stat.npy'
     elif args.dataset == 'celeba_256':
@@ -201,7 +199,7 @@ def sample_and_test(args):
 
 
     netG = NCSNpp(args).to(device)
-
+    netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]
 
     if args.epoch_id == -1:
         epochs = range(1000)
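Note on the hunk above: NCSN++ applies self-attention only at the feature-map resolutions listed in attn_resolutions. When the test-time input is widened by --scale_factor_w, the feature maps at any given depth grow by the same factor, so the trigger resolutions are rescaled to keep attention at the same depths. An illustrative calculation with assumed numbers (they are not taken from the repo):

```python
# Assumed example: a model trained at 64x64 that attends on 16-pixel feature maps.
attn_resolutions = [16]
scale_factor_w = 2  # value passed via --scale_factor_w

# At 2x width, the map that used to be 16 wide is 32 wide at the same depth,
# so the trigger list is scaled the same way as the added line in the diff.
scaled = [r * scale_factor_w for r in attn_resolutions]
print(scaled)  # [32]
```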
@@ -214,7 +212,7 @@ def sample_and_test(args):
         if not os.path.exists(path):
             continue
         ckpt = torch.load(path, map_location=device)
-        dest = './saved_info/dd_gan/{}/{}/
+        dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
 
         if args.compute_fid and os.path.exists(dest):
             continue
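The evaluation results for each checkpoint now go to an explicit per-epoch file, and an existing file makes the script skip that epoch. For illustration only, with made-up dataset, experiment and epoch values:

```python
# The format string is the one added in the diff; the values are examples.
dataset, exp, epoch_id = "coco", "ddgan_t2i", 500
dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(dataset, exp, epoch_id)
print(dest)  # ./saved_info/dd_gan/coco/ddgan_t2i/eval_500.json
```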
@@ -258,6 +256,15 @@ def sample_and_test(args):
             block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
             inceptionv3 = InceptionV3([block_idx]).to(device)
 
+            if args.compute_clip_score:
+                import clip
+                CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+                CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+                clip_model, preprocess = clip.load(args.clip_model, device)
+                clip_mean = torch.Tensor(CLIP_MEAN).view(1,-1,1,1).to(device)
+                clip_std = torch.Tensor(CLIP_STD).view(1,-1,1,1).to(device)
+
+
             if not args.real_img_dir.endswith("npz"):
                 real_mu, real_sigma = compute_statistics_of_path(
                     args.real_img_dir, inceptionv3, args.batch_size, dims, device,
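When --compute_clip_score is set, an OpenAI CLIP model is loaded alongside the Inception network, and the standard CLIP per-channel mean/std are kept as broadcastable tensors. As a hedged aside, those statistics are normally applied to [0,1] images before encode_image; a hypothetical helper showing that usage (illustrative, not code from this commit):

```python
import torch

def clip_normalize(images, clip_mean, clip_std):
    # images: (N, 3, H, W) in [0, 1]; clip_mean/clip_std are (1, 3, 1, 1) tensors
    return (images - clip_mean) / clip_std
```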
@@ -270,6 +277,9 @@ def sample_and_test(args):
                 real_sigma = stats['sigma']
 
             fake_features = []
+            if args.compute_clip_score:
+                clip_scores = []
+
             for b in range(0, len(texts), args.batch_size):
                 text = texts[b:b+args.batch_size]
                 with torch.no_grad():
@@ -277,6 +287,7 @@ def sample_and_test(args):
                     bs = len(text)
                     t0 = time.time()
                     x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
+                    #print(x_t_1.shape)
                     if args.guidance_scale:
                         fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                     else:
@@ -295,6 +306,17 @@ def sample_and_test(args):
                     pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
                     pred = pred.squeeze(3).squeeze(2).cpu().numpy()
                     fake_features.append(pred)
+
+                    if args.compute_clip_score:
+                        with torch.no_grad():
+                            clip_ims = torch.nn.functional.interpolate(fake_sample, (224, 224), mode="bicubic")
+                            clip_txt = clip.tokenize(text).to(device)
+                            imf = clip_model.encode_image(clip_ims)
+                            txtf = clip_model.encode_text(clip_txt)
+                            imf = torch.nn.functional.normalize(imf, dim=1)
+                            txtf = torch.nn.functional.normalize(txtf, dim=1)
+                            clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
+                        break
                     if i % 10 == 0:
                         print('generating batch ', i, time.time() - t0)
             """
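For each generated batch, the added code resizes the samples to 224x224, embeds images and tokenized captions with CLIP, L2-normalizes both embeddings, and stores their per-sample cosine similarities; like that code, the sketch below feeds the resized [0,1] samples to encode_image directly. A self-contained version of the same computation as a standalone function; the function name and defaults are mine, not from the repo:

```python
import torch
import clip

def clip_score(images, captions, model_name="ViT-L/14", device="cuda"):
    """Mean cosine similarity between CLIP image and text embeddings.

    images: (N, 3, H, W) tensor in [0, 1]; captions: list of N strings.
    """
    model, _ = clip.load(model_name, device=device)
    with torch.no_grad():
        ims = torch.nn.functional.interpolate(images.to(device), (224, 224), mode="bicubic")
        tokens = clip.tokenize(captions).to(device)
        imf = torch.nn.functional.normalize(model.encode_image(ims), dim=1)
        txtf = torch.nn.functional.normalize(model.encode_text(tokens), dim=1)
        return (imf * txtf).sum(dim=1).mean().item()
```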
@@ -311,14 +333,17 @@ def sample_and_test(args):
             fake_mu = np.mean(fake_features, axis=0)
             fake_sigma = np.cov(fake_features, rowvar=False)
             fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-            dest = './saved_info/dd_gan/{}/{}/
+            dest = './saved_info/dd_gan/{}/{}/eval_{}.json'.format(args.dataset, args.exp, args.epoch_id)
             results = {
                 "fid": fid,
             }
+            if args.compute_clip_score:
+                clip_score = torch.cat(clip_scores).mean().item()
+                results['clip_score'] = clip_score
             results.update(vars(args))
             with open(dest, "w") as fd:
                 json.dump(results, fd)
-            print(
+            print(results)
         else:
             if args.cond_text.endswith(".txt"):
                 texts = open(args.cond_text).readlines()
@@ -326,11 +351,13 @@ def sample_and_test(args):
             else:
                 texts = [args.cond_text] * args.batch_size
             cond = text_encoder(texts, return_only_pooled=False)
-            x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size, args.image_size).to(device)
+            x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+            t0 = time.time()
             if args.guidance_scale:
                 fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
             else:
                 fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+            print(time.time() - t0)
             fake_sample = to_range_0_1(fake_sample)
             torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
 
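In the non-FID branch, the initial noise now spans image_size*scale_factor_h by image_size*scale_factor_w pixels, which, together with the attn_resolutions rescaling above, is what allows sampling at a higher resolution than the model was trained at. Illustrative shape arithmetic with assumed values:

```python
import torch

# Assumed values: a model trained at 64x64, sampled at twice the height and width.
image_size, scale_factor_h, scale_factor_w = 64, 2, 2
batch, num_channels = 4, 3

x_t_1 = torch.randn(batch, num_channels,
                    image_size * scale_factor_h,
                    image_size * scale_factor_w)
print(x_t_1.shape)  # torch.Size([4, 3, 128, 128])
```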
@@ -344,11 +371,16 @@ if __name__ == '__main__':
                         help='seed used for initialization')
     parser.add_argument('--compute_fid', action='store_true', default=False,
                             help='whether or not compute FID')
+    parser.add_argument('--compute_clip_score', action='store_true', default=False,
+                            help='whether or not compute CLIP score')
+    parser.add_argument('--clip_model', type=str,default="ViT-L/14")
+
     parser.add_argument('--epoch_id', type=int,default=1000)
     parser.add_argument('--guidance_scale', type=float,default=0)
     parser.add_argument('--dynamic_thresholding_quantile', type=float,default=0)
     parser.add_argument('--cond_text', type=str,default="0")
-
+    parser.add_argument('--scale_factor_h', type=int,default=1)
+    parser.add_argument('--scale_factor_w', type=int,default=1)
     parser.add_argument('--cross_attention', action='store_true',default=False)
 
 
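Taken together, the new flags slot into an evaluation run roughly as follows; the placeholders and the choice of 2x scaling are illustrative, and the remaining arguments (dataset, experiment name, model hyper-parameters) must match the training configuration:

    python test_ddgan.py --dataset <dataset> --exp <exp_name> --epoch_id <epoch> \
        --compute_fid --compute_clip_score --clip_model ViT-L/14 \
        --scale_factor_h 2 --scale_factor_w 2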
@@ -419,7 +451,7 @@ if __name__ == '__main__':
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
     parser.add_argument('--masked_mean', action='store_true',default=False)
     parser.add_argument('--nb_images_for_fid', type=int, default=0)
-
+
 
 
 