# VFusion3D Gradio demo: reconstruct a 3D mesh or an orbit video from a single input image.
import spaces
import torch
import gradio as gr
import os
import numpy as np
import trimesh
import mcubes
import imageio
from torchvision.utils import save_image
from PIL import Image
from transformers import AutoModel, AutoConfig
from rembg import remove, new_session
from functools import partial
from kiui.op import recenter
import kiui
from gradio_litmodel3d import LitModel3D

# we load the pre-trained model from HF
class LRMGeneratorWrapper:
    """Holder for the VFusion3D generator loaded from the Hugging Face Hub.

    Construction loads the config and weights of ``jadechoghari/vfusion3d``,
    places the model on CUDA when available (CPU otherwise) and puts it in
    eval mode.

    NOTE: ``trust_remote_code=True`` executes Python code shipped with the Hub
    repository, so the repository must be trusted.
    """

    def __init__(self):
        repo_id = "jadechoghari/vfusion3d"
        self.config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def forward(self, image, camera):
        """Call the underlying model on ``(image, camera)`` and return its output unchanged."""
        result = self.model(image, camera)
        return result

# Module-level singleton: constructing the wrapper loads the VFusion3D config and
# weights (see LRMGeneratorWrapper.__init__) once, when this file is imported; the
# generation code below reuses it through this global.
model_wrapper = LRMGeneratorWrapper()

# we preprocess the input image
def preprocess_image(image, source_size):
    """Remove the background, recenter the object and resize it for the model.

    Args:
        image: input image accepted by ``np.array`` (the UI passes a PIL image).
        source_size: side length, in pixels, of the square output.

    Returns:
        Float tensor of shape (1, C, source_size, source_size) with values in [0, 1].
        When the cut-out is RGBA (rembg's normal output), the removed background is
        composited onto white, giving C == 3.

    Raises:
        ValueError: if ``image`` is None (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("No input image provided.")

    session = _get_rembg_session()
    rgba = remove(np.array(image), session=session)
    # The cut-out's alpha channel is the foreground mask rembg just predicted, so
    # reuse it instead of running segmentation a second time on the cut-out.
    if rgba.ndim == 3 and rgba.shape[-1] == 4:
        mask = rgba[..., 3]
    else:
        mask = remove(rgba, session=session, only_mask=True)
    centered = recenter(rgba, mask, border_ratio=0.20)

    tensor = torch.tensor(centered).permute(2, 0, 1).unsqueeze(0) / 255.0
    if tensor.shape[1] == 4:
        rgb, alpha = tensor[:, :3, ...], tensor[:, 3:, ...]
        tensor = rgb * alpha + (1 - alpha)  # alpha-composite onto a white background
    tensor = torch.nn.functional.interpolate(
        tensor, size=(source_size, source_size), mode='bicubic', align_corners=True
    )
    # Bicubic interpolation can overshoot slightly outside [0, 1].
    return torch.clamp(tensor, 0, 1)


# rembg model used for background removal in ``preprocess_image``.
_REMBG_MODEL_NAME = "isnet-general-use"
# rembg sessions keyed by model name. ``new_session`` loads the segmentation model,
# so it is created once per process rather than once per request.
_REMBG_SESSIONS = {}


def _get_rembg_session(model_name=_REMBG_MODEL_NAME):
    """Return the cached rembg session for ``model_name``, creating it on first use."""
    session = _REMBG_SESSIONS.get(model_name)
    if session is None:
        session = new_session(model_name)
        _REMBG_SESSIONS[model_name] = session
    return session

# Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py and
# https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
    """
    Divide pixel-space intrinsics by the image size.

    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    return: (fx / width, fy / height, cx / width, cy / height), each of shape (N,)
    """
    size = intrinsics[:, 2]             # (N, 2): width, height
    focal = intrinsics[:, 0] / size     # (N, 2): fx / width, fy / height
    center = intrinsics[:, 1] / size    # (N, 2): cx / width, cy / height
    fx, fy = focal.unbind(dim=-1)
    cx, cy = center.unbind(dim=-1)
    return fx, fy, cx, cy

def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    Pack cameras into a compact 16-value vector (used here for the source view).

    RT: (N, 3, 4) extrinsics, flattened to 12 values per camera
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    return: (N, 16) = [RT (12), fx, fy, cx, cy], intrinsics normalized by image size
    """
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    focal_and_center = torch.stack((fx, fy, cx, cy), dim=-1)
    flat_rt = RT.reshape(-1, 12)
    return torch.cat((flat_rt, focal_and_center), dim=-1)

