|
|
|
import os, glob |
|
import numpy as np |
|
from PIL import Image |
|
import gradio as gr |
|
import tensorflow as tf |
|
from functools import lru_cache |
|
from huggingface_hub import hf_hub_download |
|
|
|
# Hugging Face Hub repository that load_model downloads "xray_cnn.keras" from.
HF_MODEL_ID = "Vedag812/xray_cnn"

# Label names ordered by class index: predict_fn maps a model output above
# 0.5 to index 1 and anything else to index 0.
CLASS_NAMES = [
    "NORMAL",
    "PNEUMONIA",
]
|
|
|
@lru_cache(maxsize=1)
def load_model():
    """Fetch ``xray_cnn.keras`` from the Hub repo ``HF_MODEL_ID`` and load it.

    Memoized with ``lru_cache(maxsize=1)``, so later calls reuse the model
    object produced by the first call instead of reloading it.
    ``compile=False`` skips restoring the training configuration; this file
    only runs inference with the model.

    Returns:
        The deserialized ``tf.keras`` model.
    """
    checkpoint = hf_hub_download(
        repo_id=HF_MODEL_ID,
        filename="xray_cnn.keras",
    )
    return tf.keras.models.load_model(checkpoint, compile=False)
|
|
|
def preprocess(pil_img: Image.Image):
    """Turn a PIL image into a one-image batch for the classifier.

    The image is converted to 8-bit grayscale ("L"), resized to 150x150,
    and scaled from 0..255 to 0..1 as float32. A leading batch axis and a
    trailing channel axis are then added.

    Returns:
        A float32 numpy array of shape (1, 150, 150, 1).
    """
    grey = pil_img.convert("L")
    small = grey.resize((150, 150))
    pixels = np.asarray(small, dtype="float32") / 255.0
    # (H, W) -> (1, H, W, 1): batch of one, single channel.
    return pixels[np.newaxis, :, :, np.newaxis]
|
|
|
def predict_fn(pil_img: "Image.Image | None"):
    """Classify one chest X-ray and format the result for the UI.

    Args:
        pil_img: The PIL image from the ``gr.Image`` input. This may be None
            when the input has been cleared, because the input's change
            event fires on clearing too.

    Returns:
        A ``(probs, msg)`` tuple:

        - ``probs`` maps each name in ``CLASS_NAMES`` to its probability,
          for ``gr.Label``.
        - ``msg`` is a Markdown summary of the predicted class and its
          confidence.

        When ``pil_img`` is None, returns ``(None, "")`` so both outputs are
        emptied instead of raising.
    """
    if pil_img is None:
        # Without this guard, clearing the image calls preprocess(None) and
        # surfaces an AttributeError in the UI.
        return None, ""

    model = load_model()
    x = preprocess(pil_img)
    # A direct call avoids predict()'s per-call batching machinery, which is
    # pure overhead for a single input. The output is indexed [0][0], so it is
    # assumed to be a single score per image (presumably P(PNEUMONIA)).
    prob = float(model(x, training=False)[0][0])

    # Scores above 0.5 select index 1 (PNEUMONIA); 0.5 and below select NORMAL.
    pred_idx = int(prob > 0.5)
    # Confidence is the probability assigned to the chosen class.
    confidence = prob if pred_idx == 1 else 1 - prob

    probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
    msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
    return probs, msg
|
|
|
def list_examples(directory="images", extensions=(".jpeg", ".jpg", ".png")):
    """Collect sample images for the ``gr.Examples`` gallery.

    Args:
        directory: Folder to scan, non-recursively. Relative paths resolve
            against the current working directory, as before.
        extensions: Iterable of file suffixes to include. Matching is
            case-insensitive, so files such as ``scan.JPG`` are not silently
            skipped on case-sensitive filesystems.

    Returns:
        A sorted list of one-element rows ``[[path], ...]``, the row shape
        ``gr.Examples`` uses for a single input component. Hidden files
        (leading ``.``) are skipped, matching the previous ``glob("*.ext")``
        behaviour. Returns ``[]`` if the folder does not exist.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    try:
        names = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        # A missing sample folder just means an empty gallery.
        return []

    files = sorted(
        os.path.join(directory, name)
        for name in names
        if not name.startswith(".")
        and name.lower().endswith(suffixes)
        # Exclude directories that happen to carry an image-like suffix.
        and os.path.isfile(os.path.join(directory, name))
    )
    return [[path] for path in files]
|
|
|
with gr.Blocks(css="""
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
.card {border:1px solid #e5e7eb; border-radius:16px; padding:16px;}
""") as demo:
    # Page header.
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample from the gallery. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        # Left column: image input, action buttons, and the sample gallery.
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                clr = gr.ClearButton(components=[inp], value="Clear")

            example_rows = list_examples()
            # Only render the gallery section when sample files exist;
            # otherwise the "Samples" heading would sit above nothing.
            if example_rows:
                gr.Markdown("### Samples")
                gr.Examples(
                    examples=example_rows,
                    inputs=inp,
                    examples_per_page=12,
                )

        # Right column: prediction outputs.
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # The output components are created after the Clear button, so register
    # them here. Otherwise "Clear" empties the image but leaves the previous
    # prediction on screen.
    clr.add([probs, out_text])

    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])

    # Also predict automatically whenever the input value changes (upload or
    # example click). This event fires on clearing too, with a None image.
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])


if __name__ == "__main__":
    demo.launch()
|
|