|
import os
from functools import lru_cache

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow import keras
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
|
|
|
|
|
# Hide every GPU from CUDA so TensorFlow runs inference on the CPU.
# NOTE(review): TensorFlow reads this when it first initializes CUDA; setting it
# before `import tensorflow` is the reliable placement — confirm it takes effect here.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
|
|
|
|
|
# Output index of the model -> bacterial morphology name shown to the user.
class_labels = dict(enumerate(("cocci", "bacilli", "spirilla")))
|
|
|
|
|
@lru_cache(maxsize=1)
def load():
    """Fetch and deserialize the bacteria classifier, once per process.

    ``bct_model.keras`` is obtained from the Hugging Face Hub repo
    ``D963934/bact_model`` with ``hf_hub_download`` and loaded with
    ``keras.models.load_model``. The result is memoized, so callers that
    invoke ``load()`` on every request share a single model instance instead
    of re-reading and re-building it from disk each time.

    Returns:
        The loaded Keras model.
    """
    model_path = hf_hub_download(repo_id="D963934/bact_model", filename="bct_model.keras")
    return keras.models.load_model(model_path)
|
|
|
|
|
def predict(image):
    """Classify a bacteria image as cocci, bacilli, or spirilla.

    Args:
        image: A PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the input is empty (e.g. after the user clears it).

    Returns:
        A string with the predicted class name and the model's score for that
        class, or a prompt to upload an image when ``image`` is ``None``.
    """
    if image is None:
        # The interface is live, so this runs whenever the input is cleared;
        # return a message instead of None so the output box is not left blank.
        return "Please upload an image."

    # Force 3 channels: RGBA or greyscale uploads would otherwise produce a
    # 4- or 1-channel array. Assumes the model takes RGB input, as the
    # MobileNetV2 preprocessing below suggests — confirm against the model.
    image = image.convert("RGB").resize((224, 224))
    batch = np.expand_dims(img_to_array(image), axis=0)  # add batch axis
    batch = preprocess_input(batch)

    model = load()
    # verbose=0 suppresses the per-call progress bar in the server logs.
    predictions = model.predict(batch, verbose=0)
    scores = predictions[0]  # single-image batch
    predicted_class_idx = int(np.argmax(scores))
    confidence_score = float(scores[predicted_class_idx])
    predicted_class = class_labels[predicted_class_idx]

    return f"Predicted Class: {predicted_class} (Confidence: {confidence_score:.2f})"
|
|
|
|
|
|
|
# Sample images hosted in the Space repo, one single-input row per example.
examples = [
    [f"https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%20{number}.jpg"]
    for number in (3001, 5000, 137, 352, 458)
]
|
|
|
# Web UI: one PIL image in, one text result out. `live=True` re-runs
# `predict` whenever the input changes; flagging is turned off.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(label="Prediction Results:"),
    title="Bacterial Identification",
    description="Upload an image of bacteria to identify - cocci, bacilli, or spirilla",
    examples=examples,
    allow_flagging="never",
    live=True,
)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script,
    # not when it is imported.
    iface.launch()
|
|