# M3000j's picture
# Upload folder using huggingface_hub
# 31726e5 verified
"""Evaluate a pretrained GAN model.
Usage:
`python eval_gan.py <path/to/model/folder>`, e.g.
`../results/quickdraw_gan_vector_bezier_fc_wgan`.
"""
import os
import argparse
import torch as th
import numpy as np
import ttools
import imageio
from subprocess import call
import pydiffvg
import models
# Module-level logger, named after this module and used by every function
# in this script.
LOG = ttools.get_logger(__name__)
def postprocess(im, invert=False):
    """Turn a raw generator output into an image ready to be saved.

    Args:
        im: tensor whose values are remapped from [-1, 1] to [0, 1];
            anything falling outside that range is clipped.
        invert: when True, intensities are flipped (``1 - value``) after
            the remapping.

    Returns:
        The result of ``ttools.tensor2image`` applied to the remapped
        tensor.
    """
    unit_range = th.clamp((im + 1.0) / 2.0, 0, 1)
    final = (1.0 - unit_range) if invert else unit_range
    return ttools.tensor2image(final)
def imsave(im, path):
    """Write an image to `path`, creating the parent folder if needed.

    Args:
        im: image data, forwarded as-is to ``imageio.imwrite``.
        path: destination file path.
    """
    parent = os.path.dirname(path)
    # `dirname` is "" for a bare filename and `os.makedirs("")` raises
    # FileNotFoundError, so only create the folder when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    imageio.imwrite(path, im)
def save_scene(scn, path):
    """Save a vector scene as an SVG file, creating the parent folder.

    Args:
        scn: sequence of positional arguments forwarded to
            ``pydiffvg.save_svg`` right after the output path.
        path: destination file path.
    """
    parent = os.path.dirname(path)
    # `dirname` is "" for a bare filename and `os.makedirs("")` raises
    # FileNotFoundError, so only create the folder when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    pydiffvg.save_svg(path, *scn, use_gamma=False)
def _build_generator(meta, imsize=None):
    """Instantiate the generator described by the checkpoint metadata.

    Args:
        meta: checkpoint metadata providing the "generator" type and the
            "model_params" constructor arguments.
        imsize: optional raster resolution override. When given, the
            "imsize" and "stroke_width" entries of
            ``meta["model_params"]`` are updated in place; stroke widths
            are rescaled by the same factor as the resolution.

    Returns:
        The newly constructed generator (weights not loaded yet).

    Raises:
        NotImplementedError: if the generator type is not supported.
    """
    model_params = meta["model_params"]
    if imsize is not None:
        LOG.info("Overriding output image size to: %dx%d", imsize, imsize)
        scale = imsize * 1.0 / model_params["imsize"]
        model_params["imsize"] = imsize
        model_params["stroke_width"] = [
            w * scale for w in model_params["stroke_width"]]
        LOG.info("Overriding width to: %s", model_params["stroke_width"])

    # Class names are resolved lazily with getattr so that only the
    # requested architecture has to exist in `models`.
    class_names = {
        "fc": "VectorGenerator",
        "bezier_fc": "BezierVectorGenerator",
        "rnn": "RNNVectorGenerator",
        "chain_rnn": "ChainRNNVectorGenerator",
    }
    generator = meta["generator"]
    if generator not in class_names:
        raise NotImplementedError(
            "Unsupported generator type %r, expected one of %s"
            % (generator, sorted(class_names)))
    return getattr(models, class_names[generator])(**model_params)


