Image reshape
Browse files
app.py
CHANGED
@@ -11,7 +11,8 @@ from misc import load_config
|
|
11 |
from torchvision import transforms as T
|
12 |
|
13 |
import gradio as gr
|
14 |
-
|
|
|
15 |
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
16 |
CACHE = True
|
17 |
|
@@ -48,7 +49,14 @@ def predict(img_input):
|
|
48 |
img_pil = Image.open(img_input)
|
49 |
img = img_pil.convert("RGB")
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
img_t = t(img)[None,:,:,:]
|
53 |
inputs = img_t
|
54 |
|
|
|
11 |
from torchvision import transforms as T
|
12 |
|
13 |
import gradio as gr
|
14 |
+
|
15 |
+
MAX_IM_SIZE = 512
|
16 |
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
17 |
CACHE = True
|
18 |
|
|
|
49 |
img_pil = Image.open(img_input)
|
50 |
img = img_pil.convert("RGB")
|
51 |
|
52 |
+
# Image transformations
|
53 |
+
transforms = [T.ToTensor()]
|
54 |
+
# Resize image if needed
|
55 |
+
if img.size[0] > MAX_IM_SIZE or img.size[1] > MAX_IM_SIZE:
|
56 |
+
transforms.append(T.Resize(max_size=MAX_IM_SIZE))
|
57 |
+
transforms.append(NORMALIZE)
|
58 |
+
t = T.Compose(transforms)
|
59 |
+
|
60 |
img_t = t(img)[None,:,:,:]
|
61 |
inputs = img_t
|
62 |
|