import os
import sys
import logging

# Repository root: two directory levels above this file. It is put first on
# sys.path so the project-local `models.*` / `utils.*` imports below can be
# resolved from it. os.path.dirname is used (rather than splitting on '/')
# so the computation also works with non-POSIX path separators.
__workdir__ = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, __workdir__)

print(__workdir__)

import numpy as np
import torch
from torchvision.transforms import v2
from PIL import Image
import rembg

from models.lrm.online_render.render_single import load_mipmap
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.infer_util import remove_background, resize_foreground

from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection

from utils.tool import NormalTransfer, get_render_cameras_frames, get_background, get_render_cameras_video, render_frames, mask_fix

# INFO level so the pipeline's "==> ..." progress messages are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('kiss3d_wrapper')

# Output locations, relative to the current working directory.
OUT_DIR = './outputs'
TMP_DIR = './outputs/tmp'  # scratch space for intermediate results

# Create the scratch directory up front (no-op if it already exists).
os.makedirs(TMP_DIR, exist_ok=True)

def lrm_reconstruct(model, infer_config, images,
                    name='', export_texmap=False,
                    input_camera_type='zero123',
                    render_3d_bundle_image=True,
                    render_azimuths=(270, 0, 90, 180),
                    render_elevations=(5, 5, 5, 5),
                    render_radius=4.15):
    """Reconstruct a mesh from multi-view images with the LRM, optionally rendering it.

    The extracted mesh is always written to
    ``TMP_DIR/{name}_recon_from_{input_camera_type}.obj``.

    Args:
        model: LRM model; ``forward_planes`` and ``extract_mesh`` are called on
            it, and it is passed to ``render_frames`` when rendering.
        infer_config: inference config, unpacked as keyword arguments into
            ``model.extract_mesh`` and read for ``render_resolution`` when
            rendering.
        images: input view tensor. Its device is used for all computation;
            it is resized to 512 and clamped to [0, 1]. NOTE(review): the
            previous docstring said shape (1, c, h, w) -- confirm against
            ``model.forward_planes``.
        name: prefix of the saved .obj file name.
        export_texmap: save a UV-textured mesh (.obj + .mtl) instead of a
            vertex-colored one.
        input_camera_type: ``'zero123'`` or ``'kiss3d'``. Selects the input
            camera set (radius 3.5, fov 30).
        render_3d_bundle_image: render the triplanes at the requested views.
        render_azimuths, render_elevations: per-view angles forwarded to
            ``get_render_cameras_frames`` (presumably degrees -- confirm).
        render_radius: camera distance forwarded to
            ``get_render_cameras_frames``.

    Returns:
        ``(vertices, faces, normals, frames, albedos)``.
        ``vertices`` and ``faces`` are CPU tensors. The vertices are
        multiplied by the 3x3 part of ``rotate_x(pi/2)`` and then of
        ``rotate_y(pi/2)``. The last three items are the ``render_frames``
        outputs, or ``None`` when rendering is disabled.

    Raises:
        NotImplementedError: unknown ``input_camera_type``.
        ValueError: rendering requested while a render parameter is None.
    """
    mesh_path_idx = os.path.join(TMP_DIR, f'{name}_recon_from_{input_camera_type}.obj')

    device = images.device
    if input_camera_type == 'zero123':
        input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
    elif input_camera_type == 'kiss3d':
        input_cameras = get_flux_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)
    else:
        raise NotImplementedError(f'Unexpected input camera type: {input_camera_type}')

    # interpolation=3 is torchvision's integer alias for bicubic. The clamp
    # keeps resampled values inside [0, 1].
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)

    logger.info("==> Runing LRM reconstruction ...")
    planes = model.forward_planes(images, input_cameras)
    mesh_out = model.extract_mesh(
        planes,
        use_texture_map=export_texmap,
        **infer_config,
    )
    if export_texmap:
        vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
        save_obj_with_mtl(
            vertices.data.cpu().numpy(),
            uvs.data.cpu().numpy(),
            faces.data.cpu().numpy(),
            mesh_tex_idx.data.cpu().numpy(),
            tex_map.permute(1, 2, 0).data.cpu().numpy(),
            mesh_path_idx,
        )
    else:
        vertices, faces, vertex_colors = mesh_out
        save_obj(vertices, faces, vertex_colors, mesh_path_idx)
    logger.info(f"Mesh saved to {mesh_path_idx}")

    normals = frames = albedos = None
    if render_3d_bundle_image:
        # Explicit check rather than `assert`, so it also runs under `python -O`.
        if render_azimuths is None or render_elevations is None or render_radius is None:
            raise ValueError(
                'render_azimuths, render_elevations and render_radius must be set '
                'when render_3d_bundle_image is True'
            )
        azimuths_t = torch.Tensor(render_azimuths).to(device)
        elevations_t = torch.Tensor(render_elevations).to(device)

        render_size = infer_config.render_resolution
        env_map = load_mipmap("models/lrm/env_mipmap/6")
        all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
            batch_size=1,
            radius=render_radius,
            azimuths=azimuths_t,
            elevations=elevations_t,
            fov=30,
        )
        # Only frames, albedos and normals are returned; the PBR light terms
        # and alphas are discarded.
        frames, albedos, _, _, normals, _ = render_frames(
            model,
            planes,
            render_cameras=all_mvp,
            camera_pos=all_campos,
            env=env_map,
            materials=(0.0, 0.9),
            render_size=render_size,
            render_mv=all_mv,
            local_normal=True,
            identity_mv=identity_mv,
        )

    # extract_mesh yields tensors in the texture-map branch (see the
    # `.data.cpu().numpy()` calls above) and numpy arrays otherwise.
    # torch.from_numpy rejects tensors, so torch.as_tensor is used to accept both.
    vertices = torch.as_tensor(vertices, device=device)
    faces = torch.as_tensor(faces, device=device)
    vertices = vertices @ rotate_x(np.pi / 2, device=device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=device)[:3, :3]

    return vertices.cpu(), faces.cpu(), normals, frames, albedos

