Spaces:
Runtime error
Runtime error
File size: 4,746 Bytes
29eec23 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
import random
from datetime import datetime
from pathlib import Path

import mysql.connector
import paho.mqtt.client as mqtt
import yolov5
# MySQL Configuration (TiDB Serverless)
#
# SECURITY: the fallback values below are credentials hard-coded in source.
# Every connection setting can be supplied through an environment variable
# instead; once the deployment provides them, rotate the password and remove
# the literal fallbacks.
db_config = {
    'host': os.environ.get('TIDB_HOST', 'gateway01.ap-southeast-1.prod.aws.tidbcloud.com'),
    'user': os.environ.get('TIDB_USER', '5ztcqT1EBcgYB5u.root'),
    'password': os.environ.get('TIDB_PASSWORD', '92t7gF7zq3Oz9eKb'),
    'database': os.environ.get('TIDB_DATABASE', 'satwa'),
    # Environment values are strings, so convert to keep the port an int.
    'port': int(os.environ.get('TIDB_PORT', '4000')),
    'ssl_ca': '/etc/ssl/certs/ca-certificates.crt',  # Works on Spaces
    'ssl_verify_cert': True,
    'ssl_verify_identity': True,
}
# Connect to MySQL. The connection and cursor stay at module level because
# the functions further down in this file use them directly.
try:
    db = mysql.connector.connect(**db_config)
    cursor = db.cursor()
    print("Connected to MySQL database successfully.")
except mysql.connector.Error as err:
    print(f"Failed to connect to MySQL: {err}")
    # exit() is a helper injected by the site module for interactive use and
    # reports success (status 0); raise SystemExit with a failure status.
    raise SystemExit(1)

# Create table (IF NOT EXISTS makes this safe to run on every start).
create_table_query = """
CREATE TABLE IF NOT EXISTS garbage_classification (
id INT AUTO_INCREMENT PRIMARY KEY,
timestamp DATETIME,
payload_value INT,
predicted_class VARCHAR(255)
)
"""
try:
    cursor.execute(create_table_query)
    db.commit()
    print("Table checked/created.")
except mysql.connector.Error as err:
    print(f"Error creating table: {err}")
    raise SystemExit(1)
# Inference is pinned to the CPU (the Spaces free tier is CPU-only).
device = 'cpu'
print(f"Using device: {device}")

# Fetch the pretrained garbage detector and move it to the chosen device.
model = yolov5.load('keremberke/yolov5m-garbage')
model.to(device)

# Detection settings, applied one attribute at a time in the original order.
_detection_settings = (
    ('conf', 0.25),
    ('iou', 0.45),
    ('agnostic', False),
    ('multi_label', False),
    ('max_det', 1000),
)
for _option, _setting in _detection_settings:
    setattr(model, _option, _setting)

# Class-id -> label lookup exposed by the model; used when reporting results.
class_names = model.names
# Base directory in Spaces
base_dir = "/data/garbage_classification"
base_path = Path(base_dir)

# is_dir() is already False for a missing path, so one check covers both the
# "does not exist" and the "is not a directory" case.
if not base_path.is_dir():
    print(f"Error: Directory '{base_dir}' not found.")
    # Exit with a failure status; the interactive exit() helper reports 0.
    raise SystemExit(1)

# Sorted so that a given index always refers to the same folder.
# NOTE(review): rglob('*') also returns nested directories at any depth, not
# only the direct children of base_dir - confirm that is intended.
subfolders = sorted(f for f in base_path.rglob('*') if f.is_dir())
if not subfolders:
    print(f"Error: No subfolders found in '{base_dir}'.")
    raise SystemExit(1)

print(f"Found {len(subfolders)} subfolders.")
def get_folder_index(value: int, bucket_size: int = 11, bucket_count: int = 9) -> int:
    """Map a 1-based payload value to the zero-based index of its bucket.

    With the defaults the buckets are 1-11 -> 0, 12-22 -> 1, ..., 89-99 -> 8,
    which matches the previously hard-coded ranges.

    Args:
        value: Payload number to classify.
        bucket_size: How many consecutive values share one index.
        bucket_count: Number of buckets; anything above
            ``bucket_size * bucket_count`` is out of range.

    Returns:
        The bucket index, or -1 when ``value`` falls inside no bucket.
    """
    for index in range(bucket_count):
        low = 1 + index * bucket_size
        # Inclusive range test, the same comparison the fixed ladder used.
        if low <= value <= low + bucket_size - 1:
            return index
    return -1
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')


