Update app.py (#3)
Browse files- Update app.py (fe73a0757085ed388a81afce43206377986ef5c5)
Co-authored-by: MUHAMMAD IBRAHIM <[email protected]>
app.py
CHANGED
@@ -1,43 +1,98 @@
|
|
1 |
# app.py
|
2 |
-
import os, glob
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import gradio as gr
|
6 |
import tensorflow as tf
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
HF_MODEL_ID = "Vedag812/xray_cnn"
|
11 |
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
|
12 |
|
13 |
-
@
|
14 |
def load_model():
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
-
def preprocess(pil_img: Image.Image):
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def predict_fn(pil_img: Image.Image):
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
def list_examples():
|
36 |
files = []
|
37 |
for pattern in ["images/*.jpeg", "images/*.jpg", "images/*.png"]:
|
38 |
files.extend(glob.glob(pattern))
|
39 |
files = sorted(files)
|
40 |
-
return [[p] for p in files]
|
41 |
|
42 |
with gr.Blocks(css="""
|
43 |
.gradio-container {max-width: 980px !important; margin: auto;}
|
@@ -52,7 +107,7 @@ with gr.Blocks(css="""
|
|
52 |
inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
|
53 |
with gr.Row():
|
54 |
btn = gr.Button("Predict", variant="primary")
|
55 |
-
|
56 |
gr.Markdown("### Samples")
|
57 |
gr.Examples(
|
58 |
examples=list_examples(),
|
@@ -63,9 +118,7 @@ with gr.Blocks(css="""
|
|
63 |
probs = gr.Label(num_top_classes=2, label="Class probabilities")
|
64 |
out_text = gr.Markdown()
|
65 |
|
66 |
-
# Run on click
|
67 |
btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
|
68 |
-
# Also auto-run when image changes (from upload or example click)
|
69 |
inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])
|
70 |
|
71 |
if __name__ == "__main__":
|
|
|
1 |
# app.py
|
2 |
+
import os, glob, traceback
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import gradio as gr
|
6 |
import tensorflow as tf
|
7 |
+
|
8 |
+
# try to use Keras 3 loader if present (many models saved with Keras 3 need this)
|
9 |
+
KERAS3_AVAILABLE = False
|
10 |
+
try:
|
11 |
+
import keras # pip package "keras" (v3.x)
|
12 |
+
KERAS3_AVAILABLE = int(keras.__version__.split(".")[0]) >= 3
|
13 |
+
except Exception:
|
14 |
+
keras = None
|
15 |
|
16 |
HF_MODEL_ID = "Vedag812/xray_cnn"
|
17 |
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
|
18 |
|
19 |
+
@gr.cache_resource
|
20 |
def load_model():
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
|
23 |
+
# Keras 3 path
|
24 |
+
if KERAS3_AVAILABLE:
|
25 |
+
os.environ.setdefault("KERAS_BACKEND", "tensorflow")
|
26 |
+
try:
|
27 |
+
return keras.saving.load_model(model_path, compile=False, safe_mode=False)
|
28 |
+
except Exception:
|
29 |
+
# fall back to tf.keras if that fails for any reason
|
30 |
+
pass
|
31 |
+
# tf.keras path
|
32 |
+
return tf.keras.models.load_model(model_path, compile=False)
|
33 |
+
|
34 |
+
def _infer_input_shape(model):
|
35 |
+
"""returns (H, W, C) with integers if available, else defaults to (150,150,1)"""
|
36 |
+
shape = None
|
37 |
+
try:
|
38 |
+
shape = tuple(model.inputs[0].shape.as_list()) # works on many TF models
|
39 |
+
except Exception:
|
40 |
+
try:
|
41 |
+
shape = tuple(model.input_shape)
|
42 |
+
except Exception:
|
43 |
+
pass
|
44 |
+
if not shape or len(shape) < 4:
|
45 |
+
return 150, 150, 1
|
46 |
+
H = int(shape[1]) if shape[1] else 150
|
47 |
+
W = int(shape[2]) if shape[2] else 150
|
48 |
+
C = int(shape[3]) if shape[3] else 1
|
49 |
+
return H, W, C
|
50 |
|
51 |
+
def preprocess(pil_img: Image.Image, target_hw_c):
|
52 |
+
H, W, C = target_hw_c
|
53 |
+
# always start from grayscale so intensity stays consistent
|
54 |
+
g = pil_img.convert("L").resize((W, H))
|
55 |
+
g_arr = np.array(g).astype("float32") / 255.0 # (H,W)
|
56 |
+
if C == 1:
|
57 |
+
x = np.expand_dims(g_arr, axis=(0, -1)) # (1,H,W,1)
|
58 |
+
elif C == 3:
|
59 |
+
x3 = np.stack([g_arr, g_arr, g_arr], axis=-1) # (H,W,3)
|
60 |
+
x = np.expand_dims(x3, axis=0) # (1,H,W,3)
|
61 |
+
else:
|
62 |
+
# unexpected channel count. tile to that count safely
|
63 |
+
xC = np.repeat(g_arr[..., None], C, axis=-1)
|
64 |
+
x = np.expand_dims(xC, axis=0)
|
65 |
+
return x
|
66 |
|
67 |
def predict_fn(pil_img: Image.Image):
|
68 |
+
try:
|
69 |
+
model = load_model()
|
70 |
+
H, W, C = _infer_input_shape(model)
|
71 |
+
x = preprocess(pil_img, (H, W, C))
|
72 |
+
preds = model.predict(x, verbose=0)
|
73 |
+
# handle models that output shape (1,1) or (1,)
|
74 |
+
prob = float(preds.ravel()[0])
|
75 |
+
pred_idx = int(prob > 0.5)
|
76 |
+
confidence = prob if pred_idx == 1 else 1 - prob
|
77 |
+
probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
|
78 |
+
msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
|
79 |
+
return probs, msg
|
80 |
+
except Exception as e:
|
81 |
+
# show a readable error with a tip
|
82 |
+
tip = (
|
83 |
+
"Tip: if this keeps happening, the Space may need keras>=3 to load a model "
|
84 |
+
"saved with newer Keras. I handled both paths here, but if your model was saved "
|
85 |
+
"with a very new version, updating the Space deps can help."
|
86 |
+
)
|
87 |
+
err_text = "⚠️ Error during prediction:\n\n" + str(e) + "\n\n" + tip
|
88 |
+
return {"NORMAL": 0.0, "PNEUMONIA": 0.0}, err_text
|
89 |
|
90 |
def list_examples():
|
91 |
files = []
|
92 |
for pattern in ["images/*.jpeg", "images/*.jpg", "images/*.png"]:
|
93 |
files.extend(glob.glob(pattern))
|
94 |
files = sorted(files)
|
95 |
+
return [[p] for p in files]
|
96 |
|
97 |
with gr.Blocks(css="""
|
98 |
.gradio-container {max-width: 980px !important; margin: auto;}
|
|
|
107 |
inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
|
108 |
with gr.Row():
|
109 |
btn = gr.Button("Predict", variant="primary")
|
110 |
+
gr.ClearButton(components=[inp], value="Clear")
|
111 |
gr.Markdown("### Samples")
|
112 |
gr.Examples(
|
113 |
examples=list_examples(),
|
|
|
118 |
probs = gr.Label(num_top_classes=2, label="Class probabilities")
|
119 |
out_text = gr.Markdown()
|
120 |
|
|
|
121 |
btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
|
|
|
122 |
inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])
|
123 |
|
124 |
if __name__ == "__main__":
|