import copy
import os
import random
import urllib.request

import torch
import torch.nn.functional as FF
import torch.optim
from torchvision import utils
from tqdm import tqdm

from stylegan2.model import Generator


class DownloadProgressBar(tqdm):
    """``tqdm`` progress bar whose ``update_to`` can be used as a ``urlretrieve`` reporthook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to ``b * bsize`` transferred bytes.

        Args:
            b: number of blocks transferred so far.
            bsize: size of each block in bytes.
            tsize: total size in bytes. ``urlretrieve`` passes -1 when the size
                is unknown; in that case (and for None) the total is left as is
                instead of being set to a negative value.
        """
        if tsize is not None and tsize >= 0:
            self.total = tsize
        # tqdm.update takes an increment, so convert the absolute position.
        self.update(b * bsize - self.n)


def get_path(base_path, base_dir='checkpoints'):
    """Return the local path of checkpoint ``base_path``, downloading it if missing.

    The file is looked up at ``os.path.join(base_dir, base_path)``. If it does not
    exist it is fetched from the ``aaronb/StyleGAN2`` Hugging Face repository.
    The download goes to a temporary ``.part`` file that is renamed into place
    only after it completes. An interrupted download therefore never leaves a
    truncated file that later calls would mistake for a finished checkpoint.

    Args:
        base_path: checkpoint path relative to ``base_dir`` (and to the repo root).
        base_dir: local directory holding checkpoints; defaults to ``'checkpoints'``.

    Returns:
        The local path of the checkpoint file.
    """
    save_path = os.path.join(base_dir, base_path)
    if not os.path.exists(save_path):
        url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
        print(f'{base_path} not found')
        print('Try to download from huggingface: ', url)
        parent = os.path.dirname(save_path)
        if parent:  # makedirs('') would raise
            os.makedirs(parent, exist_ok=True)
        tmp_path = save_path + '.part'
        try:
            download_url(url, tmp_path)
            # Atomic on the same filesystem: save_path is either absent or complete.
            os.replace(tmp_path, save_path)
        finally:
            # Only still present if the download or rename failed / was interrupted.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print('Downloaded to ', save_path)
    return save_path


def download_url(url, output_path):
    """Fetch ``url`` into the file ``output_path``, reporting progress with a tqdm bar.

    The bar is labelled with the last path component of the URL and counts bytes.
    """
    label = url.rsplit('/', 1)[-1]
    progress = DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=label)
    with progress as bar:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=bar.update_to)


class CustomGenerator(Generator):
    """``Generator`` whose forward pass is split into ``prepare`` and ``generate``.

    ``prepare`` builds the per-layer latent tensor and the noise list from input
    styles. ``generate`` runs synthesis on them and additionally returns an
    intermediate activation, resized to the output image size.
    """

    def prepare(
        self,
        styles,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Turn a list of style codes into a per-layer latent tensor plus noise inputs.

        Args:
            styles: list of style tensors; only the first two are used. If
                ``input_is_latent`` is False each is first passed through
                ``self.style``.
            inject_index: when two styles are given, number of leading layers
                driven by the first style; drawn uniformly from
                ``[1, n_latent - 1]`` when None. Ignored for a single style.
            truncation: when < 1, each style becomes
                ``truncation_latent + truncation * (style - truncation_latent)``.
            truncation_latent: centre of the truncation; required if ``truncation < 1``.
            input_is_latent: True if ``styles`` are already latents (skip ``self.style``).
            noise: explicit per-layer noise list; built here when None.
            randomize_noise: when ``noise`` is None, use ``[None] * num_layers``
                (presumably letting each layer sample fresh noise) if True,
                otherwise the stored ``self.noises.noise_{i}`` tensors.

        Returns:
            ``(latent, noise)``. A single style of ndim < 3 is repeated
            ``n_latent`` times along a new dim 1. A single style that is already
            3-D is returned as is. Two styles are concatenated along dim 1 at
            ``inject_index``.
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)

        return latent, noise

    def generate(
        self,
        latent,
        noise,
        feature_res=256,
    ):
        """Run the synthesis network on a prepared latent.

        Args:
            latent: per-layer latent tensor (as from ``prepare``); synthesis layer
                k reads ``latent[:, k]``.
            noise: per-layer noise list (as from ``prepare``).
            feature_res: width (last dim) of the synthesis activation returned as
                the feature map. Defaults to 256, the previously hard-coded value.

        Returns:
            ``(image, feat)``. ``image`` is the output accumulated through the
            ``to_rgb`` skip connections. ``feat`` is the chosen activation,
            bilinearly resized to ``image``'s spatial size.

        Raises:
            ValueError: if no activation has width ``feature_res`` (e.g. a
                generator whose output is smaller than 256 px, with the default).
        """
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])
        feat = out if out.shape[-1] == feature_res else None

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            if out.shape[-1] == feature_res:
                feat = out
            i += 2

        if feat is None:
            # Previously this surfaced as an UnboundLocalError on small generators.
            raise ValueError(
                f"no synthesis activation has width {feature_res}; "
                f"generator output is {tuple(skip.shape[-2:])}"
            )

        image = skip
        feat = FF.interpolate(feat, image.shape[-2:], mode='bilinear')
        return image, feat


