Vedag812 ibrahim313 commited on
Commit
ae73f85
·
verified ·
1 Parent(s): 7c76a73

Update app.py (#3)

Browse files

- Update app.py (fe73a0757085ed388a81afce43206377986ef5c5)


Co-authored-by: MUHAMMAD IBRAHIM <[email protected]>

Files changed (1) hide show
  1. app.py +77 -24
app.py CHANGED
@@ -1,43 +1,98 @@
1
  # app.py
2
- import os, glob
3
  import numpy as np
4
  from PIL import Image
5
  import gradio as gr
6
  import tensorflow as tf
7
- from functools import lru_cache
8
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
9
 
10
  HF_MODEL_ID = "Vedag812/xray_cnn"
11
  CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
12
 
13
- @lru_cache(maxsize=1)
14
  def load_model():
15
- model_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="xray_cnn.keras")
16
- model = tf.keras.models.load_model(model_path, compile=False)
17
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- def preprocess(pil_img: Image.Image):
20
- img = pil_img.convert("L").resize((150, 150))
21
- arr = np.array(img).astype("float32") / 255.0
22
- arr = np.expand_dims(arr, axis=(0, -1)) # shape (1,150,150,1)
23
- return arr
 
 
 
 
 
 
 
 
 
 
24
 
25
  def predict_fn(pil_img: Image.Image):
26
- model = load_model()
27
- x = preprocess(pil_img)
28
- prob = float(model.predict(x, verbose=0)[0][0]) # sigmoid
29
- pred_idx = int(prob > 0.5)
30
- confidence = prob if pred_idx == 1 else 1 - prob
31
- probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
32
- msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
33
- return probs, msg
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def list_examples():
36
  files = []
37
  for pattern in ["images/*.jpeg", "images/*.jpg", "images/*.png"]:
38
  files.extend(glob.glob(pattern))
39
  files = sorted(files)
40
- return [[p] for p in files] # gr.Examples expects list of [path]
41
 
42
  with gr.Blocks(css="""
43
  .gradio-container {max-width: 980px !important; margin: auto;}
@@ -52,7 +107,7 @@ with gr.Blocks(css="""
52
  inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
53
  with gr.Row():
54
  btn = gr.Button("Predict", variant="primary")
55
- clr = gr.ClearButton(components=[inp], value="Clear")
56
  gr.Markdown("### Samples")
57
  gr.Examples(
58
  examples=list_examples(),
@@ -63,9 +118,7 @@ with gr.Blocks(css="""
63
  probs = gr.Label(num_top_classes=2, label="Class probabilities")
64
  out_text = gr.Markdown()
65
 
66
- # Run on click
67
  btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
68
- # Also auto-run when image changes (from upload or example click)
69
  inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])
70
 
71
  if __name__ == "__main__":
 
1
  # app.py
2
+ import os, glob, traceback
3
  import numpy as np
4
  from PIL import Image
5
  import gradio as gr
6
  import tensorflow as tf
7
+
8
+ # try to use Keras 3 loader if present (many models saved with Keras 3 need this)
9
+ KERAS3_AVAILABLE = False
10
+ try:
11
+ import keras # pip package "keras" (v3.x)
12
+ KERAS3_AVAILABLE = int(keras.__version__.split(".")[0]) >= 3
13
+ except Exception:
14
+ keras = None
15
 
16
  HF_MODEL_ID = "Vedag812/xray_cnn"
17
  CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
18
 
19
+ @gr.cache_resource
20
  def load_model():
21
+ from huggingface_hub import hf_hub_download
22
+ model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
23
+ # Keras 3 path
24
+ if KERAS3_AVAILABLE:
25
+ os.environ.setdefault("KERAS_BACKEND", "tensorflow")
26
+ try:
27
+ return keras.saving.load_model(model_path, compile=False, safe_mode=False)
28
+ except Exception:
29
+ # fall back to tf.keras if that fails for any reason
30
+ pass
31
+ # tf.keras path
32
+ return tf.keras.models.load_model(model_path, compile=False)
33
+
34
+ def _infer_input_shape(model):
35
+ """returns (H, W, C) with integers if available, else defaults to (150,150,1)"""
36
+ shape = None
37
+ try:
38
+ shape = tuple(model.inputs[0].shape.as_list()) # works on many TF models
39
+ except Exception:
40
+ try:
41
+ shape = tuple(model.input_shape)
42
+ except Exception:
43
+ pass
44
+ if not shape or len(shape) < 4:
45
+ return 150, 150, 1
46
+ H = int(shape[1]) if shape[1] else 150
47
+ W = int(shape[2]) if shape[2] else 150
48
+ C = int(shape[3]) if shape[3] else 1
49
+ return H, W, C
50
 
51
+ def preprocess(pil_img: Image.Image, target_hw_c):
52
+ H, W, C = target_hw_c
53
+ # always start from grayscale so intensity stays consistent
54
+ g = pil_img.convert("L").resize((W, H))
55
+ g_arr = np.array(g).astype("float32") / 255.0 # (H,W)
56
+ if C == 1:
57
+ x = np.expand_dims(g_arr, axis=(0, -1)) # (1,H,W,1)
58
+ elif C == 3:
59
+ x3 = np.stack([g_arr, g_arr, g_arr], axis=-1) # (H,W,3)
60
+ x = np.expand_dims(x3, axis=0) # (1,H,W,3)
61
+ else:
62
+ # unexpected channel count. tile to that count safely
63
+ xC = np.repeat(g_arr[..., None], C, axis=-1)
64
+ x = np.expand_dims(xC, axis=0)
65
+ return x
66
 
67
  def predict_fn(pil_img: Image.Image):
68
+ try:
69
+ model = load_model()
70
+ H, W, C = _infer_input_shape(model)
71
+ x = preprocess(pil_img, (H, W, C))
72
+ preds = model.predict(x, verbose=0)
73
+ # handle models that output shape (1,1) or (1,)
74
+ prob = float(preds.ravel()[0])
75
+ pred_idx = int(prob > 0.5)
76
+ confidence = prob if pred_idx == 1 else 1 - prob
77
+ probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
78
+ msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
79
+ return probs, msg
80
+ except Exception as e:
81
+ # show a readable error with a tip
82
+ tip = (
83
+ "Tip: if this keeps happening, the Space may need keras>=3 to load a model "
84
+ "saved with newer Keras. I handled both paths here, but if your model was saved "
85
+ "with a very new version, updating the Space deps can help."
86
+ )
87
+ err_text = "⚠️ Error during prediction:\n\n" + str(e) + "\n\n" + tip
88
+ return {"NORMAL": 0.0, "PNEUMONIA": 0.0}, err_text
89
 
90
  def list_examples():
91
  files = []
92
  for pattern in ["images/*.jpeg", "images/*.jpg", "images/*.png"]:
93
  files.extend(glob.glob(pattern))
94
  files = sorted(files)
95
+ return [[p] for p in files]
96
 
97
  with gr.Blocks(css="""
98
  .gradio-container {max-width: 980px !important; margin: auto;}
 
107
  inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
108
  with gr.Row():
109
  btn = gr.Button("Predict", variant="primary")
110
+ gr.ClearButton(components=[inp], value="Clear")
111
  gr.Markdown("### Samples")
112
  gr.Examples(
113
  examples=list_examples(),
 
118
  probs = gr.Label(num_top_classes=2, label="Class probabilities")
119
  out_text = gr.Markdown()
120
 
 
121
  btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
 
122
  inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])
123
 
124
  if __name__ == "__main__":