Spaces:
Running
Running
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: Word-As-Image pipeline; optimizes one letter of a word with score distillation sampling.
from pathlib import Path | |
from tqdm.auto import tqdm | |
import torch | |
from pytorch_svgrender.libs.engine import ModelState | |
from pytorch_svgrender.painter.wordasimage import Painter, PainterOptimizer | |
from pytorch_svgrender.painter.wordasimage.losses import ToneLoss, ConformalLoss | |
from pytorch_svgrender.painter.vectorfusion import LSDSPipeline | |
from pytorch_svgrender.plt import plot_img, plot_couple | |
from pytorch_svgrender.diffusers_warp import init_StableDiffusion_pipeline | |
from pytorch_svgrender.svgtools import FONT_LIST | |
class WordAsImagePipeline(ModelState):
    """Pipeline that optimizes the shape of one letter of a word.

    The letter outline is rendered by ``Painter`` and optimized against a
    text prompt with the score-distillation loss of ``LSDSPipeline``,
    optionally combined with a tone loss and a conformal loss.
    """

    def __init__(self, args):
        """Validate the config, create log directories and load the diffusion model.

        Args:
            args: run configuration. This method reads ``args.seed``,
                ``args.mv``, ``args.diffuser.*`` and the method-specific
                section ``args.x`` (``word``, ``optim_letter``, ``font``,
                ``font_path``, ``image_size``, ...).

        Raises:
            AssertionError: if the letter is not part of the word, the font
                path does not exist, or the font is not in ``FONT_LIST``.
        """
        # assert
        assert args.x.optim_letter in args.x.word, \
            f"letter '{args.x.optim_letter}' is not in word '{args.x.word}'."
        assert Path(args.x.font_path).exists(), f"{args.x.font_path} is not exist."
        assert args.x.font in FONT_LIST, f"{args.x.font} is not currently supported."

        # make logdir
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"-{args.x.word}-{args.x.optim_letter}"
        super().__init__(args, log_path_suffix=logdir_)

        # log dir
        self.png_log_dir = self.result_path / "png_logs"
        self.svg_log_dir = self.result_path / "svg_logs"

        # font
        self.font = self.x_cfg.font
        self.font_path = self.x_cfg.font_path
        self.optim_letter = self.x_cfg.optim_letter

        # letter
        self.letter = self.x_cfg.optim_letter
        self.target_letter = self.result_path / f"{self.font}_{self.optim_letter}_scaled.svg"

        # make log dir
        if self.accelerator.is_main_process:
            self.png_log_dir.mkdir(parents=True, exist_ok=True)
            self.svg_log_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            # running index of saved frames; keeps frame file names gap-free
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.diffusion = init_StableDiffusion_pipeline(
            self.x_cfg.model_id,
            custom_pipeline=LSDSPipeline,
            device=self.device,
            local_files_only=not args.diffuser.download,
            force_download=args.diffuser.force_download,
            resume_download=args.diffuser.resume_download,
            ldm_speed_up=self.x_cfg.ldm_speed_up,
            enable_xformers=self.x_cfg.enable_xformers,
            gradient_checkpoint=self.x_cfg.gradient_checkpoint,
            lora_path=self.x_cfg.lora_path
        )

        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

    def painterly_rendering(self, word, semantic_concept, optimized_letter):
        """Optimize ``optimized_letter`` of ``word`` towards ``semantic_concept``.

        Writes the initial and final letter SVGs, periodic PNG/SVG snapshots
        and, when video logging is enabled, per-frame images plus an mp4
        assembled by ffmpeg, all below ``self.result_path``.

        Args:
            word: the full word that contains the letter.
            semantic_concept: text concept; the prompt is this concept
                followed by ``". "`` and ``self.x_cfg.prompt_suffix``.
            optimized_letter: the letter whose shape is optimized.
        """
        prompt = semantic_concept + ". " + self.x_cfg.prompt_suffix
        self.print(f"prompt: {prompt}")

        # load the optimized letter
        renderer = Painter(self.font, canvas_size=self.x_cfg.image_size, device=self.device)

        # font to svg
        self.print(f"font type: {self.font}\n")
        renderer.preprocess_font(word,
                                 optimized_letter,
                                 self.x_cfg.level_of_cc,
                                 self.font_path,
                                 self.result_path.as_posix())

        # init letter shape
        img_init = renderer.init_shape(self.target_letter)
        plot_img(img_init, self.result_path, fname="word_init")
        # save init letter
        renderer.pretty_save_svg(self.result_path / "letter_init.svg")
        init_letter = renderer.get_image()

        n_iter = self.x_cfg.num_iter

        # init optimizer and lr_schedular
        optimizer = PainterOptimizer(renderer, n_iter, self.x_cfg.lr)
        optimizer.init_optimizers()

        # init Tone loss
        if self.x_cfg.tone_loss.use:
            tone_loss = ToneLoss(self.x_cfg.tone_loss)
            tone_loss.set_image_init(img_init)

        # init conformal loss
        if self.x_cfg.conformal.use:
            conformal_loss = ConformalLoss(renderer.get_point_parameters(),
                                           renderer.shape_groups,
                                           optimized_letter, self.device)

        with tqdm(initial=self.step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
            for i in range(n_iter):
                raster_img = renderer.get_image(step=i)

                if self.make_video and (i % self.args.framefreq == 0 or i == n_iter - 1):
                    # Frames must be numbered consecutively: the ffmpeg call
                    # below reads them through the `iter%d.png` sequence
                    # pattern, which stops at the first gap. Naming frames by
                    # the step counter would leave gaps whenever framefreq > 1.
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # the second return value (gradient) is not used here
                L_sds, _ = self.diffusion.score_distillation_sampling(
                    raster_img,
                    im_size=self.x_cfg.sds.im_size,
                    prompt=[prompt],
                    negative_prompt=self.args.neg_prompt,
                    guidance_scale=self.x_cfg.sds.guidance_scale,
                    grad_scale=self.x_cfg.sds.grad_scale,
                    t_range=list(self.x_cfg.sds.t_range),
                )
                loss = L_sds

                if self.x_cfg.tone_loss.use:
                    tone_loss_res = tone_loss(raster_img, step=i)
                    loss = loss + tone_loss_res

                if self.x_cfg.conformal.use:
                    loss_angles = conformal_loss()
                    loss_angles = self.x_cfg.conformal.angeles_w * loss_angles
                    loss = loss + loss_angles

                pbar.set_description(
                    f"n_params: {len(renderer.get_point_parameters())}, "
                    f"lr: {optimizer.get_lr():.4f}, "
                    f"L_total: {loss.item():.4f}, "
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                # periodic snapshot: init/current comparison image and current SVG
                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(init_letter,
                                raster_img,
                                self.step,
                                output_dir=self.png_log_dir.as_posix(),
                                fname=f"iter{self.step}",
                                prompt=prompt)
                    renderer.pretty_save_svg(self.svg_log_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # save final optimized letter
        renderer.pretty_save_svg(self.result_path / "final_letter.svg")
        # combine word
        renderer.combine_word(word, optimized_letter, self.font, self.result_path)

        if self.make_video:
            from subprocess import call
            # assemble the saved frames into a video; requires ffmpeg on PATH
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "wordasimg_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")