Vizuara commited on
Commit
354cddb
·
verified ·
1 Parent(s): 5301c86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -25,12 +25,18 @@ transform = transforms.Compose([
25
  ])
26
 
27
  def predict(image):
28
- img = transform(image).unsqueeze(0) # shape: (1,3,128,128)
 
29
  with torch.no_grad():
30
  pred = model(img)
 
31
  mask = pred.squeeze(0).squeeze(0).cpu().numpy()
32
- mask = (mask * 255).astype(np.uint8) # keep gray levels
33
- return Image.fromarray(mask)
 
 
 
 
34
 
35
  # Gradio interface
36
  demo = gr.Interface(
 
25
  ])
26
 
27
  def predict(image):
28
+ orig_w, orig_h = image.size # original size of uploaded image
29
+ img = transform(image).unsqueeze(0) # (1,3,128,128)
30
  with torch.no_grad():
31
  pred = model(img)
32
+
33
  mask = pred.squeeze(0).squeeze(0).cpu().numpy()
34
+ mask = (mask * 255).astype(np.uint8) # grayscale mask
35
+
36
+ # Resize back to original size
37
+ mask_img = Image.fromarray(mask).resize((orig_w, orig_h), Image.NEAREST)
38
+ return mask_img
39
+
40
 
41
  # Gradio interface
42
  demo = gr.Interface(