Luecke committed on
Commit 7b4e127 · 1 Parent(s): 6b1d0c8

classification model

Files changed (1)
  1. classification/classification.py +396 -0
classification/classification.py ADDED
@@ -0,0 +1,396 @@
# -*- coding: utf-8 -*-
"""classification.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1JuZNV3fqC5XQ0L-jhIyVRbIDPfWWGkVI
"""

import copy
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms

# Define the data directories
data_dir = 'drive/MyDrive/Ai_Hackathon_2024/plant_data/data_for_training'
augmented_data_dir = 'drive/MyDrive/Ai_Hackathon_2024/plant_data/augmented_data'

# Define the desired number of images per class
N = 50

# Define the augmentation transforms
augmentation_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),
    transforms.Pad(padding=10, padding_mode='reflect'),  # Add padding with reflection
    transforms.ToTensor(),
])

# Load the dataset
print('loading dataset...')
dataset = datasets.ImageFolder(data_dir)
class_names = dataset.classes

print('loaded dataset.')


# Function to save augmented images
def save_image(img, path, idx):
    img.save(os.path.join(path, f'{idx}.png'))


# Augment the dataset
if not os.path.exists(augmented_data_dir):
    os.makedirs(augmented_data_dir)

print('starting augmentation process...')
for class_idx in range(len(dataset.classes)):
    print(f"class_idx = {class_idx}")
    class_dir = os.path.join(augmented_data_dir, dataset.classes[class_idx])
    if not os.path.exists(class_dir):
        os.makedirs(class_dir)

    class_images = [img_path for img_path, label in dataset.samples if label == class_idx]
    current_count = 0

    # Save original images first
    for img_path in class_images:
        img = Image.open(img_path).convert('RGB')  # Ensure 3-channel RGB before saving
        save_image(img, class_dir, current_count)
        current_count += 1

    # If there are fewer than N images, augment the dataset
    while current_count < N:
        img_path = random.choice(class_images)
        img = Image.open(img_path).convert('RGB')
        img = augmentation_transforms(img)
        img = transforms.ToPILImage()(img)  # Convert back to PIL Image
        save_image(img, class_dir, current_count)
        current_count += 1

print('Data augmentation completed.')

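# Optional sanity check (a minimal sketch using only the paths defined above):
# confirm that every augmented class folder now contains at least N images
# before training on it.
for cls in dataset.classes:
    cls_dir = os.path.join(augmented_data_dir, cls)
    n_files = len([f for f in os.listdir(cls_dir) if f.endswith('.png')])
    print(f'{cls}: {n_files} images')
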
# Define the data directory
data_dir = augmented_data_dir  # 'drive/MyDrive/Ai_Hackathon_2024/plant_data/data_for_training'

# Set the random seed for reproducibility
seed = 42
torch.manual_seed(seed)

# Define transforms
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Create the dataset
full_dataset = datasets.ImageFolder(data_dir, transform=data_transforms)

# Define the train-validation split ratio
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Split the dataset
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size],
                                          generator=torch.Generator().manual_seed(seed))

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Load the pre-trained ResNet50 model
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Freeze the parameters of the pre-trained model
for param in resnet50.parameters():
    param.requires_grad = False

# Remove the final fully connected layer
num_ftrs = resnet50.fc.in_features
resnet50.fc = nn.Identity()  # Replace the final layer with an identity function to get the feature vectors


# Define a custom neural network with one hidden layer and an output layer
class CustomNet(nn.Module):
    def __init__(self, num_ftrs, num_classes):
        super(CustomNet, self).__init__()
        self.resnet50 = resnet50
        self.hidden = nn.Linear(num_ftrs, 512)
        self.relu = nn.ReLU()
        self.output = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.resnet50(x)  # Extract features using the pre-trained model
        x = self.hidden(x)    # Pass through the hidden layer
        x = self.relu(x)      # Apply ReLU activation
        x = self.output(x)    # Output layer
        return x


# Instantiate the custom network
num_classes = len(full_dataset.classes)
model = CustomNet(num_ftrs, num_classes)

# Move the model to the appropriate device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define criterion and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def train_model(model, dataloaders, criterion, optimizer, num_epochs=10):
    best_model_wts = copy.deepcopy(model.state_dict())  # Deep copy so later training steps don't overwrite the best weights
    best_acc = 0.0

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'train':
                train_losses.append(epoch_loss)
            else:
                val_losses.append(epoch_loss)

            # Deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)

    # Plot the training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    return model


# Create a dictionary to hold the dataloaders
dataloaders = {'train': train_loader, 'val': val_loader}

# Train and evaluate the model
model = train_model(model, dataloaders, criterion, optimizer, num_epochs=10)

# Save the model
torch.save(model.state_dict(), 'drive/MyDrive/Ai_Hackathon_2024/plant_data/fine_tuned_plant_classifier.pth')


# Function to evaluate the model
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    accuracy = correct / total
    return accuracy, all_preds, all_labels


# Evaluate the model on the (augmented) training dataset
dataloader = DataLoader(full_dataset, batch_size=32, shuffle=True)
accuracy, all_preds, all_labels = evaluate_model(model, dataloader)

# Calculate the number of correct and incorrect predictions
correct_preds = sum(np.array(all_preds) == np.array(all_labels))
incorrect_preds = len(all_labels) - correct_preds

print(f'Total images: {len(all_labels)}')
print(f'Correct predictions: {correct_preds}')
print(f'Incorrect predictions: {incorrect_preds}')
print(f'Accuracy: {accuracy:.4f}')

##-----------------------------------------------------------##
real_dataset = datasets.ImageFolder('drive/MyDrive/Ai_Hackathon_2024/plant_data/data_for_training',
                                    transform=data_transforms)

# Evaluate the model on the original (non-augmented) dataset
dataloader = DataLoader(real_dataset, batch_size=32, shuffle=True)
accuracy, all_preds, all_labels = evaluate_model(model, dataloader)

# Calculate the number of correct and incorrect predictions
correct_preds = sum(np.array(all_preds) == np.array(all_labels))
incorrect_preds = len(all_labels) - correct_preds
print('-' * 10)
print(f'Total images: {len(all_labels)}')
print(f'Correct predictions: {correct_preds}')
print(f'Incorrect predictions: {incorrect_preds}')
print(f'Accuracy: {accuracy:.4f}')

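# Optional sketch: per-class accuracy from the predictions just computed. This uses
# only variables already defined above (all_preds, all_labels, class_names) and
# assumes the class ordering of real_dataset matches class_names, which holds when
# both ImageFolder datasets are built from the same class folder names.
preds_arr = np.array(all_preds)
labels_arr = np.array(all_labels)
for idx, name in enumerate(class_names):
    mask = labels_arr == idx
    if mask.any():
        print(f'{name}: {(preds_arr[mask] == idx).mean():.4f} ({mask.sum()} images)')
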
# Function to load and preprocess the image
def process_image(image_path):
    data_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image = data_transform(image)  # Deterministic transform; data_transforms would also apply the random augmentations
    image = image.unsqueeze(0)  # Add batch dimension
    return image


#----------------------------INFERENCE PART----------------------------

# Function to predict the class of a single image
def predict_single_image(image_path, model):
    # Load the image and preprocess it
    image = process_image(image_path)

    # Load the model
    model.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Pass the image through the model
    with torch.no_grad():
        image = image.to(device)
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)

    # Return the class names and their probabilities as a Pandas Series
    return pd.Series(probabilities.cpu().numpy(), index=class_names).sort_values(ascending=False)


def classify(img_path):
    # Path to the single image
    image_path = img_path

    # Initialize your custom model
    model = CustomNet(num_ftrs, num_classes)
    # Load the trained model weights (map_location allows loading a GPU-trained checkpoint on CPU)
    model.load_state_dict(torch.load('./fine_tuned_plant_classifier.pth', map_location='cpu'))

    # Predict the class probabilities
    class_probabilities = predict_single_image(image_path, model)
    return class_probabilities

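# Example usage (a minimal sketch; 'example_leaf.png' is a hypothetical path used only
# for illustration, and the checkpoint is expected next to this script as loaded above):
example_image = 'example_leaf.png'
if os.path.exists(example_image) and os.path.exists('./fine_tuned_plant_classifier.pth'):
    probs = classify(example_image)
    print(probs.head())  # Top classes with their predicted probabilities
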

#----------------------------INFERENCE PART----------------------------


## script to automatically include larger drone images

import os
import shutil
from PIL import Image

# Define the paths
source_dir = 'path/to/source_images'  # The directory with new images
target_base_dir = 'path/to/training_images'  # The base directory containing original class folders
new_base_dir = 'path/to/training_images_2'  # The base directory for the new substructure

# Extract the class folders
class_folders = [d for d in os.listdir(target_base_dir) if os.path.isdir(os.path.join(target_base_dir, d))]


# Function to extract ID from a filename
def extract_id(filename):
    return filename.split('_')[0]  # Assumes the ID is the first part of the filename, separated by '_'


# Function to crop the middle section of an image
def crop_middle_section(image):
    # Keep the central third of the image in each dimension
    # (e.g. a 3000x2000 drone image is reduced to its central 1000x666 region)
    width, height = image.size
    new_width = width // 3
    new_height = height // 3
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height
    return image.crop((left, top, right, bottom))


# Create the new base directory if it does not exist
os.makedirs(new_base_dir, exist_ok=True)

# Create a dictionary to map IDs to their respective class folders
id_to_class_folder = {}
for class_folder in class_folders:
    class_folder_path = os.path.join(target_base_dir, class_folder)
    for filename in os.listdir(class_folder_path):
        if os.path.isfile(os.path.join(class_folder_path, filename)):
            file_id = extract_id(filename)
            id_to_class_folder[file_id] = class_folder

# Copy and manipulate the matching images
for filename in os.listdir(source_dir):
    if os.path.isfile(os.path.join(source_dir, filename)):
        file_id = extract_id(filename)
        if file_id in id_to_class_folder:
            target_class_folder = id_to_class_folder[file_id]
            new_class_folder_path = os.path.join(new_base_dir, target_class_folder)
            os.makedirs(new_class_folder_path, exist_ok=True)  # Create the class folder if it doesn't exist

            target_path = os.path.join(new_class_folder_path, filename)

            # Open and manipulate the image
            image_path = os.path.join(source_dir, filename)
            with Image.open(image_path) as img:
                cropped_img = crop_middle_section(img)
                cropped_img.save(target_path)

            print(f'Copied and cropped {filename} to {new_class_folder_path}')

print('Image processing and copying completed.')