# pytorch_svgrender/pipelines/CLIPascene_pipeline.py
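"""CLIPascene pipeline.

Decomposes a scene photo into a salient foreground object and an inpainted
background, sketches each part separately with CLIP-guided differentiable
vector-stroke optimization, and composites the two sketches into one drawing.

Reference (inferred from the module name): "CLIPascene: Scene Sketching with
Different Types and Levels of Abstraction", Vinker et al.
"""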
import shutil
from pathlib import Path

import imageio
import numpy as np
import torch
from PIL import Image
from skimage.transform import resize
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from tqdm.auto import tqdm

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.clipascene import Painter, PainterOptimizer, Loss
from pytorch_svgrender.painter.clipascene.lama_utils import apply_inpaint
from pytorch_svgrender.painter.clipascene.scripts_utils import read_svg
from pytorch_svgrender.painter.clipascene.sketch_utils import plot_attn, get_mask_u2net, fix_image_scale
from pytorch_svgrender.plt import plot_img, plot_couple
class CLIPascenePipeline(ModelState):
def __init__(self, args):
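        # the log-dir suffix encodes seed, image size, number of paths, and stroke width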
logdir_ = f"sd{args.seed}" \
f"-im{args.x.image_size}" \
f"-P{args.x.num_paths}W{args.x.width}"
super().__init__(args, log_path_suffix=logdir_)
def painterly_rendering(self, image_path):
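        """End-to-end sketching: preprocess the scene into foreground/background
        targets, sketch each part separately, then composite the results."""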
foreground_target, background_target = self.preprocess_image(image_path)
background_output_dir = self.run_background(background_target)
foreground_output_dir = self.run_foreground(foreground_target)
self.combine(background_output_dir, foreground_output_dir, self.device)
self.close(msg="painterly rendering complete.")
def preprocess_image(self, image_path):
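        """Split the input into sketching targets: rescale to 512x512 if the
        image is larger, compute a U^2-Net salient-object mask, and inpaint the
        object region (via LaMa) to obtain a clean background.
        Returns (foreground_target_path, background_target_path)."""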
image_path = Path(image_path)
scene_path = self.result_path / "scene"
background_path = self.result_path / "background"
if self.accelerator.is_main_process:
scene_path.mkdir(parents=True, exist_ok=True)
background_path.mkdir(parents=True, exist_ok=True)
        im = Image.open(image_path)
        max_size = max(im.size)
        scaled_path = scene_path / f"{image_path.stem}.png"
        if max_size > 512:
            # downscale large inputs to a fixed 512x512 canvas (aspect ratio is not preserved)
            im = im.convert("RGB").resize((512, 512))
            im.save(scaled_path)
else:
shutil.copyfile(image_path, scaled_path)
scaled_img = Image.open(scaled_path)
mask = get_mask_u2net(scaled_img, scene_path, self.args.x.u2net_path, preprocess=True, device=self.device)
masked_path = scene_path / f"{image_path.stem}_mask.png"
imageio.imsave(masked_path, mask)
apply_inpaint(scene_path, background_path, self.device)
return scaled_path, background_path / f"{image_path.stem}_mask.png"
def run_background(self, target_file):
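        """Sketch the inpainted background. CLIP conv-layer weights become a
        one-hot vector over the 12 ViT layers selecting `background_layer`,
        serialized to the comma-separated string the painter config expects."""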
print("=====Start background=====")
self.args.x.resize_obj = 0
self.args.x.mask_object = 0
clip_conv_layer_weights_int = [0 for _ in range(12)]
clip_conv_layer_weights_int[self.args.x.background_layer] = 1
clip_conv_layer_weights_str = [str(j) for j in clip_conv_layer_weights_int]
self.args.x.clip_conv_layer_weights = ','.join(clip_conv_layer_weights_str)
output_dir = self.result_path / "background"
if self.accelerator.is_main_process:
output_dir.mkdir(parents=True, exist_ok=True)
self.paint(target_file, output_dir, self.args.x.background_num_iter)
print("=====End background=====")
return output_dir
def run_foreground(self, target_file):
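        """Sketch the masked foreground object. Layer 4 gets weight 0.5
        (overwritten with 1 when it is itself `foreground_layer`), the chosen
        layer gets weight 1, and gradnorm is enabled whenever the chosen layer
        differs from 4."""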
print("=====Start foreground=====")
self.args.x.resize_obj = 1
if self.args.x.foreground_layer != 4:
self.args.x.gradnorm = 1
self.args.x.mask_object = 1
clip_conv_layer_weights_int = [0 for _ in range(12)]
clip_conv_layer_weights_int[4] = 0.5
clip_conv_layer_weights_int[self.args.x.foreground_layer] = 1
clip_conv_layer_weights_str = [str(j) for j in clip_conv_layer_weights_int]
self.args.x.clip_conv_layer_weights = ','.join(clip_conv_layer_weights_str)
output_dir = self.result_path / "object"
if self.accelerator.is_main_process:
output_dir.mkdir(parents=True, exist_ok=True)
self.paint(target_file, output_dir, self.args.x.foreground_num_iter)
print("=====End foreground=====")
return output_dir
def paint(self, target, output_dir, num_iter):
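        """Run the stroke-optimization loop on `target` for `num_iter` steps:
        forward render -> weighted CLIP losses -> backprop, with periodic
        PNG/SVG snapshots, eval-loss best-iterate tracking, and optional
        frame/video logging."""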
png_log_dir = output_dir / "png_logs"
svg_log_dir = output_dir / "svg_logs"
if self.accelerator.is_main_process:
png_log_dir.mkdir(parents=True, exist_ok=True)
svg_log_dir.mkdir(parents=True, exist_ok=True)
# make video log
self.make_video = self.args.mv
if self.make_video:
self.frame_idx = 0
self.frame_log_dir = output_dir / "frame_logs"
self.frame_log_dir.mkdir(parents=True, exist_ok=True)
# preprocess input image
inputs, mask = self.get_target(target,
self.args.x.image_size,
output_dir,
self.args.x.resize_obj,
self.args.x.u2net_path,
self.args.x.mask_object,
self.args.x.fix_scale,
self.device)
plot_img(inputs, output_dir, fname="target")
loss_func = Loss(self.x_cfg, mask, self.device)
# init renderer
renderer = self.load_renderer(inputs, mask)
# init optimizer
optimizer = PainterOptimizer(self.x_cfg, renderer)
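        # track the best (lowest) eval loss; improvements smaller than `min_delta` are ignored as noise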
best_loss, best_fc_loss, best_num_strokes = 100, 100, self.args.x.num_paths
best_iter, best_iter_fc = 0, 0
min_delta = 1e-7
renderer.set_random_noise(0)
renderer.init_image(stage=0)
renderer.save_svg(svg_log_dir, "init_svg")
optimizer.init_optimizers()
        if self.args.x.switch_loss:
            # start with width optimization, then switch every `switch_loss` iterations
            renderer.turn_off_points_optim()
            optimizer.turn_off_points_optim()
        with torch.no_grad():
            # render the initial canvas once (return value unused) and save the initial SVG
            renderer.get_image("init").to(self.device)
            renderer.save_svg(self.result_path, "init")
total_step = num_iter
step = 0
with tqdm(initial=step, total=total_step, disable=not self.accelerator.is_main_process) as pbar:
while step < total_step:
optimizer.zero_grad_()
sketches = renderer.get_image().to(self.device)
if self.make_video and (step % self.args.framefreq == 0 or step == total_step - 1):
plot_img(sketches, self.frame_log_dir, fname=f"iter{self.frame_idx}")
self.frame_idx += 1
losses_dict_weighted, _, _ = loss_func(sketches, inputs.detach(), step,
renderer.get_widths(), renderer,
optimizer, mode="train",
width_opt=renderer.width_optim)
loss = sum(list(losses_dict_weighted.values()))
loss.backward()
optimizer.step_()
if step % self.args.x.save_step == 0:
                    plot_couple(inputs,
                                sketches,
                                step,
                                output_dir=png_log_dir.as_posix(),
                                fname=f"iter{step}")
renderer.save_svg(svg_log_dir.as_posix(), f"svg_iter{step}")
if step % self.args.x.eval_step == 0:
with torch.no_grad():
losses_dict_weighted_eval, _, _ = loss_func(
sketches,
inputs,
step,
renderer.get_widths(),
renderer=renderer,
mode="eval",
width_opt=renderer.width_optim)
loss_eval = sum(list(losses_dict_weighted_eval.values()))
cur_delta = loss_eval.item() - best_loss
if abs(cur_delta) > min_delta:
if cur_delta < 0:
best_loss = loss_eval.item()
best_iter = step
plot_couple(inputs,
sketches,
best_iter,
output_dir=output_dir.as_posix(),
fname="best_iter")
renderer.save_svg(output_dir.as_posix(), "best_iter")
if step == 0 and self.x_cfg.attention_init and self.accelerator.is_main_process:
plot_attn(renderer.get_attn(),
renderer.get_thresh(),
inputs,
renderer.get_inds(),
(output_dir / "attention_map.png").as_posix(),
self.x_cfg.saliency_model)
if self.args.x.switch_loss:
if step > 0 and step % self.args.x.switch_loss == 0:
renderer.switch_opt()
optimizer.switch_opt()
step += 1
pbar.update(1)
if self.make_video:
from subprocess import call
call([
"ffmpeg",
"-framerate", f"{self.args.framerate}",
"-i", (self.frame_log_dir / "iter%d.png").as_posix(),
"-vb", "20M",
(output_dir / f"clipascene_sketch.mp4").as_posix()
])
def load_renderer(self, target_im=None, mask=None):
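        """Instantiate the diffvg-backed Painter for the given target image and mask."""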
renderer = Painter(method_cfg=self.x_cfg,
diffvg_cfg=self.args.diffvg,
num_strokes=self.x_cfg.num_paths,
canvas_size=self.x_cfg.image_size,
device=self.device,
target_im=target_im,
mask=mask)
return renderer
def get_target(self,
target_file,
image_size,
output_dir,
resize_obj,
u2net_path,
mask_object,
fix_scale,
device):
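        """Load and preprocess the sketching target: flatten alpha onto white,
        compute the U^2-Net mask (optionally masking out everything but the
        object and fixing its scale), then resize to a square `image_size`
        (non-square inputs are stretched; square inputs are center-cropped).
        Returns a (1, 3, H, W) tensor on `self.device` and the mask."""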
target = Image.open(target_file)
if target.mode == "RGBA":
# Create a white rgba background
new_image = Image.new("RGBA", target.size, "WHITE")
# Paste the image on the background.
new_image.paste(target, (0, 0), target)
target = new_image
target = target.convert("RGB")
        # U^2-Net salient-object mask
masked_im, mask = get_mask_u2net(target, output_dir, u2net_path, resize_obj=resize_obj, device=device)
if mask_object:
target = masked_im
if fix_scale:
target = fix_image_scale(target)
transforms_ = []
if target.size[0] != target.size[1]:
transforms_.append(
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)
)
else:
transforms_.append(transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC))
transforms_.append(transforms.CenterCrop(image_size))
transforms_.append(transforms.ToTensor())
data_transforms = transforms.Compose(transforms_)
target_ = data_transforms(target).unsqueeze(0).to(self.device)
return target_, mask
def combine(self, background_output_dir, foreground_output_dir, device, output_size=448):
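        """Rasterize the background and object SVGs at `output_size`, white-out
        the object's mask region in the background, overlay the object's stroke
        pixels, and save the combined sketch."""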
params_path = foreground_output_dir / "resize_params.npy"
params = None
if params_path.exists():
params = np.load(params_path, allow_pickle=True)[()]
mask_path = foreground_output_dir / "mask.png"
mask = imageio.imread(mask_path)
mask = resize(mask, (output_size, output_size), anti_aliasing=False)
object_svg_path = foreground_output_dir / "best_iter.svg"
raster_o = read_svg(object_svg_path, resize_obj=1, params=params, multiply=1.8, device=device)
background_svg_path = background_output_dir / "best_iter.svg"
raster_b = read_svg(background_svg_path, resize_obj=0, params=params, multiply=1.8, device=device)
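        # composite: white-out the object's mask region in the background raster,
        # then copy over every non-white (stroke) pixel from the object raster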
raster_b[mask == 1] = 1
raster_b[raster_o != 1] = raster_o[raster_o != 1]
raster_b = torch.from_numpy(raster_b).unsqueeze(0).permute(0, 3, 1, 2).to(device)
plot_img(raster_b, self.result_path, fname="combined")
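# Usage sketch (hypothetical; the real entry point is the project's config-driven
# runner, and the exact schema of `args`/`args.x` is an assumption, not shown here):
#
#   pipeline = CLIPascenePipeline(args)
#   pipeline.painterly_rendering("./data/scene.png")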