# -*- coding: utf-8 -*- # Author: ximing xing # Description: the main func of this project. # Copyright (c) 2023, XiMing Xing. import os import sys from functools import partial from accelerate.utils import set_seed import hydra import omegaconf sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0]) from pytorch_svgrender.utils import render_batch_wrap, get_seed_range METHODS = [ 'diffvg', 'live', 'vectorfusion', 'clipasso', 'clipascene', 'diffsketcher', 'stylediffsketcher', 'clipdraw', 'styleclipdraw', 'wordasimage', 'clipfont', 'svgdreamer' ] @hydra.main(version_base=None, config_path="conf", config_name='config') def main(cfg: omegaconf.DictConfig): # print(omegaconf.OmegaConf.to_yaml(cfg)) flag = cfg.x.method assert flag in METHODS, f"{flag} is not currently supported!" # seed prepare set_seed(cfg.seed) seed_range = get_seed_range(cfg.srange) if cfg.multirun else None # render function render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range) if flag == "diffvg": # img2svg from pytorch_svgrender.pipelines.DiffVG_pipeline import DiffVGPipeline pipe = DiffVGPipeline(cfg) pipe.painterly_rendering(cfg.target) elif flag == "live": # img2svg from pytorch_svgrender.pipelines.LIVE_pipeline import LIVEPipeline pipe = LIVEPipeline(cfg) pipe.painterly_rendering(cfg.target) elif flag == "vectorfusion": # text2svg from pytorch_svgrender.pipelines.VectorFusion_pipeline import VectorFusionPipeline if not cfg.multirun: pipe = VectorFusionPipeline(cfg) pipe.painterly_rendering(cfg.prompt) else: # generate many SVG at once render_batch_fn(pipeline=VectorFusionPipeline, text_prompt=cfg.prompt) elif flag == "svgdreamer": # text2svg from pytorch_svgrender.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline if not cfg.multirun: pipe = SVGDreamerPipeline(cfg) pipe.painterly_rendering(cfg.prompt) else: # generate many SVG at once render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None) elif flag == "wordasimage": # text2font from pytorch_svgrender.pipelines.WordAsImage_pipeline import WordAsImagePipeline pipe = WordAsImagePipeline(cfg) pipe.painterly_rendering(cfg.x.word, cfg.prompt, cfg.x.optim_letter) elif flag == "clipasso": # img2sketch from pytorch_svgrender.pipelines.CLIPasso_pipeline import CLIPassoPipeline pipe = CLIPassoPipeline(cfg) pipe.painterly_rendering(cfg.target) elif flag == 'clipascene': from pytorch_svgrender.pipelines.CLIPascene_pipeline import CLIPascenePipeline pipe = CLIPascenePipeline(cfg) pipe.painterly_rendering(cfg.target) elif flag == "clipdraw": # text2svg from pytorch_svgrender.pipelines.CLIPDraw_pipeline import CLIPDrawPipeline pipe = CLIPDrawPipeline(cfg) pipe.painterly_rendering(cfg.prompt) elif flag == "clipfont": # text and font to font from pytorch_svgrender.pipelines.CLIPFont_pipeline import CLIPFontPipeline if not cfg.multirun: pipe = CLIPFontPipeline(cfg) pipe.painterly_rendering(svg_path=cfg.target, prompt=cfg.prompt) else: # generate many SVG at once render_batch_fn(pipeline=CLIPFontPipeline, svg_path=cfg.target, prompt=cfg.prompt) elif flag == "styleclipdraw": # text to stylized svg from pytorch_svgrender.pipelines.StyleCLIPDraw_pipeline import StyleCLIPDrawPipeline pipe = StyleCLIPDrawPipeline(cfg) pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target) elif flag == "diffsketcher": # text2sketch from pytorch_svgrender.pipelines.DiffSketcher_pipeline import DiffSketcherPipeline if not cfg.multirun: pipe = DiffSketcherPipeline(cfg) pipe.painterly_rendering(cfg.prompt) else: # generate many SVG at once render_batch_fn(pipeline=DiffSketcherPipeline, prompt=cfg.prompt) elif flag == "stylediffsketcher": # text2sketch + style transfer from pytorch_svgrender.pipelines.DiffSketcher_stylized_pipeline import StylizedDiffSketcherPipeline if not cfg.multirun: pipe = StylizedDiffSketcherPipeline(cfg) pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target) else: # generate many SVG at once render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=cfg.prompt, style_fpath=cfg.style_file) if __name__ == '__main__': main()