import os

# PyOpenGL selects its platform backend when it is first imported (pyrender
# imports it at module load), so the EGL backend must be requested before any
# import that may pull in OpenGL. Setting it after `import pyrender` is a no-op.
os.environ['PYOPENGL_PLATFORM'] = "egl"

import json
import pickle as pkl
import random
import argparse
import torch
from TADA import smplx
import imageio
import numpy as np
from tqdm import tqdm
from PIL import Image
from TADA.lib.common.remesh import subdivide_inorder
from TADA.lib.common.utils import SMPLXSeg
from TADA.lib.common.lbs import warp_points
from TADA.lib.common.obj import compute_normal
import trimesh
import pyrender
from shapely import geometry
import moviepy.editor as mpy

def build_new_mesh(v, f, vt, ft):
    """Re-index per-frame vertex positions onto the texture (UV) vertex layout.

    A geometric vertex on a UV seam has one index in ``f`` but several indices
    in ``ft``. This copies each geometric vertex position to every UV vertex it
    corresponds to, so the result can be drawn with ``ft`` as the face array and
    shares indices with ``vt``.

    Args:
        v: NumPy array indexed as ``v[:, vertex_index]`` with a trailing axis of
            size 3 (presumably ``(frames, V, 3)`` -- confirm with caller).
        f: geometric face indices, corner-for-corner aligned with ``ft``. May be a
            NumPy array or a torch tensor (CPU or GPU).
        vt: UV coordinates; only ``vt.shape[0]`` is used.
        ft: UV face indices.

    Returns:
        ``(new_v, f_new)`` where ``new_v`` is a float64 array of shape
        ``(v.shape[0], vt.shape[0], 3)`` (UV vertices not referenced by ``ft``
        stay zero) and ``f_new`` is ``ft`` itself.
    """
    def _as_index_array(idx):
        # Torch tensors hash by identity, so they must not be used as dict keys
        # or indexed element-by-element; convert once to NumPy instead.
        if torch.is_tensor(idx):
            return idx.detach().cpu().numpy()
        return np.asarray(idx)

    f_flat = _as_index_array(f).reshape(-1)
    ft_flat = _as_index_array(ft).reshape(-1)

    # One source geometric vertex per referenced UV vertex. For a consistent
    # mesh every UV vertex maps to exactly one geometric vertex; if it does not,
    # the first occurrence wins.
    uv_ids, first_corner = np.unique(ft_flat, return_index=True)

    new_v = np.zeros((v.shape[0], vt.shape[0], 3))
    new_v[:, uv_ids] = v[:, f_flat[first_corner]]

    # define new faces using the texture map faces
    f_new = ft
    return new_v, f_new

