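"""Gradio demo: histopathology image segmentation with a 2D MONAI UNet.

The app loads trained weights from best_model.pth and returns a predicted
segmentation mask for an uploaded RGB image.
"""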
import gradio as gr
import monai
import torch
from monai.networks.nets import UNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# 2D UNet; the architecture must match the checkpoint being loaded below.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
model.eval()
def greet(image):
    # Gradio passes the input as an RGB numpy array of shape (H, W, 3); scale to [0, 1].
    image = image / 255.0
    image = image.astype(np.float32)

    inference_transforms = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),  # HWC numpy array -> CHW torch tensor
    ])
    image = inference_transforms(image=image)["image"]
    image = image.unsqueeze(0)  # add batch dimension -> (1, 3, 512, 512)

    with torch.no_grad():
        mask_pred = torch.sigmoid(model(image))  # (1, 1, 512, 512), values in [0, 1]

    # Drop the batch and channel dimensions so Gradio receives an (H, W) grayscale mask.
    return mask_pred[0, 0].numpy()
demo = gr.Interface(
    fn=greet,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            type="numpy",
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
        )
    ],
)
demo.launch()