import gradio as gr
import plotly.graph_objs as go
import trimesh
import numpy as np
from PIL import Image, ImageDraw
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
import io
import matplotlib.pyplot as plt
#import pyrender
#import scipy
import csv
import sys
import os

# Stable Diffusion pipelines: text-to-image (used for texture generation) and
# inpainting. Everything runs on the GPU when CUDA is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

# Half precision is requested for the inpainting model only on CUDA; on CPU the
# pipeline's default dtype is used.
_inpaint_kwargs = {"torch_dtype": torch.float16} if device == "cuda" else {}
pipeline_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    **_inpaint_kwargs,
).to(device)

# Example assets are resolved relative to the directory the app is started from.
CURRENT_DIR = os.getcwd()

DEFAULT_OBJ_FILE = os.path.join(CURRENT_DIR, "female.obj")
DEFAULT_GLB_FILE = os.path.join(CURRENT_DIR, "vroid_girl1.glb")
DEFAULT_VRM_FILE = os.path.join(CURRENT_DIR, "fischl.vrm")
DEFAULT_VRM_FILE2 = os.path.join(CURRENT_DIR, "woman.vrm")
DEFAULT_VRM_FILE3 = os.path.join(CURRENT_DIR, "mona.vrm")
DEFAULT_TEXTURE = os.path.join(CURRENT_DIR, "future.png")
DEFAULT_TEXTURE2 = os.path.join(CURRENT_DIR, "woman1.jpeg")
DEFAULT_TEXTURE3 = os.path.join(CURRENT_DIR, "woman2.jpeg")

# (model, texture) pairs offered as examples; ``None`` means "no texture".
example_files = [
    [DEFAULT_VRM_FILE, DEFAULT_TEXTURE],
    [DEFAULT_OBJ_FILE, None],
    [DEFAULT_GLB_FILE, None],
    [DEFAULT_VRM_FILE2, DEFAULT_TEXTURE2],
    [DEFAULT_VRM_FILE3, DEFAULT_TEXTURE3]
]

# Missing example assets only produce a warning; the app still starts.
_missing_example_paths = [
    path
    for pair in example_files
    for path in pair
    if path and not os.path.exists(path)
]
for _missing_path in _missing_example_paths:
    print(f"Warning: Example file {_missing_path} does not exist!")

def _normalize_to_unit(values, lower, span):
    """Linearly map ``values`` from ``[lower, lower + span]`` onto ``[0, 1]``.

    A zero (or negative) span -- e.g. a mesh that is flat along this axis --
    maps every value to 0 instead of dividing by zero and producing NaN/inf.
    """
    if span <= 0:
        return np.zeros_like(values, dtype=float)
    return (values - lower) / span


