# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
# All rights reserved.
# This file is part of the pytorch-nicp,
# and is released under the "MIT License Agreement". Please see the LICENSE
# file that should have been included as part of this package.

import torch
import torch.nn as nn
import trimesh
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Meshes
from tqdm import tqdm

from lib.common.train_util import init_loss
from lib.dataset.mesh_util import update_mesh_shape_prior_losses


# reference: https://github.com/wuhaozhe/pytorch-nicp
class LocalAffine(nn.Module):
    def __init__(self, num_points, batch_size=1, edges=None):
        '''
            A local (per-vertex) affine transformation layer.
            num_points: the number of points, which must be constant across the batch
            edges: torch.LongTensor of shape N_edges * 2 listing pairs of neighbouring vertices
            The operator supports batched input, but the batch size must be constant.
            Additional pooling can be added on top of the weight matrix.
        '''
        super(LocalAffine, self).__init__()
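        # per-vertex affine parameters: A initialised to identity matrices
        # (B x N x 3 x 3) and b to zero translations (B x N x 3 x 1)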
        self.A = nn.Parameter(
            torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)
        )
        self.b = nn.Parameter(
            torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(
                batch_size, num_points, 1, 1
            )
        )
        self.edges = edges
        self.num_points = num_points

    def stiffness(self):
        '''
            calculate the stiffness and rigidity terms of the local affine transformation
            (the Frobenius norm has an infinite gradient when the weight matrix is zero,
            so the squared difference is used instead)
        '''
        if self.edges is None:
            raise Exception("edges cannot be none when calculate stiff")
        affine_weight = torch.cat((self.A, self.b), dim=3)
        w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
        w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
        w_diff = (w1 - w2)**2
        w_rigid = (torch.linalg.det(self.A) - 1.0)**2
        return w_diff, w_rigid

    def forward(self, x):
        '''
            x should have shape B * N * 3; it is expanded to B * N * 3 * 1 internally
        '''
        x = x.unsqueeze(3)
        out_x = torch.matmul(self.A, x)
        out_x = out_x + self.b
        out_x.squeeze_(3)
        stiffness, rigid = self.stiffness()

        return out_x, stiffness, rigid
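
# A minimal usage sketch for LocalAffine (hedged; the shapes and the dummy edge
# list below are illustrative, not part of the original pipeline):
#
#   verts = torch.rand(1, 100, 3)                   # B x N x 3 source vertices
#   edges = torch.randint(0, 100, (200, 2)).long()  # N_edges x 2 neighbouring-vertex pairs
#   model = LocalAffine(num_points=100, batch_size=1, edges=edges)
#   deformed, stiffness, rigid = model(verts)       # deformed: B x N x 3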


def trimesh2meshes(mesh):
    '''
        convert a trimesh mesh into a pytorch3d Meshes object
    '''
    verts = torch.from_numpy(mesh.vertices).float()
    faces = torch.from_numpy(mesh.faces).long()
    mesh = Meshes(verts.unsqueeze(0), faces.unsqueeze(0))
    return mesh


def register(target_mesh, src_mesh, device, verbose=True):

    # convert the target mesh and set up the local affine model that deforms the source vertices
    tgt_mesh = trimesh2meshes(target_mesh).to(device)
    src_verts = src_mesh.verts_padded().clone()

    local_affine_model = LocalAffine(
        src_mesh.verts_padded().shape[1],
        src_mesh.verts_padded().shape[0], src_mesh.edges_packed()
    ).to(device)

    optimizer_cloth = torch.optim.Adam([{'params': local_affine_model.parameters()}],
                                       lr=1e-2,
                                       amsgrad=True)
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth,
        mode="min",
        factor=0.1,
        verbose=0,
        min_lr=1e-5,
        patience=5,
    )

    losses = init_loss()

    if verbose:
        loop_cloth = tqdm(range(100))
    else:
        loop_cloth = range(100)

    for i in loop_cloth:

        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
        src_mesh = src_mesh.update_padded(deformed_verts)

        # losses for laplacian, edge, normal consistency
        update_mesh_shape_prior_losses(src_mesh, losses)

        losses["cloth"]["value"] = chamfer_distance(
            x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded()
        )[0]
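        # regularisers from the local affine model: per-edge weight smoothness (stiffness)
        # and deviation of det(A) from 1 (rigid)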
        losses["stiff"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # Weighted sum of the losses
        cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
        pbar_desc = "Register SMPL-X -> d-BiNI -- "

        for k in losses.keys():
            if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
                cloth_loss = cloth_loss + \
                    losses[k]["value"] * losses[k]["weight"]
                pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | "

        if verbose:
            pbar_desc += f"TOTAL: {cloth_loss:.3f}"
            loop_cloth.set_description(pbar_desc)

        # update params
        cloth_loss.backward(retain_graph=True)
        optimizer_cloth.step()
        scheduler_cloth.step(cloth_loss)

    final = trimesh.Trimesh(
        src_mesh.verts_packed().detach().squeeze(0).cpu(),
        src_mesh.faces_packed().detach().squeeze(0).cpu(),
        process=False,
        maintain_order=True
    )

    return final
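

if __name__ == "__main__":
    # Minimal usage sketch (hedged): "target.obj" and "source.obj" are hypothetical
    # mesh files, not part of this repository. register() expects the target as a
    # trimesh.Trimesh and the source as a pytorch3d Meshes object on the same device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    target = trimesh.load("target.obj", process=False, maintain_order=True)
    source = trimesh2meshes(
        trimesh.load("source.obj", process=False, maintain_order=True)
    ).to(device)
    registered = register(target, source, device, verbose=True)
    registered.export("registered.obj")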