import gradio as gr
import monai
import torch
from monai.networks.nets import UNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

# Side length (pixels) of the square image the network is fed.
INPUT_SIZE = 512

# Architecture must match the one used to produce "best_model.pth";
# otherwise load_state_dict raises on mismatched keys/shapes.
model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
model.eval()  # switch dropout off for inference

# Built once at import time rather than on every request.
inference_transforms = A.Compose([
    A.Resize(height=INPUT_SIZE, width=INPUT_SIZE),
    ToTensorV2(),  # HWC numpy array -> CHW torch tensor
])


def greet(image):
    """Run the segmentation model on one image and return the predicted mask.

    Args:
        image: HxWx3 numpy array supplied by the gradio Image input
            (``type="numpy"``, ``image_mode="RGB"``), or None when the user
            submits without uploading an image.

    Returns:
        HxW uint8 numpy array with the same height/width as the input, where
        each pixel is the model's sigmoid output scaled to 0-255; or None if
        no image was given.
    """
    if image is None:
        return None
    height, width = image.shape[:2]

    # Scale pixels to [0, 1] float32, as in the original `/ 255.0`
    # preprocessing. NOTE(review): confirm this matches training.
    image = image.astype(np.float32) / 255.0
    batch = inference_transforms(image=image)["image"].unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    # Assumes the single output channel holds raw logits (no final
    # activation in the network) -- TODO confirm against the training loss.
    prob = torch.sigmoid(logits)[0, 0].cpu().numpy()
    mask = (prob * 255.0).round().astype(np.uint8)

    # Resize back so the prediction overlays the uploaded image 1:1.
    mask_image = Image.fromarray(mask).resize((width, height), Image.BILINEAR)
    return np.array(mask_image)


demo = gr.Interface(
    fn=greet,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            width=400,
            type="numpy",
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="RGB",
            height=400,
            width=400,
        )
    ],
)

demo.launch()