update the checkpoints fine-tuned on TextCaps 5K.
Files changed:
- app.py (+12, -5)
- cldm/cldm.py (+1, -47)
- config_ema_unlock.yaml (+88, -0)
- laion1M_model_wo_ema.ckpt → textcaps5K_epoch_10_model_wo_ema.ckpt (+2, -2)
- textcaps5K_epoch_20_model_wo_ema.ckpt (+3, -0)
- textcaps5K_epoch_40_model_wo_ema.ckpt (+3, -0)
- transfer.py (+7, -4)
app.py
CHANGED
@@ -82,12 +82,18 @@ def load_ckpt(model_ckpt = "LAION-Glyph-10M-Epoch-5"):
     time.sleep(2)
     print("empty the cuda cache")
 
-    if model_ckpt == "LAION-Glyph-1M":
-        model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
-    elif model_ckpt == "LAION-Glyph-10M-Epoch-5":
+    # if model_ckpt == "LAION-Glyph-1M":
+    #     model = load_model_ckpt(model, "laion1M_model_wo_ema.ckpt")
+    if model_ckpt == "LAION-Glyph-10M-Epoch-5":
         model = load_model_ckpt(model, "laion10M_epoch_5_model_wo_ema.ckpt")
     elif model_ckpt == "LAION-Glyph-10M-Epoch-6":
         model = load_model_ckpt(model, "laion10M_epoch_6_model_wo_ema.ckpt")
+    elif model_ckpt == "TextCaps-5K-Epoch-10":
+        model = load_model_ckpt(model, "textcaps5K_epoch_10_model_wo_ema.ckpt")
+    elif model_ckpt == "TextCaps-5K-Epoch-20":
+        model = load_model_ckpt(model, "textcaps5K_epoch_20_model_wo_ema.ckpt")
+    elif model_ckpt == "TextCaps-5K-Epoch-40":
+        model = load_model_ckpt(model, "textcaps5K_epoch_40_model_wo_ema.ckpt")
 
     render_tool = Render_Text(model)
     output_str = f"already change the model checkpoint to {model_ckpt}"
@@ -126,7 +132,7 @@ with block:
         only_show_rendered_image = gr.Number(value=1, visible=False)
         default_width = [0.3, 0.3, 0.3, 0.3]
         default_top_left_x = [0.35, 0.15, 0.15, 0.5]
-        default_top_left_y = [0.
+        default_top_left_y = [0.4, 0.15, 0.65, 0.65]
         with gr.Column():
 
             with gr.Row():
@@ -154,7 +160,8 @@ with block:
         with gr.Accordion("Model Options", open=False):
             with gr.Row():
                 # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M", "Textcaps5K-10"], label="Checkpoint", default = "LAION-Glyph-10M")
-                model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
+                # model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "LAION-Glyph-1M"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
+                model_ckpt = gr.inputs.Dropdown(["LAION-Glyph-10M-Epoch-6", "LAION-Glyph-10M-Epoch-5", "TextCaps-5K-Epoch-10", "TextCaps-5K-Epoch-20", "TextCaps-5K-Epoch-40"], label="Checkpoint", default = "LAION-Glyph-10M-Epoch-6")
                 # load_button = gr.Button(value = "Load Checkpoint")
 
         with gr.Accordion("Shared Advanced Options", open=False):
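Note: the if/elif chain in load_ckpt now repeats one branch per dropdown entry. A dictionary-based dispatch would keep the labels and filenames in one place. The sketch below is a hypothetical refactor, not code from this commit; CKPT_FILES and resolve_ckpt are invented names:

# Hypothetical refactor (not in the commit): one table mapping dropdown labels
# to checkpoint files, shared by the dropdown choices and the loader.
CKPT_FILES = {
    "LAION-Glyph-10M-Epoch-5": "laion10M_epoch_5_model_wo_ema.ckpt",
    "LAION-Glyph-10M-Epoch-6": "laion10M_epoch_6_model_wo_ema.ckpt",
    "TextCaps-5K-Epoch-10": "textcaps5K_epoch_10_model_wo_ema.ckpt",
    "TextCaps-5K-Epoch-20": "textcaps5K_epoch_20_model_wo_ema.ckpt",
    "TextCaps-5K-Epoch-40": "textcaps5K_epoch_40_model_wo_ema.ckpt",
}

def resolve_ckpt(model_ckpt: str) -> str:
    # Fail loudly on an unknown label instead of silently keeping the old weights.
    try:
        return CKPT_FILES[model_ckpt]
    except KeyError:
        raise ValueError(f"unknown checkpoint choice: {model_ckpt}")

# usage inside load_ckpt:
#   model = load_model_ckpt(model, resolve_ckpt(model_ckpt))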
cldm/cldm.py
CHANGED
@@ -532,13 +532,6 @@ class ControlLDM(LatentDiffusion):
         self.freeze_glyph_image_encoder = model.freeze_image_encoder #image_encoder.freeze_model
         self.glyph_control_model = model
         self.glyph_image_encoder_type = model.image_encoder_type
-        # self.glyph_control_optim = torch.optim.AdamW([
-        #     {"params": gain_or_bias_params, "weight_decay": 0.}, # "lr": self.glycon_lr},
-        #     {"params": rest_params, "weight_decay": self.glycon_wd} #, "lr": self.glycon_lr},
-        #     ],
-        #     lr = self.glycon_lr
-        # )
-        # params += list(model.image_encoder.parameters())
 
 
 
@@ -738,16 +731,6 @@ class ControlLDM(LatentDiffusion):
         if decoder_params is not None:
             params_wlr.append({"params": decoder_params, "lr": self.decoder_lr})
 
-
-        # if not self.sep_lr:
-        #     opt = torch.optim.AdamW(params, lr=lr)
-        # else:
-        #     opt = torch.optim.AdamW(
-        #         [
-        #             {"params": params},
-        #             {"params": decoder_params, "lr": self.decoder_lr}
-        #         ], lr=lr
-        #     )
         if not self.freeze_glyph_image_encoder:
             if self.glyph_image_encoder_type == "CLIP":
                 # assert self.sep_lr
@@ -866,20 +849,6 @@ class ControlLDM(LatentDiffusion):
             if p.requires_grad and p.grad is not None:
                 grad_norm_v = p.grad.cpu().detach().norm().item()
                 gradnorm_list.append(grad_norm_v)
-        # for name, p in self.named_parameters():
-        #     if p.requires_grad and p.grad is not None:
-        #         grad_norm_v = p.grad.detach().norm().item()
-        #         gradnorm_list.append(grad_norm_v)
-        #         if "textemb_merge_model" in name:
-        #             self.log("all_gradients/{}_norm".format(name),
-        #                 gradnorm_list[-1],
-        #                 prog_bar=False, logger=True, on_step=True, on_epoch=False
-        #             )
-        #         # if grad_norm_v > 0.1:
-        #         #     print("the norm of gradient w.r.t {} > 0.1: {:.2f}".format
-        #         #         (
-        #         #             name, grad_norm_v
-        #         #         ))
         if len(gradnorm_list):
             self.log("all_gradients/grad_norm_mean",
                 np.mean(gradnorm_list),
@@ -943,19 +912,4 @@ class ControlLDM(LatentDiffusion):
                 prog_bar=False, logger=True, on_step=True, on_epoch=False
             )
         del gradnorm_list
-        del zeroconvs
-
-    # def freeze_unet(self):
-    #     # Have some bugs
-    #     self.model.eval()
-    #     # self.model.train = disabled_train
-    #     for param in self.model.parameters():
-    #         param.requires_grad = False
-
-    #     if not self.sd_locked:
-    #         self.model.diffusion_model.output_blocks.train()
-    #         self.model.diffusion_model.out.train()
-    #         for param in self.model.diffusion_model.out.parameters():
-    #             param.requires_grad = True
-    #         for param in self.model.diffusion_model.output_blocks.parameters():
-    #             param.requires_grad = True
+        del zeroconvs
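For context, the code these hunks keep is the per-parameter gradient-norm collection whose mean is logged through PyTorch Lightning. A self-contained sketch of that pattern follows; the loop header is outside the diff, so iterating over self.parameters() is an assumption:

# Sketch of the gradient-norm logging retained in cldm/cldm.py. The loop header
# is not visible in the diff; self.parameters() is assumed.
import numpy as np
import pytorch_lightning as pl

class GradNormLogger(pl.LightningModule):
    def log_grad_norms(self):
        gradnorm_list = []
        for p in self.parameters():
            if p.requires_grad and p.grad is not None:
                # move the gradient to CPU before taking the norm, as in the kept line
                gradnorm_list.append(p.grad.cpu().detach().norm().item())
        if len(gradnorm_list):
            self.log("all_gradients/grad_norm_mean",
                     np.mean(gradnorm_list),
                     prog_bar=False, logger=True, on_step=True, on_epoch=False)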
config_ema_unlock.yaml
ADDED
@@ -0,0 +1,88 @@
+model:
+  base_learning_rate: 1.0e-6 #1.0e-5 #1.0e-4
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    control_key: "hint"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: #val/loss_simple_ema
+    scale_factor: 0.18215
+    only_mid_control: False
+    sd_locked: False #True
+    use_ema: True #TODO: specify
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        hint_channels: 3
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        # device: "cpu" #TODO: specify
laion1M_model_wo_ema.ckpt → textcaps5K_epoch_10_model_wo_ema.ckpt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c26cd80dcdd8b5563a68f397f291d0d2d7a4bef7a8c2435fd97a36be32ef61be
+size 6671914001
textcaps5K_epoch_20_model_wo_ema.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85c887aa42db7afbed071629bcf5a07cfccdcdc80216475d8a2536fed75cc600
+size 6671914001
textcaps5K_epoch_40_model_wo_ema.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:511be806f6e44f9c33af75df181adacfb3a0bb71aac8df8b303fff36e8e97dae
+size 6671914001
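These three-line files are Git LFS pointers rather than the weights themselves: the ~6.7 GB checkpoints live in LFS storage, keyed by the SHA-256 oid. A minimal sketch of reading such a pointer; the parsing code is illustrative only, not part of this repository:

# Illustrative only: read a Git LFS pointer file and report the object it references.
def read_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

info = read_lfs_pointer("textcaps5K_epoch_20_model_wo_ema.ckpt")
print(info["oid"], info["size"])  # e.g. sha256:85c8... 6671914001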
transfer.py
CHANGED
@@ -2,10 +2,13 @@ from omegaconf import OmegaConf
 from scripts.rendertext_tool import Render_Text, load_model_from_config
 import torch
 
-cfg = OmegaConf.load("config_ema.yaml")
-# model = load_model_from_config(cfg, "model_states.pt", verbose=True)
-model = load_model_from_config(cfg, "mp_rank_00_model_states.pt", verbose=True)
+# cfg = OmegaConf.load("config_ema.yaml")
+# # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
+# model = load_model_from_config(cfg, "mp_rank_00_model_states.pt", verbose=True)
 
+cfg = OmegaConf.load("config_ema_unlock.yaml")
+epoch_idx = 39
+model = load_model_from_config(cfg, "epoch={:0>6d}.ckpt".format(epoch_idx), verbose=True)
 
 from pytorch_lightning.callbacks import ModelCheckpoint
 with model.ema_scope("store ema weights"):
@@ -18,6 +21,6 @@ with model.ema_scope("store ema weights"):
     file_content = {
         'state_dict': store_sd
     }
-    torch.save(file_content, "
+    torch.save(file_content, f"textcaps5K_epoch_{epoch_idx+1}_model_wo_ema.ckpt")
     print("has stored the transfered ckpt.")
     print("trial ends!")
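Taken together, transfer.py now loads an unlocked-decoder EMA checkpoint (epoch 39, saved as epoch 40's weights) and re-saves it without EMA buffers. The middle of the script is outside the diff; the sketch below fills it with an assumed construction of store_sd, so it is not confirmed code:

# Hedged sketch of the pattern transfer.py appears to implement. Building
# store_sd from model.state_dict() inside ema_scope is an assumption.
from omegaconf import OmegaConf
import torch
from scripts.rendertext_tool import load_model_from_config

cfg = OmegaConf.load("config_ema_unlock.yaml")
epoch_idx = 39
model = load_model_from_config(cfg, "epoch={:0>6d}.ckpt".format(epoch_idx), verbose=True)

# ema_scope temporarily copies the EMA shadow weights into the live parameters,
# so a state_dict captured inside the scope holds the EMA-averaged weights.
with model.ema_scope("store ema weights"):
    store_sd = {k: v.cpu() for k, v in model.state_dict().items()
                if not k.startswith("model_ema.")}  # drop the EMA buffers themselves

torch.save({"state_dict": store_sd},
           f"textcaps5K_epoch_{epoch_idx+1}_model_wo_ema.ckpt")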