File size: 4,641 Bytes
429fe88
ae73f85
429fe88
 
 
 
ae73f85
 
 
 
 
 
 
 
429fe88
 
 
 
de17834
429fe88
ae73f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429fe88
ae73f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429fe88
 
ae73f85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429fe88
 
 
 
 
 
ae73f85
429fe88
 
 
 
 
 
 
 
 
 
 
 
 
 
ae73f85
429fe88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# app.py
import functools
import glob
import logging
import os
import traceback

import numpy as np
from PIL import Image
import gradio as gr
import tensorflow as tf

# Keras 3 picks its backend once, when the package is first imported, so the
# default must be in place before `import keras` below; setting it afterwards
# has no effect. setdefault keeps any backend the environment already chose.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

# Try to use the Keras 3 loader if present (many models saved with Keras 3
# need it). `keras` ends up None when the package is missing or its version
# string cannot be parsed.
KERAS3_AVAILABLE = False
try:
    import keras  # pip package "keras" (v3.x)
    KERAS3_AVAILABLE = int(keras.__version__.split(".")[0]) >= 3
except Exception:
    keras = None

# Hugging Face repo that hosts xray_cnn.keras, and the label order used when
# reading predictions (the model output is treated as P(CLASS_NAMES[1])).
HF_MODEL_ID = "Vedag812/xray_cnn"
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]


@functools.lru_cache(maxsize=1)
def load_model():
    """Download the classifier from the Hugging Face Hub and load it once.

    ``xray_cnn.keras`` is fetched from the repo ``HF_MODEL_ID``. When Keras 3
    is available its loader is tried first; if it raises, the failure is
    logged and ``tf.keras`` is used instead.

    The loaded model is memoized, so the per-prediction calls return the same
    object instead of re-reading and rebuilding the model every time. A call
    that raises is not cached, so a later call retries.

    Returns:
        The loaded, uncompiled Keras model.

    Raises:
        Whatever ``hf_hub_download`` or ``tf.keras.models.load_model`` raises.
    """
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")
    # KERAS_BACKEND is read when keras is first imported, so it is not set here.
    if KERAS3_AVAILABLE:
        # NOTE(security): safe_mode=False lets the downloaded file deserialize
        # arbitrary Python code (e.g. lambda layers). Only acceptable because
        # HF_MODEL_ID points at a trusted repo; never aim it at untrusted ones.
        try:
            return keras.saving.load_model(model_path, compile=False, safe_mode=False)
        except Exception:
            # Best-effort fallback kept, but record why it was needed instead
            # of silently discarding the Keras 3 error.
            logging.getLogger(__name__).warning(
                "Keras 3 could not load %s; falling back to tf.keras",
                model_path,
                exc_info=True,
            )
    return tf.keras.models.load_model(model_path, compile=False)

def _infer_input_shape(model):
    """returns (H, W, C) with integers if available, else defaults to (150,150,1)"""
    shape = None
    try:
        shape = tuple(model.inputs[0].shape.as_list())  # works on many TF models
    except Exception:
        try:
            shape = tuple(model.input_shape)
        except Exception:
            pass
    if not shape or len(shape) < 4:
        return 150, 150, 1
    H = int(shape[1]) if shape[1] else 150
    W = int(shape[2]) if shape[2] else 150
    C = int(shape[3]) if shape[3] else 1
    return H, W, C

def preprocess(pil_img: Image.Image, target_hw_c):
    """Turn a PIL image into a one-image batch for the model.

    The image is converted to 8-bit grayscale ("L"), resized to the target
    size, scaled from [0, 255] to [0.0, 1.0] as float32, and the gray plane
    is copied into every requested channel.

    Args:
        pil_img: Input image in any PIL mode.
        target_hw_c: ``(height, width, channels)`` expected by the model.

    Returns:
        A float32 array of shape ``(1, height, width, channels)``.
    """
    height, width, channels = target_hw_c
    # Going through grayscale first keeps intensities identical regardless
    # of the channel count the model wants. PIL sizes are (width, height).
    gray = pil_img.convert("L").resize((width, height))
    plane = np.array(gray).astype("float32") / 255.0  # (H, W)

    if channels == 1:
        batch = plane[np.newaxis, :, :, np.newaxis]  # (1, H, W, 1)
    elif channels == 3:
        batch = np.stack((plane,) * 3, axis=-1)[np.newaxis]  # (1, H, W, 3)
    else:
        # Unusual channel count: replicate the gray plane that many times.
        batch = np.repeat(plane[..., np.newaxis], channels, axis=-1)[np.newaxis]
    return batch

