import os

import gradio as gr

from .convert import nifti_to_obj
from .css_style import css
from .inference import run_model
from .logger import flush_logs
from .logger import read_logs
from .logger import setup_logger
from .utils import load_ct_to_numpy
from .utils import load_pred_volume_to_numpy

# setup logging
LOGGER = setup_logger()


class WebUI:
    def __init__(
        self,
        model_name: str = None,
        cwd: str = "/home/user/app/",
        share: int = 1,
    ):
        # global states
        self.images = []
        self.pred_images = []

        # @TODO: This should be dynamically set based on chosen volume size
        self.nb_slider_items = 820

        self.model_name = model_name
        self.cwd = cwd
        self.share = share

        self.filename = None
        self.extension = None

        self.class_name = "airways"  # default
        self.class_names = {
            "airways": "CT_Airways",
            "lungs": "CT_Lungs",
        }

        self.result_names = {
            "airways": "Airways",
            "lungs": "Lungs",
        }

        # define widgets not to be rendered immediantly, but later on
        self.slider = gr.Slider(
            minimum=1,
            maximum=self.nb_slider_items,
            value=1,
            step=1,
            label="Which 2D slice to show",
        )
        self.volume_renderer = gr.Model3D(
            clear_color=[0.0, 0.0, 0.0, 0.0],
            label="3D Model",
            show_label=True,
            visible=True,
            elem_id="model-3d",
            camera_position=[90, 180, 768],
            height=512,
        )

    def set_class_name(self, value):
        LOGGER.info(f"Changed task to: {value}")
        self.class_name = value

    def combine_ct_and_seg(self, img, pred):
        return (img, [(pred, self.class_name)])

    def upload_file(self, file):
        out = file.name
        LOGGER.info(f"File uploaded: {out}")
        return out

    def process(self, mesh_file_name):
        path = mesh_file_name.name
        curr = path.split("/")[-1]
        self.extension = ".".join(curr.split(".")[1:])
        self.filename = (
            curr.split(".")[0] + "-" + self.class_names[self.class_name]
        )
        run_model(
            path,
            model_path=os.path.join(self.cwd, "resources/models/"),
            task=self.class_names[self.class_name],
            name=self.result_names[self.class_name],
            output_filename=self.filename + "." + self.extension,
        )
        LOGGER.info("Converting prediction NIfTI to OBJ...")
        nifti_to_obj(path=self.filename + "." + self.extension)

        LOGGER.info("Loading CT to numpy...")
        self.images = load_ct_to_numpy(path)

        LOGGER.info("Loading prediction volume to numpy..")
        self.pred_images = load_pred_volume_to_numpy(
            self.filename + "." + self.extension
        )

        return "./prediction.obj"

    def download_prediction(self):
        if (self.filename is None) or (self.extension is None):
            LOGGER.error(
                "The prediction is not available or ready to download. Wait until the result is available in the 3D viewer."
            )
            raise ValueError("Run inference before downloading!")
        return self.filename + "." + self.extension

    def get_img_pred_pair(self, k):
        k = int(k)
        out = gr.AnnotatedImage(
            self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
            visible=True,
            elem_id="model-2d",
            color_map={self.class_name: "#ffae00"},
            height=512,
            width=512,
        )
        return out

    def toggle_sidebar(self, state):
        state = not state
        return gr.update(visible=state), state

    def run(self):
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column(visible=True, scale=0.2) as sidebar_left:
                    logs = gr.Textbox(
                        placeholder="\n" * 16,
                        label="Logs",
                        info="Verbose from inference will be displayed below.",
                        lines=36,
                        max_lines=36,
                        autoscroll=True,
                        elem_id="logs",
                        show_copy_button=True,
                        container=True,
                    )
                    demo.load(read_logs, None, logs, every=1)

                with gr.Column():
                    with gr.Row():
                        with gr.Column(scale=1, min_width=150):
                            sidebar_state = gr.State(True)

                            btn_toggle_sidebar = gr.Button(
                                "Toggle Sidebar",
                                elem_id="toggle-button",
                            )
                            btn_toggle_sidebar.click(
                                self.toggle_sidebar,
                                [sidebar_state],
                                [sidebar_left, sidebar_state],
                            )

                            btn_clear_logs = gr.Button(
                                "Clear logs", elem_id="logs-button"
                            )
                            btn_clear_logs.click(flush_logs, [], [])

                        file_output = gr.File(
                            file_count="single",
                            elem_id="upload",
                            scale=3,
                        )
                        file_output.upload(
                            self.upload_file, file_output, file_output
                        )

                        model_selector = gr.Dropdown(
                            list(self.class_names.keys()),
                            label="Task",
                            info="Which structure to segment.",
                            multiselect=False,
                            scale=1,
                        )
                        model_selector.input(
                            fn=lambda x: self.set_class_name(x),
                            inputs=model_selector,
                            outputs=None,
                        )

                        with gr.Column(scale=1, min_width=150):
                            run_btn = gr.Button(
                                "Run analysis",
                                variant="primary",
                                elem_id="run-button",
                            )
                            run_btn.click(
                                fn=lambda x: self.process(x),
                                inputs=file_output,
                                outputs=self.volume_renderer,
                            )

                            download_btn = gr.DownloadButton(
                                "Download prediction",
                                visible=True,
                                variant="secondary",
                                elem_id="download",
                            )
                            download_btn.click(
                                fn=self.download_prediction,
                                inputs=None,
                                outputs=download_btn,
                            )

                    with gr.Row():
                        gr.Examples(
                            examples=[
                                os.path.join(self.cwd, "test_thorax_CT.nii.gz"),
                            ],
                            inputs=file_output,
                            outputs=file_output,
                            fn=self.upload_file,
                            cache_examples=True,
                        )

                        gr.Markdown(
                            """
                            **NOTE:** Inference might take several minutes (Airways: ~8 minutes), see logs to the left. \\
                            The segmentation will be available in the 2D and 3D viewers below when finished.
                            """
                        )

                    with gr.Row():
                        with gr.Group():
                            with gr.Column():
                                # create dummy image to be replaced by loaded images
                                t = gr.AnnotatedImage(
                                    visible=True,
                                    elem_id="model-2d",
                                    color_map={self.class_name: "#ffae00"},
                                    # height=512,
                                    # width=512,
                                )
                                self.slider.input(
                                    self.get_img_pred_pair,
                                    self.slider,
                                    t,
                                )

                                self.slider.render()

                        with gr.Group():  # gr.Box():
                            self.volume_renderer.render()

        # sharing app publicly -> share=True:
        # https://gradio.app/sharing-your-app/
        # inference times > 60 seconds -> need queue():
        # https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
        demo.queue().launch(
            server_name="0.0.0.0", server_port=7860, share=self.share
        )