File size: 4,012 Bytes
1b803a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c6ee43
 
1b803a3
 
 
 
 
 
30a4b5d
 
 
1b803a3
 
135e43d
 
1b803a3
 
 
 
 
 
 
30a4b5d
 
 
1b803a3
9d4d41a
1b803a3
 
 
135e43d
 
1b803a3
135e43d
 
1b803a3
135e43d
 
1b803a3
4c6ee43
1b803a3
 
 
4c6ee43
 
fb2321f
4c6ee43
1b803a3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python
# coding: utf-8


import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
from ultralytics import YOLO
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from huggingface_hub import snapshot_download
import os
from torchvision import transforms

# Define the class labels: classify_image treats a score >= 0.5 as "Good"
# and anything lower as "Defective".
classes = {0: "Defective", 1: "Good"}

model_path = "ResNet50_Classification.h5"  # Trained ResNet50 classification model

best_yolo_model = "best.pt"  # Trained YOLOv8 detection model

# Both models are loaded once, at import time, and reused by every request.
# Load from `model_path` instead of repeating the file name so the two
# cannot drift apart.
classification_model = tf.keras.models.load_model(model_path)

detection_model = YOLO(best_yolo_model, task='detect')


# Define the image preprocessing function
def preprocess_image(pilimg):
    """Turn a PIL image into a one-image batch for the ResNet50 classifier.

    Args:
        pilimg: Input ``PIL.Image.Image`` in any mode (RGB, RGBA, L, ...).

    Returns:
        A float array of shape ``(1, 224, 224, 3)`` (with Keras' default
        channels-last data format).
    """
    # Force three channels so RGBA or grayscale uploads yield the same
    # (224, 224, 3) layout as RGB ones; presumably the model was trained on
    # RGB images -- confirm against the training pipeline.
    img = pilimg.convert("RGB").resize((224, 224))  # Resize to the input size of ResNet50
    img_array = image.img_to_array(img)
    # NOTE(review): `preprocess_input` is imported at the top of the file but
    # not applied here; verify this matches how the model was trained.
    return np.expand_dims(img_array, axis=0)  # Add batch dimension

def classify_image(pilimg):
    """Classify a cap image as "Good" or "Defective".

    Args:
        pilimg: Input ``PIL.Image.Image``.

    Returns:
        The label from ``classes``: "Good" when the model's score is at
        least 0.5, otherwise "Defective".
    """
    img_array = preprocess_image(pilimg)  # Preprocess the input image
    # First row / first column: the single score for the single image in the
    # batch (presumably a probability -- confirm the model's output layer).
    # verbose=0 keeps Keras from printing a progress bar on every request.
    classify_result = classification_model.predict(img_array, verbose=0)[0][0]
    print(">>> Result : ", classify_result)

    # Threshold to an index and map it through `classes` (1 -> "Good",
    # 0 -> "Defective") so the label strings live in one place.
    predicted_class = classes[int(classify_result >= 0.5)]
    print(">>> predicted_class : ", predicted_class)

    return predicted_class

def detect_defect(img, conf=0.4, iou=0.5):
    """Run the YOLOv8 detector on ``img`` and return its raw results.

    Args:
        img: Image passed straight to ``detection_model.predict``.
        conf: Minimum confidence for a detection to be kept (default 0.4,
            the value previously hardcoded).
        iou: IoU threshold for non-maximum suppression (default 0.5, the
            value previously hardcoded).

    Returns:
        Whatever ``detection_model.predict`` returns; ``process_image``
        indexes element ``[0]`` and reads ``.plot()`` and ``.boxes`` from it.
    """
    return detection_model.predict(img, conf=conf, iou=iou)

   
def _load_label_font(size=30):
    """Return the bold font used for the status label drawn on the image.

    Falls back to Pillow's built-in default font when the DejaVu TrueType
    file is not installed, instead of failing the request with ``OSError``.
    """
    font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
    try:
        return ImageFont.truetype(font_path, size)
    except OSError:
        return ImageFont.load_default()


def _summary_html(color, text):
    """Wrap ``text`` in the bold, 30px, colored span shown in the summary box."""
    return f"<span style='font-size:30px; font-weight:bold; color:{color}'>{text}</span>"


def process_image(pilimg):
    """Classify a cap image and, when it is defective, localize the defect.

    Classification runs first. A "Good" cap is returned as an RGB copy of the
    input with a green "Good" label. Otherwise the YOLOv8 detector runs and
    its annotated output is returned with a red "Defective" label.

    Args:
        pilimg: Input ``PIL.Image.Image`` from the Gradio image component.

    Returns:
        A ``(PIL.Image.Image, str)`` tuple: the annotated image and an HTML
        summary string for the Markdown output.
    """
    classification = classify_image(pilimg)
    font = _load_label_font(30)

    if classification == "Good":
        # convert() returns a copy, so the caller's image is never drawn on.
        out_pilimg = pilimg.convert("RGB")
        ImageDraw.Draw(out_pilimg).text((250, 10), "Good", fill="green", font=font)
        return out_pilimg, _summary_html(
            "green", "No defect is detected, the cap is GOOD!"
        )

    # Defective: run detection and annotate the detector's plotted output.
    result = detect_defect(pilimg)[0]
    # Results.plot() returns a BGR array; reverse the channel axis for PIL (RGB).
    out_pilimg = Image.fromarray(result.plot()[..., ::-1])
    ImageDraw.Draw(out_pilimg).text((300, 10), "Defective", fill="red", font=font)

    if len(result.boxes.data) > 0:
        summary_str = _summary_html("red", "Defect is detected, the cap is BAD!")
    else:
        # The classifier and the detector disagree: report it explicitly.
        summary_str = _summary_html(
            "blue",
            "The cap is classified as Defective but the defect cannot be detected!",
        )
    return out_pilimg, summary_str

title = "Detect the status of the cap: DEFECTIVE or GOOD"

# One image in; the annotated image plus an HTML summary out.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=[
        gr.Image(type="pil", label="Classification/Detection result"),
        gr.Markdown(label="Classification/Detection Summary"),
    ],
    title=title,
)

# Launch the interface only when run as a script, so importing this module
# (e.g. from tests or another app) does not start a web server.
if __name__ == "__main__":
    interface.launch(share=True)