"""Gradio demo that classifies bacterial morphology (cocci, bacilli, spirilla).

The Keras model is downloaded from the Hugging Face Hub on first use and then
memoized in-process, so later predictions reuse the already-loaded model.
"""

import functools
import os

# Disable GPU (forces CPU usage). Set this before TensorFlow is imported:
# CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so assigning it after
# the TensorFlow import is not guaranteed to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from PIL import Image  # noqa: E402,F401
from tensorflow import keras  # noqa: E402
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input  # noqa: E402
from tensorflow.keras.preprocessing.image import img_to_array  # noqa: E402

# Mapping from the model's output index to a human-readable class name.
class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}

# Size passed to PIL's Image.resize, i.e. (width, height).
_INPUT_SIZE = (224, 224)


@functools.lru_cache(maxsize=1)
def load():
    """Download (if needed) and load the bacteria classification model.

    The result is memoized, so the Hub download and Keras deserialization run
    once per process instead of on every prediction. If loading raises, nothing
    is cached and the next call retries.

    Returns:
        The Keras model loaded from ``D963934/bact_model/bct_model.keras``.
    """
    model_path = hf_hub_download(repo_id="D963934/bact_model", filename="bct_model.keras")
    return keras.models.load_model(model_path)


def predict(image):
    """Classify a bacteria image and describe the result.

    Args:
        image: PIL image from the Gradio input, or None when no image is
            present (e.g. the input was cleared while ``live=True``).

    Returns:
        ``"Predicted Class: <label> (Confidence: <score>)"`` for an image, or a
        prompt asking the user to upload one when ``image`` is None.
    """
    if image is None:
        return "Please upload an image."

    # Force 3 channels: uploads may be RGBA (e.g. PNG) or grayscale, which would
    # yield a 4- or 1-channel array. The model presumably expects RGB input,
    # given the MobileNetV2 preprocessing below.
    rgb = image.convert("RGB").resize(_INPUT_SIZE)
    batch = np.expand_dims(img_to_array(rgb), axis=0)  # add batch dim -> (1, H, W, 3)
    batch = preprocess_input(batch)  # MobileNetV2 pixel scaling

    # Row 0 holds the scores for our single image.
    scores = load().predict(batch, verbose=0)[0]
    predicted_class_idx = int(np.argmax(scores))
    # Score of the top class (a probability if the model ends in softmax).
    confidence_score = float(scores[predicted_class_idx])
    predicted_class = class_labels.get(predicted_class_idx, 'Unknown')
    return f"Predicted Class: {predicted_class} (Confidence: {confidence_score:.2f})"


# Example images shown under the input; Gradio fetches them by URL.
examples = [
    ["https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%203001.jpg"],
    ["https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%205000.jpg"],
    ["https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%20137.jpg"],
    ["https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%20352.jpg"],
    ["https://huggingface.co/spaces/D963934/BactSpace/resolve/main/img%20458.jpg"],
]

# Define the Gradio interface. live=True re-runs predict whenever the input
# changes, which is why the model must not be reloaded per call.
# NOTE(review): newer Gradio releases rename allow_flagging to flagging_mode;
# confirm against the pinned Gradio version before upgrading.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(label="Prediction Results:"),
    title="Bacterial Identification",
    description="Upload an image of bacteria to identify - cocci, bacilli, or spirilla",
    examples=examples,
    allow_flagging="never",
    live=True,
)

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch()