"""Collection of generative models."""

import torch as th
import ttools

import rendering
import modules

# Module-level logger obtained from ttools; the generator classes below use it
# to emit warnings (e.g. when falling back to a default stroke-width range).
LOG = ttools.get_logger(__name__)


class BaseModel(th.nn.Module):
    """Common base for the generators in this module.

    Subclasses are expected to set ``self.zdim``, the length of a latent code.
    """

    def sample_z(self, bs, device="cpu"):
        """Draw a batch of standard-normal latent codes.

        Args:
            bs: number of latent vectors to draw.
            device: device the returned tensor is transferred to.

        Returns:
            A tensor of shape ``(bs, self.zdim)``.
        """
        # Sampled on the default device first, then transferred to ``device``.
        noise = th.randn(bs, self.zdim)
        return noise.to(device)


class BaseVectorModel(BaseModel):
    """Base for generators whose ``_forward`` yields a ``(raster, scenes)`` pair.

    ``forward`` exposes only the raster, while ``get_vector`` exposes only the
    vector ``scenes``.
    """

    def get_vector(self, z):
        """Return the vector ``scenes`` generated from latent codes ``z``."""
        _raster, scenes = self._forward(z)
        return scenes

    def _forward(self, x):
        """Compute ``(raster, scenes)`` for latent codes ``x``; subclasses override."""
        raise NotImplementedError()

    def forward(self, z):
        """Return only the raster image generated from latent codes ``z``."""
        outputs = self._forward(z)
        raster = outputs[0]
        return raster


class BezierVectorGenerator(BaseVectorModel):
    """MLP generator emitting a fixed number of piecewise Bezier strokes.

    A latent code is mapped by an MLP trunk to per-stroke control points,
    widths, opacities and (optionally) RGB colors, which are rasterized with
    ``rendering.bezier_render``.
    """

    # Number of Bezier segments chained together in each stroke.
    NUM_SEGMENTS = 2

    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        """
        Args:
            num_strokes: number of strokes generated per sample.
            zdim: length of the latent code.
            width: base MLP width; hidden layers use 1x, 2x, 4x and 8x it.
            imsize: canvas size passed to the renderer.
            color_output: if True, also predict an RGB color per stroke.
            stroke_width: ``(min, max)`` range that predicted widths are
                mapped into; defaults to ``(0.5, 3.0)``.
        """
        super(BezierVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim

        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )

        # 4-point (cubic) Bezier segments sharing endpoints: n segments need
        # 3*n + 1 control points per stroke, each with (x, y) coordinates.
        num_points = BezierVectorGenerator.NUM_SEGMENTS*3 + 1
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, 2*self.num_strokes*num_points),
            th.nn.Tanh()  # bound spatial extent to [-1, 1]
        )

        # Sigmoid in [0, 1]; rescaled to the stroke-width range in _forward.
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        # Per-stroke opacity in [0, 1].
        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        # Optional per-stroke RGB color in [0, 1].
        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )

    def _forward(self, z):
        """Generate strokes from latent codes ``z`` and render them.

        Returns:
            ``(output, scenes)``: the rendered raster rescaled by
            ``2*x - 1``, and the second value returned by
            ``rendering.bezier_render``.
        """
        bs = z.shape[0]

        feats = self.trunk(z)
        all_points = self.point_predictor(feats)
        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor is not None:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        # Map the sigmoid output into [min_width, max_width].
        all_widths = self.width_predictor(feats)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        all_points = all_points.view(
            bs, self.num_strokes, BezierVectorGenerator.NUM_SEGMENTS*3+1, 2)

        output, scenes = rendering.bezier_render(all_points, all_widths, all_alphas,
                                                 colors=all_colors,
                                                 canvas_size=self.imsize)

        # map to [-1, 1] (assuming the renderer outputs values in [0, 1])
        output = output*2.0 - 1.0

        return output, scenes