def generate_default_uv(mesh, quality='medium'):
    """
    Generate default UV coordinates for a mesh if UV mapping is missing.

    Args:
        mesh: Object exposing ``vertices`` (an ``(n, 3)`` array) and, for
            ``quality='low'``, ``bounds`` (``[[min_x, min_y, min_z],
            [max_x, max_y, max_z]]``).
        quality: Projection used to derive UVs:
            - ``'low'``: planar projection of x/y onto the bounding box.
            - ``'medium'``: cylindrical projection (azimuth around z, height along z).
            - ``'high'``: spherical projection (azimuth and polar angle).

    Returns:
        ``(n, 2)`` float array of UVs. Degenerate inputs (flat meshes, vertices
        at the origin, empty meshes) produce finite values rather than NaN.

    Raises:
        ValueError: if ``quality`` is not one of the supported values.
    """
    if quality not in ('low', 'medium', 'high'):
        raise ValueError("Invalid quality parameter. Choose from 'low', 'medium', or 'high'.")

    vertices = np.asarray(mesh.vertices, dtype=float)
    uv_coords = np.zeros((len(vertices), 2))
    if len(vertices) == 0:
        # Nothing to project; also avoids min()/max() on an empty array.
        return uv_coords

    x, y, z = vertices[:, 0], vertices[:, 1], vertices[:, 2]

    if quality == 'low':
        lower, upper = np.asarray(mesh.bounds, dtype=float)
        uv_coords[:, 0] = _normalize_to_unit(x, lower[0], upper[0] - lower[0])
        uv_coords[:, 1] = _normalize_to_unit(y, lower[1], upper[1] - lower[1])
        return uv_coords

    # Both remaining projections use the azimuth around the z axis,
    # mapped from [-pi, pi] onto [0, 1].
    uv_coords[:, 0] = np.arctan2(y, x) / (2 * np.pi) + 0.5

    if quality == 'medium':
        z_min = z.min()
        uv_coords[:, 1] = _normalize_to_unit(z, z_min, z.max() - z_min)
    else:  # 'high'
        radius = np.linalg.norm(vertices, axis=1)
        # A vertex at the origin has no defined polar angle; dividing by 1
        # instead of 0 gives cos = 0, i.e. the equator (v = 0.5).
        safe_radius = np.where(radius > 0, radius, 1.0)
        # Clamp to arccos' domain to absorb floating-point overshoot.
        cos_polar = np.clip(z / safe_radius, -1.0, 1.0)
        uv_coords[:, 1] = np.arccos(cos_polar) / np.pi

    return uv_coords

def apply_texture(mesh, texture_image, uv_scale, uv_quality='medium'):
    """
    Applies the texture to the mesh with UV scaling, returning one colour per face.

    Each face is coloured with the mean of the texture pixels under its
    vertices' UV coordinates. UVs come from ``mesh.visual.uv`` when present;
    otherwise they are generated with ``generate_default_uv(mesh, quality=uv_quality)``.

    Args:
        mesh: Object exposing ``faces`` (integer vertex indices per face) and
            ``visual`` (optionally carrying a per-vertex ``uv`` array).
        texture_image: PIL image. Modes other than RGB/RGBA (e.g. grayscale
            "L", palette "P", "CMYK") are converted to RGB before sampling so
            every face gets a proper colour vector.
        uv_scale: Multiplier applied to the UVs; the result is clipped to [0, 1].
        uv_quality: Projection for ``generate_default_uv`` when the mesh has no UVs.

    Returns:
        ``(n_faces, C)`` float array in [0, 1], where C is 3 for RGB or 4 for RGBA
        textures. Faces whose UVs are not finite (NaN) are coloured mid-grey.

    Raises:
        ValueError: if UV coordinates cannot be obtained (e.g. propagated from
            ``generate_default_uv`` for an invalid ``uv_quality``).
    """
    uv_coords = getattr(mesh.visual, 'uv', None)
    if uv_coords is None:
        # If the mesh does not have UV coordinates, generate them
        print("No UV coordinates found; generating default UV mapping.")
        uv_coords = generate_default_uv(mesh, quality=uv_quality)

    if uv_coords is None:
        raise ValueError("UV coordinates are missing from the mesh.")

    # Apply UV scaling and keep every coordinate inside the image.
    uv_coords = np.clip(np.asarray(uv_coords, dtype=float) * uv_scale, 0, 1)

    # Without this, grayscale images would yield a scalar per face and palette
    # images would average palette indices instead of colours.
    if texture_image.mode not in ('RGB', 'RGBA'):
        texture_image = texture_image.convert('RGB')

    img_width, img_height = texture_image.size
    texture_array = np.asarray(texture_image)
    channels = texture_array.shape[2]

    faces = np.asarray(mesh.faces, dtype=int)
    face_uvs = uv_coords[faces]  # (n_faces, verts_per_face, 2)
    # np.clip keeps NaN, so only NaN can survive as non-finite here.
    finite = np.all(np.isfinite(face_uvs), axis=(1, 2))

    # After clipping (and zeroing NaN) every pixel index is within the image.
    pixel_coords = np.round(
        np.nan_to_num(face_uvs, nan=0.0) * np.array([img_width - 1, img_height - 1])
    ).astype(int)

    # Image arrays are indexed [row, column] == [y, x].
    samples = texture_array[pixel_coords[..., 1], pixel_coords[..., 0]]
    face_colors = samples.mean(axis=1) / 255.0

    # Mid-grey (opaque when the texture has an alpha channel) for unusable UVs.
    face_colors[~finite] = np.array([0.5, 0.5, 0.5, 1.0][:channels])
    return face_colors

