# Spaces: Runtime error
import os
import random
import sys
import traceback
from datetime import datetime
from pathlib import Path

import mysql.connector
import paho.mqtt.client as mqtt
import yolov5
# MySQL Configuration (TiDB Serverless)
#
# Each setting can be overridden through a TIDB_* environment variable
# (e.g. Space secrets). When a variable is unset, the previous hard-coded
# value is used so existing deployments keep working.
# SECURITY: the fallback user/password below are committed to source and may
# be publicly visible. Rotate the password, set TIDB_USER / TIDB_PASSWORD in
# the environment, and then remove these fallbacks.
db_config = {
    'host': os.environ.get('TIDB_HOST', 'gateway01.ap-southeast-1.prod.aws.tidbcloud.com'),
    'user': os.environ.get('TIDB_USER', '5ztcqT1EBcgYB5u.root'),
    'password': os.environ.get('TIDB_PASSWORD', '92t7gF7zq3Oz9eKb'),
    'database': os.environ.get('TIDB_DATABASE', 'satwa'),
    # Environment values are strings; the connector expects an int port.
    'port': int(os.environ.get('TIDB_PORT', '4000')),
    # System CA bundle used to verify the server certificate (the original
    # note says this path works on Spaces).
    'ssl_ca': os.environ.get('TIDB_SSL_CA', '/etc/ssl/certs/ca-certificates.crt'),
    'ssl_verify_cert': True,
    'ssl_verify_identity': True,
}
# Connect to MySQL. The connection and its single cursor are module-level and
# are reused by every later query in this script.
try:
    db = mysql.connector.connect(**db_config)
    cursor = db.cursor()
except mysql.connector.Error as err:
    print(f"Failed to connect to MySQL: {err}")
    # Fatal at start-up. Use sys.exit(1): the bare site-module exit() raises
    # SystemExit(None), which ends the process with status 0 (success).
    sys.exit(1)
print("Connected to MySQL database successfully.")
# Create the results table on start-up. IF NOT EXISTS makes this a no-op when
# the table is already present.
create_table_query = """
CREATE TABLE IF NOT EXISTS garbage_classification (
id INT AUTO_INCREMENT PRIMARY KEY,
timestamp DATETIME,
payload_value INT,
predicted_class VARCHAR(255)
)
"""
try:
    cursor.execute(create_table_query)
    db.commit()
except mysql.connector.Error as err:
    print(f"Error creating table: {err}")
    # Fatal: exit with a non-zero status (bare exit() would report status 0).
    sys.exit(1)
print("Table checked/created.")
# Spaces' free tier is CPU-only, so inference is pinned to the CPU.
device = 'cpu'
print(f"Using device: {device}")

# Load the 'keremberke/yolov5m-garbage' checkpoint and place it on the device.
model = yolov5.load('keremberke/yolov5m-garbage')
model.to(device)

# Inference attributes assigned on the loaded model, in this order.
_MODEL_SETTINGS = (
    ('conf', 0.25),
    ('iou', 0.45),
    ('agnostic', False),
    ('multi_label', False),
    ('max_det', 1000),
)
for _attr_name, _attr_value in _MODEL_SETTINGS:
    setattr(model, _attr_name, _attr_value)

# Class-id -> label lookup used when reporting predictions.
class_names = model.names
# Base directory in Spaces
base_dir = "/data/garbage_classification"
base_path = Path(base_dir)
# is_dir() is also False for a path that does not exist, so this one check
# covers both "missing" and "not a directory".
if not base_path.is_dir():
    print(f"Error: Directory '{base_dir}' not found.")
    # Fatal: exit non-zero (bare exit() would report success with status 0).
    sys.exit(1)

# Every directory under base_path, at any depth, sorted by full path. Indices
# into this list are what on_message() selects via get_folder_index().
# NOTE(review): rglob('*') is recursive, so nested directories interleave with
# top-level ones (e.g. A, A/x, B) and shift the indices. If only the immediate
# class folders are intended, iterdir() is likely what was meant. Confirm
# against the dataset layout.
subfolders = sorted(f for f in base_path.rglob('*') if f.is_dir())
if not subfolders:
    print(f"Error: No subfolders found in '{base_dir}'.")
    sys.exit(1)
print(f"Found {len(subfolders)} subfolders.")
# Inclusive (low, high) payload bounds for folder indices 0..8:
# (1, 11), (12, 22), ..., (89, 99).
_FOLDER_RANGES = tuple((start, start + 10) for start in range(1, 100, 11))


def get_folder_index(value):
    """Map an MQTT payload value onto a folder index.

    The values 1..99 are split into nine consecutive inclusive buckets of
    eleven (1-11 -> 0, 12-22 -> 1, ..., 89-99 -> 8). A value that falls in no
    bucket yields -1.
    """
    for index, (low, high) in enumerate(_FOLDER_RANGES):
        if low <= value <= high:
            return index
    return -1
