# garbage-classifier / model2.py
# Author: Shiva Subhan S
# Commit: "Add garbage classifier app" (29eec23)
import os
import random
import sys
import traceback
from datetime import datetime
from pathlib import Path

import mysql.connector
import paho.mqtt.client as mqtt
import yolov5
# MySQL configuration (TiDB Serverless).
# Credentials are read from the environment instead of being committed to
# source control. The password previously hard-coded here must be treated as
# leaked and rotated. Non-secret settings keep their former values as defaults.
db_config = {
    'host': os.environ.get('TIDB_HOST', 'gateway01.ap-southeast-1.prod.aws.tidbcloud.com'),
    'user': os.environ.get('TIDB_USER', '5ztcqT1EBcgYB5u.root'),
    'password': os.environ.get('TIDB_PASSWORD', ''),
    'database': os.environ.get('TIDB_DATABASE', 'satwa'),
    'port': int(os.environ.get('TIDB_PORT', '4000')),
    # System CA bundle used to verify the server certificate (works on Spaces).
    'ssl_ca': os.environ.get('TIDB_SSL_CA', '/etc/ssl/certs/ca-certificates.crt'),
    # Verify the certificate chain and that it matches 'host'.
    'ssl_verify_cert': True,
    'ssl_verify_identity': True,
}
# Connect to MySQL. Nothing downstream works without a connection, so a failure
# aborts the script with a non-zero exit status. sys.exit() is used rather than
# the interactive-shell helper exit(): exit() with no argument exits with status
# 0 (reporting success) and is missing when Python runs without the site module.
try:
    db = mysql.connector.connect(**db_config)
    cursor = db.cursor()
    print("Connected to MySQL database successfully.")
except mysql.connector.Error as err:
    print(f"Failed to connect to MySQL: {err}", file=sys.stderr)
    sys.exit(1)
# Create the results table on first run; IF NOT EXISTS makes this a no-op when
# the table is already present.
create_table_query = """
CREATE TABLE IF NOT EXISTS garbage_classification (
id INT AUTO_INCREMENT PRIMARY KEY,
timestamp DATETIME,
payload_value INT,
predicted_class VARCHAR(255)
)
"""
try:
    cursor.execute(create_table_query)
    db.commit()
    print("Table checked/created.")
except mysql.connector.Error as err:
    print(f"Error creating table: {err}", file=sys.stderr)
    # Abort with a failure status; a bare exit() would exit with status 0.
    sys.exit(1)
# The Spaces free tier is CPU-only, so inference is pinned to the CPU.
device = 'cpu'
print(f"Using device: {device}")

# Fetch the pretrained garbage-detection checkpoint and place it on the device.
model = yolov5.load('keremberke/yolov5m-garbage')
model.to(device)

# Inference/NMS settings read by the yolov5 model wrapper, applied in the same
# order as before. Per yolov5 conventions: confidence threshold, NMS IoU
# threshold, class-agnostic NMS off, one label per box, cap on detections.
for _attr, _setting in (
    ('conf', 0.25),
    ('iou', 0.45),
    ('agnostic', False),
    ('multi_label', False),
    ('max_det', 1000),
):
    setattr(model, _attr, _setting)
del _attr, _setting

# Lookup from a predicted class id to its label text.
class_names = model.names
# Root of the image dataset. Overridable via GARBAGE_DATA_DIR; defaults to the
# original Spaces location.
base_dir = os.environ.get("GARBAGE_DATA_DIR", "/data/garbage_classification")
base_path = Path(base_dir)
# Path.is_dir() is also False for a missing path, so one check covers both the
# "does not exist" and "is not a directory" cases.
if not base_path.is_dir():
    print(f"Error: Directory '{base_dir}' not found.", file=sys.stderr)
    sys.exit(1)
# Every directory below base_path at any depth (rglob), sorted by full path so
# that the index get_folder_index() returns always names the same folder.
# NOTE(review): nested directories are included too; confirm the dataset is
# exactly one level deep, otherwise indices may point at unexpected folders.
subfolders = sorted(p for p in base_path.rglob('*') if p.is_dir())
if not subfolders:
    print(f"Error: No subfolders found in '{base_dir}'.", file=sys.stderr)
    sys.exit(1)
print(f"Found {len(subfolders)} subfolders.")
def get_folder_index(value):
    """Map an MQTT value onto a position in ``subfolders``.

    The range 1-99 is split into nine consecutive buckets of eleven values:
    1-11 -> 0, 12-22 -> 1, ..., 89-99 -> 8. Any value that falls in no bucket
    (including values below 1 or above 99) yields -1.
    """
    # Bucket lower bounds are 1, 12, 23, ..., 89; each bucket spans lower..lower+10.
    for bucket, lower in enumerate(range(1, 100, 11)):
        if lower <= value <= lower + 10:
            return bucket
    return -1