def _list_images(subfolder):
    """Return the files directly inside *subfolder* that have an image suffix."""
    return [
        entry for entry in subfolder.iterdir()
        if entry.is_file() and entry.suffix.lower() in _IMAGE_EXTENSIONS
    ]


def _predict_top_class(image_path):
    """Run the detector on one image and return the label of its best detection.

    Returns the string "None" when nothing is detected and "Unknown" when the
    winning class id is not below ``len(class_names)``.
    """
    results = model(str(image_path), size=640)
    predictions = results.pred[0]
    if len(predictions) == 0:
        print("No objects detected.")
        return "None"

    # Column 4 is read as the confidence score and column 5 as the class id.
    scores = predictions[:, 4]
    top_idx = scores.argmax()
    top_score = scores[top_idx]
    top_cat = int(predictions[top_idx, 5])
    predicted_class = class_names[top_cat] if top_cat < len(class_names) else "Unknown"
    print(f"Class: {predicted_class}, Confidence: {top_score:.2f}")
    return predicted_class


def _save_classification(payload_value, predicted_class):
    """Insert one result row; a database error is reported, not raised."""
    insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
    try:
        cursor.execute(insert_query, (datetime.now(), payload_value, predicted_class))
        db.commit()
        print("Data saved to database.")
    except mysql.connector.Error as err:
        print(f"Error saving to database: {err}")


def process_random_image(subfolder, payload_value):
    """Classify a random image from *subfolder* and store the result.

    Args:
        subfolder: ``pathlib.Path`` of the directory to pick an image from.
        payload_value: The MQTT value that selected this folder; stored next
            to the prediction.

    Returns:
        The predicted class label ("None" if nothing was detected), or
        ``None`` when the folder contains no image files, in which case
        nothing is written to the database.
    """
    images = _list_images(subfolder)
    if not images:
        print(f"No images in {subfolder}")
        return None

    random_image = random.choice(images)
    print(f"Selected image: {random_image}")

    predicted_class = _predict_top_class(random_image)
    _save_classification(payload_value, predicted_class)
    return predicted_class
def on_message(client, userdata, msg):
    """Handle one MQTT message whose payload is expected to be an integer.

    A value that maps to an available folder triggers classification of a
    random image from that folder; any other integer is recorded in the
    database with the class "Out of Range". Non-integer payloads are reported
    and ignored.

    Args:
        client: The MQTT client that received the message (unused).
        userdata: User data passed by the client (unused).
        msg: Received message; only ``msg.payload`` (bytes) is read.
    """
    # Keep the try limited to parsing so a ValueError raised elsewhere is not
    # misreported as a bad payload. UnicodeDecodeError subclasses ValueError,
    # so undecodable bytes land here as well.
    try:
        value = int(msg.payload.decode())
    except ValueError:
        print("Invalid payload - must be an integer")
        return
    print(f"Received MQTT value: {value}")

    folder_idx = get_folder_index(value)
    if folder_idx != -1 and folder_idx < len(subfolders):
        selected_folder = subfolders[folder_idx]
        print(f"Mapped to folder: {selected_folder}")
        process_random_image(selected_folder, value)
        return

    insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
    # A database failure must not escape the callback, matching how
    # process_random_image reports failed inserts.
    try:
        cursor.execute(insert_query, (datetime.now(), value, "Out of Range"))
        db.commit()
        print("Value out of range. Saved to database.")
    except mysql.connector.Error as err:
        print(f"Error saving to database: {err}")
# MQTT Setup
broker = "test.mosquitto.org"
topic = "garbage/index"


def on_connect(client, userdata, flags, rc, properties=None):
    """Subscribe to the topic each time the broker accepts a connection.

    Subscribing here rather than once after connect() means the subscription
    is issued again after every reconnect made by loop_forever().

    Args:
        client: The connected MQTT client.
        userdata: User data passed by the client (unused).
        flags: Connection flags sent by the broker (unused).
        rc: Connection result; 0 means success.
        properties: Optional, so the signature also fits the five-argument
            callback form of newer paho-mqtt releases.
    """
    if rc != 0:
        print(f"MQTT connection failed: {rc}")
        return
    client.subscribe(topic)
    print(f"Subscribed to {topic}, waiting for messages...")


client = mqtt.Client()
client.on_connect = on_connect
client.on_message = on_message
client.connect(broker, 1883, 60)
client.loop_forever()  # Runs indefinitely