# Shared NormalTransfer instance, created at import time. It is used by
# local_normal_global_transform (trans_local_2_global) to convert per-view
# local normals into the global frame.
normal_transfer = NormalTransfer()

def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg, radius=4.5):
    """Convert per-view local normal maps into global-frame normal maps.

    Args:
        local_normal_images: normal maps, channel-last (the caller in this
            file permutes channel-first maps with ``(0, 2, 3, 1)`` first).
            A map with no negative values is treated as [0, 1]-encoded and
            rescaled to [-1, 1]. Otherwise it is used as-is.
        azimuths_deg, elevations_deg: per-view camera angles, forwarded to
            ``NormalTransfer.trans_local_2_global``.
        radius: camera distance forwarded to ``trans_local_2_global``. It
            defaults to 4.5, the value previously hard-coded here.

    Returns:
        Global normal maps remapped from [-1, 1] to [0, 1], with the first
        channel negated, permuted with ``(0, 3, 1, 2)`` (channel-last to
        channel-first).
    """
    # Heuristic range detection. NOTE(review): a [-1, 1] map that happens to
    # contain no negative values would be rescaled incorrectly.
    if local_normal_images.min() >= 0:
        local_normal = local_normal_images.float() * 2 - 1
    else:
        local_normal = local_normal_images.float()
    global_normal = normal_transfer.trans_local_2_global(
        local_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False
    )
    # Negate the x channel. NOTE(review): presumably an axis-convention
    # alignment for the downstream ISOMER stage -- confirm.
    global_normal[..., 0] *= -1
    # Back to the [0, 1] image range, then channel-first.
    global_normal = (global_normal + 1) / 2
    return global_normal.permute(0, 3, 1, 2)


