Spaces:
Sleeping
Sleeping
File size: 2,904 Bytes
15216b5 68ba513 58b19a1 994b6ed 57e3a39 3f7676c f2c95e9 58b19a1 b38da9b 58b19a1 7b7ab95 f489399 58b19a1 7b7ab95 58b19a1 65e5c64 cd8aa5a 0260030 58b19a1 15216b5 994b6ed 7b7ab95 58b19a1 f3dfeae 7b7ab95 f3dfeae 7b7ab95 f3dfeae 58b19a1 f3dfeae f6524bf 7b7ab95 f3dfeae 58b19a1 f3dfeae 994b6ed 68cc1c8 994b6ed 7c4ff58 994b6ed c3fc582 f2c95e9 4cc4c48 994b6ed 4cc4c48 994b6ed 4cc4c48 994b6ed 15216b5 994b6ed 15216b5 994b6ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
import shutil
import tempfile

import albumentations as A
import gradio as gr
import monai
import numpy as np
import openslide
import torch
from albumentations.pytorch import ToTensorV2
from monai.networks.nets import UNet
from PIL import Image

from project_utils.preprocessing import expand2square
# Checkpoint location, resolved next to this file rather than against the
# current working directory so the app starts regardless of where it is run.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "best_model.pth")

# 2-D MONAI UNet: 3 input channels (RGB) -> 1 output channel (mask logits).
# These hyper-parameters must match the ones the checkpoint was trained with,
# otherwise load_state_dict fails on missing/unexpected keys or shape mismatch.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
# weights_only=True restricts unpickling to tensors and plain containers --
# all a state_dict needs -- so a tampered checkpoint cannot execute code.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))
# Evaluation mode: among other things, disables the dropout configured above.
model.eval()
# Pre-processing pipeline applied to every input: resize to the network's
# 512x512 input size, then convert the HWC numpy array to a CHW torch tensor.
# Built once at import time instead of on every request.
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=512, width=512),
    ToTensorV2(),
])


def process_image(image):
    """Segment an image and return the model's per-pixel probability map.

    Args:
        image: numpy array with pixel values in [0, 255], normally (H, W, 3)
            RGB as produced by ``gr.Image(type="numpy", image_mode="RGB")``.
            (H, W), (H, W, 1) grayscale and (H, W, 4) RGBA arrays are also
            accepted and converted to 3 channels. ``None`` means "no image".

    Returns:
        A (512, 512) float32 array of sigmoid probabilities in [0, 1], or
        ``None`` when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        return None

    image = np.asarray(image)
    # Normalise the channel layout to the 3 channels the model was built for.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]  # drop alpha

    image = (image / 255.0).astype(np.float32)
    tensor = _INFERENCE_TRANSFORMS(image=image)["image"]
    batch = tensor.unsqueeze(0)  # add the batch dimension: (1, C, H, W)
    with torch.no_grad():
        mask_pred = torch.sigmoid(model(batch))
    # First (only) sample, first (only) output channel.
    return mask_pred[0, 0].numpy()
# "Image-to-Mask" tab: upload a regular image and view the predicted mask.
# type="numpy" matches process_image, which expects a numpy pixel array; the
# prediction is a single-channel probability map, hence image_mode="L".
interface_image = gr.Interface(
    fn=process_image,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            type="numpy",
            width=400,
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
            height=400,
            width=400,
        )
    ],
)
def _find_mrxs(root):
    """Return the sorted-first ``.mrxs`` path anywhere under *root*, or None."""
    matches = []
    for dirpath, _dirnames, filenames in os.walk(root):
        matches.extend(
            os.path.join(dirpath, name)
            for name in filenames
            if name.lower().endswith(".mrxs")
        )
    # Sorting makes the choice deterministic (os.walk order is not).
    return min(matches) if matches else None


def _slide_to_square_thumbnail(path):
    """Open the slide at *path*, take a <=512 px thumbnail and pad it to a white square."""
    slide = openslide.OpenSlide(path)
    try:
        thumbnail = slide.get_thumbnail((512, 512))
    finally:
        # Release the native OpenSlide handle; the thumbnail is an independent
        # PIL image and stays valid after close().
        slide.close()
    return expand2square(thumbnail, "white")


def process_slide(slide_path):
    """Render a whole-slide file as a square thumbnail and segment it.

    Args:
        slide_path: path of the uploaded slide (any format OpenSlide opens), or
            of a ``.zip`` archive containing a ``.mrxs`` slide. Objects exposing
            the path as ``.name`` (older Gradio file wrappers) are accepted too.
            ``None`` means nothing was uploaded.

    Returns:
        ``(thumbnail, mask)``: the padded square PIL thumbnail and the
        ``process_image`` prediction for it, or ``(None, None)`` without input.

    Raises:
        gr.Error: if a zip archive contains no ``.mrxs`` file.
    """
    if slide_path is None:
        return None, None
    slide_path = getattr(slide_path, "name", slide_path)

    if not slide_path.lower().endswith(".zip"):
        image = _slide_to_square_thumbnail(slide_path)
    else:
        # .mrxs slides are uploaded as a zip (presumably because the .mrxs
        # index needs its companion data folder). Extract into a fresh
        # per-request directory. A shared fixed folder would accumulate
        # slides from earlier and concurrent uploads, so the .mrxs lookup
        # could pick the wrong one.
        with tempfile.TemporaryDirectory(prefix="cache_mrxs_") as extract_dir:
            # format="zip": unpack_archive's extension inference is
            # case-sensitive and would reject e.g. "SLIDE.ZIP".
            shutil.unpack_archive(slide_path, extract_dir, format="zip")
            mrxs_path = _find_mrxs(extract_dir)
            if mrxs_path is None:
                raise gr.Error("No .mrxs file found in the uploaded zip archive.")
            # Read the thumbnail before the temporary directory is removed.
            image = _slide_to_square_thumbnail(mrxs_path)

    return image, process_image(np.array(image))
# "Slide-to-Mask" tab: upload a whole-slide file (zip for .mrxs), then show its
# rendered thumbnail next to the predicted mask. Both panes share one size.
_SLIDE_VIEW_SIZE = {"height": 400, "width": 400}

interface_slide = gr.Interface(
    fn=process_slide,
    inputs=[gr.File(label="Input slide file (input zip for `.mrxs` files)")],
    outputs=[
        gr.Image(label="Input Image", image_mode="RGB", **_SLIDE_VIEW_SIZE),
        gr.Image(label="Model Prediction", image_mode="L", **_SLIDE_VIEW_SIZE),
    ],
)
# Combine both interfaces into one app, one tab per input type.
demo = gr.TabbedInterface(
    [interface_image, interface_slide],
    ["Image-to-Mask", "Slide-to-Mask"],
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launching a web server.
if __name__ == "__main__":
    demo.launch()
|