def _default_intrinsics():
    """
    Return the default pixel intrinsics as a float32 (3, 2) tensor.

    Layout: [[fx, fy], [cx, cy], [width, height]] with fx = fy = 384,
    cx = cy = 256 for a 512 x 512 image.
    """
    focal, principal, size = 384, 256, 512
    table = [[focal] * 2, [principal] * 2, [size] * 2]
    return torch.tensor(table, dtype=torch.float32)

def _default_source_camera(batch_size: int = 1):
    """
    Return the canonical source-view camera repeated ``batch_size`` times.

    The [R | t] matrix below has translation column (1, 0, 0). It is combined with
    the default intrinsics and packed by ``build_camera_principle``, giving a
    tensor of shape (batch_size, 16).
    """
    extrinsics = torch.tensor(
        [[
            [0, 0, 1, 1],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
        ]],
        dtype=torch.float32,
    )
    intrinsics = _default_intrinsics()[None]
    packed = build_camera_principle(extrinsics, intrinsics)
    return packed.repeat(batch_size, 1)

def _center_looking_at_camera_pose(camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
    """
    Build [R | t] matrices for cameras placed at ``camera_position`` and aimed at ``look_at``.

    The columns of R are the camera x, y and z axes expressed in world coordinates.
    The z axis points from ``look_at`` towards the camera. t is the camera position.

    camera_position: (M, 3)
    look_at: (3) or (M, 3); defaults to the origin
    up_world: (3) or (M, 3); defaults to +z
    return: (M, 3, 4), same dtype/device as ``camera_position``

    Note: a camera whose viewing direction is parallel to ``up_world`` has an
    undefined x axis (zero cross product) and produces NaNs.
    """
    ref = dict(dtype=camera_position.dtype, device=camera_position.device)
    # By default, look at the origin with +z as world up.
    if look_at is None:
        look_at = torch.zeros(3, **ref)
    if up_world is None:
        up_world = torch.tensor([0.0, 0.0, 1.0], **ref)
    look_at = look_at.to(**ref).expand_as(camera_position)
    up_world = up_world.to(**ref).expand_as(camera_position)

    z_axis = camera_position - look_at
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    # Pass dim explicitly. Without it torch.cross picks the first dimension of size 3,
    # which is the batch dimension when M == 3 and silently yields wrong axes.
    x_axis = torch.cross(up_world, z_axis, dim=-1)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
    return extrinsics

def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Append the homogeneous row [0, 0, 0, 1] to every 3x4 [R | t] matrix.

    Batched: (N, 3, 4) -> (N, 4, 4).
    """
    count = RT.shape[0]
    last_row = torch.tensor([0, 0, 0, 1], dtype=torch.float32, device=RT.device)
    last_row = last_row.view(1, 1, 4).expand(count, 1, 4)
    return torch.cat((RT, last_row), dim=1)

def _build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    Flatten cameras into 25 values: a 4x4 extrinsic followed by a 3x3 intrinsic matrix.

    The intrinsic matrix uses focal lengths and principal point normalized by image size.

    RT: (N, 3, 4)
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    return: (N, 25)
    """
    batch = RT.shape[0]
    extrinsic = compose_extrinsic_RT(RT)
    fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
    zeros = torch.zeros_like(fx)
    bottom = torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(batch, 1)
    rows = (
        torch.stack([fx, zeros, cx], dim=-1),
        torch.stack([zeros, fy, cy], dim=-1),
        bottom,
    )
    K = torch.stack(rows, dim=1)  # (N, 3, 3)
    flat_extrinsic = extrinsic.reshape(-1, 16)
    flat_intrinsic = K.reshape(-1, 9)
    return torch.cat([flat_extrinsic, flat_intrinsic], dim=-1)

def _default_render_cameras(batch_size: int = 1, num_views: int = 160, radius: float = 1.5, elevation: float = 0.0):
    """
    Build an orbit of render cameras around the origin for the preview video.

    batch_size: number of copies of the orbit (leading dimension of the result)
    num_views: number of cameras on the orbit (default 160)
    radius: distance of each camera from the origin (default 1.5)
    elevation: orbit elevation in degrees (default 0, the z = 0 plane); +-90 makes the
        view direction parallel to the +z world-up vector and produces NaN poses
    return: (batch_size, num_views, 25) tensor in the ``_build_camera_standard`` layout

    A random azimuth offset in [0, 1 degree) is applied to the whole orbit on every call.
    """
    phase = np.random.uniform(0, np.pi / 180)
    camera_positions = _orbit_camera_positions(num_views, radius, elevation, phase)
    extrinsics = _center_looking_at_camera_pose(camera_positions)

    render_camera_intrinsics = _default_intrinsics().unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
    render_cameras = _build_camera_standard(extrinsics, render_camera_intrinsics)
    return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)