class VectorGenerator(BaseVectorModel):
    """MLP generator emitting independent straight-line strokes.

    A latent code is mapped by an MLP trunk to per-stroke segment endpoints,
    widths, opacities and (optionally) RGB colors, which are rasterized with
    ``rendering.line_render``.
    """

    def __init__(self, num_strokes=4,
                 zdim=128, width=32, imsize=32,
                 color_output=False,
                 stroke_width=None):
        """
        Args:
            num_strokes: number of line segments generated per sample.
            zdim: length of the latent code.
            width: base MLP width; hidden layers use 1x, 2x, 4x and 8x it.
            imsize: canvas size passed to the renderer.
            color_output: if True, also predict an RGB color per stroke.
            stroke_width: ``(min, max)`` range that predicted widths are
                mapped into; defaults to ``(0.5, 3.0)``.
        """
        super(VectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        self.imsize = imsize
        self.num_strokes = num_strokes
        self.zdim = zdim

        self.trunk = th.nn.Sequential(
            th.nn.Linear(zdim, width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(width, 2*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(2*width, 4*width),
            th.nn.SELU(inplace=True),

            th.nn.Linear(4*width, 8*width),
            th.nn.SELU(inplace=True),
        )

        # Each stroke is a straight segment with 2 endpoints of (x, y).
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, 2*(self.num_strokes*2)),
            th.nn.Tanh()  # bound spatial extent to [-1, 1]
        )

        # Sigmoid in [0, 1]; rescaled to the stroke-width range in _forward.
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        # Per-stroke opacity in [0, 1].
        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(8*width, self.num_strokes),
            th.nn.Sigmoid()
        )

        # Optional per-stroke RGB color in [0, 1].
        self.color_predictor = None
        if color_output:
            self.color_predictor = th.nn.Sequential(
                th.nn.Linear(8*width, 3*self.num_strokes),
                th.nn.Sigmoid()
            )

    def _forward(self, z):
        """Generate line segments from latent codes ``z`` and render them.

        Returns:
            ``(output, scenes)``: the rendered raster rescaled by
            ``2*x - 1``, and the second value returned by
            ``rendering.line_render``.
        """
        bs = z.shape[0]

        feats = self.trunk(z)

        all_points = self.point_predictor(feats)

        all_alphas = self.alpha_predictor(feats)

        if self.color_predictor is not None:
            all_colors = self.color_predictor(feats)
            all_colors = all_colors.view(bs, self.num_strokes, 3)
        else:
            all_colors = None

        # Map the sigmoid output into [min_width, max_width].
        all_widths = self.width_predictor(feats)
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        # (batch, stroke, endpoint, xy)
        all_points = all_points.view(bs, self.num_strokes, 2, 2)
        output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
                                               colors=all_colors,
                                               canvas_size=self.imsize)

        # map to [-1, 1] (assuming the renderer outputs values in [0, 1])
        output = output*2.0 - 1.0

        return output, scenes


