import spaces
import tempfile
import os
from pathlib import Path
import SimpleITK as sitk
import numpy as np
import nibabel as nib
from totalsegmentator.python_api import totalsegmentator
import gradio as gr
from segmap import seg_map
import logging

# Root logging setup: DEBUG and above, each record prefixed with timestamp,
# logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# File names of the bundled example inputs (offered from input_examples/);
# their segmentations are read from output_examples/ instead of running the model.
sample_files = [
    "ct1.nii.gz",
    "ct2.nii.gz",
    "ct3.nii.gz",
]


def map_labels(seg_array):
    """Split a label array into per-class masks for ``gr.AnnotatedImage``.

    Args:
        seg_array: Integer label array (the viewer passes one 2-D slice of
            the segmentation volume). Label 0 is treated as background and
            skipped.

    Returns:
        List of ``(mask, name)`` tuples, one per non-zero label present,
        where ``mask`` is the boolean array ``seg_array == label`` and
        ``name`` is ``seg_map[label]``, or the label number as text when
        ``seg_map`` has no entry for it.
    """
    classes = np.unique(seg_array)
    logger.debug("unique segs:")
    logger.debug(str(len(classes)))

    labels = []
    for seg_class in classes:
        if seg_class == 0:
            continue  # background
        try:
            name = seg_map[seg_class]
        except (KeyError, IndexError):
            # Unknown label: show its number instead of breaking the viewer.
            name = str(seg_class)
        labels.append((seg_array == seg_class, name))
    return labels

def sitk_to_numpy(img_sitk, norm=False):
    """Convert a SimpleITK image to an LPS-oriented ``uint8`` NumPy array.

    Args:
        img_sitk: SimpleITK image to convert; it is reoriented to "LPS" first.
        norm: If True, linearly rescale intensities to [0, 255] using the
            image's own min/max (for display). A constant image, or one whose
            min/max is NaN, becomes all zeros. If False, values are cast
            directly, so they must already fit in ``uint8`` (e.g. label maps).

    Returns:
        ``uint8`` array in the axis order produced by ``sitk.GetArrayFromImage``.
    """
    img_sitk = sitk.DICOMOrient(img_sitk, "LPS")
    img_np = sitk.GetArrayFromImage(img_sitk)
    if norm:
        # Work in float: in the source integer dtype (e.g. int16 CT)
        # ``img - min`` can overflow and wrap around.
        img_np = img_np.astype(np.float64)
        min_val, max_val = img_np.min(), img_np.max()
        value_range = max_val - min_val
        if value_range > 0:
            img_np -= min_val
            img_np /= value_range
            np.clip(img_np, 0, 1, out=img_np)
            img_np *= 255
        else:
            # Constant (or NaN-containing) image: 0/0 would yield NaN, whose
            # uint8 cast is meaningless. Render it black instead.
            img_np[...] = 0
    return img_np.astype(np.uint8)


def load_image(path, norm=False):
    """Read the image file at *path* and convert it with ``sitk_to_numpy``.

    ``norm`` is forwarded unchanged (True rescales intensities to 0-255).
    """
    return sitk_to_numpy(sitk.ReadImage(path), norm=norm)


def show_img_seg(img_np, seg_np=None, slice_idx=50):
    """Build the ``gr.AnnotatedImage`` value for one slice along axis 0.

    Args:
        img_np: Image volume, or a list of volumes of which the last is used.
        seg_np: Optional label volume (or list; the last is used), expected to
            have the same shape as ``img_np``.
        slice_idx: Relative slice position in percent (the UI slider spans
            1-100). It is clamped to the valid slice range.

    Returns:
        ``None`` when no image is available. Otherwise ``(img_slice, labels)``,
        where ``labels`` is ``map_labels`` applied to the matching
        segmentation slice, or ``[]`` when there is no segmentation.
    """
    if img_np is None or (isinstance(img_np, list) and len(img_np) == 0):
        return None
    if isinstance(img_np, list):
        img_np = img_np[-1]

    n_slices = img_np.shape[0]
    # Map the percentage onto [0, n_slices - 1]. Without the clamp,
    # slice_idx=100 would index one past the last slice.
    slice_pos = min(max(int(slice_idx * (n_slices / 100)), 0), n_slices - 1)
    img_slice = img_np[slice_pos, :, :]

    if seg_np is None or (isinstance(seg_np, list) and len(seg_np) == 0):
        return img_slice, []
    if isinstance(seg_np, list):
        seg_np = seg_np[-1]
    return img_slice, map_labels(seg_np[slice_pos, :, :])


def load_img_to_state(path, img_state, seg_state):
    """Reset both session states and, if *path* is given, load it as the image.

    The image is loaded with ``load_image(path, norm=True)``. Returns
    ``(None, img_state, seg_state)``: ``None`` clears the viewer, and the two
    lists are returned after being mutated in place.
    """
    for state in (img_state, seg_state):
        state.clear()
    if path:
        img_state.append(load_image(path, norm=True))
    return None, img_state, seg_state
    

def save_seg(seg, path):
    """Return a downloadable path for the segmentation of the input at *path*.

    For bundled samples (``sample_files``), the precomputed mask in
    ``output_examples/`` is returned and nothing is written. Otherwise *seg*
    is written next to the input as ``<name>_seg.nii.gz``.

    Args:
        seg: SimpleITK segmentation image.
        path: Path of the input image the segmentation belongs to.

    Returns:
        Path (str) of the segmentation file.
    """
    input_path = Path(path)
    # Strip up to two suffixes so "ct.nii.gz" and "ct.nii" both give "ct".
    base_name = Path(input_path.stem).stem
    if input_path.name in sample_files:
        return os.path.join("output_examples", f"{base_name}_seg.nii.gz")
    # Never write to *path* itself: that would overwrite the uploaded image
    # and hand the user a "segmentation" that carries the input's name.
    seg_path = str(input_path.with_name(f"{base_name}_seg.nii.gz"))
    sitk.WriteImage(seg, seg_path)
    return seg_path


