BactSpace / app.py
D963934's picture
Update app.py
1036510 verified
import functools
import os

import gradio as gr
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow import keras
from huggingface_hub import hf_hub_download
from PIL import Image
# Hide every CUDA device from TensorFlow so inference runs on the CPU.
# NOTE(review): this is set after tensorflow is imported above; it only takes
# effect if TensorFlow has not initialised its GPU devices yet — confirm, or
# move this assignment above the tensorflow imports.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

# Model output index -> bacterial morphology name shown to the user.
class_labels = dict(enumerate(("cocci", "bacilli", "spirilla")))
# Function to load and use the model
def load():
# Step 1: Download the model from Hugging Face Model Hub
model_path = hf_hub_download(repo_id="D963934/bact_model", filename="bct_model.keras")
# Step 2: Load the model using Keras
model = keras.models.load_model(model_path)
# Return the model for inference
return model
def _preprocess(image):
    """Turn a PIL image into a single-image MobileNetV2 input batch."""
    # Force 3 channels: RGBA/greyscale uploads would otherwise yield 4 or 1
    # channels and not match the model's expected input.
    image = image.convert("RGB").resize((224, 224))
    # Assumes Keras' default channels_last layout -> shape (1, 224, 224, 3).
    batch = np.expand_dims(img_to_array(image), axis=0)
    return preprocess_input(batch)


def predict(image):
    """Classify a bacteria image as cocci, bacilli or spirilla.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when no image is present (e.g. the input was cleared).

    Returns:
        A human-readable string with the predicted class and the model's
        score for that class. If ``image`` is ``None``, a prompt asking the
        user to upload an image.
    """
    if image is None:
        # With live=True the interface re-runs on every input change, including
        # clearing the image; show a prompt instead of a blank result.
        return "Please upload an image of bacteria to classify."

    batch = _preprocess(image)
    # load() supplies the Keras model; whether it is cached depends on load().
    model = load()
    # verbose=0 keeps Keras from printing a progress bar for every request.
    scores = model.predict(batch, verbose=0)[0]

    predicted_class_idx = int(np.argmax(scores))
    # A probability only if the model's final layer is softmax (not visible here).
    confidence_score = float(scores[predicted_class_idx])
    predicted_class = class_labels.get(predicted_class_idx, "Unknown")
    return f"Predicted Class: {predicted_class} (Confidence: {confidence_score:.2f})"
# Sample images hosted in this Space's repository. Gradio expects one inner
# list per example, holding one value per input component (here: the image).
examples = [
    [url]
    for url in (
        "https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%203001.jpg",
        "https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%205000.jpg",
        "https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%20137.jpg",
        "https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%20352.jpg",
        "https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%20458.jpg",
    )
]
# Gradio UI: one PIL image input wired to predict(), with a text output.
# live=True re-runs predict() whenever the input changes; flagging is off.
# NOTE(review): newer Gradio releases rename allow_flagging to flagging_mode;
# confirm against the Gradio version pinned for this Space.
image_input = gr.Image(type="pil")
result_output = gr.Text(label="Prediction Results:")
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=result_output,
    title="Bacterial Identification",
    description="Upload an image of bacteria to identify - cocci, bacilli, or spirilla",
    examples=examples,
    allow_flagging="never",
    live=True,
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()