class Animation:
    """Pose a TADA avatar checkpoint with per-frame SMPL-X joint rotations.

    On construction this loads the dense (re-meshed) avatar data from
    ``<workspace_dir>/init_body/data.npz``, an SMPL-X body model from
    ``<workspace_dir>/smplx/SMPLX_NEUTRAL_2020.npz``, and the per-avatar
    parameters and albedo texture from ``<workspace_dir>/MESH/<ckpt_path>/``.
    ``forward_mdm`` then produces posed vertices for every frame of a motion.
    """

    def __init__(self, ckpt_path, workspace_dir, device="cuda"):
        """Load mesh data, the SMPL-X body model and the avatar checkpoint.

        Args:
            ckpt_path: sub-directory of ``<workspace_dir>/MESH`` that holds
                ``params.pt`` and ``mesh_albedo.png``.
            workspace_dir: root directory containing ``init_body``, ``smplx``
                and ``MESH``.
            device: torch device for the dense-mesh tensors, the body model and
                the loaded checkpoint tensors.
        """
        self.device = device
        self.SMPLXSeg = SMPLXSeg(workspace_dir)

        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # copy every array out inside a `with` so the file handle is released.
        with np.load(os.path.join(workspace_dir, "init_body/data.npz")) as init_data:
            self.dense_faces = torch.as_tensor(init_data['dense_faces'], device=self.device)
            self.dense_lbs_weights = torch.as_tensor(init_data['dense_lbs_weights'], device=self.device)
            self.unique = init_data['unique']
            self.vt = init_data['vt']
            self.ft = init_data['ft']

        model_params = dict(
            model_path=os.path.join(workspace_dir, "smplx/SMPLX_NEUTRAL_2020.npz"),
            model_type='smplx',
            create_global_orient=True,
            create_body_pose=True,
            create_betas=True,
            create_left_hand_pose=True,
            create_right_hand_pose=True,
            create_jaw_pose=True,
            create_leye_pose=True,
            create_reye_pose=True,
            create_expression=True,
            create_transl=False,
            use_pca=False,
            flat_hand_mean=False,
            num_betas=300,
            num_expression_coeffs=100,
            num_pca_comps=12,
            dtype=torch.float32,
            batch_size=1,
        )
        self.body_model = smplx.create(**model_params).to(device=self.device)
        self.smplx_face = self.body_model.faces.astype(np.int32)

        ckpt_file = os.path.join(workspace_dir, "MESH", ckpt_path, "params.pt")
        albedo_path = os.path.join(workspace_dir, "MESH", ckpt_path, "mesh_albedo.png")
        self.load_ckpt_data(ckpt_file, albedo_path)

    def load_ckpt_data(self, ckpt_file, albedo_path):
        """Load per-avatar parameters and the albedo texture.

        Sets ``betas`` and ``v_offsets`` (with the eyeball and hand entries
        zeroed), ``expression``/``jaw_pose`` (``None`` when absent from the
        checkpoint), ``raw_albedo`` (channel-first tensor scaled by 1/255) and
        ``trimesh_visual`` (texture visuals built from a V-flipped copy of
        ``self.vt``).
        """
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        model_data = torch.load(ckpt_file, map_location=self.device)
        self.expression = model_data["expression"] if "expression" in model_data else None
        self.jaw_pose = model_data["jaw_pose"] if "jaw_pose" in model_data else None

        self.betas = model_data['betas']
        self.v_offsets = model_data['v_offsets']
        self.v_offsets[self.SMPLXSeg.eyeball_ids] = 0.
        self.v_offsets[self.SMPLXSeg.hands_ids] = 0.

        # tex to trimesh texture; the V coordinate is flipped, presumably to
        # convert between top-left and bottom-left UV origin conventions.
        vt = self.vt.copy()
        vt[:, 1] = 1 - vt[:, 1]
        albedo = Image.open(albedo_path)

        # Assumes an 8-bit image with a channel axis (H, W, C) -- the permute
        # below fails for single-channel images.
        self.raw_albedo = torch.from_numpy(np.array(albedo))
        self.raw_albedo = self.raw_albedo / 255.0
        self.raw_albedo = self.raw_albedo.permute(2, 0, 1)

        self.trimesh_visual = trimesh.visual.TextureVisuals(
            uv=vt,
            image=albedo,
            material=trimesh.visual.texture.SimpleMaterial(
                image=albedo,
                diffuse=[255, 255, 255, 255],
                ambient=[255, 255, 255, 255],
                specular=[0, 0, 0, 255],
                glossiness=0)
        )

    def forward_mdm(self, motion):
        """Pose the avatar for every frame of ``motion``.

        Args:
            motion: either a mapping with ``"poses"`` and ``"trans"`` entries,
                or a 2-D array whose last 3 columns are the root translation
                and whose remaining columns are joint rotations, 3 per joint.
                Per-frame poses are reshaped to ``(joints, 3)``; joint 0 is
                used as the global orientation and joints 1-21 as the body pose.
                Translation is applied relative to the first frame.

        Returns:
            ``(vertices, faces)``: ``vertices`` is the float32 result of
            ``build_new_mesh`` on the stacked per-frame vertices (presumably one
            entry per frame along axis 0 -- assumes ``warp_points`` returns a
            leading batch dim of 1, confirm), laid out on the UV vertex indexing;
            ``faces`` is ``self.ft``.
        """
        try:
            mdm_body_pose = motion["poses"]
            translate = motion["trans"]
        except (KeyError, IndexError, TypeError, ValueError):
            # Not a mapping with "poses"/"trans": treat it as a packed
            # (frames, 3 * joints + 3) array with the translation last.
            translate = motion[:, -3:]
            mdm_body_pose = motion[:, :-3]

        # No-op for (frames, joints, 3) input; unflattens (frames, 3 * joints).
        mdm_body_pose = mdm_body_pose.reshape(mdm_body_pose.shape[0], -1, 3)
        translate = torch.as_tensor(translate, device=self.device)

        # Loop-invariant work hoisted out of the per-frame loop.
        remesh_faces = self.smplx_face[self.SMPLXSeg.remesh_mask]
        origin = translate[0]

        scan_v_posed = []
        # The result is detached and converted to NumPy below, so autograd
        # tracking is unnecessary; no_grad avoids keeping per-frame graphs alive.
        with torch.no_grad():
            for pose, t in tqdm(zip(mdm_body_pose, translate), total=len(translate)):
                # The body model is created with dtype=torch.float32.
                pose = torch.as_tensor(pose, dtype=torch.float32, device=self.device)
                output = self.body_model(
                    betas=self.betas,
                    global_orient=pose[None, :1, :],
                    jaw_pose=self.jaw_pose,
                    body_pose=pose[None, 1:22, :],
                    expression=self.expression,
                    return_verts=True
                )

                v_cano = output.v_posed[0]
                # re-mesh onto the dense topology
                v_cano_dense = subdivide_inorder(v_cano, remesh_faces, self.unique).squeeze(0)
                # add offsets, scaled by the vertex normals
                vn = compute_normal(v_cano_dense, self.dense_faces)[0]
                v_cano_dense += self.v_offsets * vn
                # do LBS
                v_posed_dense = warp_points(v_cano_dense, self.dense_lbs_weights, output.joints_transform[:, :55])
                # translate relative to the first frame
                v_posed_dense += t - origin

                scan_v_posed.append(v_posed_dense)

        scan_v_posed = torch.cat(scan_v_posed).detach().cpu().numpy()
        # Pass faces as NumPy: integer NumPy indices hash by value, which the
        # index-based remapping relies on (torch tensors hash by identity).
        dense_faces = self.dense_faces.detach().cpu().numpy()
        new_scan_v_posed, new_face = build_new_mesh(scan_v_posed, dense_faces, self.vt, self.ft)
        new_scan_v_posed = new_scan_v_posed.astype(np.float32)

        return new_scan_v_posed, new_face