Shiva Subhan S commited on
Commit
29eec23
·
1 Parent(s): 1b89ac9

Add garbage classifier app

Browse files
Files changed (2) hide show
  1. model2.py +157 -0
  2. requirements.txt +3 -0
model2.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yolov5
2
+ import random
3
+ from pathlib import Path
4
+ import paho.mqtt.client as mqtt
5
+ import mysql.connector
6
+ from datetime import datetime
7
+
8
+ # MySQL Configuration (TiDB Serverless)
9
+ db_config = {
10
+ 'host': 'gateway01.ap-southeast-1.prod.aws.tidbcloud.com',
11
+ 'user': '5ztcqT1EBcgYB5u.root',
12
+ 'password': '92t7gF7zq3Oz9eKb',
13
+ 'database': 'satwa',
14
+ 'port': 4000,
15
+ 'ssl_ca': '/etc/ssl/certs/ca-certificates.crt', # Works on Spaces
16
+ 'ssl_verify_cert': True,
17
+ 'ssl_verify_identity': True
18
+ }
19
+
20
+ # Connect to MySQL
21
+ try:
22
+ db = mysql.connector.connect(**db_config)
23
+ cursor = db.cursor()
24
+ print("Connected to MySQL database successfully.")
25
+ except mysql.connector.Error as err:
26
+ print(f"Failed to connect to MySQL: {err}")
27
+ exit()
28
+
29
+ # Create table
30
+ create_table_query = """
31
+ CREATE TABLE IF NOT EXISTS garbage_classification (
32
+ id INT AUTO_INCREMENT PRIMARY KEY,
33
+ timestamp DATETIME,
34
+ payload_value INT,
35
+ predicted_class VARCHAR(255)
36
+ )
37
+ """
38
+ try:
39
+ cursor.execute(create_table_query)
40
+ db.commit()
41
+ print("Table checked/created.")
42
+ except mysql.connector.Error as err:
43
+ print(f"Error creating table: {err}")
44
+ exit()
45
+
46
+ # Force CPU (Spaces free tier is CPU-only)
47
+ device = 'cpu'
48
+ print(f"Using device: {device}")
49
+
50
+ # Load model
51
+ model = yolov5.load('keremberke/yolov5m-garbage')
52
+ model.to(device)
53
+
54
+ # Model parameters
55
+ model.conf = 0.25
56
+ model.iou = 0.45
57
+ model.agnostic = False
58
+ model.multi_label = False
59
+ model.max_det = 1000
60
+
61
+ class_names = model.names
62
+
63
+ # Base directory in Spaces
64
+ base_dir = "/data/garbage_classification"
65
+ base_path = Path(base_dir)
66
+
67
+ if not base_path.exists() or not base_path.is_dir():
68
+ print(f"Error: Directory '{base_dir}' not found.")
69
+ exit()
70
+
71
+ subfolders = sorted([f for f in base_path.rglob('*') if f.is_dir()])
72
+ if not subfolders:
73
+ print(f"Error: No subfolders found in '{base_dir}'.")
74
+ exit()
75
+
76
+ print(f"Found {len(subfolders)} subfolders.")
77
+
78
+ def get_folder_index(value):
79
+ if 1 <= value <= 11: return 0
80
+ elif 12 <= value <= 22: return 1
81
+ elif 23 <= value <= 33: return 2
82
+ elif 34 <= value <= 44: return 3
83
+ elif 45 <= value <= 55: return 4
84
+ elif 56 <= value <= 66: return 5
85
+ elif 67 <= value <= 77: return 6
86
+ elif 78 <= value <= 88: return 7
87
+ elif 89 <= value <= 99: return 8
88
+ else: return -1
89
+
90
+ def process_random_image(subfolder, payload_value):
91
+ image_extensions = ('.jpg', '.jpeg', '.png', '.bmp')
92
+ images = [f for f in subfolder.iterdir() if f.is_file() and f.suffix.lower() in image_extensions]
93
+
94
+ if not images:
95
+ print(f"No images in {subfolder}")
96
+ return None
97
+
98
+ random_image = random.choice(images)
99
+ print(f"Selected image: {random_image}")
100
+
101
+ results = model(str(random_image), size=640)
102
+ predictions = results.pred[0]
103
+
104
+ predicted_class = "None"
105
+ if len(predictions) > 0:
106
+ scores = predictions[:, 4]
107
+ top_idx = scores.argmax()
108
+ top_score = scores[top_idx]
109
+ top_cat = int(predictions[top_idx, 5])
110
+ predicted_class = class_names[top_cat] if top_cat < len(class_names) else "Unknown"
111
+ print(f"Class: {predicted_class}, Confidence: {top_score:.2f}")
112
+ else:
113
+ print("No objects detected.")
114
+
115
+ timestamp = datetime.now()
116
+ insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
117
+ try:
118
+ cursor.execute(insert_query, (timestamp, payload_value, predicted_class))
119
+ db.commit()
120
+ print("Data saved to database.")
121
+ except mysql.connector.Error as err:
122
+ print(f"Error saving to database: {err}")
123
+
124
+ return predicted_class
125
+
126
+ def on_message(client, userdata, msg):
127
+ try:
128
+ value = int(msg.payload.decode())
129
+ print(f"Received MQTT value: {value}")
130
+
131
+ folder_idx = get_folder_index(value)
132
+
133
+ if folder_idx != -1 and folder_idx < len(subfolders):
134
+ selected_folder = subfolders[folder_idx]
135
+ print(f"Mapped to folder: {selected_folder}")
136
+ process_random_image(selected_folder, value)
137
+ else:
138
+ timestamp = datetime.now()
139
+ insert_query = "INSERT INTO garbage_classification (timestamp, payload_value, predicted_class) VALUES (%s, %s, %s)"
140
+ cursor.execute(insert_query, (timestamp, value, "Out of Range"))
141
+ db.commit()
142
+ print("Value out of range. Saved to database.")
143
+ except ValueError:
144
+ print("Invalid payload - must be an integer")
145
+
146
+ # MQTT Setup
147
+ broker = "test.mosquitto.org"
148
+ topic = "garbage/index"
149
+
150
+ client = mqtt.Client()
151
+ client.on_message = on_message
152
+
153
+ client.connect(broker, 1883, 60)
154
+ client.subscribe(topic)
155
+
156
+ print(f"Subscribed to {topic}, waiting for messages...")
157
+ client.loop_forever() # Runs indefinitely
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ yolov5
2
+ paho-mqtt
3
+ mysql-connector-python