File size: 6,998 Bytes
31726e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os, sys
import pydiffvg
import argparse
import torch
# import torch as th
import scipy.ndimage.filters as filters
# import numba
import numpy as np
from skimage import io
sys.path.append('./textureSyn')
from patchBasedTextureSynthesis import *
from make_gif import make_gif
import random
import ttools.modules

from svgpathtools import svg2paths2, Path, is_path_segment
"""
python texture_synthesis.py textureSyn/traced_1.png  --svg-path textureSyn/traced_1.svg --case 1
"""

def texture_syn(img_path):
    """Synthesize a texture image from the exemplar at ``img_path``.

    Builds a ``patchBasedTextureSynthesis`` instance with horizontal and
    vertical mirroring enabled and snapshots disabled, runs ``resolveAll()``
    and returns its result converted to a NumPy array.

    NOTE: the output folder is derived from the module-level ``args.case``,
    so ``args`` must be defined (as done in the ``__main__`` guard) before
    this function is called.
    """
    synthesizer = patchBasedTextureSynthesis(
        img_path,
        "results/texture_synthesis/%d" % (args.case),  # per-case output folder
        [300, 300],  # output size
        40,  # size of the patch (without the overlap)
        10,  # the width of the overlap region
        in_windowStep=5,
        in_mirror_hor=True,
        in_mirror_vert=True,
        in_shapshots=False,
    )
    return np.array(synthesizer.resolveAll())


def render(canvas_width, canvas_height, shapes, shape_groups, samples=2):
    """Render a scene with pydiffvg and return the resulting image.

    The scene is serialized with ``RenderFunction.serialize_scene`` and then
    passed to ``RenderFunction.apply``. ``samples`` is used for both the x and
    y sample counts, and the seed is fixed at 0.
    """
    serialized_scene = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    return pydiffvg.RenderFunction.apply(
        canvas_width,   # width
        canvas_height,  # height
        samples,        # num_samples_x
        samples,        # num_samples_y
        0,              # seed
        None,
        *serialized_scene)

def big_bounding_box(paths_n_stuff):
    """Find a bounding box containing a collection of paths, Bezier path
    segments, and points (given as complex numbers).

    Args:
        paths_n_stuff: iterable whose items are ``Path`` objects, path
            segments, complex numbers, or values accepted by ``complex()``.

    Returns:
        Tuple ``(xmin, xmax, ymin, ymax)`` enclosing every item.

    Raises:
        TypeError: if an item is neither a path/segment nor convertible to
            a complex number.
        ValueError: if ``paths_n_stuff`` contains no items.
    """
    bbs = []
    for thing in paths_n_stuff:
        if isinstance(thing, complex):
            # A point is a degenerate box: min and max coincide.
            bbs.append((thing.real, thing.real, thing.imag, thing.imag))
        elif is_path_segment(thing) or isinstance(thing, Path):
            bbs.append(thing.bbox())
        else:
            try:
                point = complex(thing)
            except (TypeError, ValueError) as err:
                # complex() raises ValueError for malformed strings but
                # TypeError for unsupported objects (e.g. lists); report
                # both with the same explanatory message.
                raise TypeError(
                    "paths_n_stuff can only contains Path, CubicBezier, "
                    "QuadraticBezier, Line, and complex objects.") from err
            bbs.append((point.real, point.real, point.imag, point.imag))

    if not bbs:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" error.
        raise ValueError(
            "paths_n_stuff is empty; cannot compute a bounding box.")

    xmins, xmaxs, ymins, ymaxs = zip(*bbs)
    return min(xmins), max(xmaxs), min(ymins), max(ymaxs)


