Spaces:
Running
Running
File size: 4,742 Bytes
966ae59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# -*- coding: utf-8 -*-
# Author: ximing xing
# Description: entry point of this project; runs the SVG rendering pipeline selected by `cfg.x.method`.
# Copyright (c) 2023, XiMing Xing.
import os
import sys
from functools import partial
from accelerate.utils import set_seed
import hydra
import omegaconf
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
from pytorch_svgrender.utils import render_batch_wrap, get_seed_range
# Every value accepted for ``cfg.x.method``; ``main`` checks the configured
# method against this list and dispatches each entry to its own pipeline.
METHODS = [
    "diffvg",             # image (cfg.target) -> SVG
    "live",               # image (cfg.target) -> SVG
    "vectorfusion",       # text prompt -> SVG
    "clipasso",           # image (cfg.target) -> sketch
    "clipascene",         # target input (cfg.target)
    "diffsketcher",       # text prompt -> sketch
    "stylediffsketcher",  # text prompt -> sketch, plus a style image
    "clipdraw",           # text prompt -> SVG
    "styleclipdraw",      # text prompt -> stylized SVG
    "wordasimage",        # text -> font glyphs
    "clipfont",           # text + font SVG -> font
    "svgdreamer",         # text prompt -> SVG
]
@hydra.main(version_base=None, config_path="conf", config_name='config')
def main(cfg: omegaconf.DictConfig):
    """Run the SVG rendering pipeline selected by ``cfg.x.method``.

    Each pipeline class is imported lazily inside its own branch, so only the
    pipeline module of the chosen method is loaded.

    When ``cfg.multirun`` is true, methods that have a batch branch
    (vectorfusion, svgdreamer, clipfont, diffsketcher, stylediffsketcher)
    render many SVGs through ``render_batch_wrap`` over the seeds produced by
    ``get_seed_range(cfg.srange)``. The remaining methods have no batch branch
    and always perform a single run.

    Args:
        cfg: Hydra-composed configuration (``config_path="conf"``,
            ``config_name="config"`` plus command-line overrides). Fields read
            here: ``x.method``, ``seed``, ``multirun``, ``srange``, ``target``,
            ``prompt``, ``x.word`` / ``x.optim_letter`` (wordasimage), and the
            optional ``style_file`` (stylediffsketcher batch mode).

    Raises:
        ValueError: if ``cfg.x.method`` is not listed in ``METHODS``.
    """
    flag = cfg.x.method
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown method fall through silently.
    if flag not in METHODS:
        raise ValueError(f"{flag} is not currently supported!")

    # seed prepare
    set_seed(cfg.seed)
    seed_range = get_seed_range(cfg.srange) if cfg.multirun else None

    # Batch renderer with cfg and seeds pre-bound; only used by multirun branches.
    render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)

    if flag == "diffvg":  # img2svg
        from pytorch_svgrender.pipelines.DiffVG_pipeline import DiffVGPipeline
        pipe = DiffVGPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "live":  # img2svg
        from pytorch_svgrender.pipelines.LIVE_pipeline import LIVEPipeline
        pipe = LIVEPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "vectorfusion":  # text2svg
        from pytorch_svgrender.pipelines.VectorFusion_pipeline import VectorFusionPipeline
        if not cfg.multirun:
            pipe = VectorFusionPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=VectorFusionPipeline, text_prompt=cfg.prompt)

    elif flag == "svgdreamer":  # text2svg
        from pytorch_svgrender.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline
        if not cfg.multirun:
            pipe = SVGDreamerPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)

    elif flag == "wordasimage":  # text2font
        from pytorch_svgrender.pipelines.WordAsImage_pipeline import WordAsImagePipeline
        pipe = WordAsImagePipeline(cfg)
        pipe.painterly_rendering(cfg.x.word, cfg.prompt, cfg.x.optim_letter)

    elif flag == "clipasso":  # img2sketch
        from pytorch_svgrender.pipelines.CLIPasso_pipeline import CLIPassoPipeline
        pipe = CLIPassoPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == 'clipascene':
        from pytorch_svgrender.pipelines.CLIPascene_pipeline import CLIPascenePipeline
        pipe = CLIPascenePipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipdraw":  # text2svg
        from pytorch_svgrender.pipelines.CLIPDraw_pipeline import CLIPDrawPipeline
        pipe = CLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt)

    elif flag == "clipfont":  # text and font to font
        from pytorch_svgrender.pipelines.CLIPFont_pipeline import CLIPFontPipeline
        if not cfg.multirun:
            pipe = CLIPFontPipeline(cfg)
            pipe.painterly_rendering(svg_path=cfg.target, prompt=cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=CLIPFontPipeline, svg_path=cfg.target, prompt=cfg.prompt)

    elif flag == "styleclipdraw":  # text to stylized svg
        from pytorch_svgrender.pipelines.StyleCLIPDraw_pipeline import StyleCLIPDrawPipeline
        pipe = StyleCLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)

    elif flag == "diffsketcher":  # text2sketch
        from pytorch_svgrender.pipelines.DiffSketcher_pipeline import DiffSketcherPipeline
        if not cfg.multirun:
            pipe = DiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=DiffSketcherPipeline, prompt=cfg.prompt)

    elif flag == "stylediffsketcher":  # text2sketch + style transfer
        from pytorch_svgrender.pipelines.DiffSketcher_stylized_pipeline import StylizedDiffSketcherPipeline
        if not cfg.multirun:
            pipe = StylizedDiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)
        else:  # generate many SVG at once
            # The single-run branch takes the style image from ``cfg.target``;
            # batch mode previously read ``cfg.style_file`` unconditionally, which
            # fails when that key is not declared in a struct-mode config. Prefer
            # ``style_file`` when it is present and set, otherwise use ``cfg.target``
            # so both modes work with the same configuration.
            has_style_file = 'style_file' in cfg and cfg.style_file is not None
            style_fpath = cfg.style_file if has_style_file else cfg.target
            render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=cfg.prompt, style_fpath=style_fpath)
if __name__ == "__main__":
    # Invoked without arguments: the ``@hydra.main`` decorator on ``main``
    # composes the config from ``conf/`` plus CLI overrides and passes it as ``cfg``.
    main()
|