osimeoni commited on
Commit
3317dd8
·
1 Parent(s): 5bf6d36

Image reshape

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -11,7 +11,8 @@ from misc import load_config
11
  from torchvision import transforms as T
12
 
13
  import gradio as gr
14
-
 
15
  NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
16
  CACHE = True
17
 
@@ -48,7 +49,14 @@ def predict(img_input):
48
  img_pil = Image.open(img_input)
49
  img = img_pil.convert("RGB")
50
 
51
- t = T.Compose([T.ToTensor(), NORMALIZE])
 
 
 
 
 
 
 
52
  img_t = t(img)[None,:,:,:]
53
  inputs = img_t
54
 
 
11
  from torchvision import transforms as T
12
 
13
  import gradio as gr
14
+
15
+ MAX_IM_SIZE = 512
16
  NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
17
  CACHE = True
18
 
 
49
  img_pil = Image.open(img_input)
50
  img = img_pil.convert("RGB")
51
 
52
+ # Image transformations
53
+ transforms = [T.ToTensor()]
54
+ # Resize image if needed
55
+ if img.size[0] > MAX_IM_SIZE or img.size[1] > MAX_IM_SIZE:
56
+ transforms.append(T.Resize(max_size=MAX_IM_SIZE))
57
+ transforms.append(NORMALIZE)
58
+ t = T.Compose(transforms)
59
+
60
  img_t = t(img)[None,:,:,:]
61
  inputs = img_t
62