def _render_interpolations(model, device, outdir, args):
    """Save latent-space interpolations as image strips, SVGs and videos.

    For each of the `args.nsamples` pairs of latent codes this writes:
      * ``latent_interp/XXX.png``: `args.nsteps` renders side by side,
      * ``latent_interp_svg/XXX/YYY.svg``: one SVG per interpolation step,
      * ``latent_interp_video/XXX/frameYYYY.png`` and
        ``latent_interp_video/XXX.mp4``: an `args.nframes`-frame
        animation assembled with ffmpeg.
    """
    # The interpolation weights do not depend on the sample: build once.
    lengths = [args.nsteps, args.nframes]
    alphas = [
        th.linspace(0, 1, count).view(count, 1).to(device)
        for count in lengths
    ]

    for i in range(args.nsamples):
        z0 = model.sample_z(1)
        z1 = model.sample_z(1)

        for idx, (count, a) in enumerate(zip(lengths, alphas)):
            _z0 = z0.repeat(count, 1).to(device)
            _z1 = z1.repeat(count, 1).to(device)
            batch = _z0*(1-a) + _z1*a
            out = model(batch)

            if idx == 0:  # image viz
                # Lay the batch out horizontally: (n, c, h, w) ->
                # (c, h, n, w) -> (1, c, h, n*w).
                n, c, h, w = out.shape
                strip = out.permute(1, 2, 0, 3)
                strip = strip.contiguous().view(1, c, h, w*n)
                strip = postprocess(strip, invert=args.invert)
                imsave(strip, os.path.join(outdir,
                                           "latent_interp", "%03d.png" % i))

                scenes = model.get_vector(batch)
                for scn_idx, scn in enumerate(scenes):
                    save_scene(scn, os.path.join(outdir, "latent_interp_svg",
                                                 "%03d" % i,
                                                 "%03d.svg" % scn_idx))
            else:  # video viz
                anim_root = os.path.join(outdir,
                                         "latent_interp_video", "%03d" % i)
                LOG.info("Rendering animation %d", i)
                for frame_idx, frame in enumerate(out):
                    LOG.info("frame %d", frame_idx)
                    frame = postprocess(frame.unsqueeze(0),
                                        invert=args.invert)
                    imsave(frame, os.path.join(anim_root,
                                               "frame%04d.png" % frame_idx))

                # "-y": overwrite an existing video instead of blocking on
                # ffmpeg's interactive prompt when the script is re-run
                # (the frames themselves are always overwritten).
                cmd = ["ffmpeg", "-y", "-framerate", "30", "-i",
                       os.path.join(anim_root, "frame%04d.png"),
                       "-vb", "20M",
                       os.path.join(outdir, "latent_interp_video",
                                    "%03d.mp4" % i)]
                try:
                    ret = call(cmd)
                except OSError as err:
                    # e.g. ffmpeg is not installed; the frames are still
                    # on disk so keep going with the remaining outputs.
                    LOG.warning("Could not run ffmpeg (%s), no video for "
                                "sample %d", err, i)
                else:
                    if ret != 0:
                        LOG.warning("ffmpeg exited with code %d for "
                                    "sample %d", ret, i)
        LOG.info(" saved %d", i)


def _render_sample_grids(model, device, outdir, nsamples, invert=False, n=8):
    """Save `nsamples` images, each an n-by-n grid of random samples."""
    for i in range(nsamples):
        z = model.sample_z(n*n).to(device)
        out = model(z)
        _, c, h, w = out.shape
        # Tile the batch into a grid: (n*n, c, h, w) -> (1, c, n*h, n*w).
        grid = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
        grid = grid.contiguous().view(1, c, h*n, w*n)
        # Honor --invert here too, consistently with the interpolations.
        grid = postprocess(grid, invert=invert)
        imsave(grid, os.path.join(outdir, "samples_%03d.png" % i))
        LOG.info(" saved %d", i)


def run(args):
    """Load a trained generator and render evaluation outputs.

    Seeds the torch and numpy random generators with 0, restores the
    latest "vect_g_" checkpoint found in `args.model`, then writes latent
    interpolations (images, SVGs, videos) and grids of random samples to
    the output folder.

    Args:
        args: parsed command-line arguments; reads `model`, `output_dir`,
            `imsize`, `nsamples`, `nsteps`, `nframes` and `invert`.

    Returns:
        None. Returns early, after logging a warning, when no checkpoint
        metadata can be loaded.

    Raises:
        NotImplementedError: if the checkpoint's generator type is not
            supported.
    """
    th.manual_seed(0)
    np.random.seed(0)

    meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")
    if meta is None:
        LOG.warning("Could not load metadata at %s, aborting.", args.model)
        return

    LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)

    if args.output_dir is None:
        outdir = os.path.join(args.model, "eval")
    else:
        outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    model = _build_generator(meta, args.imsize)
    model.eval()

    device = "cuda" if th.cuda.is_available() else "cpu"
    model.to(device)

    checkpointer = ttools.Checkpointer(
        args.model, model, meta=meta, prefix="vect_g_")
    checkpointer.load_latest()

    # Pure inference: no autograd graph is needed for any of the renders.
    with th.no_grad():
        LOG.info("Computing latent space interpolation")
        _render_interpolations(model, device, outdir, args)

        LOG.info("Sampling latent space")
        _render_sample_grids(model, device, outdir, args.nsamples,
                             invert=args.invert)

    LOG.info("output images saved to %s", outdir)
if __name__ == "__main__":
    # Command-line entry point: parse the options, configure the
    # libraries, then run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("model", help="path to the folder containing the "
                        "trained model")
    # The help text states the default that `run` actually applies
    # (<model>/eval).
    parser.add_argument("--output_dir", help="output directory for "
                        "the samples. Defaults to the `eval` subfolder of "
                        "the model's path")
    parser.add_argument("--nsamples", default=16, type=int,
                        help="number of output to compute")
    parser.add_argument("--imsize", type=int,
                        help="if provided, override the raster output "
                        "resolution")
    parser.add_argument("--nsteps", default=9, type=int, help="number of "
                        "interpolation steps for the interpolation")
    parser.add_argument("--nframes", default=120, type=int, help="number of "
                        "frames for the interpolation video")
    parser.add_argument("--invert", default=False, action="store_true",
                        help="if True, render black on white rather than the"
                        " opposite")

    args = parser.parse_args()

    pydiffvg.set_use_gpu(False)

    ttools.set_logger(False)

    run(args)