# hjc-owo: init repo (966ae59)
# -*- coding: utf-8 -*-
# Author: ximing
# Description: LIVE pipeline
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
import shutil
from pathlib import Path
from functools import partial
from typing import AnyStr
from PIL import Image
from tqdm.auto import tqdm
import torch
from torchvision import transforms
from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.diffvg import Painter, PainterOptimizer
from pytorch_svgrender.plt import plot_img, plot_couple
from pytorch_svgrender.libs.metric.lpips_origin import LPIPS
class DiffVGPipeline(ModelState):
    """LIVE-style painterly rendering pipeline built on DiffVG.

    Optimizes a set of vector paths (open strokes or closed shapes) so that
    their differentiable rasterization matches a target image under a pixel
    (L1/MSE) and/or perceptual (LPIPS) reconstruction loss, logging png/svg
    snapshots and optionally assembling a rendering video via ffmpeg.
    """

    # loss types understood by `painterly_rendering`
    SUPPORTED_LOSS_TYPES = ('l1', 'l2', 'lpips', 'l2+lpips')

    def __init__(self, args):
        logdir_ = f"sd{args.seed}" \
                  f"-{args.x.path_type}" \
                  f"-P{args.x.num_paths}"
        super().__init__(args, log_path_suffix=logdir_)

        # explicit raise instead of `assert`: asserts are stripped under `python -O`
        if self.x_cfg.path_type not in ['unclosed', 'closed']:
            raise ValueError(
                f"path_type must be 'unclosed' or 'closed', got {self.x_cfg.path_type!r}"
            )

        # create log dirs (only the main process touches the filesystem)
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # optionally dump intermediate frames for a rendering video
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

    def target_file_preprocess(self, tar_path: AnyStr) -> torch.Tensor:
        """Load an image file as a batched RGB tensor on ``self.device``.

        Args:
            tar_path: path to the target image file.

        Returns:
            A float tensor of shape (1, 3, H, W) with values in [0, 1]
            (``ToTensor`` scaling).
        """
        process_comp = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),  # add batch dim
        ])

        tar_pil = Image.open(tar_path).convert("RGB")  # open file
        target_img = process_comp(tar_pil)  # preprocess
        target_img = target_img.to(self.device)
        return target_img

    def painterly_rendering(self, img_path: AnyStr):
        """Fit vector paths to the image at ``img_path``, logging progress.

        Args:
            img_path: path to the target image file.

        Raises:
            FileNotFoundError: if ``img_path`` does not exist.
            ValueError: if ``self.x_cfg.loss_type`` is not supported.
        """
        # load target file
        target_file = Path(img_path)
        if not target_file.exists():
            # raise (not assert) so the check survives `python -O`
            raise FileNotFoundError(f"{target_file} does not exist!")
        shutil.copy(target_file, self.result_path)  # keep a copy next to the results
        target_img = self.target_file_preprocess(target_file.as_posix())
        self.print(f"load image from: '{target_file.as_posix()}'")

        # validate up front: an unknown loss type would otherwise surface as a
        # confusing NameError on `loss_recon` inside the optimization loop
        loss_type = self.x_cfg.loss_type
        if loss_type not in self.SUPPORTED_LOSS_TYPES:
            raise ValueError(
                f"loss_type must be one of {self.SUPPORTED_LOSS_TYPES}, got {loss_type!r}"
            )

        # init Painter
        renderer = Painter(target_img,
                           self.args.diffvg,
                           canvas_size=[target_img.shape[3], target_img.shape[2]],  # [W, H]
                           path_type=self.x_cfg.path_type,
                           max_width=self.x_cfg.max_width,
                           device=self.device)
        init_img = renderer.init_image(num_paths=self.x_cfg.num_paths)
        self.print("init_image shape: ", init_img.shape)
        plot_img(init_img, self.result_path, fname="init_img")

        # init Painter Optimizer
        num_iter = self.x_cfg.num_iter
        optimizer = PainterOptimizer(renderer,
                                     num_iter,
                                     self.x_cfg.lr_base,
                                     trainable_stroke=self.x_cfg.path_type == 'unclosed')
        optimizer.init_optimizer()

        # Set Loss: build the (heavy) LPIPS network only when actually needed
        if loss_type in ['lpips', 'l2+lpips']:
            lpips_loss_fn = LPIPS(net=self.x_cfg.perceptual.lpips_net).to(self.device)
            perceptual_loss_fn = partial(lpips_loss_fn.forward, return_per_layer=False, normalize=False)

        with tqdm(initial=self.step, total=num_iter, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < num_iter:
                raster_img = renderer.get_image(self.step).to(self.device)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == num_iter - 1):
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # Reconstruction Loss (loss_type already validated above)
                if loss_type == 'l1':
                    loss_recon = torch.nn.functional.l1_loss(raster_img, target_img)
                elif loss_type == 'lpips':
                    loss_recon = perceptual_loss_fn(raster_img, target_img).mean()
                elif loss_type == 'l2':  # default: MSE loss
                    loss_recon = torch.nn.functional.mse_loss(raster_img, target_img)
                else:  # 'l2+lpips': MSE plus perceptual term
                    lpips = perceptual_loss_fn(raster_img, target_img).mean()
                    loss_mse = torch.nn.functional.mse_loss(raster_img, target_img)
                    loss_recon = loss_mse + lpips

                # total loss
                loss = loss_recon

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.4f}, "
                    f"L_recon: {loss_recon.item():.4f}"
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()  # keep path parameters in a valid range

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(target_img,
                                raster_img,
                                self.step,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # end rendering: persist the final vector result
        renderer.save_svg(self.result_path / "final_svg.svg")

        if self.make_video:
            # list-form argv (shell=False): no shell-injection surface
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "live_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")