File size: 2,599 Bytes
429fe88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# app.py
import os, glob
import numpy as np
from PIL import Image
import gradio as gr
import tensorflow as tf
from functools import lru_cache
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository that hosts the trained Keras checkpoint.
HF_MODEL_ID = "Vedag812/xray_cnn"

# Display labels, in output order: the model's probability is shown as
# PNEUMONIA (index 1) and its complement as NORMAL (index 0).
CLASS_NAMES = [
    "NORMAL",
    "PNEUMONIA",
]

@lru_cache(maxsize=1)
def load_model():
    """Fetch the ``xray_cnn.keras`` checkpoint from the Hub and load it.

    The ``lru_cache`` wrapper means the download and deserialization run
    only on the first call; every later call returns the same model object.
    The model is loaded with ``compile=False`` because inference never
    needs an optimizer or loss.
    """
    checkpoint_path = hf_hub_download(
        repo_id=HF_MODEL_ID,
        filename="xray_cnn.keras",
    )
    return tf.keras.models.load_model(checkpoint_path, compile=False)

def preprocess(pil_img: Image.Image):
    """Turn a PIL image into a one-sample batch for the CNN.

    The image is converted to 8-bit grayscale ("L" mode) and resized to
    150x150. Its pixel values are then scaled from [0, 255] into [0.0, 1.0].

    Returns:
        float32 numpy array of shape (1, 150, 150, 1): a leading batch axis
        and a trailing single-channel axis wrapped around the image.
    """
    target_size = (150, 150)
    grayscale = pil_img.convert("L").resize(target_size)
    pixels = np.array(grayscale, dtype="float32") / 255.0
    # Add the batch axis in front and the channel axis at the end.
    return pixels[None, ..., None]

def predict_fn(pil_img: "Image.Image | None"):
    """Classify one chest X-ray and format the result for the UI.

    Args:
        pil_img: Image from the ``gr.Image`` input, or ``None`` when that
            input is empty (e.g. after Clear, or Predict pressed before any
            upload).

    Returns:
        A ``(probs, msg)`` pair. ``probs`` maps each class name to its
        probability for the ``gr.Label`` output. ``msg`` is a Markdown
        summary of the predicted class and its confidence. For an empty
        input, ``probs`` is ``None`` (clears the label) and ``msg`` asks the
        user to upload an image.
    """
    # Gradio delivers None for an empty image input, and the image's change
    # event also fires on clear. Without this guard, preprocess(None) raises.
    if pil_img is None:
        return None, "Upload an X-ray image to get a prediction."

    model = load_model()
    x = preprocess(pil_img)
    # Single sigmoid output, read as the probability of PNEUMONIA (index 1).
    prob = float(model.predict(x, verbose=0)[0][0])
    pred_idx = int(prob > 0.5)
    confidence = prob if pred_idx == 1 else 1 - prob
    probs = {CLASS_NAMES[0]: 1 - prob, CLASS_NAMES[1]: prob}
    msg = f"Prediction: {CLASS_NAMES[pred_idx]} | Confidence: {confidence*100:.2f}%"
    return probs, msg

def list_examples(directory="images", extensions=(".jpeg", ".jpg", ".png")):
    """Collect sample images for the ``gr.Examples`` gallery.

    Args:
        directory: Folder to scan (non-recursively). A relative path is
            resolved against the current working directory.
        extensions: File suffixes to accept. They are matched
            case-insensitively, so files such as ``scan.JPG`` are not
            silently skipped on case-sensitive filesystems.

    Returns:
        A list of one-element rows ``[[path], ...]`` sorted by path. This
        is the row format ``gr.Examples`` expects for a single input
        component. Dot-prefixed (hidden) entries and directories are
        skipped. Returns ``[]`` when ``directory`` does not exist.
    """
    try:
        names = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []

    suffixes = tuple(ext.lower() for ext in extensions)
    files = sorted(
        os.path.join(directory, name)
        for name in names
        # Hidden files are excluded, matching glob's "*" semantics.
        if not name.startswith(".")
        and name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(directory, name))
    )
    return [[path] for path in files]  # gr.Examples expects list of [path]

with gr.Blocks(css="""
.gradio-container {max-width: 980px !important; margin: auto;}
#title {text-align:center;}
.card {border:1px solid #e5e7eb; border-radius:16px; padding:16px;}
""") as demo:
    gr.Markdown("<h1 id='title'>Chest X-Ray Classification</h1>")
    gr.Markdown("Upload an image or click a sample from the gallery. The model predicts NORMAL or PNEUMONIA.")

    with gr.Row():
        with gr.Column(scale=2):
            inp = gr.Image(type="pil", image_mode="L", label="Upload X-ray")
            with gr.Row():
                btn = gr.Button("Predict", variant="primary")
                clr = gr.ClearButton(components=[inp], value="Clear")
            # Render the samples section only when sample files were found.
            # Otherwise the "Samples" heading would sit above an empty gallery.
            sample_rows = list_examples()
            if sample_rows:
                gr.Markdown("### Samples")
                gr.Examples(
                    examples=sample_rows,
                    inputs=inp,
                    examples_per_page=12,
                )
        with gr.Column(scale=1):
            probs = gr.Label(num_top_classes=2, label="Class probabilities")
            out_text = gr.Markdown()

    # The output widgets are created after the Clear button, so register them
    # here. Clearing the input must not leave a stale prediction on screen.
    clr.add([probs, out_text])

    # Run on click
    btn.click(predict_fn, inputs=inp, outputs=[probs, out_text])
    # Also auto-run when image changes (from upload or example click)
    inp.change(predict_fn, inputs=inp, outputs=[probs, out_text])

if __name__ == "__main__":
    demo.launch()