Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
·
ae26d48
1
Parent(s):
1a02524
update run.py
Browse files
run.py
CHANGED
|
@@ -91,6 +91,11 @@ def ddgan_cc12m_v14():
|
|
| 91 |
cfg['model']['num_channels_dae'] = 192
|
| 92 |
return cfg
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
def ddgan_cifar10_cond17():
|
| 96 |
cfg = base()
|
|
@@ -107,6 +112,13 @@ def ddgan_cifar10_cond18():
|
|
| 107 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
| 108 |
return cfg
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
def ddgan_laion_aesthetic_v1():
|
| 111 |
cfg = ddgan_cc12m_v11()
|
| 112 |
cfg['model']['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
|
|
@@ -122,10 +134,23 @@ def ddgan_laion_aesthetic_v3():
|
|
| 122 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
| 123 |
return cfg
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
models = [
|
| 127 |
ddgan_cifar10_cond17, # cifar10, cross attn for discr
|
| 128 |
ddgan_cifar10_cond18, # cifar10, xl encoder
|
|
|
|
|
|
|
| 129 |
ddgan_cc12m_v2, # baseline (no large text encoder, no classifier guidance)
|
| 130 |
ddgan_cc12m_v6, # like v2 but using large T5 text encoder
|
| 131 |
ddgan_cc12m_v7, # like v2 but with classifier guidance
|
|
@@ -135,9 +160,12 @@ models = [
|
|
| 135 |
ddgan_cc12m_v12, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1
|
| 136 |
ddgan_cc12m_v13, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
|
| 137 |
ddgan_cc12m_v14, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
|
|
|
|
| 138 |
ddgan_laion_aesthetic_v1, # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
|
| 139 |
ddgan_laion_aesthetic_v2, # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
|
| 140 |
-
ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL
|
|
|
|
|
|
|
| 141 |
]
|
| 142 |
|
| 143 |
def get_model(model_name):
|
|
@@ -146,7 +174,7 @@ def get_model(model_name):
|
|
| 146 |
return model()
|
| 147 |
|
| 148 |
|
| 149 |
-
def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0):
|
| 150 |
|
| 151 |
cfg = get_model(model_name)
|
| 152 |
model = cfg['model']
|
|
@@ -173,12 +201,16 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
|
|
| 173 |
args['guidance_scale'] = guidance_scale
|
| 174 |
args['masked_mean'] = model.get("masked_mean")
|
| 175 |
args['dynamic_thresholding_quantile'] = q
|
|
|
|
|
|
|
| 176 |
args['n_mlp'] = model.get("n_mlp")
|
| 177 |
-
|
| 178 |
if fid:
|
| 179 |
args['compute_fid'] = ''
|
| 180 |
args['real_img_dir'] = real_img_dir
|
| 181 |
args['nb_images_for_fid'] = nb_images_for_fid
|
|
|
|
|
|
|
|
|
|
| 182 |
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
|
| 183 |
print(cmd)
|
| 184 |
call(cmd, shell=True)
|
|
|
|
| 91 |
cfg['model']['num_channels_dae'] = 192
|
| 92 |
return cfg
|
| 93 |
|
| 94 |
+
def ddgan_cc12m_v15():
|
| 95 |
+
cfg = ddgan_cc12m_v11()
|
| 96 |
+
cfg['model']['mismatch_loss'] = ''
|
| 97 |
+
cfg['model']['grad_penalty_cond'] = ''
|
| 98 |
+
return cfg
|
| 99 |
|
| 100 |
def ddgan_cifar10_cond17():
|
| 101 |
cfg = base()
|
|
|
|
| 112 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
| 113 |
return cfg
|
| 114 |
|
| 115 |
+
def ddgan_cifar10_cond19():
|
| 116 |
+
cfg = ddgan_cifar10_cond17()
|
| 117 |
+
cfg['model']['discr_type'] = 'small_cond_attn'
|
| 118 |
+
cfg['model']['mismatch_loss'] = ''
|
| 119 |
+
cfg['model']['grad_penalty_cond'] = ''
|
| 120 |
+
return cfg
|
| 121 |
+
|
| 122 |
def ddgan_laion_aesthetic_v1():
|
| 123 |
cfg = ddgan_cc12m_v11()
|
| 124 |
cfg['model']['dataset_root'] = '"/p/scratch/ccstdl/cherti1/LAION-aesthetic/output/{00000..05038}.tar"'
|
|
|
|
| 134 |
cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
|
| 135 |
return cfg
|
| 136 |
|
| 137 |
+
def ddgan_laion_aesthetic_v4():
|
| 138 |
+
cfg = ddgan_laion_aesthetic_v1()
|
| 139 |
+
cfg['model']['text_encoder'] = "openclip/ViT-L-14-336/openai"
|
| 140 |
+
return cfg
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def ddgan_laion_aesthetic_v5():
|
| 144 |
+
cfg = ddgan_laion_aesthetic_v1()
|
| 145 |
+
cfg['model']['mismatch_loss'] = ''
|
| 146 |
+
cfg['model']['grad_penalty_cond'] = ''
|
| 147 |
+
return cfg
|
| 148 |
|
| 149 |
models = [
|
| 150 |
ddgan_cifar10_cond17, # cifar10, cross attn for discr
|
| 151 |
ddgan_cifar10_cond18, # cifar10, xl encoder
|
| 152 |
+
ddgan_cifar10_cond19, # cifar10, xl encoder
|
| 153 |
+
|
| 154 |
ddgan_cc12m_v2, # baseline (no large text encoder, no classifier guidance)
|
| 155 |
ddgan_cc12m_v6, # like v2 but using large T5 text encoder
|
| 156 |
ddgan_cc12m_v7, # like v2 but with classifier guidance
|
|
|
|
| 160 |
ddgan_cc12m_v12, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1
|
| 161 |
ddgan_cc12m_v13, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + cond attn
|
| 162 |
ddgan_cc12m_v14, # T5-XL + cross attention + classifier free guidance + random_resized_crop_v1 + 300M model
|
| 163 |
+
ddgan_cc12m_v15, # fine-tune v11 with --mismatch_loss and --grad_penalty_cond
|
| 164 |
ddgan_laion_aesthetic_v1, # like ddgan_cc12m_v11 but fine-tuned on laion aesthetic
|
| 165 |
ddgan_laion_aesthetic_v2, # like ddgan_laion_aesthetic_v1 but trained from scratch with the new cross attn discr
|
| 166 |
+
ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL (continue from 23aug with mismatch and grad penalty and random_resized_crop_v1)
|
| 167 |
+
ddgan_laion_aesthetic_v4, # like ddgan_laion_aesthetic_v1 but trained from scratch with OpenAI's ClipEncoder
|
| 168 |
+
ddgan_laion_aesthetic_v5, # fine-tune ddgan_laion_aesthetic_v1 with mismatch and cond grad penalty losses
|
| 169 |
]
|
| 170 |
|
| 171 |
def get_model(model_name):
|
|
|
|
| 174 |
return model()
|
| 175 |
|
| 176 |
|
| 177 |
+
def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False):
|
| 178 |
|
| 179 |
cfg = get_model(model_name)
|
| 180 |
model = cfg['model']
|
|
|
|
| 201 |
args['guidance_scale'] = guidance_scale
|
| 202 |
args['masked_mean'] = model.get("masked_mean")
|
| 203 |
args['dynamic_thresholding_quantile'] = q
|
| 204 |
+
args['scale_factor_h'] = scale_factor_h
|
| 205 |
+
args['scale_factor_w'] = scale_factor_w
|
| 206 |
args['n_mlp'] = model.get("n_mlp")
|
|
|
|
| 207 |
if fid:
|
| 208 |
args['compute_fid'] = ''
|
| 209 |
args['real_img_dir'] = real_img_dir
|
| 210 |
args['nb_images_for_fid'] = nb_images_for_fid
|
| 211 |
+
if compute_clip_score:
|
| 212 |
+
args['compute_clip_score'] = ""
|
| 213 |
+
|
| 214 |
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
|
| 215 |
print(cmd)
|
| 216 |
call(cmd, shell=True)
|