"""Evaluate a pretrained GAN model.
Usage:

`python eval_gan.py <path/to/model/folder>`, e.g.
`python eval_gan.py ../results/quickdraw_gan_vector_bezier_fc_wgan`.

"""
import os
import argparse
import torch as th
import numpy as np
import ttools
import imageio
from subprocess import call

import pydiffvg

import models


LOG = ttools.get_logger(__name__)


def postprocess(im, invert=False):
    """Turn a generator output tensor into an image.

    The input is remapped with `(im + 1) / 2` and clamped to [0, 1], then
    handed to `ttools.tensor2image`.

    Args:
        im: tensor to convert (presumably valued in [-1, 1] -- confirm
            against the generator's output range).
        invert: when True, the remapped values `v` are replaced by `1 - v`
            before the conversion.

    Returns:
        Whatever `ttools.tensor2image` returns for the remapped tensor.
    """
    rescaled = (im + 1.0) / 2.0
    rescaled = th.clamp(rescaled, 0, 1)
    if not invert:
        return ttools.tensor2image(rescaled)
    return ttools.tensor2image(1.0 - rescaled)


def imsave(im, path):
    """Write image `im` to `path`, creating the parent directory if needed.

    Args:
        im: image data, forwarded unchanged to `imageio.imwrite`.
        path: destination file path. May be a bare filename, in which case
            the file is written relative to the current working directory.
    """
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    imageio.imwrite(path, im)


def save_scene(scn, path):
    """Save a vector scene as an SVG file at `path`.

    The parent directory of `path` is created if it does not exist.

    Args:
        scn: sequence unpacked as the positional arguments that follow the
            path in `pydiffvg.save_svg` (presumably canvas size, shapes and
            shape groups -- confirm against the pydiffvg API).
        path: destination file path. May be a bare filename.
    """
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    pydiffvg.save_svg(path, *scn, use_gamma=False)


def _build_generator(name, model_params):
    """Instantiate the generator architecture identified by `name`."""
    class_names = {
        "fc": "VectorGenerator",
        "bezier_fc": "BezierVectorGenerator",
        "rnn": "RNNVectorGenerator",
        "chain_rnn": "ChainRNNVectorGenerator",
    }
    if name not in class_names:
        raise NotImplementedError(
            "Unsupported generator type: %r" % (name,))
    # Resolve lazily so only the requested class is looked up in `models`.
    return getattr(models, class_names[name])(**model_params)


def _save_interpolation_strip(model, batch, index, outdir, invert):
    """Save one interpolation as a horizontal image strip plus SVGs."""
    out = model(batch)
    n, c, h, w = out.shape
    # Lay the n images side by side: (n, c, h, w) -> (1, c, h, n*w).
    out = out.permute(1, 2, 0, 3).contiguous().view(1, c, h, w*n)
    strip = postprocess(out, invert=invert)
    imsave(strip, os.path.join(outdir, "latent_interp", "%03d.png" % index))

    scenes = model.get_vector(batch)
    for scn_idx, scn in enumerate(scenes):
        save_scene(scn, os.path.join(outdir, "latent_interp_svg",
                                     "%03d" % index, "%03d.svg" % scn_idx))


def _save_interpolation_video(model, batch, index, outdir, invert):
    """Save one interpolation as PNG frames and encode them with ffmpeg."""
    out = model(batch)
    anim_root = os.path.join(outdir, "latent_interp_video", "%03d" % index)
    LOG.info("Rendering animation %d", index)
    for frame_idx, frame in enumerate(out):
        LOG.info("frame %d", frame_idx)
        frame = postprocess(frame.unsqueeze(0), invert=invert)
        imsave(frame, os.path.join(anim_root, "frame%04d.png" % frame_idx))

    video_path = os.path.join(outdir, "latent_interp_video",
                              "%03d.mp4" % index)
    status = call(["ffmpeg", "-framerate", "30", "-i",
                   os.path.join(anim_root, "frame%04d.png"), "-vb", "20M",
                   video_path])
    if status != 0:
        # The frames are still on disk; report the failure rather than
        # silently continuing without a video.
        LOG.warning("ffmpeg exited with status %d, video %s may be missing "
                    "or incomplete", status, video_path)