def isomer_reconstruct(
        rgb_multi_view,
        normal_multi_view,
        multi_view_mask,
        vertices,
        faces,
        save_path=None,
        azimuths=(0, 90, 180, 270),
        elevations=(5, 5, 5, 5),
        geo_weights=(1, 0.9, 1, 0.9),
        color_weights=(1, 0.5, 1, 0.5),
        reconstruction_stage1_steps=10,
        reconstruction_stage2_steps=50,
        radius=4.5):
    """Refine an initial mesh with ISOMER and project the multi-view colors onto it.

    Args:
        rgb_multi_view: per-view RGB images, channel-first. They are permuted
            with ``(0, 2, 3, 1)`` before projection; their device is used for
            all tensors created here.
        normal_multi_view: per-view local normal maps, channel-first.
        multi_view_mask: per-view masks with a channel dimension at dim 1,
            which is squeezed out. NOTE(review): presumably values in [0, 1].
        vertices, faces: initial mesh, passed to ``reconstruction`` as
            ``init_verts`` / ``init_faces``.
        save_path: forwarded to ``projection`` as ``save_glb_addr``.
        azimuths, elevations: per-view camera angles.
        geo_weights: per-view weights for geometry reconstruction.
        color_weights: per-view weights for color projection.
        reconstruction_stage1_steps, reconstruction_stage2_steps: step
            counts for the two reconstruction stages.
        radius: camera distance used by ``reconstruction`` and ``projection``.
            NOTE(review): it is not forwarded to local_normal_global_transform,
            so the normal conversion uses that function's own radius (4.5).

    Returns:
        The address returned by ``projection`` (logged as the saved mesh path).
    """
    device = rgb_multi_view.device

    def to_tensor_(values):
        # Fresh float32 tensor on the views' device. A new tensor is built per
        # call so no callee can observe another callee's in-place changes.
        return torch.Tensor(values).float().to(device)

    # Local normals -> global normals (in [0, 1]). Pixels outside the mask are
    # then set to 1.
    global_normal = local_normal_global_transform(
        normal_multi_view.permute(0, 2, 3, 1).cpu(),
        to_tensor_(azimuths),
        to_tensor_(elevations),
    ).to(device)
    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    # Move channels last and drop the mask's channel dim for the ISOMER calls.
    global_normal = global_normal.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.squeeze(1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)

    logger.info("==> Runing ISOMER reconstruction ...")
    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=to_tensor_(geo_weights),
        fov=30,
        radius=radius,
        camera_angles_azi=to_tensor_(azimuths),
        camera_angles_ele=to_tensor_(elevations),
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=reconstruction_stage1_steps,
        stage2_steps=reconstruction_stage2_steps,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    # Masks for color projection are post-processed by mask_fix
    # (erode_dilate=-10, blur=5).
    multi_view_mask_proj = mask_fix(multi_view_mask, erode_dilate=-10, blur=5)

    logger.info("==> Runing ISOMER projection ...")
    save_glb_addr = projection(
        meshes,
        masks=multi_view_mask_proj.to(device),
        images=rgb_multi_view.to(device),
        azimuths=to_tensor_(azimuths),
        elevations=to_tensor_(elevations),
        weights=to_tensor_(color_weights),
        fov=30,
        radius=radius,
        save_dir=TMP_DIR,
        save_glb_addr=save_path,
    )

    logger.info(f"==> Save mesh to {save_glb_addr} ...")
    return save_glb_addr


def to_rgb_image(maybe_rgba):
    """Return an RGB version of a PIL image together with its alpha band, if any.

    RGB input is returned unchanged, paired with ``None``. RGBA input is
    composited onto a uniform (127, 127, 127) gray background and returned
    together with its alpha band. Grayscale (``L``) and palette (``P``)
    images are converted to RGB first. Grayscale+alpha (``LA``) and palette
    images carrying transparency are converted to RGBA first.

    Args:
        maybe_rgba: a ``PIL.Image.Image``.

    Returns:
        ``(rgb_image, alpha)``, where ``alpha`` is the single-band alpha
        image, or ``None`` when the input had no alpha.

    Raises:
        AssertionError: the input is not a PIL image.
        ValueError: the image mode is not supported.
    """
    assert isinstance(maybe_rgba, Image.Image)
    mode = maybe_rgba.mode
    # Normalize common non-RGB(A) modes instead of rejecting them.
    if mode == 'LA' or (mode == 'P' and 'transparency' in maybe_rgba.info):
        maybe_rgba = maybe_rgba.convert('RGBA')
    elif mode in ('L', 'P'):
        maybe_rgba = maybe_rgba.convert('RGB')

    if maybe_rgba.mode == 'RGB':
        return maybe_rgba, None
    if maybe_rgba.mode == 'RGBA':
        alpha = maybe_rgba.getchannel('A')
        # Uniform mid-gray canvas; the alpha band decides where the input shows.
        img = Image.new('RGB', maybe_rgba.size, (127, 127, 127))
        img.paste(maybe_rgba, mask=alpha)
        return img, alpha
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
    
# rembg background-removal session for the "u2net" model. It is created once
# at import time and reused by preprocess_input_image.
rembg_session = rembg.new_session("u2net")
def preprocess_input_image(input_image):
    """Remove the background of a user-supplied image and re-frame its foreground.

    Args:
        input_image: a PIL image accepted by ``to_rgb_image``.

    Returns:
        An RGB PIL image. The steps are:
        - ``remove_background`` runs with the module's rembg session and
          ``bgcolor=(255, 255, 255, 255)`` (presumably a white fill).
        - ``resize_foreground`` re-frames the result with ratio 0.85 and
          pad value 255.
        NOTE(review): the former docstring promised a 512x512 output, but
        nothing here resizes explicitly -- confirm against resize_foreground.
    """
    rgb_input, _ = to_rgb_image(input_image)
    foreground = remove_background(rgb_input, rembg_session, bgcolor=(255, 255, 255, 255))
    framed = resize_foreground(foreground, ratio=0.85, pad_value=255)
    rgb_output, _alpha = to_rgb_image(framed)
    return rgb_output