File size: 2,634 Bytes
15216b5
68ba513
 
58b19a1
 
 
 
 
994b6ed
57e3a39
3f7676c
58b19a1
 
 
 
 
 
 
 
 
 
b38da9b
58b19a1
 
7b7ab95
f489399
58b19a1
 
 
 
 
 
7b7ab95
58b19a1
 
 
 
65e5c64
cd8aa5a
0260030
58b19a1
15216b5
994b6ed
7b7ab95
58b19a1
 
f3dfeae
 
 
7b7ab95
f3dfeae
7b7ab95
f3dfeae
58b19a1
 
f3dfeae
 
f6524bf
7b7ab95
 
f3dfeae
58b19a1
f3dfeae
 
 
 
 
 
994b6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3dfeae
994b6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15216b5
 
994b6ed
 
 
15216b5
994b6ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import monai
import torch
from monai.networks.nets import UNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import shutil
import os
import openslide

# Segmentation backbone: a 2-D residual U-Net that maps an RGB image to a
# single-channel mask. Trained weights are restored from "best_model.pth";
# inference is CPU-only.
_state_dict = torch.load("best_model.pth", map_location="cpu")
model = UNet(
    spatial_dims=2,   # 2-D images
    in_channels=3,    # RGB input
    out_channels=1,   # binary mask logits
    channels=(16, 32, 64, 128, 256, 512),
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
model.load_state_dict(_state_dict)
model.eval()

# Built once at import time: the resize + tensor conversion is invariant
# across requests, so rebuilding it on every call (as the original did)
# is wasted work.
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=512, width=512),
    ToTensorV2(),
])


def process_image(image):
    """Run the segmentation model on a single RGB image.

    Parameters
    ----------
    image : np.ndarray
        H x W x 3 RGB image with values in 0-255, as supplied by
        ``gr.Image(type="numpy")``.

    Returns
    -------
    np.ndarray
        512 x 512 float32 array of sigmoid probabilities in [0, 1].
    """
    # Normalise to [0, 1] float32 — the model expects scaled inputs.
    image = np.asarray(image).astype(np.float32) / 255.0

    # HWC ndarray -> resized CHW tensor, then add a batch dimension.
    tensor = _INFERENCE_TRANSFORMS(image=image)["image"].unsqueeze(0)

    # No gradients needed at inference time.
    with torch.no_grad():
        mask_pred = torch.sigmoid(model(tensor))

    # Drop batch and channel dims so the result displays as grayscale.
    return mask_pred[0, 0, :, :].numpy()

# Tab 1: single-image inference — upload an RGB image, view the predicted
# probability mask. (Fixed user-facing typo "Histapathology"; removed the
# stale commented-out `examples` referencing unrelated image files.)
interface_image = gr.Interface(
    fn=process_image,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            type="numpy",  # process_image expects an ndarray
            width=400,
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",  # grayscale probability mask
            height=400,
            width=400,
        )
    ],
)

def process_slide(slide_path):
    """Open a whole-slide image and return a <= 512x512 RGB thumbnail.

    Parameters
    ----------
    slide_path : str
        Path to the uploaded slide file. ``.mrxs`` slides must be uploaded
        as a zip archive containing the .mrxs file plus its data folder.

    Returns
    -------
    PIL.Image.Image
        Thumbnail of the slide, renderable by ``gr.Image``.
    """
    if not slide_path.endswith("zip"):
        # Single-file slide formats open directly.
        # BUG FIX: the original called OpenSlide with undefined names
        # `path` / `image_path`, raising NameError at runtime.
        slide = openslide.OpenSlide(slide_path)
    else:
        # .mrxs slides ship as a file plus a sibling data directory, so they
        # arrive zipped; unpack and locate the .mrxs entry inside.
        shutil.unpack_archive(slide_path, "cache_mrxs")
        files = os.listdir("cache_mrxs")
        slide_name = [f for f in files if f.endswith("mrxs")][0]
        slide = openslide.OpenSlide(os.path.join("cache_mrxs", slide_name))

    # BUG FIX: the original discarded the thumbnail and returned the raw
    # OpenSlide handle, which the gr.Image output cannot display.
    return slide.get_thumbnail((512, 512))


# Tab 2: whole-slide upload — accepts a slide file (zip for .mrxs) and
# displays the result as an RGB image.
interface_slide = gr.Interface(
    fn=process_slide,
    inputs=[gr.File(label="Input slide file (input zip for `.mrxs` files)")],
    outputs=[
        gr.Image(label="Model Prediction", image_mode="RGB", height=400, width=400)
    ],
)


# Combine both workflows into a single tabbed app and serve it.
demo = gr.TabbedInterface(
    interface_list=[interface_image, interface_slide],
    tab_names=["Image-to-Mask", "Slide-to-Mask"],
)

demo.launch()