AttentionDistillation / train_vae.py
ccchenzc's picture
Init demo.
f2f17f4
import argparse
import os
import torch
from diffusers import AutoencoderKL
from torch import nn
from torch.optim import Adam
from utils import load_image, save_image
def main(args):
    """Fine-tune the decoder of a pretrained VAE on a single image.

    The image is encoded once with the frozen VAE and the mean of the
    resulting latent distribution is kept fixed. Only the decoder weights are
    then optimized with an L1 loss between the decoded latents and the image.
    The input image, the latents, the reconstructions before and after
    training, and finally the whole fine-tuned VAE are written under
    ``args.out_dir``.

    Args:
        args: Parsed command-line namespace providing ``out_dir``,
            ``vae_model_path``, ``image_path``, ``learning_rate`` and
            ``num_epochs``.
    """
    os.makedirs(args.out_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vae = AutoencoderKL.from_pretrained(args.vae_model_path).to(
        device, dtype=torch.float32
    )
    # Freeze every weight first; only the decoder is unfrozen further down.
    vae.requires_grad_(False)

    image = load_image(args.image_path, size=(512, 512)).to(
        device, dtype=torch.float32
    )
    # NOTE(review): presumably load_image yields values in [0, 1], rescaled
    # here to [-1, 1] for the VAE -- confirm against utils.load_image.
    image = image * 2 - 1
    save_image(image / 2 + 0.5, f"{args.out_dir}/ori_image.png")

    # Inference-only passes: nothing here is back-propagated through, so skip
    # autograd bookkeeping. no_grad (not inference_mode) is used on purpose:
    # `latents` is fed to the decoder again inside the training loop.
    with torch.no_grad():
        latents = vae.encode(image)["latent_dist"].mean
        save_image(latents, f"{args.out_dir}/latents.png")

        rec_image = vae.decode(latents, return_dict=False)[0]
        save_image(rec_image / 2 + 0.5, f"{args.out_dir}/rec_image.png")

    # Unfreeze the decoder only; the encoder (and the latents) stay fixed.
    for param in vae.decoder.parameters():
        param.requires_grad = True

    loss_fn = nn.L1Loss()
    optimizer = Adam(vae.decoder.parameters(), lr=args.learning_rate)

    # Training loop: each "epoch" is one optimizer step on the single image.
    for epoch in range(args.num_epochs):
        reconstructed = vae.decode(latents, return_dict=False)[0]
        loss = loss_fn(reconstructed, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}")

    # The decoder now requires grad, so without no_grad this final decode
    # would build an unused backward graph and hand save_image a tensor that
    # still requires grad.
    with torch.no_grad():
        rec_image = vae.decode(latents, return_dict=False)[0]
    save_image(rec_image / 2 + 0.5, f"{args.out_dir}/trained_rec_image.png")

    vae.save_pretrained(
        f"{args.out_dir}/trained_vae_{os.path.basename(args.image_path)}"
    )
if __name__ == "__main__":
    # Script entry point: declare the CLI options, parse them, run training.
    cli = argparse.ArgumentParser(
        description="Train a VAE with given image and settings."
    )

    # (flag, add_argument keyword arguments) pairs. They are registered in
    # this order so the generated --help text lists the options as before.
    option_table = [
        (
            "--out_dir",
            {
                "type": str,
                "default": "./trained_vae/",
                "help": "Output directory to save results",
            },
        ),
        (
            "--vae_model_path",
            {
                "type": str,
                "required": True,
                "help": "Path to the pretrained VAE model",
            },
        ),
        (
            "--image_path",
            {
                "type": str,
                "required": True,
                "help": "Path to the input image",
            },
        ),
        (
            "--learning_rate",
            {
                "type": float,
                "default": 1e-4,
                "help": "Learning rate for the optimizer",
            },
        ),
        (
            "--num_epochs",
            {
                "type": int,
                "default": 75,
                "help": "Number of training epochs",
            },
        ),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    main(cli.parse_args())