def process_random_image(subfolder, payload_value):
    """Classify one randomly chosen image from *subfolder* and record the result.

    Picks a random regular file directly inside *subfolder* whose suffix is
    .jpg/.jpeg/.png/.bmp (case-insensitive). It runs the module-level ``model``
    on it with size=640 and keeps the single highest-scoring detection. One row
    (timestamp, payload_value, predicted_class) is then inserted into
    ``garbage_classification``. A database error is printed, not raised.

    Args:
        subfolder: pathlib.Path of the directory to sample an image from.
        payload_value: The MQTT value that triggered this call; stored with the
            prediction.

    Returns:
        One of:
        - the predicted class name;
        - "None" when nothing was detected;
        - "Unknown" when the class id is outside ``class_names``;
        - None when the folder holds no images (nothing is written to the
          database in that case).
    """
    allowed_suffixes = ('.jpg', '.jpeg', '.png', '.bmp')
    candidates = []
    for entry in subfolder.iterdir():
        if entry.is_file() and entry.suffix.lower() in allowed_suffixes:
            candidates.append(entry)
    if not candidates:
        print(f"No images in {subfolder}")
        return None

    chosen = random.choice(candidates)
    print(f"Selected image: {chosen}")

    # Detections for the single input image; column 4 holds the score and
    # column 5 the class id.
    detections = model(str(chosen), size=640).pred[0]
    if len(detections) == 0:
        label = "None"
        print("No objects detected.")
    else:
        confidences = detections[:, 4]
        best = confidences.argmax()
        best_conf = confidences[best]
        cls_id = int(detections[best, 5])
        label = class_names[cls_id] if cls_id < len(class_names) else "Unknown"
        print(f"Class: {label}, Confidence: {best_conf:.2f}")

    sql = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
    row = (datetime.now(), payload_value, label)
    try:
        cursor.execute(sql, row)
        db.commit()
    except mysql.connector.Error as db_err:
        print(f"Error saving to database: {db_err}")
    else:
        print("Data saved to database.")
    return label
def on_message(client, userdata, msg):
    """paho-mqtt message callback: classify an image for each received value.

    The payload must be an integer encoded as text. If get_folder_index() maps
    it onto an existing entry of ``subfolders``, process_random_image() runs on
    that folder. Any other integer is recorded in ``garbage_classification``
    with predicted_class "Out of Range". Errors are printed, never raised.

    Args:
        client: The paho client that received the message (unused).
        userdata: User data configured on the client (unused).
        msg: The received message; only ``msg.payload`` (bytes) is read.
    """
    # Parse in its own try so a ValueError raised later (e.g. during inference)
    # is not misreported as a bad payload.
    try:
        value = int(msg.payload.decode())
    except ValueError:  # also covers UnicodeDecodeError (a ValueError subclass)
        print("Invalid payload - must be an integer")
        return
    print(f"Received MQTT value: {value}")

    try:
        folder_idx = get_folder_index(value)
        if folder_idx != -1 and folder_idx < len(subfolders):
            selected_folder = subfolders[folder_idx]
            print(f"Mapped to folder: {selected_folder}")
            process_random_image(selected_folder, value)
        else:
            timestamp = datetime.now()
            insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
            cursor.execute(insert_query, (timestamp, value, "Out of Range"))
            db.commit()
            print("Value out of range. Saved to database.")
    except mysql.connector.Error as err:
        # Same handling as the insert in process_random_image(): report, keep going.
        print(f"Error saving to database: {err}")
    except Exception:
        # Callback boundary: an exception escaping here can propagate out of
        # client.loop_forever() (paho-mqtt re-raises callback errors by default)
        # and stop the subscriber, so log it and wait for the next message.
        print(f"Error processing MQTT value {value}:")
        traceback.print_exc()
# MQTT setup: listen for integer values on a public test broker.
broker = "test.mosquitto.org"
topic = "garbage/index"


def on_connect(client, userdata, *args):
    """Subscribe to ``topic`` each time the client (re)connects.

    Subscribing here rather than once before the loop means the subscription is
    restored after loop_forever() reconnects following a dropped connection.
    ``*args`` absorbs the remaining callback arguments, whose count differs
    between paho-mqtt callback API versions.
    """
    client.subscribe(topic)
    print(f"Subscribed to {topic}, waiting for messages...")


client = mqtt.Client()
client.on_message = on_message
client.on_connect = on_connect
client.connect(broker, 1883, 60)
client.loop_forever()  # Blocks indefinitely, reconnecting if the link drops.