# Source path: DiffSketcher/pytorch_svgrender/pipelines/WordAsImage_pipeline.py
# Last commit: "init repo" (966ae59) by hjc-owo
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
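# Description: Word-As-Image pipeline - optimizes the outline of one letter of a word
#              with Stable Diffusion score distillation sampling (SDS).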
from pathlib import Path
from tqdm.auto import tqdm
import torch
from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.wordasimage import Painter, PainterOptimizer
from pytorch_svgrender.painter.wordasimage.losses import ToneLoss, ConformalLoss
from pytorch_svgrender.painter.vectorfusion import LSDSPipeline
from pytorch_svgrender.plt import plot_img, plot_couple
from pytorch_svgrender.diffusers_warp import init_StableDiffusion_pipeline
from pytorch_svgrender.svgtools import FONT_LIST
class WordAsImagePipeline(ModelState):
    """Word-As-Image pipeline: optimize the outline of one letter of a word.

    The glyph of the configured letter is converted to vector paths by
    ``Painter`` and optimized with Stable Diffusion score distillation
    sampling (SDS). A tone loss and a conformal loss can optionally be added
    as regularizers (``x.tone_loss.use`` / ``x.conformal.use``).
    """

    def __init__(self, args):
        """Validate the config, create log directories and load the SD pipeline.

        Args:
            args: run configuration. Fields read here include ``seed``, ``mv``,
                ``diffuser.{download, force_download, resume_download}`` and the
                method section ``x`` (``word``, ``optim_letter``, ``font``,
                ``font_path``, ``image_size``, ``model_id``, ...).

        Raises:
            AssertionError: if ``optim_letter`` is not contained in ``word``,
                ``font_path`` does not exist, or ``font`` is not in ``FONT_LIST``.
        """
        # NOTE: validation uses `assert`, so it is skipped under `python -O`.
        assert args.x.optim_letter in args.x.word, \
            f"{args.x.optim_letter} is not contained in {args.x.word}."
        assert Path(args.x.font_path).exists(), f"{args.x.font_path} is not exist."
        assert args.x.font in FONT_LIST, f"{args.x.font} is not currently supported."

        # log directory suffix encodes seed, image size, word and optimized letter
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"-{args.x.word}-{args.x.optim_letter}"
        super().__init__(args, log_path_suffix=logdir_)

        # log dirs
        self.png_log_dir = self.result_path / "png_logs"
        self.svg_log_dir = self.result_path / "svg_logs"

        # font
        self.font = self.x_cfg.font
        self.font_path = self.x_cfg.font_path
        self.optim_letter = self.x_cfg.optim_letter
        # letter (`letter` mirrors `optim_letter`; both attributes are kept for compatibility)
        self.letter = self.x_cfg.optim_letter
        self.target_letter = self.result_path / f"{self.font}_{self.optim_letter}_scaled.svg"

        # only the main process creates the png/svg log dirs
        if self.accelerator.is_main_process:
            self.png_log_dir.mkdir(parents=True, exist_ok=True)
            self.svg_log_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            # consecutive frame counter -> frames are named iter0.png, iter1.png, ...
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.diffusion = init_StableDiffusion_pipeline(
            self.x_cfg.model_id,
            custom_pipeline=LSDSPipeline,
            device=self.device,
            local_files_only=not args.diffuser.download,
            force_download=args.diffuser.force_download,
            resume_download=args.diffuser.resume_download,
            ldm_speed_up=self.x_cfg.ldm_speed_up,
            enable_xformers=self.x_cfg.enable_xformers,
            gradient_checkpoint=self.x_cfg.gradient_checkpoint,
            lora_path=self.x_cfg.lora_path
        )

        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

    def painterly_rendering(self, word, semantic_concept, optimized_letter):
        """Optimize ``optimized_letter`` of ``word`` towards ``semantic_concept``.

        Args:
            word: the full word containing the letter.
            semantic_concept: text concept; the SDS prompt is
                ``semantic_concept + ". " + x.prompt_suffix``.
            optimized_letter: the letter of ``word`` whose outline is optimized.

        Side effects:
            Writes the initial/final letter SVGs, periodic PNG/SVG logs and the
            combined word under ``self.result_path``. When video logging is
            enabled, it also writes frame PNGs and an ffmpeg-encoded MP4.
            Finally calls ``self.close``.
        """
        prompt = semantic_concept + ". " + self.x_cfg.prompt_suffix
        self.print(f"prompt: {prompt}")

        # load the optimized letter
        renderer = Painter(self.font, canvas_size=self.x_cfg.image_size, device=self.device)

        # font to svg
        self.print(f"font type: {self.font}\n")
        renderer.preprocess_font(word,
                                 optimized_letter,
                                 self.x_cfg.level_of_cc,
                                 self.font_path,
                                 self.result_path.as_posix())

        # init letter shape
        # NOTE(review): target_letter is derived from the configured optim_letter,
        # not from `optimized_letter`; presumably they are the same - confirm.
        img_init = renderer.init_shape(self.target_letter)
        plot_img(img_init, self.result_path, fname="word_init")
        # save init letter
        renderer.pretty_save_svg(self.result_path / "letter_init.svg")
        init_letter = renderer.get_image()

        n_iter = self.x_cfg.num_iter

        # init optimizer and lr_schedular
        optimizer = PainterOptimizer(renderer, n_iter, self.x_cfg.lr)
        optimizer.init_optimizers()

        # init Tone loss
        if self.x_cfg.tone_loss.use:
            tone_loss = ToneLoss(self.x_cfg.tone_loss)
            tone_loss.set_image_init(img_init)

        # init conformal loss
        if self.x_cfg.conformal.use:
            conformal_loss = ConformalLoss(renderer.get_point_parameters(),
                                           renderer.shape_groups,
                                           optimized_letter, self.device)

        with tqdm(initial=self.step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
            for i in range(n_iter):
                raster_img = renderer.get_image(step=i)

                if self.make_video and (i % self.args.framefreq == 0 or i == n_iter - 1):
                    # Frames are named by a consecutive counter rather than by step:
                    # ffmpeg's `iter%d.png` image-sequence input needs gap-free
                    # numbering. Step-based names (iter0, iter{framefreq}, ...)
                    # would encode only the first frame when framefreq > 1.
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                L_sds, grad = self.diffusion.score_distillation_sampling(
                    raster_img,
                    im_size=self.x_cfg.sds.im_size,
                    prompt=[prompt],
                    negative_prompt=self.args.neg_prompt,
                    guidance_scale=self.x_cfg.sds.guidance_scale,
                    grad_scale=self.x_cfg.sds.grad_scale,
                    t_range=list(self.x_cfg.sds.t_range),
                )

                loss = L_sds
                if self.x_cfg.tone_loss.use:
                    tone_loss_res = tone_loss(raster_img, step=i)
                    loss = loss + tone_loss_res
                if self.x_cfg.conformal.use:
                    loss_angles = conformal_loss()
                    loss_angles = self.x_cfg.conformal.angeles_w * loss_angles
                    loss = loss + loss_angles

                pbar.set_description(
                    f"n_params: {len(renderer.get_point_parameters())}, "
                    f"lr: {optimizer.get_lr():.4f}, "
                    f"L_total: {loss.item():.4f}, "
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                # periodic logging on the main process only
                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(init_letter,
                                raster_img,
                                self.step,
                                output_dir=self.png_log_dir.as_posix(),
                                fname=f"iter{self.step}",
                                prompt=prompt)
                    renderer.pretty_save_svg(self.svg_log_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # save final optimized letter
        renderer.pretty_save_svg(self.result_path / "final_letter.svg")

        # combine word
        renderer.combine_word(word, optimized_letter, self.font, self.result_path)

        if self.make_video:
            from subprocess import call
            # encode the consecutively numbered frames into a video
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "wordasimg_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")