import os
import shlex
import subprocess
import sys

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline

# Install the prebuilt diff_gaussian_rasterization wheel bundled with this app
# (its filename tags it for CPython 3.10 on linux x86_64).
# - `sys.executable -m pip` installs into the interpreter actually running this
#   script, not whatever `pip` happens to be first on PATH.
# - The wheel is located relative to this file, not the current working
#   directory, so launching from another directory still finds it.
# - check=True makes a failed install stop startup here with a clear error
#   instead of being silently ignored.
_RASTERIZER_WHEEL = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "wheel",
    "diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl",
)
subprocess.run(
    [sys.executable, "-m", "pip", "install", _RASTERIZER_WHEEL],
    check=True,
)

# Scratch directory for generated .ply files; created eagerly so writes into it
# never fail on a missing directory (no-op if it already exists).
TMP_DIR = "/tmp"
os.makedirs(name=TMP_DIR, exist_ok=True)


def _load_cuda_pipeline(repo_id, custom_pipeline):
    """Load a fp16 Hub pipeline with remote custom code and move it to CUDA.

    NOTE(review): trust_remote_code=True executes code fetched from the Hub
    repositories named here; only point this at trusted repos.
    """
    loaded = DiffusionPipeline.from_pretrained(
        repo_id,
        custom_pipeline=custom_pipeline,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return loaded.to("cuda")


# Stage 1: turns the input image into the `images` consumed by splat_pipeline.
image_pipeline = _load_cuda_pipeline(
    "dylanebert/imagedream", "dylanebert/multi-view-diffusion"
)

# Stage 2: turns those images into Gaussians that can be saved as .ply.
splat_pipeline = _load_cuda_pipeline("dylanebert/LGM", "dylanebert/LGM")


@spaces.GPU
def run(input_image, seed):
    """Generate a 3D Gaussian splat (.ply) from a single image.

    Args:
        input_image: Image array from ``gr.Image(type="numpy")``, presumably
            with 0-255 pixel values (it is divided by 255 below); ``None``
            when the user has not uploaded anything.
        seed: Random seed. ``gr.Number`` may deliver it as a float (e.g.
            ``42.0``), so it is coerced to ``int`` here.

    Returns:
        Path of the written ``.ply`` file.

    Raises:
        gr.Error: If no image or no seed was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if seed is None:
        raise gr.Error("Please enter a seed.")
    # np.random.seed() rejects floats and anything outside [0, 2**32 - 1].
    # Folding into that range is the identity for every seed that worked before.
    seed = int(seed) % (2**32)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Scale 0-255 pixel values to float32 in [0, 1].
    input_image = input_image.astype("float32") / 255.0
    images = image_pipeline(
        "", input_image, guidance_scale=5, num_inference_steps=30, elevation=0
    )
    gaussians = splat_pipeline(images)
    # NOTE(review): fixed filename; concurrent runs would overwrite each other's
    # output. Confirm the queue processes one request at a time.
    output_ply_path = os.path.join(TMP_DIR, "output.ply")
    splat_pipeline.save_ply(gaussians, output_ply_path)
    return output_ply_path


# Page title; also rendered as the top-level Markdown heading.
_TITLE = "LGM Mini"

# Intro text rendered with gr.Markdown; it mixes raw HTML and Markdown.
# Keep the blank line after the first sentence: in CommonMark a raw HTML
# block ends at a blank line, which is what lets the Markdown link on the
# following line render as a link.
_DESCRIPTION = (
    "\n"
    "<div>\n"
    'A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">'
    "LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.\n"
    "\n"
    "To convert to mesh, download the output splat and visit "
    "[splat-to-mesh](https://huggingface.co/spaces/dylanebert/splat-to-mesh).\n"
    "</div>\n"
)

# Styling for the element with elem_id="duplicate-button": white text on blue,
# fully rounded corners, auto margins.
css = (
    "\n"
    "#duplicate-button {\n"
    "    margin: auto;\n"
    "    color: white;\n"
    "    background: #1565c0;\n"
    "    border-radius: 100vh;\n"
    "}\n"
)

# Seed shared by the seed input's initial value and the cached examples, so
# both paths use the same seed unless the user changes it.
_DEFAULT_SEED = 42

block = gr.Blocks(title=_TITLE, css=css)
with block:
    gr.DuplicateButton(
        value="Duplicate Space for private use", elem_id="duplicate-button"
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# " + _TITLE)
    gr.Markdown(_DESCRIPTION)

    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            input_image = gr.Image(label="image", type="numpy")
            # precision=0 makes gr.Number pass an int. Without it the value
            # arrives as a float, which np.random.seed() in run() rejects.
            seed_input = gr.Number(label="seed", value=_DEFAULT_SEED, precision=0)
            button_gen = gr.Button("Generate")

        with gr.Column(scale=1):
            output_splat = gr.Model3D(label="3D Gaussians")

        button_gen.click(
            fn=run, inputs=[input_image, seed_input], outputs=[output_splat]
        )

    gr.Examples(
        examples=[
            "https://huggingface.co/datasets/dylanebert/iso3d/resolve/main/jpg@512/a_cat_statue.jpg",
            "https://huggingface.co/datasets/dylanebert/iso3d/resolve/main/jpg@512/a_baby_penguin.jpg",
            "https://huggingface.co/datasets/dylanebert/iso3d/resolve/main/jpg@512/A_cartoon_house_with_red_roof.jpg",
            "https://huggingface.co/datasets/dylanebert/iso3d/resolve/main/jpg@512/a_hat.jpg",
            "https://huggingface.co/datasets/dylanebert/iso3d/resolve/main/jpg@512/an_antique_chest.jpg",
            "https://huggingface.co/datasets/dylanebert/iso3d/resolve/main/jpg@512/metal.jpg",
        ],
        inputs=[input_image],
        outputs=[output_splat],
        # Examples only supply the image; the seed is pinned to the default.
        fn=lambda x: run(input_image=x, seed=_DEFAULT_SEED),
        cache_examples=True,
        label="Image-to-3D Examples",
    )

# Enable request queueing (queue() returns the same Blocks), then start the
# server. share=True requests a public share link; debug=True blocks the
# main thread and prints errors.
block.queue()
block.launch(share=True, debug=True)