# -*- coding: utf-8 -*-
#
# @File:   app.py
# @Author: Haozhe Xie
# @Date:   2024-03-02 16:30:00
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-10-13 15:36:50
# @Email:  root@haozhexie.com

import gradio as gr
import logging
import numpy as np
import os
import pickle
import ssl
import subprocess
import sys
import urllib.request

from PIL import Image

# Reinstall PyTorch with CUDA 11.8 (Default version is 12.1)
# subprocess.call(
#     [
#         "pip",
#         "install",
#         "torch==2.2.2",
#         "torchvision==0.17.2",
#         "--index-url",
#         "https://download.pytorch.org/whl/cu118",
#     ]
# )
import torch

# Create a dummy decorator for Non-ZeroGPU environments
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Minimal stand-in for the Hugging Face ``spaces`` package.

        Only ``spaces.GPU`` is provided. It accepts both the bare form
        (``@spaces.GPU``) and the parameterized form
        (``@spaces.GPU(duration=...)``); keyword arguments are accepted and
        ignored, and the decorated function is called unchanged.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """Wrap *func* in a transparent pass-through.

            When called without a function (parameterized form), return the
            decorator instead. ``functools.wraps`` keeps the wrapped
            function's name, docstring and signature visible to
            introspection (which UI frameworks such as Gradio may inspect).
            """

            def decorate(f):
                @functools.wraps(f)
                def wrapper(*args, **kw):
                    return f(*args, **kw)

                return wrapper

            if func is None:
                return decorate
            return decorate(func)


# Fix: ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
# WARNING(security): this swaps the process-wide default HTTPS context factory
# for one that performs NO certificate or hostname verification. Every HTTPS
# request that relies on the default context (e.g. the
# urllib.request.urlretrieve checkpoint download in get_models) becomes open
# to man-in-the-middle tampering. Those checkpoints are then loaded with
# torch.load(..., weights_only=False), so a tampered file could run arbitrary
# code. Consider fixing the CA bundle instead (e.g. via SSL_CERT_FILE) and
# removing this override.
ssl._create_default_https_context = ssl._create_unverified_context
# Import GaussianCity modules
# Appends the "gaussiancity" directory next to this file to sys.path, so
# modules inside it can also be imported by their top-level names.
sys.path.append(os.path.join(os.path.dirname(__file__), "gaussiancity"))


def _get_output(cmd):
    """Run *cmd* and return its standard output decoded as UTF-8.

    Any failure (missing executable, non-zero exit status, undecodable
    output, ...) is logged with its traceback and ``None`` is returned, so
    the helper is safe for best-effort diagnostics.
    """
    try:
        raw_stdout = subprocess.check_output(cmd)
        text = raw_stdout.decode("utf-8")
    except Exception as err:
        logging.exception(err)
        text = None
    return text


def install_cuda_toolkit():
    """Download and silently install the CUDA toolkit, then export its paths.

    Fetches the NVIDIA ``.run`` installer into ``/tmp`` with ``wget``, runs it
    in ``--silent --toolkit`` mode, then points ``CUDA_HOME`` at
    ``/usr/local/cuda`` and prepends its ``bin`` and ``lib64`` directories to
    ``PATH`` and ``LD_LIBRARY_PATH``. ``TORCH_CUDA_ARCH_LIST`` is pinned so
    PyTorch extension builds get an explicit architecture list.

    The download/install steps are best effort: a non-zero exit code is
    logged as a warning rather than raised.
    """
    # cuda_toolkit_url = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
    cuda_toolkit_url = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    cuda_toolkit_file = "/tmp/%s" % os.path.basename(cuda_toolkit_url)
    for cmd in (
        ["wget", "-q", cuda_toolkit_url, "-O", cuda_toolkit_file],
        ["chmod", "+x", cuda_toolkit_file],
        [cuda_toolkit_file, "--silent", "--toolkit"],
    ):
        ret = subprocess.call(cmd)
        if ret != 0:
            logging.warning("Command %s exited with code %d" % (cmd, ret))

    def prepend(entry, var):
        # Skip an unset/empty current value: a trailing separator would add an
        # empty entry, which PATH lookup and the dynamic loader treat as the
        # current working directory.
        current = os.environ.get(var, "")
        os.environ[var] = os.pathsep.join(p for p in (entry, current) if p)

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    prepend(os.path.join(cuda_home, "bin"), "PATH")
    # The toolkit installs its shared libraries under lib64 on Linux x86_64.
    prepend(os.path.join(cuda_home, "lib64"), "LD_LIBRARY_PATH")
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"