class RNNVectorGenerator(BaseVectorModel):
    """LSTM generator emitting one independent straight-line stroke per step.

    The latent code is fed to the LSTM at each of ``num_strokes`` steps; each
    step's output predicts one segment (two endpoints), a width and an
    opacity, which are rasterized with ``rendering.line_render``. No
    per-stroke colors are predicted.
    """

    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        """
        Args:
            num_strokes: number of LSTM steps, one stroke per step.
            zdim: length of the latent code (also the LSTM input size).
            width: unused by this model.
            imsize: canvas size passed to the renderer.
            hidden_size: LSTM hidden size.
            dropout: dropout passed to ``th.nn.LSTM``.
            color_output: not supported; a warning is logged if True.
            num_layers: number of stacked LSTM layers.
            stroke_width: ``(min, max)`` range that predicted widths are
                mapped into; defaults to ``(0.5, 3.0)``.
        """
        super(RNNVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        if color_output:
            # Previously this flag was silently ignored; make that visible.
            LOG.warning("%s does not support color_output; strokes are "
                        "rendered without predicted colors",
                        type(self).__name__)

        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        # Predicts the initial hidden and cell states for every LSTM layer.
        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)

        # Each stroke is a straight segment with 2 endpoints of (x, y).
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2*2),  # 2 points, (x,y)
            th.nn.Tanh()  # bound spatial extent to [-1, 1]
        )

        # Sigmoid in [0, 1]; rescaled to the stroke-width range in _forward.
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        # Per-stroke opacity in [0, 1].
        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

    def _init_hidden_and_cell(self, z):
        """Predict the initial LSTM ``(hidden, cell)`` state from ``z``.

        Each returned tensor has shape ``(num_layers, bs, hidden_size)``,
        the layout ``th.nn.LSTM`` expects for its initial state.
        """
        split = self.hidden_size*self.num_layers
        packed = self.hidden_cell_predictor(th.tanh(z))

        def unpack(flat):
            # (bs, num_layers*hidden) -> (num_layers, bs, hidden)
            flat = flat.view(-1, self.num_layers, self.hidden_size)
            return flat.permute(1, 0, 2).contiguous()

        return unpack(packed[:, :split]), unpack(packed[:, split:])

    def _forward(self, z, hidden_and_cell=None):
        """Generate strokes from latent codes ``z`` and render them.

        Args:
            z: batch of latent codes, first dimension is the batch size.
            hidden_and_cell: optional initial LSTM ``(hidden, cell)`` state;
                predicted from ``z`` when omitted.

        Returns:
            ``(output, scenes)``: the rendered raster rescaled by
            ``2*x - 1``, and the second value returned by
            ``rendering.line_render``.
        """
        steps = self.num_strokes

        # z is passed at each step, duplicate it
        bs = z.shape[0]
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        if hidden_and_cell is None:
            hidden_and_cell = self._init_hidden_and_cell(z)

        # The final LSTM state is not needed.
        feats, _ = self.lstm(expanded_z, hidden_and_cell)
        feats = feats.reshape(bs*steps, self.hidden_size)

        all_points = self.point_predictor(feats).view(bs, steps, 2, 2)
        all_alphas = self.alpha_predictor(feats).view(bs, steps)
        all_widths = self.width_predictor(feats).view(bs, steps)

        # Map the sigmoid output into [min_width, max_width].
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
                                               canvas_size=self.imsize)

        # map to [-1, 1] (assuming the renderer outputs values in [0, 1])
        output = output*2.0 - 1.0

        return output, scenes


