# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
import shutil
from pathlib import Path

import clip
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.style_clipdraw import (
    Painter, PainterOptimizer, VGG16Extractor, StyleLoss, sample_indices
)
from pytorch_svgrender.plt import plot_img, plot_couple


class StyleCLIPDrawPipeline(ModelState):

    def __init__(self, args):
        logdir_ = f"sd{args.seed}" \
                  f"-P{args.x.num_paths}" \
                  f"-style{args.x.style_strength}" \
                  f"-n{args.x.num_aug}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.clip, self.tokenize_fn = self.init_clip()
        self.style_extractor = VGG16Extractor(space="normal").to(self.device)
        self.style_loss = StyleLoss()

    def init_clip(self):
        model, _ = clip.load('ViT-B/32', self.device, jit=False)
        return model, clip.tokenize
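
    # Augmentation note (CLIPDraw-style): each rendering is encoded through
    # several randomly perspective-warped and cropped views, so the CLIP loss
    # is averaged over views and cannot overfit to a single adversarial
    # rendering of the canvas.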
    def drawing_augment(self, image):
        augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))
        ])
        # image augmentation transformation
        img_augs = []
        for n in range(self.x_cfg.num_aug):
            img_augs.append(augment_trans(image))
        im_batch = torch.cat(img_augs)
        # CLIP visual encoding
        image_features = self.clip.encode_image(im_batch)
        return image_features
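
    # Preprocessing note: ToTensor() already maps a PIL image into [0, 1], so
    # the (t + 1) / 2 lambda below further compresses the style image into
    # [0.5, 1]. Kept as in the original code.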
    def style_file_preprocess(self, style_file):
        process_comp = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
            transforms.Lambda(lambda t: (t + 1) / 2),
        ])
        style_file = process_comp(style_file)
        style_file = style_file.to(self.device)
        return style_file
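
    # painterly_rendering couples two objectives, following StyleCLIPDraw
    # [Schaldenbrand et al., 2022]: a CLIP loss that keeps the vector drawing
    # faithful to the text prompt, and a STROTSS-based style loss that pulls
    # the stroke appearance toward the reference style image.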
    def painterly_rendering(self, prompt, style_fpath):
        # load style file
        style_path = Path(style_fpath)
        assert style_path.exists(), f"{style_fpath} does not exist!"
        self.print(f"load style file from: {style_path.as_posix()}")
        style_pil = Image.open(style_path.as_posix()).convert("RGB")
        style_img = self.style_file_preprocess(style_pil)
        shutil.copy(style_fpath, self.result_path)  # copy style file to the result dir

        # extract style features from the style image
        feat_style = None
        for i in range(5):
            with torch.no_grad():
                # sample 1000 hypercolumn feature vectors per pass
                feat_e = self.style_extractor.forward_samples_hypercolumn(style_img, samps=1000)
                feat_style = feat_e if feat_style is None else torch.cat((feat_style, feat_e), dim=2)
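
        # feat_style now holds the samples from 5 passes (5 x 1000 hypercolumn
        # vectors, concatenated along dim=2); the style loss later matches
        # features of the current rendering against this fixed set.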

        # text prompt encoding
        self.print(f"prompt: {prompt}")
        text_tokenize = self.tokenize_fn(prompt).to(self.device)
        with torch.no_grad():
            text_features = self.clip.encode_text(text_tokenize)

        renderer = Painter(self.x_cfg,
                           self.args.diffvg,
                           num_strokes=self.x_cfg.num_paths,
                           canvas_size=self.x_cfg.image_size,
                           device=self.device)
        img = renderer.init_image(stage=0)
        self.print("init_image shape: ", img.shape)
        plot_img(img, self.result_path, fname="init_img")

        optimizer = PainterOptimizer(renderer, self.x_cfg.lr, self.x_cfg.width_lr, self.x_cfg.color_lr)
        optimizer.init_optimizers()

        style_weight = 4 * (self.x_cfg.style_strength / 100)
        self.print(f'style_weight: {style_weight}')
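        # style_strength is a percentage; it maps linearly onto a loss weight
        # in [0, 4], e.g. style_strength=50 gives style_weight=2.0.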

        total_step = self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                rendering = renderer.get_image(self.step).to(self.device)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(rendering, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                rendering_aug = self.drawing_augment(rendering)

                loss = torch.tensor(0., device=self.device)
                # CLIP optimization: maximize the cosine similarity between the
                # text features and each augmented view of the rendering
                if self.step < 0.9 * total_step:
                    for n in range(self.x_cfg.num_aug):
                        loss -= torch.cosine_similarity(text_features, rendering_aug[n:n + 1], dim=1).mean()
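
                # CLIP guidance is dropped for the final 10% of steps
                # (self.step >= 0.9 * total_step), so the style term alone
                # refines the drawing at the end.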
                # style optimization: match rendering features to style features
                # following STROTSS [Kolkin et al., 2019]
                feat_content = self.style_extractor(rendering)

                # shuffle the sampled feature indices each step for stochastic matching
                xx, xy = sample_indices(feat_content[0], feat_style)
                np.random.shuffle(xx)
                np.random.shuffle(xy)

                L_style = self.style_loss.forward(feat_content, feat_content, feat_style, [xx, xy], 0)
                loss += L_style * style_weight
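
                # Total objective at step t (summary of the code above):
                #   L = -sum_n cos_sim(text_feat, CLIP(aug_n(render)))   [t < 0.9 * T]
                #       + style_weight * L_style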

                pbar.set_description(
                    f"lr: {optimizer.get_lr():.3f}, "
                    f"L_train: {loss.item():.4f}, "
                    f"L_style: {L_style.item():.4f}"
                )

                # optimization step
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr(self.step)

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(style_img,
                                rendering,
                                self.step,
                                prompt=prompt,
                                output_dir=self.png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.save_svg(self.svg_logs_dir.as_posix(), f"svg_iter{self.step}")

                self.step += 1
                pbar.update(1)

        # save the final results
        plot_couple(style_img,
                    rendering,
                    self.step,
                    prompt=prompt,
                    output_dir=self.result_path.as_posix(),
                    fname="final_iter")
        renderer.save_svg(self.result_path.as_posix(), "final_svg")

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "styleclipdraw_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")
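
# Usage sketch (hedged): in pytorch_svgrender this pipeline is normally driven
# by the repo's config-based entry point rather than instantiated by hand; the
# flags below are illustrative assumptions, not the repo's verified CLI:
#
#   python svg_render.py x=styleclipdraw \
#       "prompt='A colorful sketch of a cat'" \
#       target='./data/starry.png'
#
# Programmatically, given a fully populated `args` config (seed, mv, framefreq,
# save_step, framerate, and an `x` sub-config with num_paths, num_aug,
# style_strength, num_iter, ...):
#
#   pipe = StyleCLIPDrawPipeline(args)
#   pipe.painterly_rendering(prompt="A colorful sketch of a cat",
#                            style_fpath="./data/starry.png")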