def predict_fn(pil_img: Image.Image):
    """Classify one chest X-ray and format the result for the UI.

    Args:
        pil_img: Uploaded image; may be None when the input is empty
            (for example after the Clear button resets it).

    Returns:
        A ``(probs, message)`` pair for the Label and Markdown outputs.
        ``probs`` maps each name in ``CLASS_NAMES`` to its probability and
        ``message`` is the formatted verdict. With no image both outputs are
        blanked (``None, ""``). On any error every probability is 0.0,
        ``message`` carries the error text plus a troubleshooting tip, and the
        full traceback is printed to stderr.
    """
    if pil_img is None:
        # Nothing to classify: clear the outputs instead of letting
        # preprocess() raise AttributeError and showing it as an error.
        return None, ""
    try:
        model = load_model()
        H, W, C = _infer_input_shape(model)
        x = preprocess(pil_img, (H, W, C))
        preds = np.asarray(model.predict(x, verbose=0)).ravel()
        # A single sigmoid unit (shape (1, 1) or (1,)) gives P(PNEUMONIA)
        # directly. A two-unit head is read at index 1, assuming its outputs
        # follow CLASS_NAMES order. NOTE(review): confirm for any model that
        # is not a single-unit sigmoid.
        prob = float(preds[1] if preds.size == len(CLASS_NAMES) else preds[0])
        pred_idx = int(prob > 0.5)
        confidence = prob if pred_idx == 1 else 1 - prob
        probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
        msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
        return probs, msg
    except Exception as e:
        # UI boundary: report the failure to the user rather than crashing
        # the request, but keep the full traceback in the server logs.
        traceback.print_exc()
        tip = (
            "Tip: if this keeps happening, the Space may need keras>=3 to load a model "
            "saved with newer Keras. I handled both paths here, but if your model was saved "
            "with a very new version, updating the Space deps can help."
        )
        err_text = "⚠️ Error during prediction:\n\n" + str(e) + "\n\n" + tip
        return dict.fromkeys(CLASS_NAMES, 0.0), err_text

def list_examples(folder="images", extensions=(".jpeg", ".jpg", ".png")):
    """Collect sample image paths for the ``gr.Examples`` gallery.

    Args:
        folder: Directory to scan (not recursive). Defaults to ``"images"``,
            resolved relative to the current working directory.
        extensions: File suffixes to accept, matched case-insensitively so
            ``scan.JPG`` is found alongside ``scan.jpg``.

    Returns:
        A sorted list of one-element lists ``[[path], ...]``: one row per
        example, one value for the single input component. Empty when the
        folder is missing or holds no matching files.
    """
    wanted = {ext.lower() for ext in extensions}
    # One wildcard scan plus a lower-cased suffix check instead of one
    # case-sensitive glob per extension: upper-case extensions are included,
    # and the set guarantees no file is listed twice.
    matches = {
        path
        for path in glob.glob(os.path.join(folder, "*"))
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in wanted
    }
    return [[path] for path in sorted(matches)]

# Build the Gradio UI: image upload, buttons and a sample gallery on the
# left; class probabilities and a text verdict on the right.
# NOTE(review): the .card CSS class below is not applied to any component here.
with gr.Blocks(css="""
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
.card {border:1px solid #e5e7eb; border-radius:16px; padding:16px;}
""") as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample from the gallery. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        with gr.Column(scale=2):
            # Requests a PIL image in mode "L"; preprocess() converts to "L"
            # again, so the model input does not depend on this setting.
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                gr.ClearButton(components=[inp], value="Clear")
            gr.Markdown("### Samples")
            # Sample files are collected once, when the UI is built.
            gr.Examples(
                examples=list_examples(),
                inputs=inp,
                examples_per_page=12,
            )
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # Predict on button click and whenever the image changes (upload or
    # sample click). NOTE(review): the change event presumably also fires when
    # the Clear button empties the input, passing None to predict_fn — confirm
    # predict_fn handles a None image gracefully.
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()