class ChainRNNVectorGenerator(BaseVectorModel):
    """Strokes form a single long chain.

    At each of ``num_strokes`` LSTM steps one new endpoint is predicted; the
    stroke drawn at that step runs from the previous step's endpoint (the
    origin for the first step) to the new one. Strokes are rasterized with
    ``rendering.line_render``; no per-stroke colors are predicted.
    """

    def __init__(self, num_strokes=64,
                 zdim=128, width=32, imsize=32,
                 hidden_size=512, dropout=0.9,
                 color_output=False,
                 num_layers=3, stroke_width=None):
        """
        Args:
            num_strokes: number of LSTM steps, one chained stroke per step.
            zdim: length of the latent code (also the LSTM input size).
            width: unused by this model.
            imsize: canvas size passed to the renderer.
            hidden_size: LSTM hidden size.
            dropout: dropout passed to ``th.nn.LSTM``.
            color_output: not supported; a warning is logged if True.
            num_layers: number of stacked LSTM layers.
            stroke_width: ``(min, max)`` range that predicted widths are
                mapped into; defaults to ``(0.5, 3.0)``.
        """
        super(ChainRNNVectorGenerator, self).__init__()

        if stroke_width is None:
            self.stroke_width = (0.5, 3.0)
            LOG.warning("Setting default stroke width %s", self.stroke_width)
        else:
            self.stroke_width = stroke_width

        if color_output:
            # Previously this flag was silently ignored; make that visible.
            LOG.warning("%s does not support color_output; strokes are "
                        "rendered without predicted colors",
                        type(self).__name__)

        self.num_layers = num_layers
        self.imsize = imsize
        self.num_strokes = num_strokes
        self.hidden_size = hidden_size
        self.zdim = zdim

        # Predicts the initial hidden and cell states for every LSTM layer.
        self.hidden_cell_predictor = th.nn.Linear(
            zdim, 2*hidden_size*num_layers)

        self.lstm = th.nn.LSTM(
            zdim, hidden_size,
            num_layers=self.num_layers, dropout=dropout,
            batch_first=True)

        # One new (x, y) endpoint per step; each stroke starts at the previous
        # step's endpoint (see _chain_points).
        self.point_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 2),  # 1 point, (x,y)
            th.nn.Tanh()  # bound spatial extent to [-1, 1]
        )

        # Sigmoid in [0, 1]; rescaled to the stroke-width range in _forward.
        self.width_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

        # Per-stroke opacity in [0, 1].
        self.alpha_predictor = th.nn.Sequential(
            th.nn.Linear(hidden_size, 1),
            th.nn.Sigmoid()
        )

    @staticmethod
    def _chain_points(end_points):
        """Turn per-step endpoints into connected line segments.

        Args:
            end_points: ``(bs, steps, 1, 2)`` endpoint of each stroke.

        Returns:
            ``(bs, steps, 2, 2)`` segments where segment ``i`` starts at
            ``end_points[:, i - 1]`` (the origin for ``i == 0``) and ends at
            ``end_points[:, i]``.
        """
        bs = end_points.shape[0]
        # first point is canvas center, (0, 0); match dtype/device of inputs
        origin = end_points.new_zeros(bs, 1, 1, 2)
        # Shift by one step so each stroke starts where the previous one
        # ended. (Using end_points[:, 1:] here would make every stroke after
        # the first start and end at the same point.)
        start_points = th.cat([origin, end_points[:, :-1]], 1)
        return th.cat([start_points, end_points], 2)

    def _init_hidden_and_cell(self, z):
        """Predict the initial LSTM ``(hidden, cell)`` state from ``z``.

        Each returned tensor has shape ``(num_layers, bs, hidden_size)``,
        the layout ``th.nn.LSTM`` expects for its initial state.
        """
        split = self.hidden_size*self.num_layers
        packed = self.hidden_cell_predictor(th.tanh(z))

        def unpack(flat):
            # (bs, num_layers*hidden) -> (num_layers, bs, hidden)
            flat = flat.view(-1, self.num_layers, self.hidden_size)
            return flat.permute(1, 0, 2).contiguous()

        return unpack(packed[:, :split]), unpack(packed[:, split:])

    def _forward(self, z, hidden_and_cell=None):
        """Generate a chain of strokes from latent codes ``z`` and render it.

        Args:
            z: batch of latent codes, first dimension is the batch size.
            hidden_and_cell: optional initial LSTM ``(hidden, cell)`` state;
                predicted from ``z`` when omitted.

        Returns:
            ``(output, scenes)``: the rendered raster rescaled by
            ``2*x - 1``, and the second value returned by
            ``rendering.line_render``.
        """
        steps = self.num_strokes

        # z is passed at each step, duplicate it
        bs = z.shape[0]
        expanded_z = z.unsqueeze(1).repeat(1, steps, 1)

        if hidden_and_cell is None:
            hidden_and_cell = self._init_hidden_and_cell(z)

        # The final LSTM state is not needed.
        feats, _ = self.lstm(expanded_z, hidden_and_cell)
        feats = feats.reshape(bs*steps, self.hidden_size)

        end_points = self.point_predictor(feats).view(bs, steps, 1, 2)
        all_points = self._chain_points(end_points)

        all_alphas = self.alpha_predictor(feats).view(bs, steps)
        all_widths = self.width_predictor(feats).view(bs, steps)

        # Map the sigmoid output into [min_width, max_width].
        min_width = self.stroke_width[0]
        max_width = self.stroke_width[1]
        all_widths = (max_width - min_width) * all_widths + min_width

        output, scenes = rendering.line_render(all_points, all_widths, all_alphas,
                                               canvas_size=self.imsize)

        # map to [-1, 1] (assuming the renderer outputs values in [0, 1])
        output = output*2.0 - 1.0

        return output, scenes