def process_random_image(subfolder, payload_value):
    """Classify one randomly chosen image from *subfolder* and store the result.

    Candidates are the files directly inside *subfolder* whose lower-cased
    suffix is .jpg, .jpeg, .png or .bmp. The detection with the highest score
    (column 4 of the model's predictions) supplies the class id in column 5.
    The prediction is "None" when nothing is detected, and "Unknown" when the
    class id is not below len(class_names). One row (current time,
    *payload_value*, prediction) is then inserted into garbage_classification.
    A database error is printed, not raised.

    Args:
        subfolder: pathlib.Path of the directory to sample from.
        payload_value: Integer stored alongside the prediction.

    Returns:
        The prediction string, or None when *subfolder* has no candidate
        images (nothing is written to the database in that case).
    """
    allowed_suffixes = ('.jpg', '.jpeg', '.png', '.bmp')
    candidates = [
        entry
        for entry in subfolder.iterdir()
        if entry.is_file() and entry.suffix.lower() in allowed_suffixes
    ]
    if not candidates:
        print(f"No images in {subfolder}")
        return None

    chosen = random.choice(candidates)
    print(f"Selected image: {chosen}")

    detections = model(str(chosen), size=640).pred[0]
    label = "None"
    if len(detections) == 0:
        print("No objects detected.")
    else:
        confidences = detections[:, 4]
        best = confidences.argmax()
        best_confidence = confidences[best]
        class_id = int(detections[best, 5])
        label = class_names[class_id] if class_id < len(class_names) else "Unknown"
        print(f"Class: {label}, Confidence: {best_confidence:.2f}")

    row = (datetime.now(), payload_value, label)
    sql = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
    try:
        cursor.execute(sql, row)
        db.commit()
        print("Data saved to database.")
    except mysql.connector.Error as err:
        print(f"Error saving to database: {err}")
    return label
def on_message(client, userdata, msg):
    """Handle one MQTT message from the subscribed topic.

    The payload must be an integer in text form; anything else is reported
    and ignored. When get_folder_index() maps the value to an existing entry
    of ``subfolders``, process_random_image() runs on that folder. Every other
    integer is recorded in the database with predicted_class "Out of Range".

    Args:
        client: The paho-mqtt client that received the message (unused).
        userdata: User data configured on the client (unused).
        msg: The received message; only ``msg.payload`` (bytes) is read.

    This is a paho-mqtt callback and never lets an exception escape. By
    default paho re-raises callback exceptions out of the network loop, which
    would stop ``loop_forever()`` and the whole subscriber.
    """
    # Keep the ValueError handler scoped to parsing only, so a ValueError from
    # image processing is not misreported as a bad payload.
    try:
        # UnicodeDecodeError subclasses ValueError, so undecodable bytes are
        # rejected here too.
        value = int(msg.payload.decode())
    except ValueError:
        print("Invalid payload - must be an integer")
        return
    print(f"Received MQTT value: {value}")

    try:
        folder_idx = get_folder_index(value)
        if folder_idx != -1 and folder_idx < len(subfolders):
            selected_folder = subfolders[folder_idx]
            print(f"Mapped to folder: {selected_folder}")
            process_random_image(selected_folder, value)
        else:
            timestamp = datetime.now()
            insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
            try:
                cursor.execute(insert_query, (timestamp, value, "Out of Range"))
                db.commit()
                print("Value out of range. Saved to database.")
            except mysql.connector.Error as err:
                print(f"Error saving to database: {err}")
    except Exception:
        # Callback boundary: log the full traceback and keep the subscriber alive.
        print(f"Error while handling MQTT value {value}:")
        traceback.print_exc()
# MQTT Setup
broker = "test.mosquitto.org"
topic = "garbage/index"


def on_connect(client, userdata, flags, rc, *extra):
    """Subscribe to ``topic`` each time the broker accepts a connection.

    The subscription is made here rather than once before the loop so that it
    is restored after loop_forever() reconnects following a dropped connection.
    With the default clean session, the broker forgets subscriptions on
    disconnect. ``*extra`` absorbs the trailing ``properties`` argument that
    paho passes for MQTT v5 / callback API version 2.
    """
    if rc == 0:
        client.subscribe(topic)
        print(f"Subscribed to {topic}, waiting for messages...")
    else:
        print(f"MQTT connection refused (rc={rc}); not subscribing.")


client = mqtt.Client()
client.on_connect = on_connect
client.on_message = on_message
client.connect(broker, 1883, 60)
client.loop_forever()  # Runs indefinitely; reconnects automatically after a connection loss