# AttentionDistillation / train_vae.py
# Commit f2f17f4 ("Init demo.") by ccchenzc, 2.57 kB
import argparse
import os
import torch
from diffusers import AutoencoderKL
from torch import nn
from torch.optim import Adam
from utils import load_image, save_image
def main(args):
    """Fine-tune the decoder of a pretrained VAE on a single image.

    The image is encoded once with the frozen VAE and the mean of the
    resulting latent distribution is kept fixed. Only the decoder
    parameters are then optimized (Adam, L1 loss) so that decoding those
    latents reproduces the input image.

    Files written under ``args.out_dir``:
        ori_image.png          the input image as loaded
        latents.png            the encoded latents passed to ``save_image``
        rec_image.png          reconstruction before training
        trained_rec_image.png  reconstruction after training
        trained_vae_<name>/    the VAE saved with ``save_pretrained``, where
                               ``<name>`` is the basename of the image path

    Args:
        args: Namespace providing ``out_dir``, ``vae_model_path``,
            ``image_path``, ``learning_rate`` and ``num_epochs``.
    """
    os.makedirs(args.out_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vae = AutoencoderKL.from_pretrained(args.vae_model_path).to(
        device, dtype=torch.float32
    )
    # Freeze everything first; only the decoder is unfrozen further down.
    vae.requires_grad_(False)

    image = load_image(args.image_path, size=(512, 512)).to(
        device, dtype=torch.float32
    )
    # Rescale for the VAE; presumably load_image yields values in [0, 1],
    # making this [-1, 1] -- TODO confirm against utils.load_image.
    image = image * 2 - 1
    save_image(image / 2 + 0.5, os.path.join(args.out_dir, "ori_image.png"))

    # The latents are a fixed training input and the first reconstruction is
    # only a preview, so neither needs autograd bookkeeping.
    with torch.no_grad():
        latents = vae.encode(image)["latent_dist"].mean
        save_image(latents, os.path.join(args.out_dir, "latents.png"))

        rec_image = vae.decode(latents, return_dict=False)[0]
        save_image(
            rec_image / 2 + 0.5, os.path.join(args.out_dir, "rec_image.png")
        )

    # Only the decoder is trained; the rest of the VAE stays frozen.
    vae.decoder.requires_grad_(True)

    loss_fn = nn.L1Loss()
    optimizer = Adam(vae.decoder.parameters(), lr=args.learning_rate)

    # Training loop: each "epoch" is one optimizer step on the single image.
    for epoch in range(args.num_epochs):
        reconstructed = vae.decode(latents, return_dict=False)[0]
        loss = loss_fn(reconstructed, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}")

    # The decoder parameters now require grad, so without no_grad this preview
    # decode would build and retain a full autograd graph and hand a
    # grad-requiring tensor to save_image.
    with torch.no_grad():
        rec_image = vae.decode(latents, return_dict=False)[0]
    save_image(
        rec_image / 2 + 0.5,
        os.path.join(args.out_dir, "trained_rec_image.png"),
    )

    vae.save_pretrained(
        os.path.join(
            args.out_dir, f"trained_vae_{os.path.basename(args.image_path)}"
        )
    )
if __name__ == "__main__":
    # Command-line entry point. Each option is declared once as
    # (flag, add_argument keyword arguments), in the order it should appear
    # in the --help output, and then registered in a single loop.
    cli_options = (
        (
            "--out_dir",
            {
                "type": str,
                "default": "./trained_vae/",
                "help": "Output directory to save results",
            },
        ),
        (
            "--vae_model_path",
            {
                "type": str,
                "required": True,
                "help": "Path to the pretrained VAE model",
            },
        ),
        (
            "--image_path",
            {
                "type": str,
                "required": True,
                "help": "Path to the input image",
            },
        ),
        (
            "--learning_rate",
            {
                "type": float,
                "default": 1e-4,
                "help": "Learning rate for the optimizer",
            },
        ),
        (
            "--num_epochs",
            {
                "type": int,
                "default": 75,
                "help": "Number of training epochs",
            },
        ),
    )

    parser = argparse.ArgumentParser(
        description="Train a VAE with given image and settings."
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main(args)