#!/bin/env python
"""Train a GAN.

Usage:

* Train an MNIST model:

`python train_gan.py`

* Train a Quickdraw model: 

`python train_gan.py --task quickdraw`

"""
import argparse
import os

import numpy as np
import torch as th
from torch.utils.data import DataLoader

import ttools
import ttools.interfaces

import losses
import data
import models

import pydiffvg

# Module-level logger obtained from ttools.
LOG = ttools.get_logger(__name__)


# Parent of the directory that contains this script (the path is left
# un-normalized, i.e. it still ends with the ".." component).
BASE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    os.pardir,
)
# Root folder under which `train` places its checkpoint directories.
OUTPUT = os.path.join(BASE_DIR, "results")


class Callback(ttools.callbacks.ImageDisplayCallback):
    """Display real and generated samples side by side during training."""

    def visualized_image(self, batch, step_data, is_val=False):
        """Build the image grid to display for one training step.

        Args:
            batch: current data batch (unused here).
            step_data: dict returned by the training step; reads the
                "gen_image", "gt_image" and "vector_image" entries.
            is_val: when True nothing is displayed and None is returned.

        Returns:
            A tensor with values clamped to [0, 1], or None for validation.
        """
        if is_val:
            return None

        # Keep at most 16 samples of each kind, cut from the autograd graph.
        fake = step_data["gen_image"][:16].detach()
        real = step_data["gt_image"][:16].detach()
        panels = [real, fake]

        vect = step_data["vector_image"]
        if vect is not None:
            panels.append(vect[:16].detach())

        # Stack the panels along dim 2 (presumably the image height, given
        # the top/middle/bottom wording of the caption).
        stacked = th.cat(panels, 2)

        # Rescale with (x + 1) / 2 -- presumably from [-1, 1] -- then clamp
        # to the displayable [0, 1] range.
        rescaled = (stacked + 1.0) * 0.5
        return th.clamp(rescaled, 0, 1)

    def caption(self, batch, step_data, is_val=False):
        """Return the caption that describes the rows of the image grid."""
        if step_data["vector_image"] is None:
            return "top: real, bottom: fake"
        return "top: real, middle: raster, bottom: vector"


