Spaces:
Sleeping
Sleeping
import gradio as gr | |
import monai | |
import torch | |
from monai.networks.nets import UNet | |
from PIL import Image | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
import numpy as np | |
# Segmentation network architecture. These hyper-parameters must match the
# ones used to produce best_model.pth; otherwise load_state_dict() below will
# fail on missing/unexpected or mis-shaped keys.
_unet_kwargs = {
    "spatial_dims": 2,
    "in_channels": 3,  # RGB input (the Gradio input below uses image_mode="RGB")
    "out_channels": 1,  # a single prediction channel
    "channels": [16, 32, 64, 128, 256, 512],
    "strides": (2, 2, 2, 2, 2),
    "num_res_units": 4,
    "dropout": 0.15,  # has no effect after model.eval() below
}
model = UNet(**_unet_kwargs)

# Load the trained weights, mapping every tensor onto the CPU so the app runs
# on CPU-only hosts. NOTE(review): torch.load unpickles the file - only ever
# point it at a trusted checkpoint.
model.load_state_dict(
    torch.load("best_model.pth", map_location=torch.device('cpu'))
)

# Inference mode for layers such as dropout.
model.eval()
def greet(image, threshold=None):
    """Run the UNet on one RGB image and return its foreground map.

    Args:
        image: RGB image as a numpy array of shape (H, W, 3) with values in
            0-255, as supplied by the ``gr.Image(type="numpy",
            image_mode="RGB")`` input of the interface; ``None`` when the user
            submits without uploading an image.
        threshold: optional probability cut-off. ``None`` (default) returns
            the soft probability map; a float such as ``0.5`` returns a
            binary mask (0 or 255).

    Returns:
        ``None`` if ``image`` is ``None``; otherwise a uint8 array of shape
        (H, W) - the same height and width as ``image`` - with values 0-255,
        which the single-channel (``image_mode="L"``) output can display.
    """
    if image is None:
        return None

    height, width = image.shape[:2]

    # Scale pixel values to [0, 1] float32, resize to the 512x512 network
    # input, and convert the HWC array into a CHW torch tensor.
    pixels = (np.asarray(image) / 255.0).astype(np.float32)
    inference_transforms = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),
    ])
    batch = inference_transforms(image=pixels)["image"].unsqueeze(0)

    with torch.no_grad():
        # Batch of 1, one channel (out_channels=1).
        logits = model(batch)
        # Bring the prediction back to the input resolution so it lines up
        # with the uploaded image instead of always being 512x512.
        logits = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
        # The model is not followed by any activation, so its raw outputs are
        # treated as logits. NOTE(review): sigmoid assumes training used a
        # sigmoid-based loss (e.g. BCE-with-logits / Dice(sigmoid=True)) -
        # confirm against the training script.
        # [0, 0] drops the batch and channel axes -> (H, W).
        probs = torch.sigmoid(logits)[0, 0]

    if threshold is not None:
        probs = (probs >= threshold).float()

    # uint8 0-255 is unambiguous for a grayscale image component.
    return (probs * 255).round().to(torch.uint8).numpy()
# Gradio UI: one RGB image in, one grayscale prediction image out, with
# greet() doing the inference.
demo = gr.Interface(
    fn=greet,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",  # uploads are converted to 3-channel RGB
            type="numpy",  # greet() receives the image as a numpy array
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",  # single-channel (grayscale) display
        )
    ],
)

# Only start the web server when run as a script, so the module can be
# imported (e.g. for testing or by the `gradio` reload CLI) without side effects.
if __name__ == "__main__":
    demo.launch()