import os
import shutil
import tempfile

import albumentations as A
import gradio as gr
import monai
import numpy as np
import openslide
import torch
from albumentations.pytorch import ToTensorV2
from monai.networks.nets import UNet
from PIL import Image

from project_utils.preprocessing import expand2square

# Segmentation network: a 2-D MONAI UNet taking a 3-channel image and producing
# a single output channel. The hyper-parameters below must match the ones used
# to create "best_model.pth", otherwise load_state_dict will reject the weights.
# The module is switched to eval mode right away (disables dropout); loading a
# state dict afterwards does not change the training/eval flag.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
).eval()
# Weights are mapped onto the CPU so the demo also runs on machines without a GPU.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device("cpu"))
)

# Preprocessing applied before inference: resize to the 512x512 model input,
# then convert the HWC numpy array into a CHW torch tensor. Built once at
# import time rather than on every request.
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=512, width=512),
    ToTensorV2(),
])


def process_image(image):
    """Run the segmentation model on one RGB image and return its mask.

    Args:
        image: ``H x W x 3`` array with values in ``[0, 255]``, as produced by
            ``gr.Image(type="numpy", image_mode="RGB")``, or ``None`` when the
            user submits without providing an image.

    Returns:
        A ``512 x 512`` NumPy array of per-pixel probabilities in ``[0, 1]``
        (sigmoid of the model's single output channel), or ``None`` when
        *image* is ``None`` (Gradio then shows an empty output).
    """
    if image is None:
        return None

    # Scale to [0, 1] and cast to float32 to match the model's parameter dtype.
    image = (image / 255.0).astype(np.float32)
    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    tensor = _INFERENCE_TRANSFORMS(image=image)["image"].unsqueeze(0)

    with torch.no_grad():
        mask_pred = torch.sigmoid(model(tensor))

    # Drop the batch and channel dimensions.
    return mask_pred[0, 0, :, :].numpy()
    

# "Image-to-Mask" tab: segment a single uploaded RGB image and display the
# predicted probability mask as a grayscale image.
# TODO: add representative histopathology tiles through `examples=[...]`.
interface_image = gr.Interface(
    fn=process_image,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            type="numpy",
            width=400,
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
            height=400,
            width=400,
        )
    ],
)

def _find_mrxs(root):
    """Return the path of an `.mrxs` file under *root*, searched recursively, or None.

    Directories and files are visited in sorted order, top level first, so the
    choice is deterministic when an archive contains several slides.
    """
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()
        for filename in sorted(filenames):
            if filename.lower().endswith(".mrxs"):
                return os.path.join(dirpath, filename)
    return None


def _read_thumbnail(path, size=(512, 512)):
    """Open the slide at *path*, return its thumbnail and always close the slide."""
    slide = openslide.OpenSlide(path)
    try:
        return slide.get_thumbnail(size)
    finally:
        slide.close()


def process_slide(slide_path):
    """Segment a whole-slide image from its thumbnail.

    Args:
        slide_path: The uploaded file from ``gr.File``. This is a slide file
            that OpenSlide can open directly, or a ``.zip`` archive that
            contains an ``.mrxs`` slide. Accepts a path (``str`` / path-like)
            or a file wrapper exposing ``.name``; both forms are handled
            because it depends on the Gradio version. ``None`` means nothing
            was uploaded.

    Returns:
        A tuple ``(image, mask)``. ``image`` is the thumbnail padded to a square
        with white by ``expand2square``. ``mask`` is the output of
        ``process_image`` on that image. Returns ``(None, None)`` when no file
        was uploaded.

    Raises:
        gr.Error: If a zip archive does not contain any ``.mrxs`` file.
    """
    if slide_path is None:
        return None, None
    if not isinstance(slide_path, (str, os.PathLike)):
        slide_path = slide_path.name
    slide_path = os.fspath(slide_path)

    if not slide_path.lower().endswith(".zip"):
        thumbnail = _read_thumbnail(slide_path)
    else:  # mrxs slide files
        # Unpack into a fresh per-request directory. Files from an earlier or
        # concurrent upload can then never be picked up, and the extracted data
        # is deleted once the thumbnail has been read. The slide is closed
        # before the directory is removed.
        with tempfile.TemporaryDirectory(prefix="cache_mrxs_") as extract_dir:
            # Pass the format explicitly: shutil infers it from a
            # case-sensitive extension match, which would reject "SLIDE.ZIP".
            shutil.unpack_archive(slide_path, extract_dir, format="zip")
            mrxs_path = _find_mrxs(extract_dir)
            if mrxs_path is None:
                raise gr.Error(
                    "The uploaded zip archive does not contain an `.mrxs` file."
                )
            thumbnail = _read_thumbnail(mrxs_path)

    image = expand2square(thumbnail, "white")

    return image, process_image(np.array(image))


# "Slide-to-Mask" tab: takes an uploaded slide file (a zip archive for `.mrxs`
# slides). Shows the square thumbnail passed to the model next to the mask the
# model predicted.
interface_slide = gr.Interface(
    fn=process_slide,
    inputs=[gr.File(label="Input slide file (input zip for `.mrxs` files)")],
    outputs=[
        gr.Image(label="Input Image", image_mode="RGB", height=400, width=400),
        gr.Image(label="Model Prediction", image_mode="L", height=400, width=400),
    ],
)


# Combine the single-image and whole-slide interfaces into one tabbed app.
demo = gr.TabbedInterface(
    [interface_image, interface_slide], ["Image-to-Mask", "Slide-to-Mask"]
)

# Start the (blocking) server only when the file is run as a script. Importing
# the module, e.g. to reuse `demo` or the processing functions, then no
# longer launches it.
if __name__ == "__main__":
    demo.launch()