def load_glb_file(filename):
    """Load ``filename`` with trimesh and return a single mesh.

    When trimesh returns a ``Scene``, its geometries are merged with
    ``dump(concatenate=True)``; any other loaded object is returned unchanged.
    """
    loaded = trimesh.load(filename)
    if not isinstance(loaded, trimesh.Scene):
        return loaded
    return loaded.dump(concatenate=True)

def generate_clothing_image(prompt, num_inference_steps):
    """Run the module-level text-to-image ``pipeline`` on ``prompt``.

    Args:
        prompt: Text description of the desired clothing texture.
        num_inference_steps: Denoising steps forwarded to the pipeline.

    Returns:
        The first image in the pipeline output's ``images``.
    """
    output = pipeline(prompt, num_inference_steps=num_inference_steps)
    generated_images = output.images
    return generated_images[0]

def load_vrm_file(filename):
    """
    Load a VRM model as a single trimesh mesh.

    The file is parsed with trimesh's GLB loader (``file_type='glb'``). A
    ``Scene`` result is merged into one mesh with ``dump(concatenate=True)``.

    Raises:
        ValueError: on any loading failure. The underlying exception is chained
            as ``__cause__`` so its traceback is not lost.
    """
    try:
        vrm_data = trimesh.load(filename, file_type='glb')
        if isinstance(vrm_data, trimesh.Scene):
            mesh = vrm_data.dump(concatenate=True)
        else:
            mesh = vrm_data
    except Exception as err:
        raise ValueError(f"Failed to load VRM file: {err}") from err

    return mesh

def _mesh_face_colors(mesh, texture_image, color, uv_scale, uv_quality):
    """Return per-face colours sampled from ``texture_image``, or a solid ``color``.

    Falls back to the solid colour (with a printed warning) when texturing
    raises ValueError, for every file format alike.
    """
    face_count = len(mesh.faces)
    if texture_image is None:
        return np.array([color] * face_count)
    try:
        return apply_texture(mesh, texture_image, uv_scale, uv_quality)
    except ValueError as err:
        print(f"Warning: could not apply texture ({err}); using solid color instead.")
        return np.array([color] * face_count)


def _build_mesh_figure(mesh, face_colors, light_intensity, ambient_intensity, transparency):
    """Build a Plotly ``Mesh3d`` figure of ``mesh`` with one colour per face."""
    vertices = mesh.vertices
    faces = mesh.faces
    fig = go.Figure(data=[
        go.Mesh3d(
            x=vertices[:, 0],
            y=vertices[:, 1],
            z=vertices[:, 2],
            i=faces[:, 0],
            j=faces[:, 1],
            k=faces[:, 2],
            facecolor=face_colors,
            opacity=transparency,
            lighting=dict(
                ambient=ambient_intensity,
                diffuse=light_intensity,
                specular=0.8,
                roughness=0.3,
                fresnel=0.1
            ),
            lightposition=dict(
                x=100,
                y=200,
                z=300
            )
        )
    ])
    # Keep the model's true proportions instead of stretching to a cube.
    fig.update_layout(scene=dict(aspectmode='data'))
    return fig


