|
|
|
import functools
import glob
import os
import traceback

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
|
|
|
|
|
# Keras 3 reads KERAS_BACKEND once, when `keras` is first imported, so it has
# to be set here (before the import below) to have any effect. setdefault()
# leaves an explicitly configured backend alone.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

# True only when a standalone `keras` package of major version >= 3 imports.
KERAS3_AVAILABLE = False

try:
    import keras
except Exception:
    # Best-effort: a missing or broken standalone keras just means
    # load_model() uses tf.keras instead.
    keras = None
else:
    try:
        KERAS3_AVAILABLE = int(keras.__version__.split(".")[0]) >= 3
    except (AttributeError, ValueError):
        # Unparseable version string: keep the module, but don't take the
        # Keras 3 loading path.
        KERAS3_AVAILABLE = False

# Hugging Face Hub repository that hosts `xray_cnn.keras`.
HF_MODEL_ID = "Vedag812/xray_cnn"

# Output labels; index 1 is the class whose probability the model score
# is treated as (see predict_fn).
CLASS_NAMES = ["NORMAL", "PNEUMONIA"]
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def load_model():
    """Download the X-ray classifier from the Hugging Face Hub and load it.

    The result is memoised, so only the first call downloads and deserializes
    ``xray_cnn.keras``. Every later prediction reuses the same in-memory
    model. ``lru_cache`` does not cache exceptions, so a failed load is
    retried on the next call.

    Loading is tried with the Keras 3 API first (when available). If that
    fails, it falls back to ``tf.keras``.

    Returns:
        The loaded model, uncompiled (``compile=False``); it is only used
        for ``predict``.

    Raises:
        Exception: anything ``hf_hub_download`` raises (e.g. network errors).
        RuntimeError: if both the Keras 3 and the tf.keras loaders fail.
            Both underlying errors are included in the message.
    """
    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(HF_MODEL_ID, filename="xray_cnn.keras")

    # Note: KERAS_BACKEND is only honoured if set before `keras` is imported,
    # so it cannot be configured from inside this function.
    keras3_error = None
    if KERAS3_AVAILABLE:
        # NOTE(review): safe_mode=False lets the file deserialize arbitrary
        # Python (e.g. Lambda layers). Acceptable only because HF_MODEL_ID is
        # a fixed, trusted repo; never point it at untrusted repositories.
        try:
            return keras.saving.load_model(model_path, compile=False, safe_mode=False)
        except Exception as err:
            # Deliberate fallback: the file may still load via tf.keras. Log
            # the failure instead of discarding it silently.
            keras3_error = err
            traceback.print_exc()

    try:
        return tf.keras.models.load_model(model_path, compile=False)
    except Exception as err:
        if keras3_error is None:
            raise
        raise RuntimeError(
            f"Could not load {model_path!r} with Keras 3 ({keras3_error!r}) "
            f"or with tf.keras ({err!r})."
        ) from err
|
|
|
def _infer_input_shape(model):
    """Return the ``(H, W, C)`` image size the model expects.

    The static shape of the first input is read via ``TensorShape.as_list()``.
    If that is unavailable, ``model.input_shape`` is used instead. When no
    shape can be read, or it has fewer than four dimensions, ``(150, 150, 1)``
    is returned. Any individual dimension that is unknown (``None``/0) falls
    back to 150, 150 or 1 respectively.
    """
    fallback = (150, 150, 1)

    def _read_dims():
        # TF TensorShape first, then a Keras-style ``input_shape`` tuple.
        try:
            return tuple(model.inputs[0].shape.as_list())
        except Exception:
            pass
        try:
            return tuple(model.input_shape)
        except Exception:
            return None

    dims = _read_dims()
    if dims and len(dims) >= 4:
        # dims is (batch, H, W, C, ...); replace each falsy dim with its default.
        return tuple(
            int(dim) if dim else default
            for dim, default in zip(dims[1:4], fallback)
        )
    return fallback
