DiffSketcher / svg_render.py
hjc-owo
init repo
966ae59
# -*- coding: utf-8 -*-
# Author: ximing xing
# Description: command-line entry point that runs the SVG rendering pipeline selected in the Hydra config.
# Copyright (c) 2023, XiMing Xing.
import os
import sys
from functools import partial
from accelerate.utils import set_seed
import hydra
import omegaconf
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
from pytorch_svgrender.utils import render_batch_wrap, get_seed_range
# Pipeline names accepted for ``cfg.x.method``; ``main`` has one branch per
# entry and rejects anything not listed here.
METHODS = [
    "diffvg",             # img2svg
    "live",               # img2svg
    "vectorfusion",       # text2svg
    "clipasso",           # img2sketch
    "clipascene",         # reads cfg.target
    "diffsketcher",       # text2sketch
    "stylediffsketcher",  # text2sketch + style transfer
    "clipdraw",           # text2svg
    "styleclipdraw",      # text to stylized svg
    "wordasimage",        # text2font
    "clipfont",           # text and font to font
    "svgdreamer",         # text2svg
]
@hydra.main(version_base=None, config_path="conf", config_name='config')
def main(cfg: omegaconf.DictConfig):
    """Render SVGs with the pipeline selected by ``cfg.x.method``.

    Hydra builds ``cfg`` from the config named ``config`` under ``conf/``
    (plus any command-line overrides) and calls this function with it.

    Config keys read here:
        x.method: pipeline name; must be one of ``METHODS``.
        seed: seed passed to ``accelerate.utils.set_seed``.
        multirun: when true, the methods that have a batch path
            (vectorfusion, svgdreamer, clipfont, diffsketcher,
            stylediffsketcher) render via ``render_batch_wrap``; the other
            methods ignore this flag and do a single run.
        srange: seed-range spec; only read when ``multirun`` is true.
        target: input path — an image for the img2svg/img2sketch methods,
            an SVG for clipfont, or a style image for styleclipdraw and
            stylediffsketcher.
        prompt: text prompt for text-conditioned methods.
        x.word, x.optim_letter: WordAsImage-specific options.

    Raises:
        ValueError: if ``cfg.x.method`` is not in ``METHODS``.
    """
    flag = cfg.x.method
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unknown method silently do nothing.
    if flag not in METHODS:
        raise ValueError(f"{flag} is not currently supported!")

    # seed prepare
    set_seed(cfg.seed)
    # In multirun mode, ``cfg.srange`` is converted to the seed range that
    # ``render_batch_wrap`` iterates over; single runs do not need it.
    seed_range = get_seed_range(cfg.srange) if cfg.multirun else None

    # render function: batch renderer with cfg/seed_range pre-bound, so each
    # branch only supplies the pipeline class and its per-method inputs.
    render_batch_fn = partial(render_batch_wrap, cfg=cfg, seed_range=seed_range)

    # Pipelines are imported lazily so only the selected method's
    # dependencies are loaded.
    if flag == "diffvg":  # img2svg
        from pytorch_svgrender.pipelines.DiffVG_pipeline import DiffVGPipeline
        pipe = DiffVGPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "live":  # img2svg
        from pytorch_svgrender.pipelines.LIVE_pipeline import LIVEPipeline
        pipe = LIVEPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "vectorfusion":  # text2svg
        from pytorch_svgrender.pipelines.VectorFusion_pipeline import VectorFusionPipeline
        if not cfg.multirun:
            pipe = VectorFusionPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=VectorFusionPipeline, text_prompt=cfg.prompt)

    elif flag == "svgdreamer":  # text2svg
        from pytorch_svgrender.pipelines.SVGDreamer_pipeline import SVGDreamerPipeline
        if not cfg.multirun:
            pipe = SVGDreamerPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=SVGDreamerPipeline, text_prompt=cfg.prompt, target_file=None)

    elif flag == "wordasimage":  # text2font
        from pytorch_svgrender.pipelines.WordAsImage_pipeline import WordAsImagePipeline
        pipe = WordAsImagePipeline(cfg)
        pipe.painterly_rendering(cfg.x.word, cfg.prompt, cfg.x.optim_letter)

    elif flag == "clipasso":  # img2sketch
        from pytorch_svgrender.pipelines.CLIPasso_pipeline import CLIPassoPipeline
        pipe = CLIPassoPipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipascene":
        from pytorch_svgrender.pipelines.CLIPascene_pipeline import CLIPascenePipeline
        pipe = CLIPascenePipeline(cfg)
        pipe.painterly_rendering(cfg.target)

    elif flag == "clipdraw":  # text2svg
        from pytorch_svgrender.pipelines.CLIPDraw_pipeline import CLIPDrawPipeline
        pipe = CLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt)

    elif flag == "clipfont":  # text and font to font
        from pytorch_svgrender.pipelines.CLIPFont_pipeline import CLIPFontPipeline
        if not cfg.multirun:
            pipe = CLIPFontPipeline(cfg)
            pipe.painterly_rendering(svg_path=cfg.target, prompt=cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=CLIPFontPipeline, svg_path=cfg.target, prompt=cfg.prompt)

    elif flag == "styleclipdraw":  # text to stylized svg
        from pytorch_svgrender.pipelines.StyleCLIPDraw_pipeline import StyleCLIPDrawPipeline
        pipe = StyleCLIPDrawPipeline(cfg)
        pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)

    elif flag == "diffsketcher":  # text2sketch
        from pytorch_svgrender.pipelines.DiffSketcher_pipeline import DiffSketcherPipeline
        if not cfg.multirun:
            pipe = DiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=DiffSketcherPipeline, prompt=cfg.prompt)

    elif flag == "stylediffsketcher":  # text2sketch + style transfer
        from pytorch_svgrender.pipelines.DiffSketcher_stylized_pipeline import StylizedDiffSketcherPipeline
        if not cfg.multirun:
            pipe = StylizedDiffSketcherPipeline(cfg)
            pipe.painterly_rendering(cfg.prompt, style_fpath=cfg.target)
        else:  # generate many SVG at once
            # The style image comes from ``cfg.target`` in both modes, matching
            # the single-run branch above and StyleCLIPDraw; the batch path
            # previously read a ``cfg.style_file`` key used nowhere else.
            render_batch_fn(pipeline=StylizedDiffSketcherPipeline, prompt=cfg.prompt, style_fpath=cfg.target)
if __name__ == "__main__":
    # The ``hydra.main`` decorator composes the config and passes it as
    # ``cfg``, so ``main`` is invoked here without arguments.
    main()