@spaces.GPU(duration=150)
def run_inference(path):
    """Segment the image file at *path* with TotalSegmentator (``fast=True``).

    The nibabel result is round-tripped through a temporary
    ``totalseg_output.nii.gz`` so it can be returned as a SimpleITK image.
    ``duration=150`` is passed to ``spaces.GPU``; presumably this is the GPU
    time budget in seconds (confirm against the ``spaces`` docs).

    NOTE(review): no ``task`` argument is passed, so TotalSegmentator's
    default task is used for every input. Confirm this is intended for the
    MR images the UI advertises.
    """
    seg_nib = totalsegmentator(nib.load(path), fast=True)
    with tempfile.TemporaryDirectory() as scratch:
        scratch_file = os.path.join(scratch, "totalseg_output.nii.gz")
        nib.save(seg_nib, scratch_file)
        # Read before the directory (and the file) is deleted on exit.
        return sitk.ReadImage(scratch_file)


def inference_wrapper(input_file, img_state, seg_state, slice_slider=50):
    """Segment the uploaded file and refresh the viewer and download link.

    Bundled sample inputs (``sample_files``) use the precomputed masks in
    ``output_examples/``. Any other file is run through ``run_inference``.

    Args:
        input_file: Value of the ``gr.File`` component. This is either a path
            string or an object exposing the path as ``.name``.
        img_state: Session list of image volumes. The input is loaded into it
            (normalised for display) when the list is empty.
        seg_state: Session list of segmentation volumes. The new one is
            appended.
        slice_slider: Relative slice position (percent) to display.

    Returns:
        ``(viewer_value, seg_state, seg_path)`` for the image viewer, the
        segmentation state and the download component.

    Raises:
        gr.Error: If no file has been provided.
    """
    if not input_file:
        raise gr.Error("Please upload a CT or MR image (.nii/.nii.gz) first.")
    # Accept both plain path strings and values that carry the path in ``.name``.
    input_path = getattr(input_file, "name", input_file)
    file_name = Path(input_path).name

    if file_name in sample_files:
        seg_sitk = sitk.ReadImage(os.path.join("output_examples", f"{Path(Path(file_name).stem).stem}_seg.nii.gz"))
    else:
        seg_sitk = run_inference(input_path)

    # Load the image before save_seg runs, so it can never pick up a file
    # written there. Normalise it the same way load_img_to_state does; a bare
    # uint8 cast does not rescale CT intensities.
    if not img_state:
        img_state.append(load_image(input_path, norm=True))

    seg_path = save_seg(seg_sitk, input_path)
    seg_state.append(sitk_to_numpy(seg_sitk))

    return show_img_seg(img_state[-1], seg_state[-1], slice_slider), seg_state, seg_path


with gr.Blocks(title="TotalSegmentator") as interface:

    # NOTE(review): the texts below advertise MR support, but run_inference
    # passes no TotalSegmentator task. Confirm MR inputs are handled as intended.
    gr.Markdown("# TotalSegmentator: Segmentation of 117 Classes in CT and MR Images")
    gr.Markdown("""
- **GitHub:** https://github.com/wasserth/TotalSegmentator
- **Please Note:** This tool is intended for research purposes only and can segment 117 classes in CT/MRI images
- Supports both CT and MR imaging modalities
- Credit: adapted from `DiGuaQiu/MRSegmentator-Gradio`
""")

    # Per-session lists holding the loaded image volume(s) and the
    # segmentation volume(s); the handlers use the last entry of each.
    img_state = gr.State([])
    seg_state = gr.State([])

    with gr.Accordion(label='Upload CT Scan (nifti file) then click on Generate Segmentation to run TotalSegmentator', open=True):
        with gr.Row():
            with gr.Column():

                # ".nii" is listed explicitly: the label advertises
                # uncompressed NIfTI, which ".gz" alone does not admit.
                file_input = gr.File(
                    type="filepath", label="Upload a CT or MR Image (.nii/.nii.gz)", file_types=[".nii", ".nii.gz", ".gz"]
                )
                gr.Examples(["input_examples/" + example for example in sample_files], file_input)

                with gr.Row():
                    infer_button = gr.Button("Generate Segmentations", variant="primary")
                    clear_button = gr.ClearButton()

            with gr.Column():
                slice_slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
                img_viewer = gr.AnnotatedImage(label="Image Viewer")
                download_seg = gr.File(label="Download Segmentation", interactive=False)

    # New file: reset both states, load the image and clear the viewer.
    file_input.change(
        load_img_to_state,
        inputs=[file_input, img_state, seg_state],
        outputs=[img_viewer, img_state, seg_state],
    )
    # Re-render the viewer at the selected relative slice.
    slice_slider.change(show_img_seg, inputs=[img_state, seg_state, slice_slider], outputs=[img_viewer])

    # Run (or look up) the segmentation, then update viewer, state and download.
    infer_button.click(
        inference_wrapper,
        inputs=[file_input, img_state, seg_state, slice_slider],
        outputs=[img_viewer, seg_state, download_seg],
    )

    clear_button.add([file_input, img_viewer, img_state, seg_state, download_seg])


if __name__ == "__main__":
    # Enable request queuing (queue() returns the Blocks), then serve with debug output.
    interface.queue().launch(debug=True)