# TrashDetection / app.py
# Mnjar
# Edit title
# 735b755
import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
# Weights file handed to the YOLO constructor.
MODEL_WEIGHTS = 'trash_detection.pt'

# Build the detector once at import time; predict() reads this module-level
# object on every call.
model = YOLO(MODEL_WEIGHTS)
def _load_label_font(size=24):
    """Return a TrueType font for box labels, falling back to PIL's default.

    The macOS Arial path used originally is tried first. The remaining
    candidates are given by bare file name so that Pillow can look them up in
    the host's font directories (assumption: one of them is installed on
    non-macOS hosts; if none is, the default bitmap font is used as before).
    """
    candidates = ("/Library/Fonts/Arial.ttf", "Arial.ttf", "DejaVuSans.ttf")
    for candidate in candidates:
        try:
            return ImageFont.truetype(candidate, size)
        except OSError:
            continue
    return ImageFont.load_default()


def predict(image):
    """
    Run the YOLO model on an image and draw every detection on a copy of it.

    Args:
        image (PIL.Image): Input image. ``None`` (nothing uploaded) is
            accepted and produces an empty result. Array-like input is
            converted to a PIL image first.

    Returns:
        Tuple: The annotated RGB image (``None`` when no image was given) and
        a list of ``[label, confidence, x1, y1, x2, y2]`` rows, one per
        detected box.
    """
    # Submitting the form without an upload passes None; return empty outputs
    # instead of failing inside the model call.
    if image is None:
        return None, []

    # Keep accepting array-like input, as the original np.array() call did.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))

    # Normalise RGBA / greyscale / palette uploads to three channels.
    rgb_image = image.convert("RGB")

    # Pass the PIL image itself: per the Ultralytics docs, PIL input is read
    # as RGB while a raw numpy array is assumed to be BGR, so converting to an
    # RGB array first would swap the red and blue channels.
    results = model(rgb_image)
    result = results[0] if isinstance(results, (list, tuple)) else results

    # Read the boxes straight from the per-image result. Indexing the result
    # again (result[0]) selects a single detection, which would drop every box
    # but the first and fail when nothing was detected.
    boxes = result.boxes
    class_names = result.names  # class id -> label

    coords = boxes.xyxy.cpu().numpy().tolist()
    confidences = boxes.conf.cpu().numpy().tolist()
    class_ids = boxes.cls.cpu().numpy().tolist()

    # One row per detection, in the column order of the output table.
    output = [
        [class_names[int(class_id)], confidence, *box]
        for box, confidence, class_id in zip(coords, confidences, class_ids)
    ]

    # Draw on a copy so the image handed to the model stays untouched.
    annotated = rgb_image.copy()
    draw = ImageDraw.Draw(annotated)
    font = _load_label_font(24)

    for label, confidence, x1, y1, x2, y2 in output:
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        text = f"{label} ({confidence:.2f})"
        # Place the caption above the box, clamped so that boxes touching the
        # top edge still get a visible label.
        draw.text((x1, max(y1 - 30, 0)), text, font=font, fill="red")

    return annotated, output
# Gradio components: one PIL image in; the annotated image and a table of
# detections out.
input_image = gr.Image(type="pil")
annotated_image = gr.Image(type="pil")
predictions_table = gr.Dataframe(
    headers=["Label", "Confidence", "Xmin", "Ymin", "Xmax", "Ymax"],
    label="Predictions",
)

# Wire predict() to the components above.
iface = gr.Interface(
    fn=predict,
    inputs=input_image,
    outputs=[annotated_image, predictions_table],
    title="Garbage Detection",
    description="Upload an image to detect objects.",
)

# Start the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch(share=True)