# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:

from typing import Union, AnyStr, List

from PIL import Image
from omegaconf.listconfig import ListConfig
import diffusers
import numpy as np
from tqdm.auto import tqdm
import torch
from torchvision import transforms
import clip

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.vectorfusion import LSDSPipeline, LSDSSDXLPipeline, Painter, PainterOptimizer
from pytorch_svgrender.painter.vectorfusion import channel_saturation_penalty_loss as pixel_penalty_loss
from pytorch_svgrender.painter.live import xing_loss_fn
from pytorch_svgrender.plt import plot_img, plot_couple
from pytorch_svgrender.token2attn.ptp_utils import view_images
from pytorch_svgrender.diffusers_warp import init_StableDiffusion_pipeline, model2res


class VectorFusionPipeline(ModelState):
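    """VectorFusion: text-to-SVG generation in two stages.

    Stage 1 (``LIVE_rendering``): sample K raster images from a diffusion
    model, select the best one by CLIP score (rejection sampling), and
    vectorize it with LIVE-style layer-wise path optimization.
    Stage 2 (``painterly_rendering``): fine-tune the SVG paths via Score
    Distillation Sampling (SDS), or optimize them from scratch when
    ``skip_live`` is set.
    """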

    def __init__(self, args):
        assert args.x.style in ["iconography", "pixelart", "low-poly", "painting", "sketch", "ink"]

        logdir_ = f"sd{args.seed}-" \
                  f"{'scratch' if args.x.skip_live else 'baseline'}" \
                  f"-{args.x.model_id}" \
                  f"-{args.x.style}" \
                  f"-P{args.x.num_paths}" \
                  f"{'-RePath' if args.x.path_reinit.use else ''}"
        super().__init__(args, log_path_suffix=logdir_)

        # create log dirs
        self.png_logs_dir = self.result_path / "png_logs"
        self.svg_logs_dir = self.result_path / "svg_logs"
        self.ft_png_logs_dir = self.result_path / "ft_png_logs"
        self.ft_svg_logs_dir = self.result_path / "ft_svg_logs"
        self.sd_sample_dir = self.result_path / 'sd_samples'
        self.reinit_dir = self.result_path / "reinit_logs"
        if self.accelerator.is_main_process:
            self.png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.ft_png_logs_dir.mkdir(parents=True, exist_ok=True)
            self.ft_svg_logs_dir.mkdir(parents=True, exist_ok=True)
            self.sd_sample_dir.mkdir(parents=True, exist_ok=True)
            self.reinit_dir.mkdir(parents=True, exist_ok=True)

        self.select_fpth = self.result_path / 'select_sample.png'

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        if self.x_cfg.model_id == "sdxl":
            # The default LSDSSDXLPipeline scheduler is EulerDiscreteScheduler.
            # When LSDSSDXLPipeline is called, scheduler.timesteps is modified in step 4,
            # which breaks the add_noise() call in SDS: the randomly drawn t may not
            # be contained in scheduler.timesteps.
            custom_pipeline = LSDSSDXLPipeline
            custom_scheduler = diffusers.DPMSolverMultistepScheduler
        elif self.x_cfg.model_id == 'sd21':
            custom_pipeline = LSDSPipeline
            custom_scheduler = diffusers.DDIMScheduler
        else:  # sd14, sd15
            custom_pipeline = LSDSPipeline
            custom_scheduler = diffusers.PNDMScheduler

        self.diffusion = init_StableDiffusion_pipeline(
            self.x_cfg.model_id,
            custom_pipeline=custom_pipeline,
            custom_scheduler=custom_scheduler,
            device=self.device,
            local_files_only=not args.diffuser.download,
            force_download=args.diffuser.force_download,
            resume_download=args.diffuser.resume_download,
            ldm_speed_up=self.x_cfg.ldm_speed_up,
            enable_xformers=self.x_cfg.enable_xformers,
            gradient_checkpoint=self.x_cfg.gradient_checkpoint,
            lora_path=self.x_cfg.lora_path
        )

        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

        self.style = self.x_cfg.style
        if self.style in ["pixelart", "low-poly"]:
            self.x_cfg.path_schedule = 'list'
            self.x_cfg.schedule_each = [args.x.grid]
        if self.style == "pixelart":
            self.x_cfg.lr_stage_one.lr_schedule = False
            self.x_cfg.lr_stage_two.lr_schedule = False

    def get_path_schedule(self, schedule_each: Union[int, List]):
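        """Build the per-stage path schedule.

        'repeat' splits num_paths into equal chunks, e.g. num_paths=32 with
        schedule_each=8 yields [8, 8, 8, 8]; 'list' expects schedule_each to
        already be an explicit list (or ListConfig) of per-stage path counts.
        """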
        if self.x_cfg.path_schedule == 'repeat':
            return int(self.x_cfg.num_paths / schedule_each) * [schedule_each]
        elif self.x_cfg.path_schedule == 'list':
            assert isinstance(self.x_cfg.schedule_each, (list, ListConfig))
            return schedule_each
        else:
            raise NotImplementedError

    def target_file_preprocess(self, tar_path: AnyStr):
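        """Load the image at tar_path and turn it into a (1, 3, H, W) tensor
        in [0, 1], resized to image_size and moved to the current device."""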
        process_comp = transforms.Compose([
            transforms.Resize(size=(self.x_cfg.image_size, self.x_cfg.image_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t.unsqueeze(0)),
        ])

        tar_pil = Image.open(tar_path).convert("RGB")  # open file
        target_img = process_comp(tar_pil)  # preprocess
        target_img = target_img.to(self.device)
        return target_img

    def rejection_sampling(self, img_caption: Union[AnyStr, List], diffusion_samples: List):
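        """Pick the diffusion sample that best matches the caption, scoring
        each sample by CLIP (ViT-B/32) image-text cosine similarity."""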
        clip_model, preprocess = clip.load("ViT-B/32", device=self.device)

        # CLIP is only used for scoring here, so no gradients are needed
        with torch.no_grad():
            text_input = clip.tokenize([img_caption]).to(self.device)
            text_features = clip_model.encode_text(text_input)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            clip_images = torch.stack([
                preprocess(sample) for sample in diffusion_samples
            ]).to(self.device)
            image_features = clip_model.encode_image(clip_images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # clip score: cosine similarity of the normalized features
            similarity_scores = (text_features @ image_features.T).squeeze(0)

        selected_image_index = similarity_scores.argmax().item()
        selected_image = diffusion_samples[selected_image_index]
        return selected_image

    def diffusion_sampling(self, text_prompt: AnyStr):
        """Sample K images for text_prompt and save them under sd_samples."""
        diffusion_samples = []
        height = width = model2res(self.x_cfg.model_id)
        for i in range(self.x_cfg.K):
            outputs = self.diffusion(prompt=[text_prompt],
                                     negative_prompt=self.args.neg_prompt,
                                     height=height,
                                     width=width,
                                     num_images_per_prompt=1,
                                     num_inference_steps=self.x_cfg.num_inference_steps,
                                     guidance_scale=self.x_cfg.guidance_scale,
                                     generator=self.g_device)
            outputs_np = [np.array(img) for img in outputs.images]
            view_images(outputs_np, save_image=True, fp=self.sd_sample_dir / f'samples_{i}.png')
            diffusion_samples.extend(outputs.images)

        self.print(f"num_generated_samples: {len(diffusion_samples)}, shape: {outputs_np[0].shape}")
        return diffusion_samples

    def LIVE_rendering(self, text_prompt: AnyStr):
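        """Stage 1: text-to-image-to-SVG.

        Samples K raster images, selects the best one via CLIP rejection
        sampling, then vectorizes the selected target with LIVE-style
        layer-wise path optimization. Returns the target image tensor and
        the path of the final SVG.
        """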
        select_fpth = self.select_fpth

        # sampling K images
        diffusion_samples = self.diffusion_sampling(text_prompt)
        # rejection sampling
        select_target = self.rejection_sampling(text_prompt, diffusion_samples)
        select_target_pil = Image.fromarray(np.asarray(select_target))  # ensure a PIL image
        select_target_pil.save(select_fpth)

        # load target file
        assert select_fpth.exists(), f"{select_fpth} does not exist!"
        target_img = self.target_file_preprocess(select_fpth.as_posix())
        self.print(f"load target file from: {select_fpth.as_posix()}")

        # log path_schedule
        path_schedule = self.get_path_schedule(self.x_cfg.schedule_each)
        self.print(f"path_schedule: {path_schedule}")

        renderer = self.load_renderer()
        # first init center
        renderer.component_wise_path_init(target_img, pred=None, init_type=self.x_cfg.coord_init)

        optimizer_list = [PainterOptimizer(renderer, self.style, self.x_cfg.num_iter,
                                           self.x_cfg.lr_stage_one, self.x_cfg.trainable_bg)
                          for _ in range(len(path_schedule))]

        pathn_record = []
        loss_weight_keep = 0
        total_step = len(path_schedule) * self.x_cfg.num_iter
        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            for path_idx, pathn in enumerate(path_schedule):
                # record path
                pathn_record.append(pathn)
                # init graphic
                img = renderer.init_image(stage=0, num_paths=pathn)
                plot_img(img, self.result_path, fname=f"init_img_{path_idx}")
                # rebuild optimizer
                optimizer_list[path_idx].init_optimizers(pid_delta=int(path_idx * pathn))

                pbar.write(f"=> adding {pathn} paths, n_path: {sum(pathn_record)}, "
                           f"n_points: {len(renderer.get_point_parameters())}, "
                           f"n_colors: {len(renderer.get_color_parameters())}")

                for t in range(self.x_cfg.num_iter):
                    raster_img = renderer.get_image(step=t).to(self.device)

                    if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                        plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                        self.frame_idx += 1
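
                    # Distance-weighted reconstruction, following LIVE: the squared
                    # per-pixel error is re-weighted by a distance-based map from
                    # renderer.calc_distance_weight before averaging.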
                    loss_weight = 1.0  # fallback weight if distance weighting is disabled
                    if self.x_cfg.use_distance_weighted_loss and self.style != "pixelart":
                        loss_weight = renderer.calc_distance_weight(loss_weight_keep)

                    # reconstruction loss
                    if self.style == "pixelart":
                        loss_recon = torch.nn.functional.l1_loss(raster_img, target_img)
                    else:  # UDF loss
                        loss_recon = ((raster_img - target_img) ** 2)
                        loss_recon = (loss_recon.sum(1) * loss_weight).mean()

                    # Xing loss for the self-interaction problem
                    loss_xing = torch.tensor(0.)
                    if self.style == "iconography":
                        loss_xing = xing_loss_fn(renderer.get_point_parameters()) * self.x_cfg.xing_loss_weight

                    # total loss
                    loss = loss_recon + loss_xing
lr_str = "" | |
for k, lr in optimizer_list[path_idx].get_lr().items(): | |
lr_str += f"{k}_lr: {lr:.4f}, " | |
pbar.set_description( | |
lr_str + | |
f"L_total: {loss.item():.4f}, " | |
f"L_recon: {loss_recon.item():.4f}, " | |
f"L_xing: {loss_xing.item()}" | |
) | |

                    # optimization
                    for i in range(path_idx + 1):
                        optimizer_list[i].zero_grad_()

                    loss.backward()

                    for i in range(path_idx + 1):
                        optimizer_list[i].step_()

                    renderer.clip_curve_shape()

                    if self.x_cfg.lr_stage_one.lr_schedule:
                        for i in range(path_idx + 1):
                            optimizer_list[i].update_lr()

                    if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                        plot_couple(target_img,
                                    raster_img,
                                    self.step,
                                    prompt=text_prompt,
                                    output_dir=self.png_logs_dir.as_posix(),
                                    fname=f"iter{self.step}")
                        renderer.pretty_save_svg(self.svg_logs_dir / f"svg_iter{self.step}.svg")

                    self.step += 1
                    pbar.update(1)

                # end of this set of path optimization
                if self.x_cfg.use_distance_weighted_loss and self.style != "pixelart":
                    loss_weight_keep = loss_weight.detach().cpu().numpy() * 1
                # recalculate the coordinates for the newly added paths
                renderer.component_wise_path_init(target_img, raster_img)

        # end LIVE
        final_svg_fpth = self.result_path / "live_stage_one_final.svg"
        renderer.pretty_save_svg(final_svg_fpth)

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "VF_rendering_stage1.mp4").as_posix()
            ])

        return target_img, final_svg_fpth

    def painterly_rendering(self, text_prompt: AnyStr):
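        """Entry point: run the full VectorFusion pipeline for text_prompt.

        With skip_live, paths are randomly initialized and optimized from
        scratch purely with SDS; otherwise stage 1 (LIVE_rendering) first
        produces an SVG that is then fine-tuned with SDS.
        """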
        # log prompts
        self.print(f"prompt: {text_prompt}")
        self.print(f"negative_prompt: {self.args.neg_prompt}\n")

        if self.x_cfg.skip_live:
            target_img = torch.randn(1, 3, self.x_cfg.image_size, self.x_cfg.image_size)
            final_svg_fpth = None
            self.print("from scratch with Score Distillation Sampling...")
        else:
            # text-to-img-to-svg
            target_img, final_svg_fpth = self.LIVE_rendering(text_prompt)
            torch.cuda.empty_cache()
            self.x_cfg.path_svg = final_svg_fpth
            self.print("\nfine-tune SVG via Score Distillation Sampling...")

        renderer = self.load_renderer(path_svg=final_svg_fpth)
        if self.x_cfg.skip_live:
            renderer.component_wise_path_init(target_img, pred=None, init_type='random')

        img = renderer.init_image(stage=0, num_paths=self.x_cfg.num_paths)
        plot_img(img, self.result_path, fname="init_img_stage_two")

        optimizer = PainterOptimizer(renderer, self.style,
                                     self.x_cfg.sds.num_iter,
                                     self.x_cfg.lr_stage_two,
                                     self.x_cfg.trainable_bg)
        optimizer.init_optimizers()

        self.print(f"-> Painter point Params: {len(renderer.get_point_parameters())}")
        self.print(f"-> Painter color Params: {len(renderer.get_color_parameters())}")
        self.print(f"-> Painter width Params: {len(renderer.get_width_parameters())}")

        self.step = 0  # reset global step
        total_step = self.x_cfg.sds.num_iter
        path_reinit = self.x_cfg.path_reinit

        self.print(f"\ntotal sds optimization steps: {total_step}")

        with tqdm(initial=self.step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
            while self.step < total_step:
                raster_img = renderer.get_image(step=self.step).to(self.device)

                if self.make_video and (self.step % self.args.framefreq == 0 or self.step == total_step - 1):
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1
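
                # Score Distillation Sampling: noise the rasterized SVG, denoise it
                # with the frozen diffusion model under classifier-free guidance, and
                # backpropagate the residual into the path parameters; grad is a
                # scalar diagnostic of the SDS gradient, logged below.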
                L_sds, grad = self.diffusion.score_distillation_sampling(
                    raster_img,
                    im_size=self.x_cfg.sds.im_size,
                    prompt=[text_prompt],
                    negative_prompt=self.args.neg_prompt,
                    guidance_scale=self.x_cfg.sds.guidance_scale,
                    grad_scale=self.x_cfg.sds.grad_scale,
                    t_range=list(self.x_cfg.sds.t_range),
                )

                # Xing loss for the self-interaction problem
                L_add = torch.tensor(0.)
                if self.style == "iconography":
                    L_add = xing_loss_fn(renderer.get_point_parameters()) * self.x_cfg.xing_loss_weight
                # pixel penalty loss to combat oversaturation
                if self.style in ["pixelart", "low-poly"]:
                    L_add = pixel_penalty_loss(raster_img) * self.x_cfg.penalty_weight

                loss = L_sds + L_add

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                renderer.clip_curve_shape()
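
                # Path re-initialization (VectorFusion): periodically respawn paths
                # whose opacity or covered area falls below the configured thresholds,
                # recycling dead paths instead of wasting the path budget.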
                if self.step % path_reinit.freq == 0 and self.step < path_reinit.stop_step and self.step != 0:
                    renderer.reinitialize_paths(path_reinit.use,  # on-off switch
                                                path_reinit.opacity_threshold,
                                                path_reinit.area_threshold,
                                                fpath=self.reinit_dir / f"reinit-{self.step}.svg")

                # update lr
                if self.x_cfg.lr_stage_two.lr_schedule:
                    optimizer.update_lr()

                lr_str = ""
                for k, lr in optimizer.get_lr().items():
                    lr_str += f"{k}_lr: {lr:.4f}, "

                pbar.set_description(
                    lr_str +
                    f"L_total: {loss.item():.4f}, "
                    f"L_add: {L_add.item():.4e}, "
                    f"sds: {grad.item():.5e}"
                )

                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(target_img,
                                raster_img,
                                self.step,
                                prompt=text_prompt,
                                output_dir=self.ft_png_logs_dir.as_posix(),
                                fname=f"iter{self.step}")
                    renderer.pretty_save_svg(self.ft_svg_logs_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        final_svg_fpth = self.result_path / "finetune_final.svg"
        renderer.pretty_save_svg(final_svg_fpth)

        if self.make_video:
            from subprocess import call
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "VF_rendering_stage2.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")

    def load_renderer(self, path_svg=None):
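        """Build a DiffVG-based Painter; when path_svg is given, paths are
        initialized from that SVG file instead of from scratch."""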
        renderer = Painter(self.args.diffvg,
                           self.style,
                           self.x_cfg.num_segments,
                           self.x_cfg.segment_init,
                           self.x_cfg.radius,
                           self.x_cfg.image_size,
                           self.x_cfg.grid,
                           self.x_cfg.trainable_bg,
                           self.x_cfg.width,
                           path_svg=path_svg,
                           device=self.device)
        return renderer