class Generator(BaseModel):
    """Convolutional raster generator producing 32x32 images.

    The latent code is reshaped into a ``(zdim // 4, 2, 2)`` feature map and
    upsampled by four stride-2 transposed convolutions
    (2 -> 4 -> 8 -> 16 -> 32 pixels). A final Tanh bounds outputs to [-1, 1].
    """

    def __init__(self, width=64, imsize=32, zdim=128,
                 stroke_width=None,
                 color_output=False,
                 num_strokes=4):
        """
        Args:
            width: base channel count; the four upsampling stages use 8x, 4x,
                2x and 1x this.
            imsize: output size; only 32 is supported.
            zdim: latent size; must be a positive multiple of 4 so it can be
                reshaped into a 2x2 feature map.
            stroke_width: unused here (presumably kept so the constructor
                matches the vector generators).
            color_output: if True the output has 3 channels, otherwise 1.
            num_strokes: unused here (see ``stroke_width``).

        Raises:
            ValueError: if ``zdim`` is not a positive multiple of 4.
        """
        super(Generator, self).__init__()
        assert imsize == 32
        if zdim <= 0 or zdim % 4 != 0:
            # Otherwise the failure only surfaces later as a view() error in
            # forward().
            raise ValueError(
                "zdim must be a positive multiple of 4 to be reshaped into a "
                "2x2 feature map, got %r" % (zdim,))

        self.imsize = imsize
        self.zdim = zdim

        num_in_chans = self.zdim // (2*2)
        num_out_chans = 3 if color_output else 1

        def lrelu():
            return th.nn.LeakyReLU(0.2, inplace=True)

        layers = []
        chans_in = num_in_chans
        # Each stage doubles the resolution: 2x2 -> 4x4 -> 8x8 -> 16x16 -> 32x32.
        for chans_out in (8*width, 4*width, 2*width, width):
            layers += [
                th.nn.ConvTranspose2d(chans_in, chans_out, 4, padding=1,
                                      stride=2),
                lrelu(),
                th.nn.Conv2d(chans_out, chans_out, 3, padding=1),
                th.nn.BatchNorm2d(chans_out),
                lrelu(),
            ]
            chans_in = chans_out

        # Refine at full resolution, then project to the output channels.
        layers += [
            th.nn.Conv2d(width, width, 3, padding=1),
            th.nn.BatchNorm2d(width),
            lrelu(),
            th.nn.Conv2d(width, width, 3, padding=1),
            lrelu(),
            th.nn.Conv2d(width, num_out_chans, 1),

            th.nn.Tanh(),
        ]
        self.net = th.nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent codes ``z`` of shape ``(bs, zdim)``.

        Returns:
            A ``(bs, C, 32, 32)`` tensor in [-1, 1], with ``C`` equal to 3 if
            ``color_output`` was set, else 1.
        """
        bs = z.shape[0]
        num_in_chans = self.zdim // (2*2)
        raster = self.net(z.view(bs, num_in_chans, 2, 2))
        return raster


class Discriminator(th.nn.Module):
    """Convolutional critic mapping 32x32 images to one unbounded score each.

    Four stride-2 convolutions reduce 32x32 inputs to 2x2 before a final
    linear layer. All convolutions except the first two are wrapped in
    spectral normalization. The ``conditional`` flag is stored but not used
    by this class.
    """

    def __init__(self, conditional=False, width=64, color_output=False):
        super(Discriminator, self).__init__()

        self.conditional = conditional

        in_chans = 3 if color_output else 1

        def lrelu():
            return th.nn.LeakyReLU(0.2, inplace=True)

        def sn_conv(cin, cout, kernel, stride=1):
            return th.nn.utils.spectral_norm(
                th.nn.Conv2d(cin, cout, kernel, padding=1, stride=stride))

        layers = [
            th.nn.Conv2d(in_chans, width, 3, padding=1),
            lrelu(),
            th.nn.Conv2d(width, 2*width, 4, padding=1, stride=2),
            lrelu(),
        ]
        # 16x16 -> 8x8 -> 4x4 -> 2x2, each stage: 3x3 conv then 4x4 stride-2 conv.
        chans = 2*width
        for out_chans in (4*width, 4*width, 4*width):
            layers += [
                sn_conv(chans, chans, 3),
                lrelu(),
                sn_conv(chans, out_chans, 4, stride=2),
                lrelu(),
            ]
            chans = out_chans

        layers += [
            modules.Flatten(),
            th.nn.Linear(4*width*2*2, 1),
        ]
        self.net = th.nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of images ``x``."""
        return self.net(x)