def stylegan2(
    size=512,
    channel_multiplier=2,
    latent=512,
    n_mlp=8,
    ckpt='stylegan2-ffhq-config-f.pt',
    map_location='cpu',
):
    """Build a ``CustomGenerator`` and load the checkpoint's ``"g_ema"`` weights into it.

    The returned generator has ``requires_grad`` disabled on all parameters and is
    in eval mode.

    Args:
        size, latent, n_mlp, channel_multiplier: forwarded to ``CustomGenerator``.
        ckpt: checkpoint name resolved (and downloaded if missing) by ``get_path``.
        map_location: forwarded to ``torch.load``. The generator is constructed
            here and never moved, so ``load_state_dict`` copies the weights into
            its own parameters regardless. Loading to CPU therefore avoids a
            temporary copy on the checkpoint's saving device. It also lets
            CUDA-saved checkpoints load on machines without a GPU.

    Returns:
        The frozen, eval-mode generator.
    """
    g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
    checkpoint = torch.load(get_path(ckpt), map_location=map_location)
    # strict=False: keys missing from / unexpected in the checkpoint are ignored.
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.requires_grad_(False)
    g_ema.eval()
    return g_ema


def bilinear_interpolate_torch(im, y, x):
    """Bilinearly sample ``im`` at fractional pixel locations.

    Args:
        im: tensor of shape (B, C, H, W).
        y: row coordinates in pixels. A tensor of any shape S (e.g. 0-d, or
            (1, numPoints)); Python numbers are accepted too.
        x: column coordinates in pixels, same shape as ``y``.

    Returns:
        Tensor of shape (B, C, *S).

    Coordinates outside ``[0, H-1] x [0, W-1]`` read the nearest border pixel
    (replicate padding). Previously negative positions silently wrapped around
    to the opposite border and positions >= W/H raised IndexError.
    """
    y = torch.as_tensor(y, device=im.device)
    x = torch.as_tensor(x, device=im.device)
    if not y.is_floating_point():
        y = y.float()
    if not x.is_floating_point():
        x = x.float()

    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # Weights come from the unclamped corners, so they always sum to 1.
    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    height, width = im.shape[-2], im.shape[-1]
    x0 = x0.clamp(0, width - 1)
    x1 = x1.clamp(0, width - 1)
    y0 = y0.clamp(0, height - 1)
    y1 = y1.clamp(0, height - 1)

    Ia = im[:, :, y0, x0]
    Ib = im[:, :, y1, x0]
    Ic = im[:, :, y0, x1]
    Id = im[:, :, y1, x1]

    return Ia * wa + Ib * wb + Ic * wc + Id * wd


def _square_neighborhood(cy, cx, radius, height, width, device):
    """Return (rows, cols) of the square patch with offsets -radius..radius around (cy, cx).

    Positions are in row-major order and restricted to a ``height x width`` image.
    Both outputs are 1-D float tensors on ``device``.
    """
    offsets = torch.arange(-radius, radius + 1, device=device)
    rows = (cy + offsets).repeat_interleave(offsets.numel())
    cols = (cx + offsets).repeat(offsets.numel())
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    return rows[inside].float(), cols[inside].float()


def _motion_supervision_loss(feat, handle_points, target_points, radius, reach_dist):
    """Sum over handles of the loss pulling features around each handle one unit step toward its target.

    For handle p with target t (skipped if ``||t - p|| < reach_dist``), each patch
    location q contributes ``mean_{B,C} |feat(q + u) - sg(feat(q))|``, where
    ``u = (t - p) / ||t - p||``.
    """
    height, width = feat.shape[-2:]
    loss = feat.new_zeros(())
    for p, t in zip(handle_points, target_points):
        delta = (t - p).float()
        dist = torch.norm(delta)
        if dist < reach_dist:
            # Reached: also avoids dividing by a zero distance (NaN latents).
            continue
        direction = delta / dist
        rows, cols = _square_neighborhood(int(p[0]), int(p[1]), radius, height, width, feat.device)
        if rows.numel() == 0:
            continue
        # Detached reference: gradients only flow through the shifted samples.
        f_ref = bilinear_interpolate_torch(feat, rows, cols).detach()
        # Clamp so patches at the border never index outside the feature map.
        f_shift = bilinear_interpolate_torch(
            feat,
            (rows + direction[0]).clamp(0, height - 1),
            (cols + direction[1]).clamp(0, width - 1),
        )
        # (B, C, N): mean over batch/channels per location, summed over the patch.
        loss = loss + (f_shift - f_ref).abs().mean(dim=(0, 1)).sum()
    return loss