def main(args):
    """Fit the SVG at ``args.svg_path`` to a synthesized texture image.

    Steps:
      1. Synthesize a target texture from ``args.target``.
      2. Load the SVG scene and rescale its path points to the target size.
      3. Optimize path points and fill colors with Adam against an L2 image
         loss for ``args.max_iter`` iterations. Every iteration's render is
         written as a PNG; an SVG snapshot is saved every 10 iterations and
         on the last one.
      4. Render the final result and assemble the per-iteration frames into
         ``out.mp4`` (via ffmpeg) and ``out.gif``.

    All outputs go to ``results/texture_synthesis/<case>/``.

    Args:
        args: parsed command-line namespace providing ``target``,
            ``svg_path``, ``max_iter`` and ``case``.
    """
    from subprocess import call

    out_dir = 'results/texture_synthesis/%d' % (args.case)
    # Make sure the result folder exists before anything tries to write there.
    os.makedirs(out_dir, exist_ok=True)

    ## set device -> use cpu now since I haven't solved the nvcc issue
    pydiffvg.set_use_gpu(False)

    ## generate a texture synthesized target
    target_img = texture_syn(args.target)
    # The array is treated as HWC below (NHWC -> NCHW permute), so axis 0 is
    # the height and axis 1 is the width.
    tar_h, tar_w = target_img.shape[0], target_img.shape[1]
    canvas_width, canvas_height, shapes, shape_groups = \
        pydiffvg.svg_to_scene(args.svg_path)

    print('tar h : %d tar w : %d'%(tar_h, tar_w))
    print('canvas h : %d canvas w : %d' % (canvas_height, canvas_width))
    # A single uniform factor (based on height) maps SVG coordinates onto the
    # target image.
    scale_ratio = tar_h / canvas_height
    print("scale ratio : ", scale_ratio)
    for path in shapes:
        path.points[..., 0] = path.points[..., 0] * scale_ratio
        path.points[..., 1] = path.points[..., 1] * scale_ratio

    # render() takes (width, height).
    init_img = render(tar_w, tar_h, shapes, shape_groups)
    pydiffvg.imwrite(init_img.cpu(), '%s/init.png' % out_dir, gamma=2.2)

    random.seed(1234)
    torch.manual_seed(1234)

    # Collect the tensors to optimize: control points per path and fill
    # color per shape group.
    points_vars = []
    for path in shapes:
        path.points.requires_grad = True
        points_vars.append(path.points)
    color_vars = []
    for group in shape_groups:
        group.fill_color.requires_grad = True
        color_vars.append(group.fill_color)

    # Separate optimizers because points and colors use different step sizes.
    points_optim = torch.optim.Adam(points_vars, lr=1.0)
    color_optim = torch.optim.Adam(color_vars, lr=0.01)

    target = torch.from_numpy(target_img).to(torch.float32) / 255.0
    target = target.pow(2.2)
    target = target.to(pydiffvg.get_device())
    target = target.unsqueeze(0)
    target = target.permute(0, 3, 1, 2) # NHWC -> NCHW
    canvas_width, canvas_height = target.shape[3], target.shape[2]

    for t in range(args.max_iter):
        print('iteration:', t)
        points_optim.zero_grad()
        color_optim.zero_grad()
        cur_img = render(canvas_width, canvas_height, shapes, shape_groups)
        pydiffvg.imwrite(cur_img.cpu(), '%s/iter_%d.png' % (out_dir, t), gamma=2.2)
        cur_img = cur_img[:, :, :3]  # keep only the first three channels
        cur_img = cur_img.unsqueeze(0)
        cur_img = cur_img.permute(0, 3, 1, 2) # NHWC -> NCHW

        ## l2 loss
        loss = (cur_img - target).pow(2).mean()
        print('render loss:', loss.item())
        loss.backward()

        points_optim.step()
        color_optim.step()

        # Keep colors inside the valid [0, 1] range after the update.
        for group in shape_groups:
            group.fill_color.data.clamp_(0.0, 1.0)
        ## write svg
        if t % 10 == 0 or t == args.max_iter - 1:
            pydiffvg.save_svg('%s/iter_%d.svg' % (out_dir, t),
                              canvas_width, canvas_height, shapes, shape_groups)

    ## render final result
    final_img = render(tar_w, tar_h, shapes, shape_groups)
    pydiffvg.imwrite(final_img.cpu(), '%s/final.png' % out_dir, gamma=2.2)

    # The frame pattern must reach ffmpeg with a literal "%d", so it is
    # appended by concatenation instead of being %-formatted.
    call(["ffmpeg", "-framerate", "24", "-i",
        out_dir + "/iter_%d.png", "-vb", "20M",
        out_dir + "/out.mp4"])
    ## make gif
    make_gif(out_dir, out_dir + "/out.gif", frame_every_X_steps=1, repeat_ending=3, total_iter=args.max_iter)

if __name__ == '__main__':
    # Command-line entry point. ``args`` is deliberately bound at module
    # level: texture_syn() reads ``args.case`` as a global.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "target",
        help="target image path",
    )
    parser.add_argument(
        "--svg-path",
        type=str,
        help="the corresponding svg file path",
    )
    parser.add_argument(
        "--max-iter",
        type=int,
        default=500,
        help="the max optimization iterations",
    )
    parser.add_argument(
        "--case",
        type=int,
        default=1,
        help="just the case id for a separate result folder",
    )
    args = parser.parse_args()
    main(args)