class Interface(ttools.ModelInterface):
    """Training interface for a raster GAN and an optional vector GAN.

    Owns the generator/discriminator pairs together with their Adam
    optimizers and exponential learning-rate schedulers, and alternates
    discriminator and generator updates across successive calls to
    `training_step`.

    Args:
        generator: raster generator network; must provide `sample_z`.
        vect_generator: vector generator network, or None to train the
            raster pair only.
        discriminator: discriminator paired with `generator`.
        vect_discriminator: discriminator paired with `vect_generator`;
            required whenever `vect_generator` is given.
        lr: learning rate of every Adam optimizer.
        lr_decay: decay factor given to every `ExponentialLR` scheduler.
        gradient_penalty: weight of the gradient penalty (only used when
            `wgan_gp` is True).
        wgan_gp: if True, use Wasserstein losses with a gradient penalty and
            5 discriminator updates per generator update; otherwise use
            binary cross-entropy losses and 1 discriminator update per
            generator update.
        raster_resolution: stored on the instance; not otherwise used by
            this class.
        device: device the networks and losses are moved to.
        grad_clip: maximum gradient norm used when clipping gradients.
    """

    def __init__(self, generator, vect_generator,
                 discriminator, vect_discriminator,
                 lr=1e-4, lr_decay=0.9999,
                 gradient_penalty=10,
                 wgan_gp=False,
                 raster_resolution=32, device="cpu", grad_clip=1.0):
        super(Interface, self).__init__()

        self.wgan_gp = wgan_gp
        self.w_gradient_penalty = gradient_penalty

        # Number of discriminator updates between two generator updates.
        self.n_critic = 5 if self.wgan_gp else 1

        self.grad_clip = grad_clip
        self.raster_resolution = raster_resolution

        self.gen = generator
        self.vect_gen = vect_generator
        self.discrim = discriminator
        self.vect_discrim = vect_discriminator

        self.device = device
        self.gen.to(self.device)
        self.discrim.to(self.device)

        beta1 = 0.5
        beta2 = 0.9

        self.gen_opt = th.optim.Adam(
            self.gen.parameters(), lr=lr, betas=(beta1, beta2))
        self.discrim_opt = th.optim.Adam(
            self.discrim.parameters(), lr=lr, betas=(beta1, beta2))

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.gen_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.discrim_opt, lr_decay),
        ]

        self.optimizers = [self.gen_opt, self.discrim_opt]

        if self.vect_gen is not None:
            assert self.vect_discrim is not None, (
                "a vector discriminator is required when a vector "
                "generator is given")

            self.vect_gen.to(self.device)
            self.vect_discrim.to(self.device)

            self.vect_gen_opt = th.optim.Adam(
                self.vect_gen.parameters(), lr=lr, betas=(beta1, beta2))
            self.vect_discrim_opt = th.optim.Adam(
                self.vect_discrim.parameters(), lr=lr, betas=(beta1, beta2))

            self.schedulers += [
                th.optim.lr_scheduler.ExponentialLR(self.vect_gen_opt,
                                                    lr_decay),
                th.optim.lr_scheduler.ExponentialLR(self.vect_discrim_opt,
                                                    lr_decay),
            ]

            self.optimizers += [self.vect_gen_opt, self.vect_discrim_opt]

        # include loss on alpha
        # (not used by the training steps of this class)
        self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)

        # Counts the discriminator updates done since the last generator
        # update.
        self.iter = 0

        self.cross_entropy = th.nn.BCEWithLogitsLoss()
        # Not used by the training steps of this class.
        self.mse = th.nn.MSELoss()

    def _gradient_penalty(self, discrim, fake, real):
        """Return the weighted gradient penalty of `discrim`.

        The penalty is evaluated at random per-sample interpolations between
        `real` and `fake`, both treated as constants.

        Args:
            discrim: discriminator to penalize.
            fake: generated samples, same shape as `real`.
            real: real samples; must be 4-dimensional since the mixing
                coefficients are created with shape (batch, 1, 1, 1).

        Returns:
            A scalar tensor, differentiable with respect to the
            discriminator parameters.
        """
        bs = real.size(0)
        # One mixing coefficient per sample, broadcast to the sample shape.
        epsilon = th.rand(bs, 1, 1, 1, device=real.device)
        epsilon = epsilon.expand_as(real)

        # Cut the inputs from their graphs so the interpolation is a leaf
        # tensor we can differentiate with respect to.
        interpolation = (epsilon * real.detach()
                         + (1 - epsilon) * fake.detach())
        interpolation.requires_grad_(True)

        interpolation_logits = discrim(interpolation)
        grad_outputs = th.ones_like(interpolation_logits)

        # create_graph=True keeps the graph of the gradient itself, so the
        # penalty can be backpropagated to the discriminator parameters.
        gradients = th.autograd.grad(outputs=interpolation_logits,
                                     inputs=interpolation,
                                     grad_outputs=grad_outputs,
                                     create_graph=True, retain_graph=True)[0]

        gradients = gradients.view(bs, -1)
        # The small constant keeps the square root differentiable at zero.
        gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)

        # Zero-centered penalty, following
        # [Thanh-Tung 2019] https://openreview.net/pdf?id=ByxPYjC5KQ
        # A one-centered variant would use (gradients_norm - 1) ** 2 instead.
        return self.w_gradient_penalty * (gradients_norm ** 2).mean()

    def _discriminator_step(self, discrim, opt, fake, real):
        """Try to classify fake as 0 and real as 1.

        Runs one optimizer step on `discrim`; the generator that produced
        `fake` receives no gradient.

        Args:
            discrim: discriminator to update.
            opt: optimizer holding the parameters of `discrim`.
            fake: generated samples.
            real: real samples.

        Returns:
            A pair (loss, gradient_penalty) of Python floats; the second
            entry is None when `wgan_gp` is disabled.
        """

        opt.zero_grad()

        # no backprop to gen
        fake = fake.detach()

        fake_pred = discrim(fake)
        real_pred = discrim(real)

        if self.wgan_gp:
            gradient_penalty = self._gradient_penalty(discrim, fake, real)
            loss_d = fake_pred.mean() - real_pred.mean() + gradient_penalty
            gradient_penalty = gradient_penalty.item()
        else:
            fake_loss = self.cross_entropy(fake_pred, th.zeros_like(fake_pred))
            real_loss = self.cross_entropy(real_pred, th.ones_like(real_pred))
            loss_d = 0.5*(fake_loss + real_loss)
            gradient_penalty = None

        loss_d.backward()
        nrm = th.nn.utils.clip_grad_norm_(
            discrim.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped discriminator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_d.item(), gradient_penalty

    def _generator_step(self, gen, discrim, opt, fake):
        """Try to classify fake as 1.

        Runs one optimizer step on `gen`, backpropagating through `discrim`
        into the graph that produced `fake`.

        Args:
            gen: generator to update.
            discrim: discriminator scoring the generated samples.
            opt: optimizer holding the parameters of `gen`.
            fake: samples produced by `gen`, still attached to its graph.

        Returns:
            The generator loss as a Python float.
        """

        opt.zero_grad()

        fake_pred = discrim(fake)

        if self.wgan_gp:
            loss_g = -fake_pred.mean()
        else:
            loss_g = self.cross_entropy(fake_pred, th.ones_like(fake_pred))

        loss_g.backward()

        # clip gradients
        nrm = th.nn.utils.clip_grad_norm_(
            gen.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped generator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_g.item()

    def training_step(self, batch):
        """Run one discriminator update or one generator update.

        `n_critic` consecutive calls update the discriminator(s); the next
        call updates the generator(s) and restarts the cycle. The raster and
        vector generators are fed the same latent sample.

        Args:
            batch: batch of real images.

        Returns:
            A dict with the losses and gradient penalties of this step
            (None for the quantities that were not computed), the real and
            generated images, and the current generator learning rate.
        """
        im = batch
        im = im.to(self.device)

        z = self.gen.sample_z(im.shape[0], device=self.device)

        generated = self.gen(z)

        vect_generated = None
        if self.vect_gen is not None:
            vect_generated = self.vect_gen(z)

        loss_g = None
        loss_d = None
        loss_g_vect = None
        loss_d_vect = None

        gp = None
        gp_vect = None

        if self.iter < self.n_critic:  # Discriminator update
            self.iter += 1

            loss_d, gp = self._discriminator_step(
                self.discrim, self.discrim_opt, generated, im)

            if vect_generated is not None:
                loss_d_vect, gp_vect = self._discriminator_step(
                    self.vect_discrim, self.vect_discrim_opt, vect_generated, im)

        else:  # Generator update
            self.iter = 0

            loss_g = self._generator_step(
                self.gen, self.discrim, self.gen_opt, generated)

            if vect_generated is not None:
                loss_g_vect = self._generator_step(
                    self.vect_gen, self.vect_discrim, self.vect_gen_opt, vect_generated)

        return {
            "loss_g": loss_g,
            "loss_d": loss_d,
            "loss_g_vect": loss_g_vect,
            "loss_d_vect": loss_d_vect,
            "gp": gp,
            "gp_vect": gp_vect,
            "gt_image": im,
            "gen_image": generated,
            "vector_image": vect_generated,
            "lr": self.gen_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        """Return the initial running data of a validation pass."""
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        """Return `running_data` unchanged; no validation metric is computed.

        This class never defines a `model` attribute, so the networks are
        deliberately left untouched here: nothing is evaluated, and putting
        them in eval mode would leave them that way for the training steps
        that follow.
        """
        return running_data


def train(args):
    """Train the GAN(s) selected by the command-line arguments.

    Builds the dataset, the raster generator/discriminator and, unless
    `args.raster_only` is set, the vector generator/discriminator; resumes
    from the latest checkpoints if any; then runs the ttools trainer with
    visualization, logging, checkpointing and LR-scheduling callbacks.

    Args:
        args: parsed command-line namespace. Reads `task`,
            `raster_resolution`, `bs`, `workers`, `zdim`, `num_strokes`,
            `stroke_width`, `raster_only`, `generator`, `lr`, `lr_decay`,
            `wgan_gp`, `freq`, `port`, `server`, `base_url` and
            `num_epochs`.

    Raises:
        NotImplementedError: if the task or the generator type is unknown.
    """
    # Fixed seeds for reproducibility.
    th.manual_seed(0)
    np.random.seed(0)

    color_output = False
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError("unknown task: %s" % args.task)

    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    # No validation set is used.
    val_dataloader = None

    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }
    gen = models.Generator(**model_params)
    gen.train()

    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator == "rnn":
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator == "chain_rnn":
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError(
                "unknown generator: %s" % args.generator)
        vect_gen.train()

        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution, lr=args.lr,
                          wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    # The experiment name encodes the task and the model variant; it names
    # both the visdom environment and the checkpoint directory.
    env_name = args.task + "_gan"

    if args.raster_only:
        env_name += "_raster"
    else:
        env_name += "_vector"

    env_name += "_" + args.generator

    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    # All checkpointers share one directory and are told apart by prefix.
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(
        chkpt, discrim,
        prefix="d_")

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim,
            prefix="vect_d_")
        # The extras of the vector generator checkpoint take precedence
        # over those of the raster generator checkpoint.
        extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()

    epoch = extras["epoch"] if extras and "epoch" in extras else 0

    # TODO: compare the "model_params" stored in the checkpoint's meta with
    # the ones given on the command line, and report any mismatch.

    trainer = ttools.Trainer(interface)

    # Add callbacks
    # (named `loss_keys` so the imported `losses` module is not shadowed)
    loss_keys = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect", "gp",
                 "gp_vect"]
    training_debug = ["lr"]

    # NOTE(review): unlike the plotting callbacks below, the image callback
    # is not given `server`/`base_url` -- confirm whether that is intended.
    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_keys, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=loss_keys, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url,
        port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer_d, max_files=2, interval=600, max_epochs=10))

    if not args.raster_only:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_vect, max_files=2, interval=600, max_epochs=10))
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_d_vect, max_files=2, interval=600, max_epochs=10))

    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)


