model-inference / app.py
osbm's picture
Update app.py
65e5c64 verified
raw
history blame
1.74 kB
import gradio as gr
import monai
import torch
from monai.networks.nets import UNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# Architecture of the segmentation network: a 2-D MONAI UNet taking
# 3-channel input and producing a single output channel.
_UNET_KWARGS = dict(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
_CHECKPOINT_PATH = "best_model.pth"

model = UNet(**_UNET_KWARGS)
# Weights are mapped onto the CPU so the app runs without a GPU.
# NOTE(review): torch.load unpickles the file; if the installed torch supports
# it, passing weights_only=True would restrict loading to plain tensors.
model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location="cpu"))
# Inference mode for layers such as dropout.
model.eval()
# Preprocessing applied to every request. It is stateless, so it is built once
# at import time instead of on every call to greet().
_INFERENCE_TRANSFORMS = A.Compose([
    A.Resize(height=512, width=512),
    ToTensorV2(),
])


def greet(image):
    """Run the module-level UNet on one image and return its mask.

    Args:
        image: Array delivered by the ``gr.Image(type="numpy")`` input, or
            ``None`` if the user submitted without an image. Assumed to be
            (H, W, 3) with values in 0..255 — TODO confirm for non-uint8 input.

    Returns:
        A 2-D float32 array of shape (512, 512) with per-pixel sigmoid
        probabilities in [0, 1], or ``None`` when no image was given.
    """
    if image is None:
        # Nothing to segment; dividing None by 255 would raise TypeError.
        return None

    # Scale pixel values to [0, 1] before resizing and tensor conversion.
    image = (np.asarray(image) / 255.0).astype(np.float32)
    # ToTensorV2 turns the HWC array into a CHW tensor; add a batch axis.
    tensor = _INFERENCE_TRANSFORMS(image=image)["image"].unsqueeze(0)

    with torch.no_grad():
        mask_pred = torch.sigmoid(model(tensor))

    # The model has out_channels=1, so mask_pred is (1, 1, H, W). Drop both
    # the batch and channel axes: a (1, H, W) array would be interpreted by
    # gr.Image as height=1 with W channels instead of a grayscale mask.
    return mask_pred[0, 0].numpy()
# Web UI: one RGB image in, the predicted segmentation mask out.
demo = gr.Interface(
    fn=greet,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            type="numpy",
        ),
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="L",
        ),
    ],
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on a running web server.
if __name__ == "__main__":
    demo.launch()