File size: 1,707 Bytes
15216b5
68ba513
 
58b19a1
 
 
 
 
15216b5
58b19a1
 
 
 
 
 
 
 
 
 
b38da9b
58b19a1
 
f3dfeae
58b19a1
 
49684df
f3dfeae
 
58b19a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15216b5
 
 
58b19a1
 
f3dfeae
 
 
 
 
 
 
58b19a1
 
f3dfeae
 
 
 
 
 
58b19a1
f3dfeae
 
 
 
 
 
 
15216b5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import gradio as gr
import monai
import torch
from monai.networks.nets import UNet
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np


# Architecture hyper-parameters of the segmentation network. The layer layout
# they produce has to match the checkpoint loaded below; load_state_dict is
# strict by default and raises on missing/unexpected keys or shape mismatches.
_UNET_KWARGS = dict(
    spatial_dims=2,
    in_channels=3,   # RGB input
    out_channels=1,  # single-channel mask output
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2),
    num_res_units=4,
    dropout=0.15,
)

model = UNet(**_UNET_KWARGS)
# Map every tensor in the checkpoint onto the CPU so the demo also runs on
# machines without a GPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))
# Switch to inference behavior (e.g. disables the dropout configured above).
model.eval()

def greet(image):
    """Segment an uploaded image with the UNet and return the predicted mask.

    Args:
        image: The array Gradio passes for ``type="numpy"`` with
            ``image_mode="RGB"`` — presumably an ``(H, W, 3)`` ``uint8``
            array — or ``None`` when the user submits without an upload.

    Returns:
        An ``(H, W)`` ``uint8`` array at the input's resolution holding the
        per-pixel foreground probability scaled to 0..255, or ``None`` when
        no image was provided.
    """
    if image is None:
        return None

    height, width = image.shape[:2]

    # Scale pixel values to [0, 1] float32, matching the commented-out
    # preprocessing this replaced (assumed to match training — confirm).
    image = (image / 255.0).astype(np.float32)

    inference_transforms = A.Compose([
        A.Resize(height=512, width=512),
        ToTensorV2(),  # HWC numpy -> CHW torch tensor
    ])
    batch = inference_transforms(image=image)["image"].unsqueeze(0)  # add batch axis

    with torch.no_grad():
        logits = model(batch)
        # Resize the prediction back to the upload's size so the mask lines
        # up with the input image instead of always being 512x512.
        logits = torch.nn.functional.interpolate(
            logits, size=(height, width), mode="bilinear", align_corners=False
        )
        # The single output channel is treated as a logit (MONAI's UNet has no
        # final activation) — NOTE(review): confirm against the training loss.
        probs = torch.sigmoid(logits)[0, 0]

    # A plain uint8 array is something gr.Image can render; a torch tensor is not.
    return (probs.numpy() * 255.0).round().astype(np.uint8)
    

# Gradio UI: one RGB image in, the predicted segmentation mask out.
demo = gr.Interface(
    fn=greet,
    title="Histopathology segmentation",
    inputs=[
        gr.Image(
            label="Input image",
            image_mode="RGB",
            height=400,
            type="numpy",  # greet receives a numpy array
            width=400,     # was misspelled "witdh", which gr.Image rejects/ignores
        )
    ],
    outputs=[
        gr.Image(
            label="Model Prediction",
            image_mode="RGB",
            height=400,
            width=400,
        )
    ],
)

# Only start the server when run as a script, so importing this module
# (e.g. to reuse `greet`) does not block on launch().
if __name__ == "__main__":
    demo.launch()