# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: CLIPFont pipeline - CLIP-guided stylization of vector fonts (glyph SVGs).
import pathlib
from typing import AnyStr

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from svgutils.transform import fromfile

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.plt import plot_img, plot_couple, plot_img_title
from pytorch_svgrender.painter.clipfont import (imagenet_templates, compose_text_with_templates, Painter,
                                                PainterOptimizer)
from pytorch_svgrender.libs.metric.clip_score import CLIPScoreWrapper
from pytorch_svgrender.libs.metric.piq.perceptual import LPIPS


class CLIPFontPipeline(ModelState):
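    """CLIPFont pipeline: stylize vector glyphs with CLIP guidance.

    Differentiably renders a glyph SVG and optimizes its point and color
    parameters with patch-level and global CLIP directional losses,
    regularized by LPIPS and L2 terms against the initial rendering.
    """
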
def __init__(self, args):
logdir_ = f"sd{args.seed}" \
f"-lpips{args.x.lam_lpips}-l2{args.x.lam_l2}" \
f"{f'-{args.x.font.reinit_color}' if args.x.font.reinit else ''}"
super().__init__(args, log_path_suffix=logdir_)
# create log dir
self.png_logs_dir = self.result_path / "png_logs"
self.svg_logs_dir = self.result_path / "svg_logs"
if self.accelerator.is_main_process:
self.png_logs_dir.mkdir(parents=True, exist_ok=True)
self.svg_logs_dir.mkdir(parents=True, exist_ok=True)
# make video log
self.make_video = self.args.mv
if self.make_video:
self.frame_idx = 0
self.frame_log_dir = self.result_path / "frame_logs"
self.frame_log_dir.mkdir(parents=True, exist_ok=True)
# init clip model
self.clip_wrapper = CLIPScoreWrapper(self.x_cfg.clip.model_name, device=self.device)
        # init LPIPS loss and its weight (0 disables the term)
        self.lam_lpips = 0 if self.x_cfg.get('lam_lpips', None) is None else self.x_cfg.lam_lpips
        self.lpips_fn = LPIPS()
        # L2 loss weight (0 disables the term)
        self.lam_l2 = 0 if self.x_cfg.get('lam_l2', None) is None else self.x_cfg.lam_l2

    def load_target_file(self, tar_path: AnyStr, image_size: int = 224):
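        """Load an RGB image, resize it to `image_size` x `image_size`, and return a 1x3xHxW tensor on the model device."""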
process_comp = transforms.Compose([
transforms.Resize(size=(image_size, image_size)),
transforms.ToTensor(),
transforms.Lambda(lambda t: t.unsqueeze(0)),
])
tar_pil = Image.open(tar_path).convert("RGB") # open file
target_img = process_comp(tar_pil) # preprocess
return target_img.to(self.device)

    def cropper(self, x: torch.Tensor) -> torch.Tensor:
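        """Take a random `crop_size` crop of the rendering for the patch-level directional loss."""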
return transforms.RandomCrop(self.x_cfg.crop_size)(x)

    def padding_cropper(self, x: torch.Tensor) -> torch.Tensor:
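        """Take a random 500x500 crop after 100px constant padding (fill=255); feeds the global directional loss."""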
return transforms.RandomCrop(size=500, padding=100, fill=255, padding_mode='constant')(x)

    def affine_to512(self, x: torch.Tensor) -> torch.Tensor:
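        """Apply a random perspective distortion (p=1, so always) and resize to 512 before CLIP encoding."""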
comp = transforms.Compose([
transforms.RandomPerspective(fill=0, p=1, distortion_scale=0.3),
transforms.Resize(512)
])
return comp(x)

    def resize224_norm(self, x: torch.Tensor) -> torch.Tensor:
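        """Bicubically resize to CLIP's 224x224 input size and apply CLIP normalization."""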
        x = F.interpolate(x, size=224, mode='bicubic')
return self.clip_wrapper.norm_(x)

    def painterly_rendering(self, svg_path, prompt):
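        """Stylize the glyphs in `svg_path` toward `prompt`.

        Rescales the SVG to 512x512, initializes shapes from it, then optimizes
        point and color parameters with patch-level and global CLIP directional
        losses plus LPIPS and L2 regularization against the initial rendering.
        """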
svg_path = pathlib.Path(svg_path)
        assert svg_path.exists(), f"'{svg_path}' does not exist."
# load renderer
renderer = self.load_renderer()
# rescale svg
fig = fromfile(svg_path.as_posix())
fig.set_size(('512', '512'))
        filename = svg_path.stem
svg_path = self.result_path / f'{filename}_scale.svg'
fig.save(svg_path.as_posix())
# init shapes and shape groups
init_img = renderer.init_shapes(svg_path.as_posix(), reinit_cfg=self.x_cfg.font)
self.print("init_image shape: ", init_img.shape)
plot_img(init_img, self.result_path, fname="init_img")
# load init file
with torch.no_grad():
source_image = self.load_target_file(self.result_path / 'init_img.png', image_size=512)
source_image = source_image.detach()
source_image_feats = self.clip_wrapper.encode_image(self.resize224_norm(source_image)).detach()
# build optimizer
optimizer = PainterOptimizer(renderer, self.x_cfg.lr_base)
optimizer.init_optimizers()
# pre-calc
with torch.no_grad():
# encode text prompt and source prompt
template_text = compose_text_with_templates(prompt, imagenet_templates)
text_features = self.clip_wrapper.encode_text(template_text).detach()
source = "A photo"
template_source = compose_text_with_templates(source, imagenet_templates)
text_source = self.clip_wrapper.encode_text(template_source).detach()
total_step = self.x_cfg.num_iter
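        # optimization loop: render the SVG, score it with CLIP, and update the shape parameters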
with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
while self.step < total_step:
img_t = renderer.get_image().to(self.device)
if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
plot_img(img_t, self.frame_log_dir, fname=f"iter{self.frame_idx}")
self.frame_idx += 1
                # style loss, part 1: patch-level CLIP directional loss.
                # Align the source->render image feature direction with the
                # "A photo"->prompt text feature direction over random crops.
img_proc = []
                for _ in range(self.x_cfg.num_crops):
target_crop = self.cropper(img_t)
target_crop = self.affine_to512(target_crop)
img_proc.append(target_crop)
img_aug = torch.cat(img_proc, dim=0)
image_features = self.clip_wrapper.encode_image(self.resize224_norm(img_aug))
loss_patch = self.x_cfg.lam_patch * self.clip_wrapper.directional_loss(text_source,
source_image_feats,
text_features,
image_features,
self.x_cfg.thresh)
                # style loss, part 2: global directional loss over padded random crops
img_proc2 = []
                for _ in range(32):
target_crop = self.padding_cropper(img_t)
target_crop = self.affine_to512(target_crop)
img_proc2.append(target_crop)
img_aug2 = torch.cat(img_proc2, dim=0)
glob_features = self.clip_wrapper.encode_image(self.resize224_norm(img_aug2))
loss_glob = self.x_cfg.lam_dir * self.clip_wrapper.directional_loss(text_source,
source_image_feats,
text_features, glob_features)
# LPIPS
loss_lpips = self.lam_lpips * self.lpips_fn(img_t, source_image)
# L2
loss_l2 = self.lam_l2 * F.mse_loss(img_t, source_image)
# total loss
loss = loss_patch + loss_glob + loss_lpips + loss_l2
# log
p_lr, c_lr = optimizer.get_lr()
pbar.set_description(
f"point_lr: {p_lr}, color_lr: {c_lr}, "
f"L_total: {loss.item():.4f}, "
f"L_patch: {loss_patch.item():.4f}, "
f"L_glob: {loss_glob.item():.4f}, "
f"L_lpips: {loss_lpips.item():.4f}, "
f"L_l2: {loss_l2.item():.4f}."
)
# backward and optimization
optimizer.zero_grad_()
loss.backward()
optimizer.step_()
renderer.clip_curve_shape()
if self.x_cfg.lr_schedule:
optimizer.update_lr(self.step)
if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
plot_couple(init_img,
img_t,
self.step,
output_dir=self.png_logs_dir.as_posix(),
fname=f"iter{self.step}")
renderer.pretty_save_svg(self.svg_logs_dir / f"svg_iter{self.step}.svg")
self.step += 1
pbar.update(1)
# log final results
renderer.pretty_save_svg(self.result_path / "final_svg.svg")
final_raster_sketch = renderer.get_image().to(self.device)
plot_img_title(final_raster_sketch,
                        title=f'final result - {self.step} steps',
output_dir=self.result_path,
fname='final_render')
if self.make_video:
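            # stitch the saved frame logs into an mp4 with ffmpeg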
from subprocess import call
call([
"ffmpeg",
"-framerate", f"{self.args.framerate}",
"-i", (self.frame_log_dir / "iter%d.png").as_posix(),
"-vb", "20M",
(self.result_path / "clipfont_rendering.mp4").as_posix()
])
self.close(msg="painterly rendering complete.")

    def load_renderer(self):
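        """Create the differentiable vector graphics painter on the target device."""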
renderer = Painter(device=self.device)
return renderer
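

# A minimal usage sketch (hypothetical path and prompt; `args` is the project's parsed config):
#   pipe = CLIPFontPipeline(args)
#   pipe.painterly_rendering(svg_path="./data/alphabet.svg", prompt="Starry Night by Vincent van Gogh")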