def display_3d_object(obj_file, texture_image, light_intensity, ambient_intensity, color, uv_scale, transparency, uv_quality=None):
    """
    Load a .obj/.glb/.vrm model and render it as a Plotly 3D figure.

    Args:
        obj_file: Path to the model file; the format is chosen from its extension.
        texture_image: PIL image to sample face colours from, or None for a solid colour.
        light_intensity: Diffuse lighting coefficient.
        ambient_intensity: Ambient lighting coefficient.
        color: Solid face colour used when there is no (usable) texture.
        uv_scale: UV multiplier forwarded to ``apply_texture``.
        transparency: Mesh opacity.
        uv_quality: Projection for meshes without UVs ('low', 'medium',
            'high'); None means 'medium', the default of ``apply_texture``.

    Returns:
        A ``plotly.graph_objs.Figure``.

    Raises:
        ValueError: if no file is given or the extension is unsupported.
    """
    if not obj_file:
        raise ValueError("No 3D model file provided. Please upload a .obj, .glb, or .vrm file.")

    file_extension = os.path.splitext(obj_file)[1].lstrip('.').lower()

    if file_extension == 'vrm':
        mesh = load_vrm_file(obj_file)
    elif file_extension in ('obj', 'glb'):
        # trimesh.load can return a Scene instead of a single mesh (e.g. for an
        # OBJ with several objects/materials); load_glb_file flattens either format.
        mesh = load_glb_file(obj_file)
    else:
        raise ValueError("Unsupported file format. Please upload a .obj, .glb, or .vrm file.")

    quality = uv_quality if uv_quality is not None else 'medium'
    face_colors = _mesh_face_colors(mesh, texture_image, color, uv_scale, quality)
    return _build_mesh_figure(mesh, face_colors, light_intensity, ambient_intensity, transparency)

def clear_texture():
    """Callback for the Clear button: return ``None`` as the new texture-preview value."""
    cleared_value = None
    return cleared_value

def restore_original(obj_file):
    """Re-render ``obj_file`` untextured, using the same defaults as the UI controls."""
    return display_3d_object(
        obj_file,
        texture_image=None,
        light_intensity=0.8,
        ambient_intensity=0.5,
        color="#D3D3D3",
        uv_scale=1.0,
        transparency=1.0,
    )

def update_texture_display(prompt, texture_file, num_inference_steps):
    """
    Produce the image shown in the texture preview.

    A prompt containing non-whitespace text takes priority and is sent to
    ``generate_clothing_image``. Otherwise an uploaded texture file is loaded.

    Args:
        prompt: Text prompt (may be empty or None).
        texture_file: Path to an uploaded image file, or None.
        num_inference_steps: Diffusion steps used when generating.

    Returns:
        A PIL image, or None when there is neither a usable prompt nor a file.
    """
    # A whitespace-only prompt should not trigger an expensive generation.
    if prompt and prompt.strip():
        return generate_clothing_image(prompt, num_inference_steps)
    if texture_file:
        # Copy the pixels into memory so the file handle is closed right away
        # instead of staying open until the image is garbage-collected.
        with Image.open(texture_file) as image:
            return image.copy()
    return None

def load_example(obj_file, texture_file):
    """Render one ``gr.Examples`` row (model path, optional texture path) with default settings.

    VRM models additionally get the 'medium' UV quality.
    """
    is_vrm = obj_file.split('.')[-1].lower() == 'vrm'
    texture_image = Image.open(texture_file) if texture_file else None
    # Default UI settings: light 0.8, ambient 0.5, light grey, UV scale 1, opaque.
    extra_args = ('medium',) if is_vrm else ()
    return display_3d_object(obj_file, texture_image, 0.8, 0.5, "#D3D3D3", 1.0, 1.0, *extra_args)



