Commit 3dcdf92 · Mehdi Cherti committed · Parent(s): e96a195

- memory efficient EMA
- fix gradient checkpointing
- evaluate using the ImageReward paper's metric
- use attn_resolutions in test

Files changed:
- EMA.py +13 -4
- run.py +29 -1
- score_sde/models/ncsnpp_generator_adagn.py +18 -10
- test_ddgan.py +29 -14
- train_ddgan.py +37 -53
EMA.py CHANGED

@@ -15,13 +15,14 @@ from torch.optim import Optimizer
 
 
 class EMA(Optimizer):
-    def __init__(self, opt, ema_decay):
+    def __init__(self, opt, ema_decay, memory_efficient=False):
         self.ema_decay = ema_decay
         self.apply_ema = self.ema_decay > 0.
         self.optimizer = opt
         self.state = opt.state
         self.param_groups = opt.param_groups
         self.defaults = {}
+        self.memory_efficient = memory_efficient
 
     def step(self, *args, **kwargs):
         # for group in self.optimizer.param_groups:
@@ -53,11 +54,19 @@ class EMA(Optimizer):
 
                 params[p.shape]['data'].append(p.data)
                 ema[p.shape].append(state['ema'])
+
+            # def stack(d, dim=0):
+            #     return torch.stack([di.cpu() for di in d], dim=dim).cuda()
 
             for i in params:
-                params[i]['data'] = torch.stack(params[i]['data'], dim=0)
-                ema[i] = torch.stack(ema[i], dim=0)
-                ema[i].mul_(self.ema_decay).add_(params[i]['data'], alpha=1. - self.ema_decay)
+                if self.memory_efficient:
+                    for j in range(len(params[i]['data'])):
+                        ema[i][j].mul_(self.ema_decay).add_(params[i]['data'][j], alpha=1. - self.ema_decay)
+                    ema[i] = torch.stack(ema[i], dim=0)
+                else:
+                    params[i]['data'] = torch.stack(params[i]['data'], dim=0)
+                    ema[i] = torch.stack(ema[i], dim=0)
+                    ema[i].mul_(self.ema_decay).add_(params[i]['data'], alpha=1. - self.ema_decay)
 
             for p in group['params']:
                 if p.grad is None:
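The rewritten `step()` loop is the "memory efficient EMA" from the commit message: instead of first stacking every same-shape parameter into one large temporary tensor and applying a single fused update, the `memory_efficient` path updates each EMA tensor in place, so a stacked copy of the live parameters is never allocated. A minimal self-contained sketch of the in-place update (names here are illustrative, not the repo's API):

```python
# In-place EMA: ema <- decay * ema + (1 - decay) * param, one tensor at a
# time, avoiding the large stacked temporary of the vectorized variant.
import torch

def ema_update_inplace(ema_tensors, param_tensors, decay):
    for e, p in zip(ema_tensors, param_tensors):
        e.mul_(decay).add_(p, alpha=1.0 - decay)

params = [torch.randn(4, 4) for _ in range(3)]
emas = [p.clone() for p in params]
ema_update_inplace(emas, params, decay=0.999)
```

In train_ddgan.py below, `memory_efficient` is tied to `--grad_checkpointing`, so the cheaper update kicks in exactly when the model is large enough to need activation checkpointing.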
run.py CHANGED

@@ -274,10 +274,30 @@ def ddgan_ddb_v7():
     cfg = ddgan_ddb_v1()
     return cfg
 
+def ddgan_ddb_v9():
+    cfg = ddgan_ddb_v3()
+    cfg['model']['attn_resolutions'] = '4 8 16 32'
+    return cfg
+
 def ddgan_laion_aesthetic_v15():
     cfg = ddgan_ddb_v3()
     return cfg
 
+def ddgan_ddb_v10():
+    cfg = ddgan_ddb_v9()
+    return cfg
+
+def ddgan_ddb_v11():
+    cfg = ddgan_ddb_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-g-14/laion2B-s12B-b42K"
+    return cfg
+
+def ddgan_ddb_v12():
+    cfg = ddgan_ddb_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-bigG-14/laion2b_s39b_b160k"
+    return cfg
+
+
 models = [
     ddgan_cifar10_cond17, # cifar10, cross attn for discr
     ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -326,6 +346,10 @@ models = [
     ddgan_ddb_v5,
     ddgan_ddb_v6,
     ddgan_ddb_v7,
+    ddgan_ddb_v9,
+    ddgan_ddb_v10,
+    ddgan_ddb_v11,
+    ddgan_ddb_v12,
 ]
 
 def get_model(model_name):
@@ -334,7 +358,7 @@ def get_model(model_name):
     return model()
 
 
-def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional"):
+def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional", compute_image_reward=False):
 
     cfg = get_model(model_name)
     model = cfg['model']
@@ -365,12 +389,16 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
     args['scale_factor_w'] = scale_factor_w
     args['n_mlp'] = model.get("n_mlp")
     args['scale_method'] = scale_method
+    args['attn_resolutions'] = model.get("attn_resolutions", "16")
     if fid:
         args['compute_fid'] = ''
         args['real_img_dir'] = real_img_dir
         args['nb_images_for_fid'] = nb_images_for_fid
     if compute_clip_score:
         args['compute_clip_score'] = ""
+
+    if compute_image_reward:
+        args['compute_image_reward'] = ""
     if eval_name:
         args["eval_name"] = eval_name
     cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
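A note on how `test()` hands options to test_ddgan.py: the `args` dict is flattened into a command line, entries whose value is `None` are skipped, and store_true flags such as `compute_image_reward` get an empty-string value so they render as bare flags. The new `attn_resolutions` entry is what carries the v9/v10 configs' `'4 8 16 32'` through to the test parser, which now accepts `nargs='+', type=int`. A small sketch of the pattern with made-up values:

```python
# Mirrors the cmd-building line in test(); the values are illustrative.
args = {
    "attn_resolutions": "4 8 16 32",  # forwarded from the model config
    "compute_image_reward": "",       # empty string -> bare CLI flag
    "batch_size": 16,
    "epoch_id": None,                 # None entries are dropped
}
cmd = "python -u test_ddgan.py " + " ".join(
    f"--{k} {v}" for k, v in args.items() if v is not None
)
print(cmd)
# python -u test_ddgan.py --attn_resolutions 4 8 16 32 --compute_image_reward  --batch_size 16
```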
score_sde/models/ncsnpp_generator_adagn.py CHANGED

@@ -37,6 +37,11 @@ import functools
 import torch
 import numpy as np
 
+try:
+    from fairscale.nn.checkpoint import checkpoint_wrapper
+except Exception:
+    checkpoint_wrapper = lambda x:x
+
 
 ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp_Adagn
 ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp_Adagn
@@ -63,6 +68,7 @@ class NCSNpp(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.config = config
+        self.grad_checkpointing = config.grad_checkpointing if hasattr(config, "grad_checkpointing") else False
         self.not_use_tanh = config.not_use_tanh
         self.act = act = nn.SiLU()
         self.z_emb_dim = z_emb_dim = config.z_emb_dim
@@ -176,6 +182,8 @@ class NCSNpp(nn.Module):
             raise ValueError(f'resblock type {resblock_type} unrecognized.')
 
         # Downsampling block
+        def wrap(block):
+            return checkpoint_wrapper(block) if self.grad_checkpointing else block
 
         channels = config.num_channels
         if progressive_input != 'none':
@@ -189,18 +197,18 @@ class NCSNpp(nn.Module):
             # Residual blocks for this resolution
             for i_block in range(num_res_blocks):
                 out_ch = nf * ch_mult[i_level]
-                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                modules.append(wrap(ResnetBlock(in_ch=in_ch, out_ch=out_ch)))
                 in_ch = out_ch
 
                 if all_resolutions[i_level] in attn_resolutions:
-                    modules.append(AttnBlock(channels=in_ch))
+                    modules.append(wrap(AttnBlock(channels=in_ch)))
                 hs_c.append(in_ch)
 
             if i_level != num_resolutions - 1:
                 if resblock_type == 'ddpm':
                     modules.append(Downsample(in_ch=in_ch))
                 else:
-                    modules.append(ResnetBlock(down=True, in_ch=in_ch))
+                    modules.append(wrap(ResnetBlock(down=True, in_ch=in_ch)))
 
                 if progressive_input == 'input_skip':
                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -214,21 +222,21 @@ class NCSNpp(nn.Module):
             hs_c.append(in_ch)
 
         in_ch = hs_c[-1]
-        modules.append(ResnetBlock(in_ch=in_ch))
-        modules.append(AttnBlock(channels=in_ch))
-        modules.append(ResnetBlock(in_ch=in_ch))
+        modules.append(wrap(ResnetBlock(in_ch=in_ch)))
+        modules.append(wrap(AttnBlock(channels=in_ch)))
+        modules.append(wrap(ResnetBlock(in_ch=in_ch)))
 
         pyramid_ch = 0
         # Upsampling block
         for i_level in reversed(range(num_resolutions)):
            for i_block in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[i_level]
-                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(),
-                                           out_ch=out_ch))
+                modules.append(wrap(ResnetBlock(in_ch=in_ch + hs_c.pop(),
+                                                out_ch=out_ch)))
                in_ch = out_ch
 
                if all_resolutions[i_level] in attn_resolutions:
-                    modules.append(AttnBlock(channels=in_ch))
+                    modules.append(wrap(AttnBlock(channels=in_ch)))
 
            if progressive != 'none':
                if i_level == num_resolutions - 1:
@@ -260,7 +268,7 @@ class NCSNpp(nn.Module):
             if resblock_type == 'ddpm':
                 modules.append(Upsample(in_ch=in_ch))
             else:
-                modules.append(ResnetBlock(in_ch=in_ch, up=True))
+                modules.append(wrap(ResnetBlock(in_ch=in_ch, up=True)))
 
         assert not hs_c
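This is the "fix gradient checkpointing" half of the commit: rather than wrapping the whole generator once in train_ddgan.py (now commented out there), checkpointing is applied per ResNet/attention block inside the model, the usual granularity for activation checkpointing, and the fairscale import degrades to an identity function when the library is absent. A hedged sketch of the same per-block pattern, using torch.utils.checkpoint here purely for illustration:

```python
# Per-block activation checkpointing: activations inside the wrapped block
# are freed after forward and recomputed during backward, trading compute
# for memory. CheckpointedBlock is an illustrative stand-in for fairscale's
# checkpoint_wrapper.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, block, enabled=True):
        super().__init__()
        self.block = block
        self.enabled = enabled

    def forward(self, x):
        if self.enabled and self.training:
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)

block = CheckpointedBlock(nn.Sequential(nn.Linear(64, 64), nn.SiLU()))
y = block(torch.randn(8, 64))
y.sum().backward()
```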
test_ddgan.py CHANGED

@@ -380,7 +380,11 @@ def sample_and_test(args):
         epochs = range(1000)
     else:
         epochs = [args.epoch_id]
-
+    if args.compute_image_reward:
+        import ImageReward as RM
+        #image_reward = RM.load("ImageReward-v1.0", download_root=".").to(device)
+        image_reward = RM.load("ImageReward.pt", download_root=".").to(device)
+
     for epoch in epochs:
         args.epoch_id = epoch
         path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
@@ -389,7 +393,7 @@ def sample_and_test(args):
             continue
         if not os.path.exists(next_next_path):
             break
-        print(path)
+        print("PATH", path)
 
         #if not os.path.exists(next_path):
         #    print(f"STOP at {epoch}")
@@ -400,9 +404,7 @@ def sample_and_test(args):
             continue
         suffix = '_' + args.eval_name if args.eval_name else ""
         dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
-
-
-        if (args.compute_fid or args.compute_clip_score) and os.path.exists(dest):
+        if (args.compute_fid or args.compute_clip_score or args.compute_image_reward) and os.path.exists(dest):
             continue
         print("Eval Epoch", args.epoch_id)
         #loading weights from ddp in single gpu
@@ -424,7 +426,8 @@ def sample_and_test(args):
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
 
-        if args.compute_fid or args.compute_clip_score:
+
+        if args.compute_fid or args.compute_clip_score or args.compute_image_reward:
             from torch.nn.functional import adaptive_avg_pool2d
             from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
             from pytorch_fid.inception import InceptionV3
@@ -472,6 +475,8 @@ def sample_and_test(args):
 
         if args.compute_clip_score:
             clip_scores = []
+        if args.compute_image_reward:
+            image_rewards = []
 
         for b in range(0, len(texts), args.batch_size):
             text = texts[b:b+args.batch_size]
@@ -485,12 +490,7 @@ def sample_and_test(args):
             else:
                 fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
             fake_sample = to_range_0_1(fake_sample)
-
-            for j, x in enumerate(fake_sample):
-                index = i * args.batch_size + j
-                torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
-            """
-
+
             if args.compute_fid:
                 with torch.no_grad():
                     pred = inceptionv3(fake_sample)[0]
@@ -511,9 +511,18 @@ def sample_and_test(args):
                 imf = torch.nn.functional.normalize(imf, dim=1)
                 txtf = torch.nn.functional.normalize(txtf, dim=1)
                 clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
-
+
+            if args.compute_image_reward:
+                for k, sample in enumerate(fake_sample):
+                    img = sample.cpu().numpy().transpose(1,2,0)
+                    img = img * 255
+                    img = img.astype(np.uint8)
+                    text_k = text[k]
+                    score = image_reward.score(text_k, img)
+                    image_rewards.append(score)
             if i % 10 == 0:
                 print('evaluating batch ', i, time.time() - t0)
+            #break
             i += 1
 
             results = {}
@@ -526,6 +535,9 @@ def sample_and_test(args):
         if args.compute_clip_score:
             clip_score = torch.cat(clip_scores).mean().item()
             results['clip_score'] = clip_score
+        if args.compute_image_reward:
+            reward = np.mean(image_rewards)
+            results['image_reward'] = reward
         results.update(vars(args))
         with open(dest, "w") as fd:
             json.dump(results, fd)
@@ -591,6 +603,9 @@ if __name__ == '__main__':
                         help='whether or not compute FID')
     parser.add_argument('--compute_clip_score', action='store_true', default=False,
                         help='whether or not compute CLIP score')
+    parser.add_argument('--compute_image_reward', action='store_true', default=False,
+                        help='whether or not compute ImageReward score')
+
     parser.add_argument('--clip_model', type=str,default="ViT-L/14")
     parser.add_argument('--eval_name', type=str,default="")
 
@@ -625,7 +640,7 @@ if __name__ == '__main__':
 
     parser.add_argument('--num_res_blocks', type=int, default=2,
                         help='number of resnet blocks per scale')
-    parser.add_argument('--attn_resolutions', default=(16,),
+    parser.add_argument('--attn_resolutions', default=(16,), nargs='+', type=int,
                         help='resolution of applying attention')
     parser.add_argument('--dropout', type=float, default=0.,
                         help='drop-out rate')
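The evaluation loop can now score samples with ImageReward, the learned human-preference model behind the "image reward paper" in the commit message, and averages the per-sample scores into the eval JSON as `image_reward`. A hedged usage sketch, assuming the `ImageReward` package and the `RM.load` / `model.score(prompt, image)` interface exactly as the diff uses them; the prompts and tensor below are made up:

```python
# Score a batch of generated images against their prompts with ImageReward.
# fake_sample is assumed to be a float tensor in [0, 1] of shape (B, C, H, W).
import numpy as np
import torch
import ImageReward as RM

model = RM.load("ImageReward-v1.0", download_root=".")
fake_sample = torch.rand(2, 3, 256, 256)
prompts = ["a red bicycle", "a bowl of soup"]

rewards = []
for k, sample in enumerate(fake_sample):
    # CHW float in [0, 1] -> HWC uint8, the format passed to score() above
    img = (sample.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
    rewards.append(model.score(prompts[k], img))
print("mean image reward:", float(np.mean(rewards)))
```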
train_ddgan.py CHANGED

@@ -4,14 +4,14 @@
 # This work is licensed under the NVIDIA Source Code License
 # for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
 # ---------------------------------------------------------------
+import torch
 
 from glob import glob
 import argparse
-import torch
 import numpy as np
-
+import json
 import os
-
+import time
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
@@ -288,6 +288,15 @@ def train(rank, gpu, args):
             transforms.ToTensor(),
             transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
         ])
+    elif args.preprocessing == "simple_random_crop_v2":
+        train_transform = transforms.Compose([
+            transforms.Resize(args.image_size),
+            transforms.RandomCrop(args.image_size, interpolation=3),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
+        ])
+    else:
+        raise ValueError(args.preprocessing)
     shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
     pipeline = [ResampledShards2(shards)]
     pipeline.extend([
@@ -312,7 +321,7 @@ def train(rank, gpu, args):
         dataset,
         batch_size=None,
         shuffle=False,
-        num_workers=
+        num_workers=1,
     )
 
     if args.dataset != "wds":
@@ -355,6 +364,7 @@ def train(rank, gpu, args):
             cond_size=text_encoder.output_size,
             act=nn.LeakyReLU(0.2)).to(device)
     elif args.discr_type == "large_attn_pool":
+        # Discriminator with Attention Pool based discriminator for text conditioning
         netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
                                    t_emb_dim = args.t_emb_dim,
                                    cond_size=text_encoder.output_size,
@@ -362,6 +372,7 @@ def train(rank, gpu, args):
                                    act=nn.LeakyReLU(0.2)).to(device)
 
     elif args.discr_type == "large_cond_attn":
+        # Discriminator with Cross-Attention based discriminator for text conditioning
         netD = CondAttnDiscriminator(
             nc = 2*args.num_channels,
             ngf = args.ngf,
@@ -391,7 +402,7 @@ def train(rank, gpu, args):
     optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
 
     if args.use_ema:
-        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)
+        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay, memory_efficient=args.grad_checkpointing)
 
     schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
     schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
@@ -403,12 +414,10 @@ def train(rank, gpu, args):
     netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu], find_unused_parameters=args.discr_type=="projected_gan")
     #if args.discr_type == "projected_gan":
     #    netD._set_static_graph()
-
 
-    if args.grad_checkpointing:
-        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-        netG = checkpoint_wrapper(netG)
-
+    #if args.grad_checkpointing:
+    #    from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+    #    netG = checkpoint_wrapper(netG)
     exp = args.exp
     parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
 
@@ -442,8 +451,9 @@ def train(rank, gpu, args):
         optimizerD.load_state_dict(checkpoint['optimizerD'])
         schedulerD.load_state_dict(checkpoint['schedulerD'])
         global_step = checkpoint['global_step']
-
-
+        if rank == 0:
+            print("=> loaded checkpoint (epoch {})"
+                  .format(checkpoint['epoch']))
     else:
         global_step, epoch, init_epoch = 0, 0, 0
     use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool", "projected_gan")
@@ -454,6 +464,7 @@ def train(rank, gpu, args):
         train_sampler.set_epoch(epoch)
 
         for iteration, (x, y) in enumerate(data_loader):
+            t0 = time.time()
             #print(x.shape)
             if args.dataset != "wds":
                 y = [str(yi) for yi in y.tolist()]
@@ -631,6 +642,8 @@ def train(rank, gpu, args):
             if rank == 0:
                 print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
                 print('Global step:', global_step)
+                dt = time.time() - t0
+                print('Time per iteration: ', dt)
             if iteration % 1000 == 0:
                 x_t_1 = torch.randn_like(real_data)
                 with autocast():
@@ -640,7 +653,8 @@ def train(rank, gpu, args):
 
         if args.save_content:
             dist.barrier()
-
+            if rank == 0:
+                print('Saving content.')
             def to_cpu(d):
                 for k, v in d.items():
                     d[k] = v.cpu()
@@ -677,6 +691,9 @@ def train(rank, gpu, args):
                        'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
             torch.save(content, os.path.join(exp_path, 'content.pth'))
             torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
+            state_content = {'epoch': epoch + 1, 'global_step': global_step}
+            with open(os.path.join(exp_path, 'netG_{}.json'.format(epoch)), "w") as fd:
+                fd.write(json.dumps(state_content))
             if args.use_ema:
                 optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
             torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
@@ -685,40 +702,8 @@ def train(rank, gpu, args):
 
 
         if not args.no_lr_decay:
-
             schedulerG.step()
             schedulerD.step()
-        """
-        if rank == 0:
-            if epoch % 10 == 0:
-                torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
-
-            x_t_1 = torch.randn_like(real_data)
-            with autocast():
-                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
-            torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
-
-            if args.save_content:
-                if epoch % args.save_content_every == 0:
-                    print('Saving content.')
-                    content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
-                               'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
-                               'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
-                               'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
-
-                    torch.save(content, os.path.join(exp_path, 'content.pth'))
-                    torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
-
-            if epoch % args.save_ckpt_every == 0:
-                if args.use_ema:
-                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
-
-                torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
-                if args.use_ema:
-                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
-            dist.barrier()
-        """
-
 
 def init_processes(rank, size, fn, args):
     """ Initialize the distributed environment. """
@@ -748,12 +733,12 @@ if __name__ == '__main__':
                         help='seed used for initialization')
 
     parser.add_argument('--resume', action='store_true',default=False)
-    parser.add_argument('--masked_mean', action='store_true',default=False)
-    parser.add_argument('--mismatch_loss', action='store_true',default=False)
+    parser.add_argument('--masked_mean', action='store_true',default=False, help="use masked mean to pool from t5-based text encoder")
+    parser.add_argument('--mismatch_loss', action='store_true',default=False, help="use mismatch loss")
     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
-    parser.add_argument('--cross_attention', action='store_true',default=False)
-    parser.add_argument('--fsdp', action='store_true',default=False)
-    parser.add_argument('--grad_checkpointing', action='store_true',default=False)
+    parser.add_argument('--cross_attention', action='store_true',default=False, help="use cross attention in generator")
+    parser.add_argument('--fsdp', action='store_true',default=False, help='use FSDP')
+    parser.add_argument('--grad_checkpointing', action='store_true',default=False, help='use grad checkpointing')
 
     parser.add_argument('--image_size', type=int, default=32,
                         help='size of image')
@@ -767,9 +752,8 @@ if __name__ == '__main__':
     parser.add_argument('--beta_max', type=float, default=20.,
                         help='beta_max for diffusion')
     parser.add_argument('--classifier_free_guidance_proba', type=float, default=0.0)
-
     parser.add_argument('--num_channels_dae', type=int, default=128,
-                        help='number of initial channels in denosing model')
+                        help='number of initial channels in denosing model generator')
     parser.add_argument('--n_mlp', type=int, default=3,
                         help='number of mlp layers for z')
     parser.add_argument('--ch_mult', nargs='+', type=int,
@@ -825,7 +809,7 @@ if __name__ == '__main__':
     parser.add_argument('--beta2', type=float, default=0.9,
                         help='beta2 for adam')
     parser.add_argument('--no_lr_decay',action='store_true', default=False)
-    parser.add_argument('--grad_penalty_cond', action='store_true',default=False)
+    parser.add_argument('--grad_penalty_cond', action='store_true',default=False, help="cond based grad penalty")
 
     parser.add_argument('--use_ema', action='store_true', default=False,
                         help='use EMA or not')