|
|
|
def preprocess(pil_img: Image.Image, target_hw_c):
    """Convert a PIL image into a one-image float32 batch for the model.

    The image is converted to 8-bit grayscale (mode ``"L"``) and resized to
    ``W x H``. Pixel values are scaled from 0-255 to ``[0, 1]``, and the gray
    plane is replicated into ``C`` channels.

    Args:
        pil_img: Input image in any PIL mode.
        target_hw_c: ``(H, W, C)`` target size, e.g. from ``_infer_input_shape``.

    Returns:
        A float32 ``np.ndarray`` of shape ``(1, H, W, C)``.
    """
    height, width, channels = target_hw_c

    # PIL's resize takes (width, height).
    gray_img = pil_img.convert("L").resize((width, height))
    gray = np.asarray(gray_img, dtype="float32") / 255.0

    if channels == 1:
        batch = gray[np.newaxis, :, :, np.newaxis]
    elif channels == 3:
        # Pseudo-RGB: the same gray plane in all three channels.
        batch = np.stack((gray,) * 3, axis=-1)[np.newaxis]
    else:
        batch = np.repeat(gray[..., np.newaxis], channels, axis=-1)[np.newaxis]
    return batch
|
|
|
def predict_fn(pil_img: "Image.Image | None"):
    """Classify one chest X-ray for the Gradio UI.

    Args:
        pil_img: Image from the ``gr.Image`` input, or ``None`` when the input
            has been cleared (``inp.change`` fires in that case too).

    Returns:
        A ``(probs, message)`` pair for the ``[probs, out_text]`` outputs.
        ``probs`` maps each entry of ``CLASS_NAMES`` to a probability, or is
        ``None`` to clear the label. ``message`` is Markdown text. Failures
        are reported through ``message`` rather than raised, and the
        traceback is printed to the server log.
    """
    if pil_img is None:
        # Nothing to classify (e.g. after "Clear"): reset the outputs instead
        # of surfacing an AttributeError from preprocess(None).
        return None, ""

    try:
        model = load_model()
        H, W, C = _infer_input_shape(model)
        x = preprocess(pil_img, (H, W, C))
        preds = np.asarray(model.predict(x, verbose=0))

        if preds.ndim >= 1 and preds.shape[-1] == len(CLASS_NAMES):
            # One output column per class. Reading column 0 here would return
            # P(CLASS_NAMES[0]), not P(CLASS_NAMES[1]).
            # NOTE(review): assumes the columns follow CLASS_NAMES order.
            prob = float(preds.reshape(-1, len(CLASS_NAMES))[0, 1])
        else:
            # Single score, treated as the probability of CLASS_NAMES[1].
            prob = float(preds.ravel()[0])

        pred_idx = int(prob > 0.5)
        confidence = prob if pred_idx == 1 else 1 - prob
        probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
        msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
        return probs, msg

    except Exception as e:
        # Top-level UI boundary: keep the full traceback in the server log and
        # show a readable message in the app instead of failing the event.
        traceback.print_exc()

        tip = (
            "Tip: if this keeps happening, the Space may need keras>=3 to load a model "
            "saved with newer Keras. I handled both paths here, but if your model was saved "
            "with a very new version, updating the Space deps can help."
        )

        err_text = "⚠️ Error during prediction:\n\n" + str(e) + "\n\n" + tip
        return {name: 0.0 for name in CLASS_NAMES}, err_text
|
|
|
def list_examples():
    """Return the bundled sample images as rows for ``gr.Examples``.

    Collects ``.jpeg``, ``.jpg`` and ``.png`` files from ``images/``, relative
    to the current working directory. They are sorted by path and each file
    is wrapped in a single-element list. Returns ``[]`` when nothing matches.
    """
    patterns = ("images/*.jpeg", "images/*.jpg", "images/*.png")
    matches = sorted(path for pattern in patterns for path in glob.glob(pattern))
    return [[path] for path in matches]
|
|
|
# Page styling for the Blocks app.
# NOTE(review): the `.card` rule is not referenced by any component below
# (nothing sets elem_classes="card"); wire it up or drop it.
_APP_CSS = """
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
.card {border:1px solid #e5e7eb; border-radius:16px; padding:16px;}
"""

with gr.Blocks(css=_APP_CSS) as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample from the gallery. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                clear_btn = gr.ClearButton(components=[inp], value="Clear")

            sample_rows = list_examples()
            # Only render the gallery when sample files exist, so a deployment
            # without an images/ folder shows no empty "Samples" section.
            if sample_rows:
                gr.Markdown("### Samples")
                gr.Examples(
                    examples=sample_rows,
                    inputs=inp,
                    examples_per_page=12,
                )

        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # Clear the stale prediction together with the image. The output
    # components are created after the button, hence the late add().
    clear_btn.add([probs, out_text])

    # Predict on explicit button press and whenever the input image changes
    # (upload or sample click).
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])


if __name__ == "__main__":
    demo.launch()
|
|