def _save_latent_interpolations(model, args, outdir, device):
    """Render `args.nsamples` interpolations between random latent pairs."""
    # The blending weights do not depend on the sample: build them once.
    alpha = th.linspace(0, 1, args.nsteps).view(args.nsteps, 1).to(device)
    alpha_video = th.linspace(0, 1, args.nframes).view(args.nframes, 1)
    alpha_video = alpha_video.to(device)

    for i in range(args.nsamples):
        z0 = model.sample_z(1)
        z1 = model.sample_z(1)

        # Coarse interpolation for the image strip and the SVG files.
        _z0 = z0.repeat(args.nsteps, 1).to(device)
        _z1 = z1.repeat(args.nsteps, 1).to(device)
        batch = _z0*(1-alpha) + _z1*alpha
        _save_interpolation_strip(model, batch, i, outdir, args.invert)

        # Fine interpolation for the video frames.
        _z0 = z0.repeat(args.nframes, 1).to(device)
        _z1 = z1.repeat(args.nframes, 1).to(device)
        batch = _z0*(1-alpha_video) + _z1*alpha_video
        _save_interpolation_video(model, batch, i, outdir, args.invert)

        LOG.info("  saved %d", i)


def _save_sample_grids(model, args, outdir, device):
    """Render `args.nsamples` grids of 8x8 random samples."""
    n = 8
    bs = n*n
    for i in range(args.nsamples):
        z = model.sample_z(bs).to(device)
        out = model(z)
        _, c, h, w = out.shape
        # Tile the batch into an n-by-n grid: -> (1, c, n*h, n*w).
        out = out.view(n, n, c, h, w).permute(2, 0, 3, 1, 4)
        out = out.contiguous().view(1, c, h*n, w*n)
        # Honor --invert here too, consistently with the interpolations.
        out = postprocess(out, invert=args.invert)
        imsave(out, os.path.join(outdir, "samples_%03d.png" % i))
        LOG.info("  saved %d", i)


def run(args):
    """Evaluate a pretrained generator and write visualizations to disk.

    Loads the latest "vect_g_" checkpoint found in `args.model`, then saves
    latent-space interpolations (image strips, SVG scenes and videos encoded
    with ffmpeg) followed by grids of random samples. Returns early, after
    logging a warning, when the checkpoint metadata cannot be loaded.

    Args:
        args: parsed command-line namespace. Reads `model`, `output_dir`,
            `imsize`, `nsamples`, `nsteps`, `nframes` and `invert`.

    Raises:
        NotImplementedError: if the generator type stored in the metadata
            is not one of "fc", "bezier_fc", "rnn" or "chain_rnn".
    """
    # Fixed seeds so repeated evaluations draw the same latent codes.
    th.manual_seed(0)
    np.random.seed(0)

    meta = ttools.Checkpointer.load_meta(args.model, "vect_g_")

    if meta is None:
        LOG.warning("Could not load metadata at %s, aborting.", args.model)
        return

    LOG.info("Loaded model %s with metadata:\n %s", args.model, meta)

    outdir = args.output_dir
    if outdir is None:
        outdir = os.path.join(args.model, "eval")
    os.makedirs(outdir, exist_ok=True)

    model_params = meta["model_params"]
    if args.imsize is not None:
        LOG.info("Overriding output image size to: %dx%d", args.imsize,
                 args.imsize)
        old_size = model_params["imsize"]
        # Scale the stroke widths by the same factor as the resolution.
        scale = args.imsize * 1.0 / old_size
        model_params["imsize"] = args.imsize
        model_params["stroke_width"] = [w*scale for w in
                                        model_params["stroke_width"]]
        LOG.info("Overriding width to: %s", model_params["stroke_width"])

    model = _build_generator(meta["generator"], model_params)
    model.eval()

    device = "cuda" if th.cuda.is_available() else "cpu"
    model.to(device)

    checkpointer = ttools.Checkpointer(
        args.model, model, meta=meta, prefix="vect_g_")
    checkpointer.load_latest()

    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with th.no_grad():
        LOG.info("Computing latent space interpolation")
        _save_latent_interpolations(model, args, outdir, device)

        LOG.info("Sampling latent space")
        _save_sample_grids(model, args, outdir, device)

    LOG.info("output images saved to %s", outdir)


if __name__ == "__main__":
    # Command-line entry point: parse the options, configure pydiffvg and
    # logging, then run the evaluation.
    parser = argparse.ArgumentParser()

    parser.add_argument("model", help="path to the model folder")
    parser.add_argument("--output_dir", help="output directory for "
                        "the samples. Defaults to an 'eval' subfolder of "
                        "the model's path")
    parser.add_argument("--nsamples", default=16, type=int,
                        help="number of outputs to compute")
    parser.add_argument("--imsize", type=int,
                        help="if provided, override the raster output "
                        "resolution")
    parser.add_argument("--nsteps", default=9, type=int, help="number of "
                        "interpolation steps for the interpolation")
    parser.add_argument("--nframes", default=120, type=int, help="number of "
                        "frames for the interpolation video")
    parser.add_argument("--invert", default=False, action="store_true",
                        help="if True, render black on white rather than the"
                        " opposite")

    args = parser.parse_args()

    pydiffvg.set_use_gpu(False)

    ttools.set_logger(False)

    run(args)