if __name__ == "__main__":
    # Script entry point: parse the command line, then launch training.
    parser = argparse.ArgumentParser()

    # Task and model selection
    parser.add_argument(
        "--task",
        default="mnist",
        choices=["mnist", "quickdraw"],
    )
    parser.add_argument(
        "--generator",
        default="bezier_fc",
        choices=["bezier_fc", "fc", "rnn", "chain_rnn"],
        help="model to use as generator",
    )
    parser.add_argument(
        "--raster_only",
        action="store_true",
        default=False,
        help="if true only train the raster baseline",
    )
    # `wgan_gp` is True unless --standard_gan is passed on the command line.
    parser.add_argument(
        "--standard_gan",
        dest="wgan_gp",
        action="store_false",
        default=True,
        help="if true, use regular GAN instead of WGAN",
    )

    # Training params
    parser.add_argument(
        "--bs",
        type=int,
        default=4,
        help="batch size",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="number of dataloader threads",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=200,
        help="number of epochs to train for",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-4,
        help="learning rate",
    )
    parser.add_argument(
        "--lr_decay",
        type=float,
        default=0.9999,
        help="exponential learning rate decay rate",
    )

    # Model configuration
    parser.add_argument(
        "--zdim",
        type=int,
        default=32,
        help="latent space dimension",
    )
    parser.add_argument(
        "--stroke_width",
        type=float,
        nargs=2,
        default=(0.5, 1.5),
        help="min and max stroke width",
    )
    parser.add_argument(
        "--num_strokes",
        type=int,
        default=16,
        help="number of strokes to generate",
    )
    parser.add_argument(
        "--raster_resolution",
        type=int,
        default=32,
        help="raster canvas resolution on each side",
    )

    # Viz params
    parser.add_argument(
        "--freq",
        type=int,
        default=10,
        help="visualization frequency",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8097,
        help="visdom port",
    )
    parser.add_argument(
        "--server",
        default=None,
        help="visdom server if not local.",
    )
    parser.add_argument(
        "--base_url",
        default="",
        help="visdom entrypoint URL",
    )

    args = parser.parse_args()

    # Turn off pydiffvg's GPU usage (presumably keeps the rasterizer on the
    # CPU even when the networks run on CUDA -- confirm).
    pydiffvg.set_use_gpu(False)

    # NOTE(review): the flag presumably disables debug-level logging --
    # verify against ttools.
    ttools.set_logger(False)

    train(args)