# model-inference / app.py (Hugging Face Space by osbm)
# Last commit: 58b19a1 "Update app.py" (verified), file size 1.12 kB.
import os
import tempfile

import albumentations as A
import gradio as gr
import monai
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from monai.networks.nets import UNet
from PIL import Image
# Segmentation network. These hyperparameters must match the ones used to
# train best_model.pth: load_state_dict (strict by default) raises on any
# missing, unexpected, or shape-mismatched parameter.
model = UNet(
    spatial_dims=2,
    in_channels=3,  # RGB input (greet converts every upload to RGB)
    out_channels=1,  # a single output channel
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
# map_location="cpu" lets a checkpoint saved from a CUDA model load on a
# CPU-only host; the model above is constructed on the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
# Inference mode: disables the dropout configured above.
model.eval()
def greet(image_path):
    """Segment an uploaded image and return the predicted mask as a PNG file.

    Steps:
    1. Convert the image to RGB and scale it to [0, 1].
    2. Resize it to 512x512 and run it through the module-level ``model``.
    3. Threshold the single-channel output into a binary mask.
    4. Write the mask to a temporary PNG.

    The function returns a file path because the ``gr.File`` output component
    expects a filepath, not a tensor.

    Args:
        image_path: The upload delivered by ``gr.File``. In recent Gradio
            versions this is a filepath string; older versions pass a
            temp-file object that exposes the path as ``.name``.

    Returns:
        Path (str) to a 512x512 grayscale PNG. Foreground pixels are 255 and
        background pixels are 0.
    """
    # Unwrap temp-file objects (older Gradio) to their on-disk path.
    if not isinstance(image_path, (str, os.PathLike)):
        image_path = image_path.name

    with Image.open(image_path) as img:
        rgb = (np.asarray(img.convert("RGB")) / 255.0).astype(np.float32)

    inference_transforms = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),  # HWC numpy array -> CHW tensor
    ])
    batch = inference_transforms(image=rgb)["image"].unsqueeze(0)  # (1, 3, 512, 512)

    with torch.no_grad():
        logits = model(batch)

    # NOTE(review): assumes the network emits raw logits for a single
    # foreground class and was trained with a sigmoid-based loss (e.g.
    # BCEWithLogits / DiceLoss(sigmoid=True)) -- confirm against training code.
    probs = torch.sigmoid(logits[0, 0]).cpu().numpy()
    mask = np.where(probs > 0.5, 255, 0).astype(np.uint8)

    # delete=False keeps the file on disk after closing so Gradio can serve it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        Image.fromarray(mask).save(tmp, format="PNG")
    return tmp.name
# Gradio UI: one uploaded image in, one downloadable mask file out.
demo = gr.Interface(
    fn=greet,
    title="Histopathology segmentation",
    # greet resizes every upload to 512x512 before inference.
    inputs=[
        gr.File(label="Input image (512x512)"),
    ],
    # Used as an output, gr.File expects fn to return a filepath.
    outputs=[
        gr.File(label="Model Prediction"),
    ],
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()