Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# Copyright (c) XiMing Xing. All rights reserved. | |
# Author: XiMing Xing | |
# Description: CLIPDraw pipeline - optimizes SVG strokes so that their rendering matches a text prompt under CLIP.
import torch | |
from tqdm.auto import tqdm | |
from torchvision import transforms | |
import clip | |
from pytorch_svgrender.libs.engine import ModelState | |
from pytorch_svgrender.painter.clipdraw import Painter, PainterOptimizer | |
from pytorch_svgrender.plt import plot_img, plot_couple | |
class CLIPDrawPipeline(ModelState):
    """CLIPDraw pipeline: optimize SVG strokes so their rendering matches a text prompt under CLIP.

    Every iteration renders the current SVG, encodes ``x_cfg.num_aug``
    randomly augmented views of it with a frozen CLIP ViT-B/32, and minimizes
    the negative cosine similarity to the prompt's CLIP text embedding.
    PNG/SVG snapshots (and, with ``args.mv``, video frames) are written under
    ``self.result_path``.
    """

    # Mean / std passed to ``Normalize`` right before CLIP image encoding.
    _CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, args):
        """Set up log directories, the CLIP model and the augmentation pipeline.

        Args:
            args: run configuration; this method reads ``args.seed``,
                ``args.x.image_size``, ``args.x.num_paths`` and ``args.mv``.
        """
        # Log-dir suffix encodes seed, canvas size and number of paths.
        logdir_ = (f"sd{args.seed}"
                   f"-im{args.x.image_size}"
                   f"-P{args.x.num_paths}")
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs (only the main process touches these)
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.clip, self.tokenize_fn = self.init_clip()

        # The transforms draw fresh random parameters on every call, so the
        # pipeline can be built once here instead of on every optimization step.
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
            transforms.Normalize(self._CLIP_MEAN, self._CLIP_STD),
        ])

    def init_clip(self):
        """Load CLIP ViT-B/32 on ``self.device`` and return ``(model, tokenize_fn)``.

        CLIP is only a fixed scorer here: gradients must flow *through* it to
        the rendered image, but never into its own weights. Freezing the
        parameters skips computing and accumulating those unused weight
        gradients on every ``backward()``.
        """
        model, _ = clip.load('ViT-B/32', self.device, jit=False)
        model.requires_grad_(False)
        return model, clip.tokenize

    def drawing_augment(self, image):
        """Return CLIP image features for ``x_cfg.num_aug`` augmented views of ``image``.

        Each view gets an independent random perspective warp (fill=1), a
        random resized 224x224 crop and normalization; the views are
        concatenated along dim 0 and encoded in one CLIP forward pass.
        """
        img_augs = [self.augment_trans(image) for _ in range(self.x_cfg.num_aug)]
        im_batch = torch.cat(img_augs)
        # clip visual encoding
        return self.clip.encode_image(im_batch)

    def painterly_rendering(self, prompt):
        """Optimize the SVG painter until ``self.step`` reaches ``x_cfg.num_iter``.

        Saves an initial image, PNG/SVG snapshots every ``args.save_step``
        steps, the final SVG as ``final_svg`` and, if enabled, an MP4 of the
        optimization; finally calls ``self.close``.

        Args:
            prompt: text prompt the drawing should depict.
        """
        self.print(f"prompt: {prompt}")

        # text prompt encoding: a fixed target, so no gradients are needed
        text_tokenize = self.tokenize_fn(prompt).to(self.device)
        with torch.no_grad():
            text_features = self.clip.encode_text(text_tokenize)

        # init SVG Painter
        renderer = Painter(self.x_cfg,
                           self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device)
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_img")

        # init painter optimizer
        optimizer = PainterOptimizer(renderer, self.x_cfg.lr, self.x_cfg.width_lr, self.x_cfg.color_lr)
        optimizer.init_optimizers()

        total_step = self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                rendering = renderer.get_image(self.step).to(self.device)

                # frames every `framefreq` steps plus the very last step
                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(rendering, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # data augmentation
                aug_svg_batch = self.drawing_augment(rendering)

                # CLIP loss: negative cosine similarity between the prompt
                # embedding and each augmented view, summed over the views
                loss = torch.tensor(0., device=self.device)
                for n in range(self.x_cfg.num_aug):
                    loss -= torch.cosine_similarity(text_features, aug_svg_batch[n:n + 1], dim=1).mean()

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.3f}, "
                    f"L_train: {loss.item():.4f}"
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step)

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(img,
                                rendering,
                                self.step,
                                prompt=prompt,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                self.step += 1
                pbar.update(1)

        # Outputs are written by the main process only, matching the snapshots
        # above and avoiding several processes racing on the same files.
        if self.accelerator.is_main_process:
            renderer.save_svg(self.result_path.as_posix(), "final_svg")
            if self.make_video:
                self._render_video()

        self.close(msg="painterly rendering complete.")

    def _render_video(self):
        """Stitch the logged ``iter%d.png`` frames into ``clipdraw_rendering.mp4`` with ffmpeg.

        Best-effort: a missing ffmpeg binary or a failed encode is reported
        instead of raising, so the run still reaches ``self.close``.
        """
        import shutil
        import subprocess

        ffmpeg = shutil.which("ffmpeg")
        if ffmpeg is None:
            self.print("ffmpeg not found on PATH; skipping video rendering.")
            return

        cmd = [
            ffmpeg,
            "-y",  # overwrite an existing output instead of blocking on an interactive prompt
            "-framerate", f"{self.args.framerate}",
            "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
            "-vb", "20M",
            (self.result_path / "clipdraw_rendering.mp4").as_posix(),
        ]
        result = subprocess.run(cmd, check=False)
        if result.returncode != 0:
            self.print(f"ffmpeg exited with code {result.returncode}; the video may be missing or incomplete.")