|
|
|
|
|
|
|
|
|
import tensorflow as tf |
|
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions |
|
from tensorflow.keras.preprocessing import image |
|
from ultralytics import YOLO |
|
import numpy as np |
|
from PIL import Image, ImageDraw, ImageFont |
|
import gradio as gr |
|
from huggingface_hub import snapshot_download |
|
import os |
|
from torchvision import transforms |
|
|
|
|
|
# Label table for the binary classifier: thresholded score index -> label.
classes = {0: "Defective", 1: "Good"}

# Model weight files (relative paths, so they are looked up from the
# process's working directory).
model_path = "ResNet50_Classification.h5"
best_yolo_model = "best.pt"

# Load both models once at import time so every request reuses them.
# Use model_path rather than repeating the literal, so the two cannot drift.
classification_model = tf.keras.models.load_model(model_path)
detection_model = YOLO(best_yolo_model, task='detect')
|
|
|
|
|
|
|
def preprocess_image(pilimg):
    """Turn a PIL image into a single-image batch for the classifier.

    The image is converted to RGB, resized to 224x224 and given a leading
    batch axis. The RGB conversion ensures RGBA, greyscale or palette images
    yield 3 channels instead of 4 or 1.

    No pixel normalisation is applied here. Presumably the saved model
    handles it, or was trained on raw 0-255 values — TODO confirm.

    Args:
        pilimg: A ``PIL.Image.Image`` of any mode and size.

    Returns:
        numpy.ndarray of shape (1, 224, 224, 3).
    """
    # Force 3 channels. The ResNet50-based classifier presumably expects an
    # RGB input layer (TODO confirm against the saved model).
    img = pilimg.convert("RGB").resize((224, 224))
    img_array = image.img_to_array(img)
    # Add the batch axis: (224, 224, 3) -> (1, 224, 224, 3).
    return np.expand_dims(img_array, axis=0)
|
|
|
def classify_image(pilimg, *, threshold=0.5):
    """Classify a cap image as "Good" or "Defective".

    Args:
        pilimg: Input ``PIL.Image.Image``.
        threshold: Scores at or above this value are labelled "Good" and
            scores below it "Defective". Defaults to 0.5, the original
            cut-off.

    Returns:
        The label from the module-level ``classes`` table: "Good" or
        "Defective".
    """
    img_array = preprocess_image(pilimg)
    # First image in the batch, first output unit. Presumably a single
    # sigmoid score where higher means "Good" — TODO confirm.
    classify_result = classification_model.predict(img_array)[0][0]
    print(">>> Result : ", classify_result)

    # Map the thresholded score onto the shared label table
    # (1 -> "Good", 0 -> "Defective") instead of repeating the strings here.
    predicted_class = classes[int(classify_result >= threshold)]
    print(">>> predicted_class : ", predicted_class)

    return predicted_class
|
|
|
def detect_defect(img, *, conf=0.4, iou=0.5):
    """Run the YOLO defect detector on ``img`` and return its raw results.

    Args:
        img: Any image source accepted by ``YOLO.predict`` (this app passes a
            PIL image).
        conf: Minimum detection confidence for a box to be kept. Default 0.4.
        iou: IoU threshold for non-maximum suppression. Default 0.5.

    Returns:
        The ultralytics results list from ``predict``. Callers index element
        ``[0]`` for the single input image.
    """
    return detection_model.predict(img, conf=conf, iou=iou)
|
|
|
|
|
# Bold TrueType font for the verdict label. This is a Linux-specific path and
# may be absent on other hosts, so _load_label_font falls back when it is.
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"


def _load_label_font(size=30):
    """Return the label font, or Pillow's built-in font if it cannot be loaded."""
    try:
        return ImageFont.truetype(_LABEL_FONT_PATH, size)
    except OSError:
        # truetype raises OSError when the file is missing. Degrade to the
        # default font instead of failing every request.
        return ImageFont.load_default()


def _stamp_label(img, xy, text, color):
    """Draw ``text`` in ``color`` onto ``img`` (in place) at position ``xy``."""
    ImageDraw.Draw(img).text(xy, text, fill=color, font=_load_label_font())


def process_image(pilimg):
    """Classify a cap image and, if it looks defective, localise the defect.

    A "Good" verdict stamps a green "Good" label on an RGB copy of the input.
    Otherwise the YOLO detector runs, and its annotated plot is returned with
    a red "Defective" label. The summary then says whether any box was found.

    Args:
        pilimg: ``PIL.Image.Image`` from the Gradio image input, or ``None``
            when the user submits without uploading an image.

    Returns:
        ``(image, summary_html)`` for the two Gradio outputs:
            image: the annotated PIL image, or ``None`` when there was no
                input.
            summary_html: an HTML snippet displayed by the Markdown output.
    """
    if pilimg is None:
        # Without this guard the classifier fails with an AttributeError on None.
        return None, (
            "<span style='font-size:30px; font-weight:bold; color:blue'>"
            "Please upload an image first.</span>"
        )

    if classify_image(pilimg) == "Good":
        # convert() returns a copy, so the caller's image is not drawn on.
        out_pilimg = pilimg.convert("RGB")
        _stamp_label(out_pilimg, (250, 10), "Good", "green")
        summary_str = (
            "<span style='font-size:30px; font-weight:bold; color:green'>"
            "No defect is detected, the cap is GOOD!</span>"
        )
        return out_pilimg, summary_str

    detection_result = detect_defect(pilimg)
    # plot() gives a BGR array with boxes drawn. Reverse the channel axis to
    # get RGB for PIL.
    img_bgr = detection_result[0].plot()
    out_pilimg = Image.fromarray(img_bgr[..., ::-1])
    _stamp_label(out_pilimg, (300, 10), "Defective", "red")

    if len(detection_result[0].boxes.data) > 0:
        summary_str = (
            "<span style='font-size:30px; font-weight:bold; color:red'>"
            "Defect is detected, the cap is BAD!</span>"
        )
    else:
        # The classifier and the detector disagree: report it explicitly.
        summary_str = (
            "<span style='font-size:30px; font-weight:bold; color:blue'>"
            "The cap is classified as Defective but the defect cannot be detected!</span>"
        )

    return out_pilimg, summary_str
|
|
|
title = "Detect the status of the cap: DEFECTIVE or GOOD"

# One PIL image in. Two outputs, matching the (image, summary_str) tuple
# returned by process_image: the annotated image and an HTML summary.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Classification/Detection result"),
        gr.Markdown(label="Classification/Detection Summary"),
    ],
    title=title,
)

# Start the server only when run as a script. Importing this module (e.g.
# from tests or another app) then builds the interface without launching it.
if __name__ == "__main__":
    # share=True also requests a public share link from Gradio.
    interface.launch(share=True)
|
|