def _orbit_camera_positions(num_views: int, radius: float, elevation: float, phase: float = 0.0):
    """
    Return (num_views, 3) float32 positions evenly spaced on a circle parallel to the xy-plane.

    num_views: number of positions (must be >= 1)
    radius: distance of every position from the origin
    elevation: elevation angle in degrees
    phase: azimuth offset in radians added to every position
    """
    if num_views < 1:
        raise ValueError(f"num_views must be >= 1, got {num_views}")
    theta = 2 * np.pi * np.arange(num_views) / num_views + phase
    elev = np.radians(elevation)
    xyz = np.stack([
        radius * np.cos(theta) * np.cos(elev),
        radius * np.sin(theta) * np.cos(elev),
        np.full_like(theta, radius * np.sin(elev)),
    ], axis=-1)
    return torch.tensor(xyz, dtype=torch.float32)
@spaces.GPU()
def generate_mesh(image, source_size=512, render_size=384, mesh_size=512, export_mesh=False, export_video=True, fps=30):
    """Reconstruct triplanes from ``image`` and export either a mesh or an orbit video.

    Args:
        image: input image (the UI passes a PIL image); see ``preprocess_image``.
        source_size: side length the input is resized to before inference.
        render_size: side length of each rendered video frame.
        mesh_size: marching-cubes grid resolution per axis.
        export_mesh: if True, export an .obj mesh (takes precedence over the video).
        export_video: if True and ``export_mesh`` is False, export an .mp4 orbit video.
        fps: frame rate of the exported video.

    Returns:
        ``(mesh_path, mesh_path)`` when exporting a mesh, ``(None, video_path)`` when
        exporting a video, and ``(None, None)`` when neither is requested. Every file is
        written into its own fresh temporary directory, so concurrent requests cannot
        overwrite each other's outputs.
    """
    device = model_wrapper.device
    image = preprocess_image(image, source_size).to(device)
    source_camera = _default_source_camera(batch_size=1).to(device)

    with torch.no_grad():
        planes = model_wrapper.forward(image, source_camera)

        if export_mesh:
            mesh_path = _export_mesh(planes, mesh_size)
            return mesh_path, mesh_path

        if export_video:
            return None, _export_video(planes, render_size, fps)

    return None, None


def _export_mesh(planes, mesh_size):
    """Extract a vertex-coloured mesh from ``planes``, save it as .obj and return its path.

    Meant to be called under ``torch.no_grad()``, as ``generate_mesh`` does.
    """
    import tempfile

    synthesizer = model_wrapper.model.synthesizer
    grid_out = synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
    sigma = grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy()
    vtx, faces = mcubes.marching_cubes(sigma, 1.0)  # iso-value 1.0 on the 'sigma' grid
    if len(vtx) == 0:
        raise gr.Error("No surface could be extracted from this image; please try another one.")
    # Map grid indices [0, mesh_size - 1] to coordinates in [-1, 1].
    vtx = vtx / (mesh_size - 1) * 2 - 1
    vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=model_wrapper.device).unsqueeze(0)
    vtx_colors = synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
    # Clip before the uint8 cast so out-of-range colours cannot become invalid uint8 values.
    vtx_colors = (np.clip(vtx_colors, 0.0, 1.0) * 255).astype(np.uint8)
    mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)

    mesh_path = os.path.join(tempfile.mkdtemp(), "awesome_mesh.obj")
    mesh.export(mesh_path, 'obj')
    return mesh_path


