# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description: CLIPasso painterly (sketch) rendering pipeline.

from PIL import Image
import torch
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.datasets.folder import is_image_file

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.clipasso import Painter, PainterOptimizer, Loss
from pytorch_svgrender.painter.clipasso.sketch_utils import plot_attn, get_mask_u2net, fix_image_scale
from pytorch_svgrender.plt import plot_img, plot_couple, plot_img_title


class CLIPassoPipeline(ModelState):

    def __init__(self, args):
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"{'-mask' if args.x.mask_object else ''}" \
                  f"{'-XDoG' if args.x.xdog_intersec else ''}" \
                  f"-P{args.x.num_paths}W{args.x.width}{'OP' if args.x.force_sparse else 'BL'}" \
                  f"-tau{args.x.softmax_temp}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

    def painterly_rendering(self, image_path):
        loss_func = Loss(self.x_cfg, self.device)

        # preprocess input image
        inputs, mask = self.get_target(image_path,
                                       self.x_cfg.image_size,
                                       self.result_path,
                                       self.x_cfg.u2net_path,
                                       self.x_cfg.mask_object,
                                       self.x_cfg.fix_scale,
                                       self.device)
        plot_img(inputs, self.result_path, fname="input")

        # init renderer
        renderer = self.load_renderer(inputs, mask)
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_img")

        # init optimizer
        optimizer = PainterOptimizer(renderer,
                                     self.x_cfg.num_iter,
                                     self.x_cfg.lr,
                                     self.x_cfg.force_sparse,
                                     self.x_cfg.color_lr)
        optimizer.init_optimizers()

        best_loss, best_fc_loss = 100, 100
        min_delta = 1e-5

        total_step = self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step,
                  disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                sketches = renderer.get_image().to(self.device)

                # dump frames for the optional rendering video
                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(sketches, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                losses_dict = loss_func(sketches,
                                        inputs.detach(),
                                        renderer.get_color_parameters(),
                                        renderer,
                                        self.step,
                                        optimizer)
                loss = sum(list(losses_dict.values()))

                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                pbar.set_description(f"L_train: {loss.item():.5f}")

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(inputs,
                                sketches,
                                self.step,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                # periodic eval pass (no grad): keep the best sketch seen so far
                if self.step % self.args.eval_step == 0 and self.accelerator.is_main_process:
                    with torch.no_grad():
                        losses_dict_eval = loss_func(sketches,
                                                     inputs,
                                                     renderer.get_color_parameters(),
                                                     renderer.get_point_parameters(),
                                                     self.step,
                                                     optimizer,
                                                     mode="eval")
                        loss_eval = sum(list(losses_dict_eval.values()))

                        # record the sketch whenever the eval loss improves by more than min_delta
                        cur_delta = loss_eval.item() - best_loss
                        if abs(cur_delta) > min_delta and cur_delta < 0:
                            best_loss = loss_eval.item()
                            best_iter = self.step
                            plot_couple(inputs,
                                        sketches,
                                        best_iter,
                                        output_dir=self.result_path.as_posix(),
                                        fname="best_iter")
                            renderer.save_svg(self.result_path.as_posix(), "best_iter")

                # visualize the attention map used to initialize stroke locations
                if self.step == 0 and self.x_cfg.attention_init and self.accelerator.is_main_process:
                    plot_attn(renderer.get_attn(),
                              renderer.get_thresh(),
                              inputs,
                              renderer.get_inds(),
                              (self.result_path / "attention_map.png").as_posix(),
                              self.x_cfg.saliency_model)

                self.step += 1
                pbar.update(1)

        # log final results
        renderer.save_svg(self.result_path.as_posix(), "final_svg")
        final_raster_sketch = renderer.get_image().to(self.device)
        plot_img_title(final_raster_sketch,
                       title=f'final result - {self.step} step',
                       output_dir=self.result_path,
                       fname='final_render')

        if self.make_video:
            # stitch the logged frames into an mp4 with ffmpeg
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "clipasso_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")

    def load_renderer(self, target_im=None, mask=None):
        renderer = Painter(method_cfg=self.x_cfg,
                           diffvg_cfg=self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device,
                           target_im=target_im,
                           mask=mask)
        return renderer

    def get_target(self, target_file, image_size, output_dir, u2net_path, mask_object, fix_scale, device):
        if not is_image_file(target_file):
            raise TypeError(f"{target_file} is not an image file.")

        target = Image.open(target_file)

        if target.mode == "RGBA":
            # composite the RGBA input onto a white background so transparency reads as white
            new_image = Image.new("RGBA", target.size, "WHITE")
            new_image.paste(target, (0, 0), target)
            target = new_image
        target = target.convert("RGB")

        # U^2-Net salient-object mask
        masked_im, mask = get_mask_u2net(target, output_dir, u2net_path, device)
        if mask_object:
            target = masked_im

        if fix_scale:
            target = fix_image_scale(target)

        # non-square inputs are resized directly to image_size x image_size;
        # square inputs are resized then center-cropped
        transforms_ = []
        if target.size[0] != target.size[1]:
            transforms_.append(
                transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)
            )
        else:
            transforms_.append(transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC))
            transforms_.append(transforms.CenterCrop(image_size))
        transforms_.append(transforms.ToTensor())

        data_transforms = transforms.Compose(transforms_)
        target_ = data_transforms(target).unsqueeze(0).to(self.device)

        return target_, mask
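# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). In the
# PyTorch-SVGRender project this pipeline is normally constructed by the
# framework's launcher, which assembles the `args` namespace: global options
# such as `seed`, `mv`, `framefreq`, `framerate`, `save_step`, `eval_step`,
# and `diffvg`, plus the method config exposed as `args.x` (`image_size`,
# `num_paths`, `width`, `num_iter`, `lr`, `color_lr`, `mask_object`,
# `fix_scale`, `u2net_path`, ...). Assuming such an `args` object has been
# built, rendering a sketch reduces to:
#
#     pipeline = CLIPassoPipeline(args)
#     pipeline.painterly_rendering("data/horse.png")
#
# The input path above is hypothetical; any RGB/RGBA file accepted by
# `is_image_file` will do.
# ---------------------------------------------------------------------------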