Upload swap_vae.py with huggingface_hub
Browse files- swap_vae.py +45 -0
swap_vae.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import click
|
| 3 |
+
|
| 4 |
+
def overwrite_first_stage(model_state_dict, vae_state_dict):
|
| 5 |
+
"""
|
| 6 |
+
Overwrite the First Stage Decoders.
|
| 7 |
+
|
| 8 |
+
From the new repo:
|
| 9 |
+
To keep compatibility with existing models,
|
| 10 |
+
only the decoder part was finetuned;
|
| 11 |
+
the checkpoints can be used as a drop-in replacement
|
| 12 |
+
for the existing autoencoder.
|
| 13 |
+
|
| 14 |
+
Sounds like we only need to change the decoder weights.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
target = "first_stage_model."
|
| 18 |
+
for key in model_state_dict.keys():
|
| 19 |
+
if target in key and ("decoder" in key or "encoder" in key):
|
| 20 |
+
matching_name = key.split(target)[1]
|
| 21 |
+
|
| 22 |
+
# double check this weight exists in the new vae
|
| 23 |
+
if matching_name in vae_state_dict:
|
| 24 |
+
model_state_dict[key] = vae_state_dict[matching_name]
|
| 25 |
+
else:
|
| 26 |
+
print(f"{key} Does not exist in the new VAE weights!")
|
| 27 |
+
|
| 28 |
+
return model_state_dict
|
| 29 |
+
|
| 30 |
+
@click.command()
|
| 31 |
+
@click.option("--base-model", type=str, default="sd-v1-5.ckpt")
|
| 32 |
+
@click.option("--vae", type=str, default="new_vae.ckpt")
|
| 33 |
+
@click.option("--output-name", type=str, default="sd-v1-5-new-vae.ckpt")
|
| 34 |
+
def main(base_model, vae, output_name):
|
| 35 |
+
print("hello")
|
| 36 |
+
model = torch.load(base_model)
|
| 37 |
+
new_vae = torch.load(vae)
|
| 38 |
+
|
| 39 |
+
model["state_dict"] = overwrite_first_stage(model["state_dict"], new_vae["state_dict"])
|
| 40 |
+
|
| 41 |
+
print(f"Saving to {output_name}")
|
| 42 |
+
torch.save(model, output_name)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
main()
|