def _export_video(planes, render_size, fps, chunk_size=2):
    """Render an orbit of views of ``planes``, save it as .mp4 and return its path.

    Meant to be called under ``torch.no_grad()``, as ``generate_mesh`` does.
    ``chunk_size`` views are rendered per synthesizer call.
    """
    import tempfile

    render_cameras = _default_render_cameras(batch_size=1).to(model_wrapper.device)
    chunks = []
    for start in range(0, render_cameras.shape[1], chunk_size):
        rendered = model_wrapper.model.synthesizer(
            planes,
            render_cameras[:, start:start + chunk_size],
            render_size,
            render_size,
            0,
            0,
        )
        chunks.append(rendered['images_rgb'])

    # Views are concatenated along dim 1 (the camera dim); [0] drops the batch dim.
    # NOTE(review): assumes images_rgb is (batch, views, C, H, W), as in the upstream
    # LRM synthesizer — confirm against the hub model.
    frames = torch.cat(chunks, dim=1)[0]
    frames = (frames.float().clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)

    video_path = os.path.join(tempfile.mkdtemp(), "awesome_video.mp4")
    imageio.mimwrite(video_path, frames, fps=fps)
    return video_path

def step_1_generate_obj(image):
    """Generate a mesh for ``image`` and return its path twice (download slot and 3D viewer).

    Raises:
        gr.Error: if no image was uploaded. Raising it here, before calling the
            ``@spaces.GPU``-decorated ``generate_mesh``, shows a readable message
            instead of an obscure failure inside preprocessing.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    mesh_path, _ = generate_mesh(image, export_mesh=True)
    return mesh_path, mesh_path

def step_2_generate_video(image):
    """Render the orbit video for ``image`` and return the video file path.

    Raises:
        gr.Error: if no image was uploaded. Raising it here, before calling the
            ``@spaces.GPU``-decorated ``generate_mesh``, shows a readable message
            instead of an obscure failure inside preprocessing.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")
    _, video_path = generate_mesh(image, export_video=True)
    return video_path

def step_3_display_3d_model(mesh_file):
    """Pass-through: hand ``mesh_file`` back unchanged (e.g. to a 3D viewer component)."""
    displayed = mesh_file
    return displayed

# Example images bundled with the app: at most 10, listed in a stable (sorted) order.
example_folder = "assets"


def _list_example_images(folder, limit=10):
    """Return up to ``limit`` image paths (.png/.jpg/.jpeg, any case) from ``folder``, sorted by name.

    A missing folder yields an empty list instead of crashing the app at import time.
    """
    if not os.path.isdir(folder):
        return []
    names = sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    return [os.path.join(folder, name) for name in names[:limit]]


examples = _list_example_images(example_folder)

with gr.Blocks() as demo:
    with gr.Row():

        with gr.Column():
            gr.Markdown("**VFusion3D is now officially on Facebook repo: you can use it with A100s GPU here: [https://huggingface.co/spaces/facebook/VFusion3D](https://huggingface.co/spaces/facebook/VFusion3D)**")
            img_input = gr.Image(type="pil", label="Input Image")

            # Only show the examples panel when example images were found.
            examples_component = None
            if examples:
                examples_component = gr.Examples(examples=examples, inputs=img_input, outputs=None, examples_per_page=3)
            generate_mesh_button = gr.Button("Generate and Download Mesh")
            generate_video_button = gr.Button("Generate and Download Video")
            obj_file_output = gr.File(label="Download .obj File")
            video_file_output = gr.File(label="Download Video")

        with gr.Column():
            model_output = LitModel3D(
                clear_color=[0.1, 0.1, 0.1, 0],  # can adjust background color for better contrast
                label="3D Model Visualization",
                scale=1.0,
                tonemapping="aces",  # can use aces tonemapping for more realistic lighting
                exposure=1.0,        # can adjust exposure to control brightness
                contrast=1.1,        # can slightly increase contrast for better depth
                camera_position=(0, 0, 2),  # will set initial camera position to center the model
                zoom_speed=0.5,      # will adjust zoom speed for better control
                pan_speed=0.5,       # will adjust pan speed for better control
                interactive=True     # this allow users to interact with the model
            )

    def clear_outputs():
        """Clear the 3D viewer and both download slots so results from a previous image are not shown."""
        return None, None, None

    # When the input image changes, first clear every output produced for the old image.
    img_input.change(clear_outputs, inputs=None, outputs=[model_output, obj_file_output, video_file_output])

    # Then generate the mesh (download + viewer) or the video on demand.
    generate_mesh_button.click(step_1_generate_obj, inputs=img_input, outputs=[obj_file_output, model_output])
    generate_video_button.click(step_2_generate_video, inputs=img_input, outputs=video_file_output)

# Only start the server when run as a script (e.g. `python app.py`), so importing
# this module (tests, `gradio` reload mode) does not block on launch.
if __name__ == "__main__":
    demo.launch(debug=True)