import sys
sys.path.append("flash3d")  # make the bundled flash3d package importable

from omegaconf import OmegaConf
import gradio as gr
import spaces
import torch
import torchvision.transforms as TT
import torchvision.transforms.functional as TTF
from huggingface_hub import hf_hub_download

from networks.gaussian_predictor import GaussianPredictor
from util.vis3d import save_ply


def main():
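    # Prefer the GPU when available, but fall back to CPU so the demo still runs locally.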
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

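    # Download the Flash3D config and checkpoint from the Hugging Face Hub
    # (cached locally after the first call).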
    model_cfg_path = hf_hub_download(repo_id="einsafutdinov/flash3d", 
                                     filename="config_re10k_v1.yaml")
    model_path = hf_hub_download(repo_id="einsafutdinov/flash3d", 
                                 filename="model_re10k_v1.pth")

    cfg = OmegaConf.load(model_cfg_path)
    model = GaussianPredictor(cfg)
    model.to(device)
    model.load_model(model_path)

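    # Preprocessing transforms: border padding per the dataset config, and PIL-to-tensor conversion.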
    pad_border_fn = TT.Pad((cfg.dataset.pad_border_aug, cfg.dataset.pad_border_aug))
    to_tensor = TT.ToTensor()

    def check_input_image(input_image):
        if input_image is None:
            raise gr.Error("No image uploaded!")

    def preprocess(image):
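        """Resize the image to the model's input resolution and pad its borders."""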
        image = TTF.resize(
            image, (cfg.dataset.height, cfg.dataset.width), 
            interpolation=TT.InterpolationMode.BICUBIC
        )
        image = pad_border_fn(image)
        return image

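    # On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of this call.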
    @spaces.GPU()
    def reconstruct_and_export(image):
        """
        Passes image through model, outputs reconstruction in form of a dict of tensors.
        """
        image = to_tensor(image).to(device).unsqueeze(0)
        inputs = {
            ("color_aug", 0, 0): image,
        }

        # Run the network; no gradients are needed at inference time.
        with torch.no_grad():
            outputs = model(inputs)

        # export reconstruction to ply
        save_ply(outputs, ply_out_path, num_gauss=2)

        return ply_out_path
    
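    # Output path for the exported point cloud; every request overwrites the same file.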
    ply_out_path = "./mesh.ply"

    css = """
        h1 {
            text-align: center;
            display:block;
        }
        """

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # Flash3D
            """
            )
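        # Layout: input image, generate button, and examples in the left column; 3D viewer in the right.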
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image = gr.Image(
                        label="Input Image",
                        image_mode="RGBA",
                        sources="upload",
                        type="pil",
                        elem_id="content_image",
                    )
                with gr.Row():
                    submit = gr.Button("Generate", elem_id="generate", variant="primary")

                with gr.Row(variant="panel"): 
                    gr.Examples(
                        examples=[
                            './demo_examples/bedroom_01.png',
                            './demo_examples/kitti_02.png',
                            './demo_examples/kitti_03.png',
                            './demo_examples/re10k_04.jpg',
                            './demo_examples/re10k_05.jpg',
                            './demo_examples/re10k_06.jpg',
                        ],
                        inputs=[input_image],
                        cache_examples=False,
                        label="Examples",
                        examples_per_page=20,
                    )

                with gr.Row():
                    processed_image = gr.Image(label="Processed Image", interactive=False)

            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Tab("Reconstruction"):
                        output_model = gr.Model3D(
                            height=512,
                            label="Output Model",
                            interactive=False
                        )

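        # Wire up the button: validate the upload, then preprocess, then reconstruct.
        # Each .success() step runs only if the previous one completed without error.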
        submit.click(fn=check_input_image, inputs=[input_image]).success(
            fn=preprocess,
            inputs=[input_image],
            outputs=[processed_image],
        ).success(
            fn=reconstruct_and_export,
            inputs=[processed_image],
            outputs=[output_model],
        )

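    # Serve one queued request at a time; share=True creates a public link when the
    # demo is launched outside a hosted Space (on Spaces it is ignored).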
    demo.queue(max_size=1)
    demo.launch(share=True)


if __name__ == "__main__":
    main()