# Source: DiffSketcher / pytorch_svgrender/pipelines/SVGDreamer_pipeline.py
# Commit: 966ae59 ("init repo") by hjc-owo
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
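# Description: SVGDreamer pipeline - optimizes a set of SVG "particles" for a text prompt with VPSD
#              (vectorized particle-based score distillation), optionally fine-tuning an existing SVG.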
import pathlib
from PIL import Image
from typing import AnyStr
import numpy as np
from tqdm.auto import tqdm
import torch
from torch.optim.lr_scheduler import LambdaLR
import torchvision
from torchvision import transforms
from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.libs.solver.optim import get_optimizer
from pytorch_svgrender.painter.svgdreamer import Painter, PainterOptimizer
from pytorch_svgrender.painter.svgdreamer.painter_params import CosineWithWarmupLRLambda
from pytorch_svgrender.painter.live import xing_loss_fn
from pytorch_svgrender.painter.svgdreamer import VectorizedParticleSDSPipeline
from pytorch_svgrender.plt import plot_img
from pytorch_svgrender.utils.color_attrs import init_tensor_with_color
from pytorch_svgrender.token2attn.ptp_utils import view_images
from pytorch_svgrender.diffusers_warp import model2res
import ImageReward as RM
class SVGDreamerPipeline(ModelState):
    """SVGDreamer: text-guided SVG optimization with VPSD.

    ``n_particle`` SVG renderers ("particles") are optimized jointly against the
    guidance of ``VectorizedParticleSDSPipeline.variational_score_distillation``,
    while the pipeline's phi model (parameters ``self.pipeline.phi_params``) is
    trained alongside them and, optionally, refined with ImageReward feedback (ReFL).
    """

    def __init__(self, args):
        """Validate the config, create the log directories and load the models.

        Args:
            args: experiment config. Reads ``args.seed``, ``args.mv``, ``args.diffuser``
                and the method-specific ``args.x`` section (accessed later as
                ``self.x_cfg`` -- presumably set up by ``ModelState``; not visible here).

        Raises:
            AssertionError: on an unsupported style or inconsistent particle counts.
        """
        assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]
        assert args.x.guidance.n_particle >= args.x.guidance.vsd_n_particle
        assert args.x.guidance.n_particle >= args.x.guidance.phi_n_particle
        assert args.x.guidance.n_phi_sample >= 1

        logdir_ = (
            f"sd{args.seed}"
            f"-{'vpsd' if args.x.skip_sive else 'sive'}"
            f"-{args.x.model_id}"
            f"-{args.x.style}"
            f"-P{args.x.num_paths}"
            f"{'-RePath' if args.x.path_reinit.use else ''}"
        )
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs (only the main process touches the filesystem)
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        self.ft_png_logs_dir = self.result_path / "ft_png_logs"
        self.ft_svg_logs_dir = self.result_path / "ft_svg_logs"
        self.sd_sample_dir = self.result_path / 'sd_samples'
        self.reinit_dir = self.result_path / "reinit_logs"
        self.init_stage_two_dir = self.result_path / "stage_two_init_logs"
        self.phi_samples_dir = self.result_path / "phi_sampling_logs"
        if self.accelerator.is_main_process:
            for log_dir in (self.png_logs_dir, self.svg_logs_dir,
                            self.ft_png_logs_dir, self.ft_svg_logs_dir,
                            self.sd_sample_dir, self.reinit_dir,
                            self.init_stage_two_dir, self.phi_samples_dir):
                log_dir.mkdir(parents=True, exist_ok=True)

        self.select_fpth = self.result_path / 'select_sample.png'

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

        self.pipeline = VectorizedParticleSDSPipeline(args, args.diffuser, self.x_cfg.guidance, self.device)

        # load reward model (only needed for ReFL)
        self.reward_model = None
        if self.x_cfg.guidance.phi_ReFL:
            self.reward_model = RM.load("ImageReward-v1.0", device=self.device, download_root=self.x_cfg.reward_path)

        self.style = self.x_cfg.style
        if self.style == "pixelart":
            # pixel art is optimized with a constant learning rate in both stages
            self.x_cfg.lr_stage_one.lr_schedule = False
            self.x_cfg.lr_stage_two.lr_schedule = False

    def target_file_preprocess(self, tar_path: AnyStr):
        """Load an image as a ``(1, 3, image_size, image_size)`` RGB tensor on ``self.device``.

        Args:
            tar_path: path to the image file (``str`` or ``pathlib.Path``).

        Returns:
            The resized image, values in ``[0, 1]`` (from ``transforms.ToTensor``),
            with a leading batch dimension of 1.
        """
        process_comp = transforms.Compose([
            transforms.Resize(size=(self.x_cfg.image_size, self.x_cfg.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])
        # context manager so the underlying file handle is closed right away
        with Image.open(tar_path) as img:
            tar_pil = img.convert("RGB")
        target_img = process_comp(tar_pil)
        return target_img.to(self.device)

    def SIVE_stage(self, text_prompt: str):
        """Stage one (SIVE): produce ``(target_img, final_svg_path)`` from the prompt.

        Not implemented yet.

        Raises:
            NotImplementedError: always. Previously this returned ``None`` and the
                caller failed with an opaque unpacking ``TypeError``.
        """
        # TODO: SIVE implementation
        raise NotImplementedError(
            "SIVE stage is not implemented; set `x.skip_sive=True` or pass an existing target SVG."
        )

    def painterly_rendering(self, text_prompt: str, target_file: AnyStr = None):
        """Optimize ``n_particle`` SVGs for ``text_prompt`` with VPSD and save the results.

        Mode is selected by ``x_cfg.skip_sive`` and ``target_file``:
          1. ``skip_sive`` and no existing ``target_file``: optimize from scratch; the
             color-init tensor is ``torch.randn`` noise or built from ``x_cfg.color_init``.
          2. ``target_file`` exists: load that SVG into every particle and fine-tune it.
          3. otherwise: run ``SIVE_stage`` first, then fine-tune its SVG.

        Args:
            text_prompt: prompt used for score distillation and (with ReFL) reward scoring.
            target_file: optional path to an SVG to fine-tune (mode 2).

        Side effects:
            Writes PNG/SVG logs under ``self.result_path``, ``finetune_final_p_{i}.svg``
            for each particle, ``all_particles.png``, optionally an mp4 (needs ``ffmpeg``
            on PATH), then calls ``self.close``.
        """
        # log prompts
        self.print(f"prompt: {text_prompt}")
        self.print(f"neg_prompt: {self.args.neg_prompt}\n")

        # for convenience
        im_size = self.x_cfg.image_size
        guidance_cfg = self.x_cfg.guidance
        n_particle = guidance_cfg.n_particle
        total_step = guidance_cfg.num_iter
        path_reinit = self.x_cfg.path_reinit

        init_from_target = bool(target_file and pathlib.Path(target_file).exists())
        # switch mode
        if self.x_cfg.skip_sive and not init_from_target:
            # mode 1: optimization with VPSD from scratch
            self.print("optimization with VPSD from scratch...")
            if self.x_cfg.color_init == 'rand':
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                target_img = init_tensor_with_color(self.x_cfg.color_init, 1, im_size, im_size)
                self.print(f"color: {self.x_cfg.color_init}")
            # log init target_img
            plot_img(target_img, self.result_path, fname='init_target_img')
            final_svg_path = None
        elif init_from_target:
            # mode 2: load the SVG file and finetune it
            self.print(f"load svg from {target_file} ...")
            self.print("SVG fine-tuning via VPSD...")
            final_svg_path = target_file
            if self.x_cfg.color_init == 'target_randn':
                # special order: init newly added paths' colors randomly
                target_img = torch.randn(1, 3, im_size, im_size)
                self.print("color: randomly init")
            else:
                # init newly added paths' colors from the target image; that PNG is
                # written by load_renderer (rasterized via the renderer) below
                target_img = None
        else:
            # mode 3: text-to-img-to-svg (two stage)
            target_img, final_svg_path = self.SIVE_stage(text_prompt)
            self.x_cfg.path_svg = final_svg_path
            self.print("\n SVG fine-tuning via VPSD...")
            plot_img(target_img, self.result_path, fname='init_target_img')

        # create svg renderers, one per particle
        renderers = [self.load_renderer(final_svg_path) for _ in range(n_particle)]

        # randomly initialize the particles
        if self.x_cfg.skip_sive or init_from_target:
            if target_img is None:
                target_img = self.target_file_preprocess(self.result_path / 'target_img.png')
            for render in renderers:
                render.component_wise_path_init(gt=target_img, pred=None, init_type='random')

        # log init images
        for i, r in enumerate(renderers):
            init_imgs = r.init_image(stage=0, num_paths=self.x_cfg.num_paths)
            plot_img(init_imgs, self.init_stage_two_dir, fname=f"init_img_stage_two_{i}")

        # init renderer optimizers
        optimizers = []
        for renderer in renderers:
            optim_ = PainterOptimizer(renderer,
                                      self.style,
                                      guidance_cfg.num_iter,
                                      self.x_cfg.lr_stage_two,
                                      self.x_cfg.trainable_bg)
            optim_.init_optimizers()
            optimizers.append(optim_)

        # init phi_model optimizer
        phi_optimizer = get_optimizer('adamW',
                                      self.pipeline.phi_params,
                                      guidance_cfg.phi_lr,
                                      guidance_cfg.phi_optim)
        # init phi_model lr scheduler
        phi_scheduler = None
        schedule_cfg = guidance_cfg.phi_schedule
        if schedule_cfg.use:
            phi_lr_lambda = CosineWithWarmupLRLambda(num_steps=schedule_cfg.total_step,
                                                     warmup_steps=schedule_cfg.warmup_steps,
                                                     warmup_start_lr=schedule_cfg.warmup_start_lr,
                                                     warmup_end_lr=schedule_cfg.warmup_end_lr,
                                                     cosine_end_lr=schedule_cfg.cosine_end_lr)
            phi_scheduler = LambdaLR(phi_optimizer, lr_lambda=phi_lr_lambda, last_epoch=-1)

        self.print(f"-> Painter point Params: {len(renderers[0].get_point_parameters())}")
        self.print(f"-> Painter color Params: {len(renderers[0].get_color_parameters())}")
        self.print(f"-> Painter width Params: {len(renderers[0].get_width_parameters())}")

        L_reward = torch.tensor(0.)
        # bound up front so the progress-bar log works even when phi_update_step == 0
        L_lora = torch.tensor(0.)

        self.step = 0  # reset global step
        self.print(f"\ntotal VPSD optimization steps: {total_step}")
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                # rasterize all particles into one batch
                particles = [renderer.get_image() for renderer in renderers]
                raster_imgs = torch.cat(particles, dim=0)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(raster_imgs, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                L_guide, grad, latents, t_step = self.pipeline.variational_score_distillation(
                    raster_imgs,
                    self.step,
                    prompt=[text_prompt],
                    negative_prompt=self.args.neg_prompt,
                    grad_scale=guidance_cfg.grad_scale,
                    enhance_particle=guidance_cfg.particle_aug,
                    im_size=model2res(self.x_cfg.model_id)
                )

                # Xing Loss for Self-Interaction Problem; created on self.device so the
                # in-place accumulation does not mix a CPU accumulator with device losses
                L_add = torch.tensor(0., device=self.device)
                if self.style == "iconography" or self.x_cfg.xing_loss.use:
                    for r in renderers:
                        L_add += xing_loss_fn(r.get_point_parameters()) * self.x_cfg.xing_loss.weight

                loss = L_guide + L_add

                # particle (SVG parameter) optimization
                for opt_ in optimizers:
                    opt_.zero_grad_()
                loss.backward()
                for opt_ in optimizers:
                    opt_.step_()

                # phi_model optimization
                for _ in range(guidance_cfg.phi_update_step):
                    L_lora = self.pipeline.train_phi_model(latents, guidance_cfg.phi_t, as_latent=True)
                    phi_optimizer.zero_grad()
                    L_lora.backward()
                    phi_optimizer.step()

                # reward feedback learning on the phi model
                if guidance_cfg.phi_ReFL and self.step % guidance_cfg.phi_sample_step == 0:
                    L_reward = self._reward_feedback_learning(text_prompt, phi_optimizer)

                # update the learning rate of the phi_model
                if phi_scheduler is not None:
                    phi_scheduler.step()

                # curve regularization
                for r in renderers:
                    r.clip_curve_shape()

                # re-init paths
                if self.step % path_reinit.freq == 0 and self.step < path_reinit.stop_step and self.step != 0:
                    for i, r in enumerate(renderers):
                        r.reinitialize_paths(path_reinit.use,  # on-off
                                             path_reinit.opacity_threshold,
                                             path_reinit.area_threshold,
                                             fpath=self.reinit_dir / f"reinit-{self.step}_p{i}.svg")

                # update lr
                if self.x_cfg.lr_stage_two.lr_schedule:
                    for opt_ in optimizers:
                        opt_.update_lr()

                # log particle lrs (first particle only) and the phi model lr
                lr_str = ""
                for k, lr in optimizers[0].get_lr().items():
                    lr_str += f"{k}_lr: {lr:.4f}, "
                cur_phi_lr = phi_optimizer.param_groups[0]['lr']
                lr_str += f"phi_lr: {cur_phi_lr:.3e}, "

                pbar.set_description(
                    lr_str +
                    f"t: {t_step.item():.2f}, "
                    f"L_total: {loss.item():.4f}, "
                    f"L_add: {L_add.item():.4e}, "
                    f"L_lora: {L_lora.item():.4f}, "
                    f"L_reward: {L_reward.item():.4f}, "
                    f"vpsd: {grad.item():.4e}"
                )

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    # save png
                    torchvision.utils.save_image(raster_imgs,
                                                 fp=self.ft_png_logs_dir / f'iter{self.step}.png')
                    # save svg
                    for i, r in enumerate(renderers):
                        r.pretty_save_svg(self.ft_svg_logs_dir / f"svg_iter{self.step}_p{i}.svg")

                self.step += 1
                pbar.update(1)

        # save final SVGs
        for i, r in enumerate(renderers):
            final_svg_path = self.result_path / f"finetune_final_p_{i}.svg"
            r.pretty_save_svg(final_svg_path)
        # Re-render instead of reusing the loop's last raster. That raster predates
        # the final update, clipping and re-init, and is unbound when total_step == 0.
        with torch.no_grad():
            raster_imgs = torch.cat([r.get_image() for r in renderers], dim=0)
        torchvision.utils.save_image(raster_imgs, fp=self.result_path / 'all_particles.png')

        if self.make_video:
            self._render_video()

        self.close(msg="painterly rendering complete.")

    def _reward_feedback_learning(self, text_prompt: str, phi_optimizer):
        """Sample from the phi model, score the samples with ImageReward, and take one
        reward-weighted ``train_phi_model_refl`` step per sample.

        Args:
            text_prompt: prompt used for sampling and reward scoring.
            phi_optimizer: optimizer over ``self.pipeline.phi_params``.

        Returns:
            The loss of the last ReFL step.
        """
        guidance_cfg = self.x_cfg.guidance
        with torch.no_grad():
            phi_outputs = []
            phi_sample_paths = []
            for idx in range(guidance_cfg.n_phi_sample):
                phi_output = self.pipeline.sample(text_prompt,
                                                  num_inference_steps=guidance_cfg.phi_infer_step,
                                                  generator=self.g_device)
                # per-sample files are overwritten every ReFL round; the grid below is kept
                sample_path = (self.phi_samples_dir / f'iter{idx}.png').as_posix()
                phi_output.images[0].save(sample_path)
                phi_sample_paths.append(sample_path)
                phi_outputs.append(np.array(phi_output.images[0]))
            # save all samples
            view_images(phi_outputs, save_image=True,
                        num_rows=max(len(phi_outputs) // 6, 1),
                        fp=self.phi_samples_dir / f'samples_iter{self.step}.png')
            ranking, rewards = self.reward_model.inference_rank(text_prompt, phi_sample_paths)
        self.print(f"ranking: {ranking}, reward score: {rewards}")

        # NOTE: with a single sample, ImageReward's inference_rank may return bare
        # scalars (it squeezes its score tensor); normalize so indexing below works.
        if not isinstance(ranking, list):
            ranking, rewards = [ranking], [rewards]

        # Per ImageReward's inference_rank, both lists are indexed by sample:
        # ranking[i] is the 1-based rank of sample i and rewards[i] is its score.
        # Train best-ranked first, keeping every image paired with its own reward.
        L_reward = None
        for k in sorted(range(len(phi_sample_paths)), key=lambda i: ranking[i]):
            phi = self.target_file_preprocess(phi_sample_paths[k])
            L_reward = self.pipeline.train_phi_model_refl(phi, weight=rewards[k])
            phi_optimizer.zero_grad()
            L_reward.backward()
            phi_optimizer.step()
        return L_reward

    def _render_video(self):
        """Assemble the logged ``iter%d.png`` frames into an mp4 via ``ffmpeg`` (must be on PATH)."""
        from subprocess import call
        call([
            "ffmpeg",
            "-framerate", f"{self.args.framerate}",
            "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
            "-vb", "20M",
            (self.result_path / "svgdreamer_rendering.mp4").as_posix()
        ])

    def load_renderer(self, path_svg=None):
        """Build a ``Painter`` from ``x_cfg``, optionally loaded from ``path_svg``.

        When ``path_svg`` is given and ``result_path/target_img.png`` does not exist yet,
        the SVG is also rasterized and saved there. ``painterly_rendering`` later reads
        that file as the color-init target.

        Args:
            path_svg: optional SVG file to initialize the painter from.

        Returns:
            The constructed ``Painter``.
        """
        renderer = Painter(self.args.diffvg,
                           self.style,
                           self.x_cfg.num_segments,
                           self.x_cfg.segment_init,
                           self.x_cfg.radius,
                           self.x_cfg.image_size,
                           self.x_cfg.grid,
                           self.x_cfg.trainable_bg,
                           self.x_cfg.width,
                           path_svg=path_svg,
                           device=self.device)

        # if a svg file is loaded, rasterize it once (subsequent particles skip this)
        save_path = self.result_path / 'target_img.png'
        if path_svg is not None and not save_path.exists():
            canvas_width, canvas_height, shapes, shape_groups = renderer.load_svg(path_svg)
            render_img = renderer.render_image(canvas_width, canvas_height, shapes, shape_groups)
            torchvision.utils.save_image(render_img, fp=save_path)

        return renderer