with gr.Blocks() as demo:
    gr.Markdown("## 3D Object Viewer with Custom Texture, UV Scale, Transparency, Color, and Adjustable Lighting")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Texture Options")
            prompt_input = gr.Textbox(label="Enter a Prompt to Generate Texture", placeholder="Type a prompt...")
            num_inference_steps_slider = gr.Slider(minimum=5, maximum=100, step=1, value=10, label="Num Inference Steps")
            generate_button = gr.Button("Generate Texture")
            texture_file = gr.File(label="Upload Texture file (PNG or JPG, optional)", type="filepath")
            # type="pil" so the previewed texture (uploaded or generated) can be
            # handed straight to display_3d_object on Submit.
            texture_preview = gr.Image(label="Texture Preview", visible=True, type="pil")

            gr.Markdown("### Mapping, Lighting & Color Settings")
            uv_scale_slider = gr.Slider(minimum=0.1, maximum=5, step=0.1, value=1.0, label="UV Mapping Scale")
            uv_quality_dropdown = gr.Dropdown(label="UV Quality (for VRM files)", choices=['low', 'medium', 'high'], value='medium')
            light_intensity_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Light Intensity")
            ambient_intensity_slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5, label="Ambient Intensity")
            transparency_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=1.0, label="Transparency")
            color_picker = gr.ColorPicker(value="#D3D3D3", label="Object Color")
            submit_button = gr.Button("Submit")
            restore_button = gr.Button("Restore")
            clear_button = gr.Button("Clear")
            obj_file = gr.File(label="Upload OBJ, GLB, or VRM file", value=DEFAULT_OBJ_FILE, type='filepath')

        with gr.Column(scale=2):
            display = gr.Plot(label="3D Viewer")

    def update_display(file, texture, uv_scale, uv_quality, light_intensity, ambient_intensity, transparency, color, preview_image=None):
        """
        Render the selected model with the current texture and settings.

        The texture preview takes priority, because it holds either the uploaded
        file or a generated image. Without it, generated textures could never be
        applied. The uploaded texture file is the fallback when the preview is
        empty. ``uv_quality`` is forwarded only for VRM files, the only case in
        which its dropdown is visible.
        """
        if not file:
            raise gr.Error("Please upload an OBJ, GLB, or VRM file.")

        if preview_image is not None:
            texture_image = preview_image
        elif texture:
            texture_image = Image.open(texture)
        else:
            texture_image = None

        if os.path.splitext(file)[1].lower() == '.vrm':
            return display_3d_object(file, texture_image, light_intensity, ambient_intensity, color, uv_scale, transparency, uv_quality)
        return display_3d_object(file, texture_image, light_intensity, ambient_intensity, color, uv_scale, transparency)

    def toggle_uv_quality_dropdown(file):
        """Show the UV-quality dropdown only when a VRM file is selected."""
        if file is None:
            return gr.update(visible=False)
        return gr.update(visible=os.path.splitext(file)[1].lower() == '.vrm')

    def preview_uploaded_texture(texture_path):
        """Preview a newly uploaded (or removed) texture file without ever running generation."""
        return update_texture_display("", texture_path, 0)

    display_inputs = [
        obj_file, texture_file, uv_scale_slider, uv_quality_dropdown,
        light_intensity_slider, ambient_intensity_slider, transparency_slider,
        color_picker, texture_preview,
    ]

    submit_button.click(fn=update_display, inputs=display_inputs, outputs=display)

    obj_file.change(fn=toggle_uv_quality_dropdown, inputs=[obj_file], outputs=uv_quality_dropdown)
    generate_button.click(fn=update_texture_display, inputs=[prompt_input, texture_file, num_inference_steps_slider], outputs=texture_preview)
    restore_button.click(fn=restore_original, inputs=[obj_file], outputs=display)
    clear_button.click(fn=clear_texture, outputs=texture_preview)
    # Uploading a file must only preview it; previously a non-empty prompt made
    # this event launch a Stable Diffusion generation instead.
    texture_file.change(fn=preview_uploaded_texture, inputs=[texture_file], outputs=texture_preview)

    demo.load(fn=update_display, inputs=display_inputs, outputs=display)
    gr.Examples(
        examples=example_files,
        inputs=[obj_file, texture_file],
        outputs=display,  # Specify the output component
        fn=load_example,  # Specify the function to load the example
        label="Example Files",
        cache_examples=False  # Disable caching
    )

demo.launch(debug=True)