import yolov5 import random from pathlib import Path import paho.mqtt.client as mqtt import mysql.connector from datetime import datetime # MySQL Configuration (TiDB Serverless) db_config = { 'host': 'gateway01.ap-southeast-1.prod.aws.tidbcloud.com', 'user': '5ztcqT1EBcgYB5u.root', 'password': '92t7gF7zq3Oz9eKb', 'database': 'satwa', 'port': 4000, 'ssl_ca': '/etc/ssl/certs/ca-certificates.crt', # Works on Spaces 'ssl_verify_cert': True, 'ssl_verify_identity': True } # Connect to MySQL try: db = mysql.connector.connect(**db_config) cursor = db.cursor() print("Connected to MySQL database successfully.") except mysql.connector.Error as err: print(f"Failed to connect to MySQL: {err}") exit() # Create table create_table_query = """ CREATE TABLE IF NOT EXISTS garbage_classification ( id INT AUTO_INCREMENT PRIMARY KEY, timestamp DATETIME, payload_value INT, predicted_class VARCHAR(255) ) """ try: cursor.execute(create_table_query) db.commit() print("Table checked/created.") except mysql.connector.Error as err: print(f"Error creating table: {err}") exit() # Force CPU (Spaces free tier is CPU-only) device = 'cpu' print(f"Using device: {device}") # Load model model = yolov5.load('keremberke/yolov5m-garbage') model.to(device) # Model parameters model.conf = 0.25 model.iou = 0.45 model.agnostic = False model.multi_label = False model.max_det = 1000 class_names = model.names # Base directory in Spaces base_dir = "/data/garbage_classification" base_path = Path(base_dir) if not base_path.exists() or not base_path.is_dir(): print(f"Error: Directory '{base_dir}' not found.") exit() subfolders = sorted([f for f in base_path.rglob('*') if f.is_dir()]) if not subfolders: print(f"Error: No subfolders found in '{base_dir}'.") exit() print(f"Found {len(subfolders)} subfolders.") def get_folder_index(value): if 1 <= value <= 11: return 0 elif 12 <= value <= 22: return 1 elif 23 <= value <= 33: return 2 elif 34 <= value <= 44: return 3 elif 45 <= value <= 55: return 4 elif 56 <= value <= 66: return 5 elif 67 <= value <= 77: return 6 elif 78 <= value <= 88: return 7 elif 89 <= value <= 99: return 8 else: return -1 def process_random_image(subfolder, payload_value): image_extensions = ('.jpg', '.jpeg', '.png', '.bmp') images = [f for f in subfolder.iterdir() if f.is_file() and f.suffix.lower() in image_extensions] if not images: print(f"No images in {subfolder}") return None random_image = random.choice(images) print(f"Selected image: {random_image}") results = model(str(random_image), size=640) predictions = results.pred[0] predicted_class = "None" if len(predictions) > 0: scores = predictions[:, 4] top_idx = scores.argmax() top_score = scores[top_idx] top_cat = int(predictions[top_idx, 5]) predicted_class = class_names[top_cat] if top_cat < len(class_names) else "Unknown" print(f"Class: {predicted_class}, Confidence: {top_score:.2f}") else: print("No objects detected.") timestamp = datetime.now() insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)" try: cursor.execute(insert_query, (timestamp, payload_value, predicted_class)) db.commit() print("Data saved to database.") except mysql.connector.Error as err: print(f"Error saving to database: {err}") return predicted_class def on_message(client, userdata, msg): try: value = int(msg.payload.decode()) print(f"Received MQTT value: {value}") folder_idx = get_folder_index(value) if folder_idx != -1 and folder_idx < len(subfolders): selected_folder = subfolders[folder_idx] print(f"Mapped to folder: {selected_folder}") process_random_image(selected_folder, value) else: timestamp = datetime.now() insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)" cursor.execute(insert_query, (timestamp, value, "Out of Range")) db.commit() print("Value out of range. Saved to database.") except ValueError: print("Invalid payload - must be an integer") # MQTT Setup broker = "test.mosquitto.org" topic = "garbage/index" client = mqtt.Client() client.on_message = on_message client.connect(broker, 1883, 60) client.subscribe(topic) print(f"Subscribed to {topic}, waiting for messages...") client.loop_forever() # Runs indefinitely