def setup_runtime_env():
    """Log toolchain diagnostics and install the bundled pre-compiled wheels.

    Logs the Python, nvcc and gcc versions (``None`` when a tool cannot be
    run), whether CUDA is available and, if it is, the device compute
    capability. Then every entry of the ``wheels`` directory next to this
    file is passed to ``pip install`` (in sorted order), and the installed
    package list is logged.

    Installation is best effort: a missing ``wheels`` directory or a failing
    ``pip install`` is logged as a warning instead of raising.
    """
    # Target the interpreter that is actually running this app, rather than
    # whichever "python"/"pip" happens to come first on PATH.
    pip = [sys.executable, "-m", "pip"]

    logging.info("Python Version: %s" % _get_output([sys.executable, "--version"]))
    logging.info("CUDA Version: %s" % _get_output(["nvcc", "--version"]))
    logging.info("GCC Version: %s" % _get_output(["gcc", "--version"]))
    cuda_available = torch.cuda.is_available()
    logging.info("CUDA is available: %s" % cuda_available)
    # get_device_capability() needs an initialized CUDA device and raises
    # without one, so only query it when CUDA is available.
    if cuda_available:
        logging.info(
            "CUDA Device Capability: %s" % (torch.cuda.get_device_capability(),)
        )

    # Install Pre-compiled CUDA extensions
    # Ref: https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/110
    ext_dir = os.path.join(os.path.dirname(__file__), "wheels")
    if os.path.isdir(ext_dir):
        # Sorted so the installation order is deterministic across runs.
        for e in sorted(os.listdir(ext_dir)):
            logging.info("Installing Extensions from %s" % e)
            ret = subprocess.call(
                pip + ["install", os.path.join(ext_dir, e)], stderr=subprocess.STDOUT
            )
            if ret != 0:
                logging.warning("pip install %s exited with code %d" % (e, ret))
    else:
        logging.warning("Pre-compiled extension directory not found: %s" % ext_dir)
    # Compile CUDA extensions (alternative to installing the wheels above)
    # ext_dir = os.path.join(os.path.dirname(__file__), "gaussiancity", "extensions")
    # for e in os.listdir(ext_dir):
    #     if os.path.isdir(os.path.join(ext_dir, e)):
    #         subprocess.call(["pip", "install", "."], cwd=os.path.join(ext_dir, e))

    logging.info("Installed Python Packages: %s" % _get_output(pip + ["list"]))


