# Last commit: 135e43d "Update app.py" (author: wookimchye, verified)
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
from ultralytics import YOLO
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from huggingface_hub import snapshot_download
import os
from torchvision import transforms
# Class-index -> label mapping for the binary classifier. classify_image maps
# a score >= its threshold to index 1 ("Good"), otherwise to index 0.
classes = {0: "Defective", 1: "Good"}
model_path = "ResNet50_Classification.h5"  # Trained ResNet50 classification model
best_yolo_model = "best.pt"  # Trained YOLOv8 detection model
# Both models are loaded once, at import time, and reused for every request.
# Load from model_path instead of repeating the literal, so the two cannot drift.
classification_model = tf.keras.models.load_model(model_path)
detection_model = YOLO(best_yolo_model, task="detect")
# Define the image preprocessing function
def preprocess_image(pilimg):
img = pilimg.resize((224, 224)) # Resize to the input size of ResNet50
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
return img_array
def classify_image(pilimg, threshold=0.5):
    """Classify a cap image as "Good" or "Defective" with the ResNet50 model.

    Args:
        pilimg: Input PIL image.
        threshold: Score at or above which the image is labelled "Good".
            Defaults to 0.5, the original cut-off.

    Returns:
        The label string taken from ``classes``: "Good" or "Defective".
    """
    img_array = preprocess_image(pilimg)
    # predict() returns one row per batch item; the batch holds a single
    # image and its first output is the score (presumably shape (1, 1)).
    # verbose=0 keeps Keras from printing a progress bar on every request.
    score = classification_model.predict(img_array, verbose=0)[0][0]
    print(">>> Result : ", score)
    # Reuse the shared label mapping: index 1 is "Good", index 0 "Defective".
    predicted_class = classes[1 if score >= threshold else 0]
    print(">>> predicted_class : ", predicted_class)
    return predicted_class
def detect_defect(img, conf=0.4, iou=0.5):
    """Run the YOLOv8 defect detector on an image.

    Args:
        img: Image to run detection on (process_image passes a PIL image).
        conf: Minimum confidence for a detection to be kept (default 0.4).
        iou: IoU threshold for non-maximum suppression (default 0.5).

    Returns:
        The Ultralytics prediction results, one entry per input image; callers
        use ``[0].plot()`` and ``[0].boxes``.
    """
    return detection_model.predict(img, conf=conf, iou=iou)
# Bold TrueType font for the status label (Debian/Ubuntu DejaVu install path).
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_LABEL_FONT_SIZE = 30


def _load_label_font(size=_LABEL_FONT_SIZE):
    """Return the bold label font, or PIL's built-in font if it is missing."""
    try:
        return ImageFont.truetype(_LABEL_FONT_PATH, size)
    except OSError:
        # The DejaVu path is platform-specific. Degrade to the default font
        # rather than failing the whole request.
        return ImageFont.load_default()


def _summary_html(message, color):
    """Wrap *message* in the large, bold, coloured span used for the summary."""
    return (
        f"<span style='font-size:30px; font-weight:bold; color:{color}'>"
        f"{message}</span>"
    )


def process_image(pilimg):
    """Classify a cap image and, if it is defective, localize the defect.

    The ResNet50 classifier runs first. A "Good" image is returned with a green
    "Good" label. Otherwise the YOLOv8 detector runs, and its annotated plot is
    returned with a red "Defective" label.

    Args:
        pilimg: PIL image from the Gradio input, or None when the user submits
            without uploading an image.

    Returns:
        Tuple ``(annotated PIL image or None, summary HTML string)`` for the
        Image and Markdown output components.
    """
    if pilimg is None:
        return None, _summary_html("Please upload an image first.", "gray")

    classification = classify_image(pilimg)
    font = _load_label_font()  # Loaded once, shared by both branches.

    if classification == "Good":
        out_pilimg = pilimg.convert("RGB")
        ImageDraw.Draw(out_pilimg).text((250, 10), "Good", fill="green", font=font)
        return out_pilimg, _summary_html(
            "No defect is detected, the cap is GOOD!", "green"
        )

    # Defective: run the detector and draw the label on its annotated output.
    result = detect_defect(pilimg)[0]
    # plot() returns a BGR numpy array; reverse the channel axis for RGB PIL.
    out_pilimg = Image.fromarray(result.plot()[..., ::-1])
    ImageDraw.Draw(out_pilimg).text((300, 10), "Defective", fill="red", font=font)

    if len(result.boxes.data) > 0:
        summary_str = _summary_html("Defect is detected, the cap is BAD!", "red")
    else:
        # Classifier and detector disagree; report it instead of hiding it.
        summary_str = _summary_html(
            "The cap is classified as Defective but the defect cannot be detected!",
            "blue",
        )
    return out_pilimg, summary_str
# Heading displayed at the top of the Gradio page.
title = "Detect the status of the cap: DEFECTIVE or GOOD"

# One PIL image in; an annotated PIL image plus a Markdown/HTML summary out,
# matching the (image, summary_str) tuple returned by process_image.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(label="Input Image", type="pil"),
    outputs=[
        gr.Image(label="Classification/Detection result", type="pil"),
        gr.Markdown(label="Classification/Detection Summary"),
    ],
    title=title,
)

# Start the web UI; share=True additionally requests a public share link.
interface.launch(share=True)