Spaces:
Running
Running
# -*- coding: utf-8 -*-
# Author: ximing xing
# Description: the main entry point of this project.
# Copyright (c) 2023, XiMing Xing.
import os | |
import sys | |
from functools import partial | |
from accelerate.utils import set_seed | |
import hydra | |
import omegaconf | |
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0]) | |
from pytorch_svgrender.utils import render_batch_wrap, get_seed_range | |
# Names accepted for `cfg.x.method`; `main` rejects any value not listed here.
METHODS = [
    "diffvg",
    "live",
    "vectorfusion",
    "clipasso",
    "clipascene",
    "diffsketcher",
    "stylediffsketcher",
    "clipdraw",
    "styleclipdraw",
    "wordasimage",
    "clipfont",
    "svgdreamer",
]
def main(cfg: omegaconf.DictConfig):
    """Run the rendering pipeline selected by ``cfg.x.method``.

    The run is seeded with ``cfg.seed``. Each method imports its own pipeline
    class inside its branch, so only the selected pipeline module is loaded.
    Methods that support batch generation go through ``render_batch_wrap``
    when ``cfg.multirun`` is truthy; otherwise a single pipeline instance is
    built from ``cfg`` and its ``painterly_rendering`` method is called.

    Args:
        cfg: Project configuration. Fields read here: ``x.method``, ``seed``,
            ``multirun``, ``srange`` (only when ``multirun`` is truthy) and,
            depending on the method, ``target``, ``prompt``, ``x.word``,
            ``x.optim_letter`` and the optional ``style_file``.

    Raises:
        AssertionError: If ``cfg.x.method`` is not listed in ``METHODS``.
    """
    # Uncomment to dump the full configuration while debugging:
    # print(omegaconf.OmegaConf.to_yaml(cfg))
    flag = cfg.x.method
    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`; the exception type and message are unchanged.
    if flag not in METHODS:
        raise AssertionError(f"{flag} is not currently supported!")

    # seed prepare
    set_seed(cfg.seed)
    seed_range = get_seed_range(cfg.srange) if cfg.multirun else None

    # render function used by the branches that support batch generation
    render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)

    if flag == "diffvg":  # img2svg
        from pytorch_svgrender.pipelines.DiffVG_pipeline import DiffVGPipeline

        pipe = DiffVGPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "live":  # img2svg
        from pytorch_svgrender.pipelines.LIVE_pipeline import LIVEPipeline

        pipe = LIVEPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "vectorfusion":  # text2svg
        from pytorch_svgrender.pipelines.VectorFusion_pipeline import VectorFusionPipeline

        if not cfg.multirun:
            pipe = VectorFusionPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=VectorFusionPipeline, text_prompt=cfg.prompt)

    elif flag == "svgdreamer":  # text2svg
        from pytorch_svgrender.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline

        if not cfg.multirun:
            pipe = SVGDreamerPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)

    elif flag == "wordasimage":  # text2font
        from pytorch_svgrender.pipelines.WordAsImage_pipeline import WordAsImagePipeline

        pipe = WordAsImagePipeline(cfg)
        pipe.painterly_rendering(cfg.x.word, cfg.prompt, cfg.x.optim_letter)

    elif flag == "clipasso":  # img2sketch
        from pytorch_svgrender.pipelines.CLIPasso_pipeline import CLIPassoPipeline

        pipe = CLIPassoPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipascene":
        from pytorch_svgrender.pipelines.CLIPascene_pipeline import CLIPascenePipeline

        pipe = CLIPascenePipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipdraw":  # text2svg
        from pytorch_svgrender.pipelines.CLIPDraw_pipeline import CLIPDrawPipeline

        pipe = CLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt)

    elif flag == "clipfont":  # text and font to font
        from pytorch_svgrender.pipelines.CLIPFont_pipeline import CLIPFontPipeline

        if not cfg.multirun:
            pipe = CLIPFontPipeline(cfg)
            pipe.painterly_rendering(svg_path=cfg.target, prompt=cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=CLIPFontPipeline, svg_path=cfg.target, prompt=cfg.prompt)

    elif flag == "styleclipdraw":  # text to stylized svg
        from pytorch_svgrender.pipelines.StyleCLIPDraw_pipeline import StyleCLIPDrawPipeline

        pipe = StyleCLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)

    elif flag == "diffsketcher":  # text2sketch
        from pytorch_svgrender.pipelines.DiffSketcher_pipeline import DiffSketcherPipeline

        if not cfg.multirun:
            pipe = DiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=DiffSketcherPipeline, prompt=cfg.prompt)

    elif flag == "stylediffsketcher":  # text2sketch + style transfer
        from pytorch_svgrender.pipelines.DiffSketcher_stylized_pipeline import StylizedDiffSketcherPipeline

        if not cfg.multirun:
            pipe = StylizedDiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)
        else:  # generate many SVG at once
            # The single-run branch takes the style image from `cfg.target`,
            # while this branch used to read `cfg.style_file`, a key nothing
            # else in this function uses. Honor `style_file` when the config
            # provides it and otherwise fall back to `cfg.target`, so both
            # modes accept the same configuration.
            style_fpath = cfg.get("style_file") or cfg.target
            render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=cfg.prompt, style_fpath=style_fpath)
if __name__ == "__main__":
    # `main` requires a config argument, so a bare `main()` raises TypeError.
    # Wrap it with hydra here, at the script boundary, so the config is
    # composed from the command line while `main(cfg)` keeps its plain
    # signature for code that imports it.
    # NOTE(review): `config_path`/`config_name` assume a `conf/config.yaml`
    # next to this file, and `version_base=None` needs Hydra >= 1.2 — confirm
    # both against the project layout and pinned Hydra version.
    hydra.main(version_base=None, config_path="conf", config_name="config")(main)()