File size: 4,742 Bytes
966ae59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-
# Author: ximing xing
# Description: the main func of this project.
# Copyright (c) 2023, XiMing Xing.

import os
import sys
from functools import partial

from accelerate.utils import set_seed
import hydra
import omegaconf

sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])

from pytorch_svgrender.utils import render_batch_wrap, get_seed_range

# Values accepted for ``cfg.x.method``; ``main`` dispatches on these names.
# Order is kept as originally declared.
METHODS = [
    'diffvg', 'live',                     # image -> SVG
    'vectorfusion',                       # text -> SVG
    'clipasso',                           # image -> sketch
    'clipascene',                         # renders from cfg.target
    'diffsketcher', 'stylediffsketcher',  # text -> sketch (+ style transfer)
    'clipdraw', 'styleclipdraw',          # text -> SVG (+ style image)
    'wordasimage', 'clipfont',            # text / font -> font
    'svgdreamer',                         # text -> SVG
]


@hydra.main(version_base=None, config_path="conf", config_name='config')
def main(cfg: omegaconf.DictConfig):
    """Run the SVG rendering pipeline selected by ``cfg.x.method``.

    Each pipeline module is imported inside its own branch, so only the
    selected method's pipeline module is imported by this function.

    When ``cfg.multirun`` is true, the vectorfusion, svgdreamer, clipfont,
    diffsketcher and stylediffsketcher methods render a batch through
    ``render_batch_wrap`` using the seeds from ``get_seed_range(cfg.srange)``;
    every other method ignores ``cfg.multirun`` and renders once.

    Args:
        cfg: Hydra-composed config. Always read: ``x.method``, ``seed`` and
            ``multirun`` (plus ``srange`` when ``multirun`` is true).
            Per-method inputs come from ``target``, ``prompt``, ``x.word``
            and ``x.optim_letter``.

    Raises:
        ValueError: if ``cfg.x.method`` is not one of ``METHODS``.
    """
    flag = cfg.x.method
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if flag not in METHODS:
        raise ValueError(f"{flag} is not currently supported!")

    # seed prepare
    set_seed(cfg.seed)
    seed_range = get_seed_range(cfg.srange) if cfg.multirun else None

    # batch-render helper; only the branches that support multirun call it
    render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)

    if flag == "diffvg":  # img2svg
        from pytorch_svgrender.pipelines.DiffVG_pipeline import DiffVGPipeline

        pipe = DiffVGPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "live":  # img2svg
        from pytorch_svgrender.pipelines.LIVE_pipeline import LIVEPipeline

        pipe = LIVEPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "vectorfusion":  # text2svg
        from pytorch_svgrender.pipelines.VectorFusion_pipeline import VectorFusionPipeline

        if not cfg.multirun:
            pipe = VectorFusionPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=VectorFusionPipeline, text_prompt=cfg.prompt)

    elif flag == "svgdreamer":  # text2svg
        from pytorch_svgrender.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline

        if not cfg.multirun:
            pipe = SVGDreamerPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)

    elif flag == "wordasimage":  # text2font
        from pytorch_svgrender.pipelines.WordAsImage_pipeline import WordAsImagePipeline

        pipe = WordAsImagePipeline(cfg)
        pipe.painterly_rendering(cfg.x.word, cfg.prompt, cfg.x.optim_letter)

    elif flag == "clipasso":  # img2sketch
        from pytorch_svgrender.pipelines.CLIPasso_pipeline import CLIPassoPipeline

        pipe = CLIPassoPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == 'clipascene':  # renders from cfg.target
        from pytorch_svgrender.pipelines.CLIPascene_pipeline import CLIPascenePipeline

        pipe = CLIPascenePipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipdraw":  # text2svg
        from pytorch_svgrender.pipelines.CLIPDraw_pipeline import CLIPDrawPipeline

        pipe = CLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt)

    elif flag == "clipfont":  # text and font to font
        from pytorch_svgrender.pipelines.CLIPFont_pipeline import CLIPFontPipeline

        if not cfg.multirun:
            pipe = CLIPFontPipeline(cfg)
            pipe.painterly_rendering(svg_path=cfg.target, prompt=cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=CLIPFontPipeline, svg_path=cfg.target, prompt=cfg.prompt)

    elif flag == "styleclipdraw":  # text to stylized svg
        from pytorch_svgrender.pipelines.StyleCLIPDraw_pipeline import StyleCLIPDrawPipeline

        pipe = StyleCLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)

    elif flag == "diffsketcher":  # text2sketch
        from pytorch_svgrender.pipelines.DiffSketcher_pipeline import DiffSketcherPipeline

        if not cfg.multirun:
            pipe = DiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=DiffSketcherPipeline, prompt=cfg.prompt)

    elif flag == "stylediffsketcher":  # text2sketch + style transfer
        from pytorch_svgrender.pipelines.DiffSketcher_stylized_pipeline import StylizedDiffSketcherPipeline

        if not cfg.multirun:
            pipe = StylizedDiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)
        else:  # generate many SVG at once
            # Style image comes from ``cfg.target``, exactly as in the single-run
            # branch above (``cfg.style_file`` is not a key used anywhere else here;
            # TODO confirm against conf/ that no ``style_file`` key is expected).
            render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=cfg.prompt, style_fpath=cfg.target)


if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator composes the config and passes it to main.
    main()