File size: 7,167 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# -*- coding: utf-8 -*-
# Copyright (c) XiMing Xing. All rights reserved.
# Author: XiMing Xing
# Description:
from pathlib import Path

from tqdm.auto import tqdm
import torch

from pytorch_svgrender.libs.engine import ModelState
from pytorch_svgrender.painter.wordasimage import Painter, PainterOptimizer
from pytorch_svgrender.painter.wordasimage.losses import ToneLoss, ConformalLoss
from pytorch_svgrender.painter.vectorfusion import LSDSPipeline
from pytorch_svgrender.plt import plot_img, plot_couple
from pytorch_svgrender.diffusers_warp import init_StableDiffusion_pipeline
from pytorch_svgrender.svgtools import FONT_LIST


class WordAsImagePipeline(ModelState):
    """Optimize the vector shape of one letter of a word against a text prompt.

    The letter outline is rendered by ``Painter`` and optimized with a
    score-distillation loss obtained from a Stable Diffusion pipeline
    (``LSDSPipeline``); a tone loss and a conformal loss are added when they
    are enabled in the configuration.
    """

    def __init__(self, args):
        """Validate the configuration, create log directories and load the diffusion model.

        Args:
            args: run configuration. This method reads ``args.seed``,
                ``args.x.word``, ``args.x.optim_letter``, ``args.x.font``,
                ``args.x.font_path``, ``args.x.image_size`` and
                ``args.diffuser.{download, force_download, resume_download}``.

        Raises:
            AssertionError: if the letter to optimize does not occur in the
                word, the font file does not exist, or the font is not listed
                in ``FONT_LIST``.
        """
        # assert
        assert args.x.optim_letter in args.x.word, \
            f"{args.x.optim_letter} does not occur in {args.x.word}."
        assert Path(args.x.font_path).exists(), f"{args.x.font_path} is not exist."
        assert args.x.font in FONT_LIST, f"{args.x.font} is not currently supported."

        # make logdir
        logdir_ = f"sd{args.seed}" \
                  f"-im{args.x.image_size}" \
                  f"-{args.x.word}-{args.x.optim_letter}"
        # NOTE(review): self.args, self.x_cfg, self.result_path, self.device,
        # self.accelerator and self.step are presumably provided by
        # ModelState.__init__ -- confirm against the base class.
        super().__init__(args, log_path_suffix=logdir_)

        # log dir
        self.png_log_dir = self.result_path / "png_logs"
        self.svg_log_dir = self.result_path / "svg_logs"
        # font
        self.font = self.x_cfg.font
        self.font_path = self.x_cfg.font_path
        self.optim_letter = self.x_cfg.optim_letter
        # letter
        self.letter = self.x_cfg.optim_letter
        # SVG of the scaled target letter; read back by `painterly_rendering`
        # after the font has been preprocessed into `self.result_path`.
        self.target_letter = self.result_path / f"{self.font}_{self.optim_letter}_scaled.svg"
        # make log dir
        if self.accelerator.is_main_process:
            self.png_log_dir.mkdir(parents=True, exist_ok=True)
            self.svg_log_dir.mkdir(parents=True, exist_ok=True)

        # make video log
        self.make_video = self.args.mv
        if self.make_video:
            # Sequential index of the next video frame. Frames must be numbered
            # consecutively because they are later consumed by ffmpeg through
            # the "iter%d.png" sequence pattern.
            self.frame_idx = 0
            self.frame_log_dir = self.result_path / "frame_logs"
            self.frame_log_dir.mkdir(parents=True, exist_ok=True)

        self.diffusion = init_StableDiffusion_pipeline(
            self.x_cfg.model_id,
            custom_pipeline=LSDSPipeline,
            device=self.device,
            local_files_only=not args.diffuser.download,
            force_download=args.diffuser.force_download,
            resume_download=args.diffuser.resume_download,
            ldm_speed_up=self.x_cfg.ldm_speed_up,
            enable_xformers=self.x_cfg.enable_xformers,
            gradient_checkpoint=self.x_cfg.gradient_checkpoint,
            lora_path=self.x_cfg.lora_path
        )

        # seeded generator on the pipeline device
        self.g_device = torch.Generator(device=self.device).manual_seed(args.seed)

    def painterly_rendering(self, word, semantic_concept, optimized_letter):
        """Run the optimization loop and write the results to ``self.result_path``.

        Args:
            word: the word that is converted from the font to SVG.
            semantic_concept: text that, followed by ``". "`` and the
                configured prompt suffix, forms the diffusion prompt.
            optimized_letter: the letter of ``word`` whose shape is optimized.

        Outputs written during the run: the initial word image and
        ``letter_init.svg``, periodic PNG/SVG snapshots in ``png_logs`` and
        ``svg_logs``, ``final_letter.svg``, the recombined word, and -- when
        video logging is enabled -- per-frame PNGs plus
        ``wordasimg_rendering.mp4`` (requires ``ffmpeg`` on PATH).
        """
        prompt = semantic_concept + ". " + self.x_cfg.prompt_suffix
        self.print(f"prompt: {prompt}")

        # load the optimized letter
        renderer = Painter(self.font, canvas_size=self.x_cfg.image_size, device=self.device)

        # font to svg
        self.print(f"font type: {self.font}\n")
        renderer.preprocess_font(word,
                                 optimized_letter,
                                 self.x_cfg.level_of_cc,
                                 self.font_path,
                                 self.result_path.as_posix())

        # init letter shape
        img_init = renderer.init_shape(self.target_letter)
        plot_img(img_init, self.result_path, fname="word_init")

        # save init letter
        renderer.pretty_save_svg(self.result_path / "letter_init.svg")
        init_letter = renderer.get_image()

        n_iter = self.x_cfg.num_iter

        # init optimizer and lr_schedular
        optimizer = PainterOptimizer(renderer, n_iter, self.x_cfg.lr)
        optimizer.init_optimizers()

        # init Tone loss
        if self.x_cfg.tone_loss.use:
            tone_loss = ToneLoss(self.x_cfg.tone_loss)
            tone_loss.set_image_init(img_init)

        # init conformal loss
        if self.x_cfg.conformal.use:
            conformal_loss = ConformalLoss(renderer.get_point_parameters(),
                                           renderer.shape_groups,
                                           optimized_letter, self.device)

        with tqdm(initial=self.step, total=n_iter, disable=not self.accelerator.is_main_process) as pbar:
            for i in range(n_iter):

                raster_img = renderer.get_image(step=i)

                if self.make_video and (i % self.args.framefreq == 0 or i == n_iter - 1):
                    # Name frames by a consecutive counter rather than by the
                    # iteration number: only every `framefreq`-th iteration is
                    # saved, and ffmpeg's "iter%d.png" input stops at the first
                    # gap in the numbering.
                    plot_img(raster_img, self.frame_log_dir, fname=f"iter{self.frame_idx}")
                    self.frame_idx += 1

                # the returned `grad` is not used by this loop
                L_sds, grad = self.diffusion.score_distillation_sampling(
                    raster_img,
                    im_size=self.x_cfg.sds.im_size,
                    prompt=[prompt],
                    negative_prompt=self.args.neg_prompt,
                    guidance_scale=self.x_cfg.sds.guidance_scale,
                    grad_scale=self.x_cfg.sds.grad_scale,
                    t_range=list(self.x_cfg.sds.t_range),
                )

                loss = L_sds

                if self.x_cfg.tone_loss.use:
                    tone_loss_res = tone_loss(raster_img, step=i)
                    loss = loss + tone_loss_res

                if self.x_cfg.conformal.use:
                    loss_angles = conformal_loss()
                    # `angeles_w` (sic) is the key name used by the configuration
                    loss_angles = self.x_cfg.conformal.angeles_w * loss_angles
                    loss = loss + loss_angles

                pbar.set_description(
                    f"n_params: {len(renderer.get_point_parameters())}, "
                    f"lr: {optimizer.get_lr():.4f}, "
                    f"L_total: {loss.item():.4f}, "
                )

                # optimization
                optimizer.zero_grad_()
                loss.backward()
                optimizer.step_()

                if self.x_cfg.lr_schedule:
                    optimizer.update_lr()

                # periodic snapshot; `raster_img` is the image rendered before
                # this iteration's parameter update
                if self.step % self.args.save_step == 0 and self.accelerator.is_main_process:
                    plot_couple(init_letter,
                                raster_img,
                                self.step,
                                output_dir=self.png_log_dir.as_posix(),
                                fname=f"iter{self.step}",
                                prompt=prompt)
                    renderer.pretty_save_svg(self.svg_log_dir / f"svg_iter{self.step}.svg")

                self.step += 1
                pbar.update(1)

        # save final optimized letter
        renderer.pretty_save_svg(self.result_path / "final_letter.svg")

        # combine word
        renderer.combine_word(word, optimized_letter, self.font, self.result_path)

        if self.make_video:
            from subprocess import call
            # assemble the consecutively numbered frames into a video
            call([
                "ffmpeg",
                "-framerate", f"{self.args.framerate}",
                "-i", (self.frame_log_dir / "iter%d.png").as_posix(),
                "-vb", "20M",
                (self.result_path / "wordasimg_rendering.mp4").as_posix()
            ])

        self.close(msg="painterly rendering complete.")