def get_models(file_name):
    """Load a pretrained GaussianCity generator checkpoint.

    Downloads ``file_name`` from the ``hzxie/gaussian-city`` Hugging Face
    repository into the current working directory if it is not present yet,
    builds a ``gaussiancity.generator.Generator`` from the config stored in
    the checkpoint, and loads its ``gaussian_g`` weights.

    Args:
        file_name: Checkpoint name; used both as the remote file name and as
            the local path (relative to the current working directory).

    Returns:
        The generator in eval mode. When CUDA is available it is wrapped in
        ``torch.nn.DataParallel`` and moved to the GPU; otherwise it is the
        bare generator on the CPU.
    """
    import gaussiancity.generator
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    if not os.path.exists(file_name):
        # Download under a temporary name and rename only on success, so an
        # interrupted download cannot leave a truncated file that later runs
        # would mistake for a cached checkpoint.
        tmp_name = "%s.part" % file_name
        urllib.request.urlretrieve(
            "https://huggingface.co/hzxie/gaussian-city/resolve/main/%s" % file_name,
            tmp_name,
        )
        os.replace(tmp_name, file_name)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(security): weights_only=False unpickles arbitrary objects (the
    # checkpoint carries a config object under "cfg"); only load checkpoints
    # from a trusted source.
    ckpt = torch.load(file_name, map_location=torch.device(device), weights_only=False)
    cfg = ckpt["cfg"]
    model = gaussiancity.generator.Generator(
        cfg.NETWORK.GAUSSIAN,
        n_classes=cfg.DATASETS.GOOGLE_EARTH.N_CLASSES,
        proj_size=cfg.DATASETS.GOOGLE_EARTH.PROJ_SIZE,
    )
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
    # Inference only: switch to eval mode on both the GPU and the CPU path.
    model.eval()

    # The weights may have been saved from a DataParallel-wrapped model, in
    # which case the keys carry a "module." prefix. Strip it if present and
    # load into the unwrapped generator so the keys line up on both the
    # DataParallel (GPU) and bare (CPU) paths; with strict=False a key
    # mismatch would otherwise silently load nothing.
    state_dict = ckpt["gaussian_g"]
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    target.load_state_dict(state_dict, strict=False)
    return model


def get_city_layout():
    """Return the New York City layout dict used for generation.

    The base layout (keys ``TD_HF``, ``BU_HF``, ``SEG``, ``INS``, ``PTS``) is
    loaded from ``assets/NYC.pkl``. If that cache is missing, it is built from
    ``assets/NYC-HghtFld.png`` and ``assets/NYC-SegMap.png`` and cached. The
    output of ``gaussiancity.inference.get_centers`` is likewise cached in
    ``assets/CENTERS.pkl`` and attached under ``CTR``; ``CTR`` itself is not
    stored in ``NYC.pkl``.

    All paths are relative to the current working directory.
    """
    import gaussiancity.inference

    layout = _load_or_build_pickle("assets/NYC.pkl", _build_nyc_layout)
    layout["CTR"] = _load_or_build_pickle(
        "assets/CENTERS.pkl",
        lambda: gaussiancity.inference.get_centers(layout["INS"], layout["TD_HF"]),
    )
    return layout


def _load_or_build_pickle(path, build):
    """Return the object pickled at *path*; otherwise ``build()`` and cache it.

    The cache is written to a temporary file and then renamed into place, so
    an interrupted run cannot leave a truncated pickle behind.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only use caches that come from a trusted source.
    if os.path.exists(path):
        with open(path, "rb") as fp:
            return pickle.load(fp)

    obj = build()
    tmp_path = "%s.tmp" % path
    with open(tmp_path, "wb") as fp:
        pickle.dump(obj, fp)
    os.replace(tmp_path, path)
    return obj


def _build_nyc_layout():
    """Build the NYC layout dict from the height-field and segmentation PNGs."""
    import gaussiancity.inference

    # Context managers close the image files once their pixels are decoded.
    with Image.open("assets/NYC-HghtFld.png") as img:
        td_hf = np.array(img).astype(np.int32)
    # Fix: nonzero is not supported for tensors with more than INT_MAX elements
    td_hf[td_hf > 500] = 500
    bu_hf = np.zeros_like(td_hf)
    with Image.open("assets/NYC-SegMap.png") as img:
        seg_map = np.array(img.convert("P")).astype(np.int32)
    ins_map = gaussiancity.inference.get_instance_seg_map(seg_map.copy())
    pts_map = gaussiancity.inference.get_point_map(seg_map)
    return {
        "TD_HF": td_hf,
        "BU_HF": bu_hf,
        "SEG": seg_map,
        "INS": ins_map,
        "PTS": pts_map,
    }


@spaces.GPU
def get_generated_city(radius, altitude, azimuth, map_center):
    """Gradio callback: render the generated city for one camera setting.

    Uses the foreground/background models and the city layout attached to
    this function as attributes (``fgm``, ``bgm``, ``city_layout``) by the
    ``__main__`` block. Both models are moved to CUDA, and the call is
    delegated to ``gaussiancity.inference.generate_city``. ``map_center`` is
    passed for both of its two positional center arguments.
    """
    cuda_ok = torch.cuda.is_available()
    logging.info("CUDA is available: %s" % cuda_ok)
    logging.info("PyTorch is built with CUDA: %s" % torch.version.cuda)
    # Imported lazily: this must only happen after the CUDA extensions have
    # been installed.
    from gaussiancity import inference

    state = get_generated_city
    fg_model = state.fgm.to("cuda")
    bg_model = state.bgm.to("cuda")
    return inference.generate_city(
        fg_model,
        bg_model,
        state.city_layout,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )


def main(debug):
    """Build the Gradio interface around ``get_generated_city`` and launch it.

    The description is the text of ``README.md`` after the last occurrence of
    ``---`` (presumably the closing delimiter of the Spaces YAML front
    matter — the whole README is used if there is none). The article is the
    content of ``ARTICLE.md``. Both files are read from the current working
    directory as UTF-8.

    Args:
        debug: Forwarded to ``app.launch(debug=...)``.
    """
    title = "Generative Gaussian Splatting for Unbounded 3D City Generation"
    # Read explicitly as UTF-8 so startup does not depend on the locale's
    # default encoding.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    sep = markdown.rfind("---")
    # rfind() returns -1 when "---" is absent; slicing at -1 + 3 would then
    # silently drop the first two characters, so fall back to the full text.
    desc = markdown if sep < 0 else markdown[sep + 3 :]
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(256, 960, value=768, step=4, label="Camera Radius (m)"),
            gr.Slider(256, 960, value=768, step=4, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=210, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1024, 7168, value=3570, step=4, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        flagging_mode="never",
    )
    app.queue(api_open=False)
    app.launch(debug=debug)


_SENSITIVE_ENV_MARKERS = ("TOKEN", "SECRET", "KEY", "PASSWORD", "CREDENTIAL")


def _redacted_environ(env):
    """Return a plain-dict copy of *env* with likely-secret values masked.

    A variable counts as sensitive when its upper-cased name contains any of
    ``_SENSITIVE_ENV_MARKERS``; its value is replaced by ``"***"``. The input
    mapping is not modified.
    """
    return {
        name: (
            "***"
            if any(marker in name.upper() for marker in _SENSITIVE_ENV_MARKERS)
            else value
        )
        for name, value in env.items()
    }


if __name__ == "__main__":
    logging.basicConfig(
        format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO
    )
    # Secrets such as access tokens may be injected as environment variables,
    # so mask them before writing the environment to the log.
    logging.info("Environment Variables: %s" % _redacted_environ(os.environ))
    # if _get_output(["nvcc", "--version"]) is None:
    #     logging.info("Installing CUDA toolkit...")
    #     install_cuda_toolkit()
    # else:
    #     logging.info("Detected CUDA: %s" % _get_output(["nvcc", "--version"]))

    logging.info("Installing pre-compiled CUDA extensions...")
    setup_runtime_env()

    logging.info("Downloading pretrained models...")
    fgm = get_models("GaussianCity-Fgnd.pth")
    bgm = get_models("GaussianCity-Bgnd.pth")
    # The Gradio callback reads its models and layout from function attributes.
    get_generated_city.fgm = fgm
    get_generated_city.bgm = bgm

    logging.info("Loading New York city layout to RAM...")
    city_layout = get_city_layout()
    get_generated_city.city_layout = city_layout

    logging.info("Starting the main application...")
    main(os.getenv("DEBUG") == "1")