osbm commited on
Commit
58b19a1
·
verified ·
1 Parent(s): 8bf1548

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -5
app.py CHANGED
@@ -1,15 +1,57 @@
1
  import gradio as gr
2
  import monai
3
  import torch
 
 
 
 
 
4
 
5
- def greet(name, intensity):
6
- return "Hello "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  demo = gr.Interface(
9
  fn=greet,
10
- title="test-title",
11
- inputs=["file"],
12
- outputs=["text"],
 
 
 
 
13
  )
14
 
15
  demo.launch()
 
1
  import gradio as gr
2
  import monai
3
  import torch
4
+ from monai.networks.nets import UNet
5
+ from PIL import Image
6
+ import albumentations as A
7
+ from albumentations.pytorch import ToTensorV2
8
+ import numpy as np
9
 
10
+
11
+ model = UNet(
12
+ spatial_dims=2,
13
+ in_channels=3,
14
+ out_channels=1,
15
+ channels=[16, 32, 64, 128, 256, 512],
16
+ strides=(2, 2, 2, 2, 2),
17
+ num_res_units=4,
18
+ dropout=0.15,
19
+ )
20
+ model.load_state_dict(torch.load("best_model.pth"))
21
+ model.eval()
22
+
23
+ def greet(image_path):
24
+
25
+
26
+ image = Image.open(image_path).convert("RGB")
27
+ image = np.array(image) / 255.0
28
+ image = image.astype(np.float32)
29
+
30
+ inference_transforms = A.Compose([
31
+ A.Resize(height=512, width=512),
32
+ ToTensorV2(),
33
+ ])
34
+
35
+ image = inference_transforms(image=image)["image"]
36
+
37
+ image = image.unsqueeze(0)
38
+
39
+
40
+ with torch.no_grad():
41
+ mask_pred = model(image)
42
+
43
+ return mask_pred[0]
44
+
45
 
46
  demo = gr.Interface(
47
  fn=greet,
48
+ title="Histapathology segmentation",
49
+ inputs=[
50
+ gr.File(label="Input image (512x512)")
51
+ ],
52
+ outputs=[
53
+ gr.File(label="Model Prediction")
54
+ ],
55
  )
56
 
57
  demo.launch()