# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: CLIPFont pipeline.

from PIL import Image
from typing import Union
import pathlib

import torch
import torch.nn.functional as F
from torchvision import transforms
from tqdm.auto import tqdm
from svgutils.transform import fromfile

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.plt import plot_img, plot_couple, plot_img_title
from pytorch_svgrender.painter.clipfont import (imagenet_templates, compose_text_with_templates,
                                                Painter, PainterOptimizer)
from pytorch_svgrender.libs.metric.clip_score import CLIPScoreWrapper
from pytorch_svgrender.libs.metric.piq.perceptual import LPIPS


class CLIPFontPipeline(ModelState):
    """CLIPFont pipeline: optimizes the point and color parameters of an input
    SVG's glyph paths so that the rendered result matches a text prompt under
    CLIP, regularized by LPIPS and L2 distance to the source rendering."""

    def __init__(self, args):
        logdir_ = f"sd{args.seed}" \
                  f"-lpips{args.x.lam_lpips}-l2{args.x.lam_l2}" \
                  f"{f'-{args.x.font.reinit_color}' if args.x.font.reinit else ''}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dir
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        # init CLIP model
        self.clip_wrapper = CLIPScoreWrapper(self.x_cfg.clip.model_name, device=self.device)

        # init LPIPS
        self.lam_lpips = 0 if self.x_cfg.get('lam_lpips', None) is None else self.x_cfg.lam_lpips
        self.lpips_fn = LPIPS()

        # L2
        self.lam_l2 = 0 if self.x_cfg.get('lam_l2', None) is None else self.x_cfg.lam_l2

    def load_target_file(self, tar_path: Union[str, pathlib.Path], image_size: int = 224):
        process_comp = transforms.Compose([
            transforms.Resize(size=(image_size, image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])

        tar_pil = Image.open(tar_path).convert("RGB")  # open file
        target_img = process_comp(tar_pil)  # preprocess
        return target_img.to(self.device)

    def cropper(self, x: torch.Tensor) -> torch.Tensor:
        return transforms.RandomCrop(self.x_cfg.crop_size)(x)

    def padding_cropper(self, x: torch.Tensor) -> torch.Tensor:
        return transforms.RandomCrop(size=500, padding=100, fill=255, padding_mode='constant')(x)

    def affine_to512(self, x: torch.Tensor) -> torch.Tensor:
        comp = transforms.Compose([
            transforms.RandomPerspective(fill=0, p=1, distortion_scale=0.3),
            transforms.Resize(512)
        ])
        return comp(x)

    def resize224_norm(self, x: torch.Tensor) -> torch.Tensor:
        # resize to CLIP's input resolution, then apply CLIP normalization
        x = F.interpolate(x, size=224, mode='bicubic')
        return self.clip_wrapper.norm_(x)

    def painterly_rendering(self, svg_path, prompt):
        svg_path = pathlib.Path(svg_path)
        assert svg_path.exists(), f"'{svg_path}' does not exist."
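        # The flow below: (1) rescale the input SVG to 512x512 and initialize
        # differentiable shapes from it, (2) encode the source rendering and the
        # source/target texts with CLIP, then (3) iteratively optimize the point
        # and color parameters against patch-wise and global directional CLIP
        # losses, plus LPIPS and L2 regularizers against the source image.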
        # load renderer
        renderer = self.load_renderer()

        # rescale svg
        fig = fromfile(svg_path.as_posix())
        fig.set_size(('512', '512'))
        filename = str(svg_path.name).split('.')[0]
        svg_path = self.result_path / f'{filename}_scale.svg'
        fig.save(svg_path.as_posix())

        # init shapes and shape groups
        init_img = renderer.init_shapes(svg_path.as_posix(), reinit_cfg=self.x_cfg.font)
        self.print("init_image shape: ", init_img.shape)
        plot_img(init_img, self.result_path, fname="init_img")

        # load init file
        with torch.no_grad():
            source_image = self.load_target_file(self.result_path / 'init_img.png', image_size=512)
            source_image = source_image.detach()
            source_image_feats = self.clip_wrapper.encode_image(self.resize224_norm(source_image)).detach()

        # build optimizer
        optimizer = PainterOptimizer(renderer, self.x_cfg.lr_base)
        optimizer.init_optimizers()

        # pre-calc
        with torch.no_grad():
            # encode text prompt and source prompt
            template_text = compose_text_with_templates(prompt, imagenet_templates)
            text_features = self.clip_wrapper.encode_text(template_text).detach()

            source = "A photo"
            template_source = compose_text_with_templates(source, imagenet_templates)
            text_source = self.clip_wrapper.encode_text(template_source).detach()

        total_step = self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step,
                  disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                img_t = renderer.get_image().to(self.device)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(img_t, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1
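                # Directional CLIP losses (in the spirit of StyleGAN-NADA /
                # CLIPstyler): `directional_loss` is expected to align the
                # CLIP-space direction from the source image to the rendered
                # image with the direction from the source text ("A photo") to
                # the target prompt, i.e. roughly
                #   1 - cos(E_I(render) - E_I(source), E_T(prompt) - E_T(source)).
                # The patch term applies it to random crops; the global term to
                # padded crops that keep the whole canvas in view.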
                # style loss
                # directional loss 1: patch-wise, on random crops
                img_proc = []
                for n in range(self.x_cfg.num_crops):
                    target_crop = self.cropper(img_t)
                    target_crop = self.affine_to512(target_crop)
                    img_proc.append(target_crop)
                img_aug = torch.cat(img_proc, dim=0)
                image_features = self.clip_wrapper.encode_image(self.resize224_norm(img_aug))
                loss_patch = self.x_cfg.lam_patch * self.clip_wrapper.directional_loss(
                    text_source, source_image_feats, text_features, image_features, self.x_cfg.thresh
                )

                # directional loss 2: global, on padded crops
                img_proc2 = []
                for n in range(32):
                    target_crop = self.padding_cropper(img_t)
                    target_crop = self.affine_to512(target_crop)
                    img_proc2.append(target_crop)
                img_aug2 = torch.cat(img_proc2, dim=0)
                glob_features = self.clip_wrapper.encode_image(self.resize224_norm(img_aug2))
                loss_glob = self.x_cfg.lam_dir * self.clip_wrapper.directional_loss(
                    text_source, source_image_feats, text_features, glob_features
                )

                # LPIPS
                loss_lpips = self.lam_lpips * self.lpips_fn(img_t, source_image)
                # L2
                loss_l2 = self.lam_l2 * F.mse_loss(img_t, source_image)

                # total loss
                loss = loss_patch + loss_glob + loss_lpips + loss_l2

                # log
                p_lr, c_lr = optimizer.get_lr()
                pbar.set_description(
                    f"point_lr: {p_lr}, color_lr: {c_lr}, "
                    f"L_total: {loss.item():.4f}, "
                    f"L_patch: {loss_patch.item():.4f}, "
                    f"L_glob: {loss_glob.item():.4f}, "
                    f"L_lpips: {loss_lpips.item():.4f}, "
                    f"L_l2: {loss_l2.item():.4f}."
                )

                # backward and optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step)

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(init_img,
                                img_t,
                                self.step,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.pretty_save_svg(self.svg_logs_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # log final results
        renderer.pretty_save_svg(self.result_path / "final_svg.svg")

        final_raster_sketch = renderer.get_image().to(self.device)
        plot_img_title(final_raster_sketch,
                       title=f'final result - {self.step} step',
                       output_dir=self.result_path,
                       fname='final_render')

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-y",  # overwrite an existing output file instead of prompting
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "clipfont_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")

    def load_renderer(self):
        renderer = Painter(device=self.device)
        return renderer
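# A minimal usage sketch (hypothetical): in the repository this pipeline is
# normally driven by the project's entry script, which assembles `args` from
# its config; the SVG path and prompt below are placeholders, not values from
# this file.
#
#   args = ...  # config object built by the framework (seed, x.*, mv, ...)
#   pipe = CLIPFontPipeline(args)
#   pipe.painterly_rendering(svg_path="./data/example_font.svg",
#                            prompt="Starry Night by Vincent van Gogh")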