def _track_handle(feat, f_init, center, radius):
    """Find the position near ``center`` whose feature in ``feat`` best matches ``f_init``.

    Searches the (2*radius+1)-square patch, clipped to the image, and picks the
    candidate with the smallest L1 distance (first in row-major order on ties).
    Returns an integer ``(row, col)``, or None if the patch is entirely outside
    the image.
    """
    height, width = feat.shape[-2:]
    rows, cols = _square_neighborhood(int(center[0]), int(center[1]), radius, height, width, feat.device)
    if rows.numel() == 0:
        return None
    candidates = bilinear_interpolate_torch(feat, rows, cols)  # (B, C, N)
    cost = (candidates - f_init.unsqueeze(-1)).abs().sum(dim=(0, 1))
    best = int(torch.argmin(cost))
    return int(rows[best]), int(cols[best])


def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask, max_iters=1000,
             *, r1=3, r2=12, lam=20, d=1, lr=2e-3):
    """Optimize ``latent`` so image content at the handle points moves toward the targets.

    Each of up to ``max_iters`` iterations does two things.

    1. Motion supervision: one Adam step on ``latent[:, :6, :]``; the rest of the
       latent stays fixed. The loss is the motion-supervision term plus
       ``lam * mean(|F2 - F0| * (1 - mask))``. The second term keeps features
       where ``mask`` is 0 close to the initial feature map ``F0``.
    2. Point tracking: regenerate with the *updated* latent. Then move each handle
       to the location in its (2*r2+1)-square patch whose feature best matches (L1)
       the handle's feature in ``F0`` at its initial position.

    Args:
        g_ema: object with ``generate(latent, noise) -> (image, features)``.
        latent: 3-D latent tensor; ``latent[:, :6, :]`` is optimized.
        noise: forwarded unchanged to ``g_ema.generate``.
        F: initial feature map, same shape as ``generate``'s second output.
        handle_points: list of (row, col) tensors; UPDATED IN PLACE after every
            tracking step.
        target_points: list of (row, col) tensors, one per handle.
        mask: applied as ``(1 - mask)`` to the feature difference. Presumably 1
            marks the editable region; confirm with the caller.
        max_iters: number of optimization/tracking iterations.
        r1: half-size of the motion-supervision patch.
        r2: half-size of the tracking search patch.
        lam: weight of the mask-based feature-preservation term.
        d: handles closer than ``d`` to their target get no motion-supervision term.
        lr: Adam learning rate.

    Yields:
        ``(sample2, latent, F2, handle_points)`` after every iteration.
        ``sample2`` and ``F2`` are ``generate``'s outputs for the updated latent.
        ``latent`` is that updated latent, with no autograd graph attached.
        ``handle_points`` holds the tracked handles.
    """
    handle_points0 = copy.deepcopy(handle_points)
    F0 = F.detach().clone()

    latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
    latent_untrainable = latent[:, 6:, :].detach().clone()
    optimizer = torch.optim.Adam([latent_trainable], lr=lr)

    # Reference features of each handle at its initial position; constant across iterations.
    with torch.no_grad():
        f_inits = [bilinear_interpolate_torch(F0, p[0], p[1]) for p in handle_points0]

    for step in range(max_iters):
        # Motion supervision.
        optimizer.zero_grad()
        latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
        _, F2 = g_ema.generate(latent, noise)
        loss = _motion_supervision_loss(F2, handle_points, target_points, r1, d)
        loss = loss + ((F2 - F0) * (1 - mask)).abs().mean() * lam
        loss.backward()
        optimizer.step()

        # Point tracking on the features of the updated latent. torch.cat copies,
        # so the latent must be rebuilt after optimizer.step() to see the update.
        with torch.no_grad():
            latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
            sample2, F2 = g_ema.generate(latent, noise)
            for handle, f_init in zip(handle_points, f_inits):
                found = _track_handle(F2, f_init, handle, r2)
                if found is not None:
                    handle[0], handle[1] = found

        print(step, loss.item(), handle_points, target_points)
        yield sample2, latent, F2, handle_points