Upload train_classifier.ipynb
Browse files- train_classifier.ipynb +20 -5
train_classifier.ipynb
CHANGED
@@ -339,7 +339,7 @@
|
|
339 |
" min_samples = counts.min()\n",
|
340 |
" # Calculate 2.0 times the minimum sample size, rounded down to the nearest integer\n",
|
341 |
" # target_samples = int(2.0 * min_samples)\n",
|
342 |
-
" target_samples =
|
343 |
" \n",
|
344 |
" indices_to_keep = np.hstack([\n",
|
345 |
" np.random.choice(\n",
|
@@ -521,7 +521,7 @@
|
|
521 |
"# Loss and optimizer\n",
|
522 |
"criterion = nn.CrossEntropyLoss()\n",
|
523 |
"optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) \n",
|
524 |
-
"lambda_l1 = 1e-
|
525 |
]
|
526 |
},
|
527 |
{
|
@@ -539,7 +539,7 @@
|
|
539 |
"metadata": {},
|
540 |
"outputs": [],
|
541 |
"source": [
|
542 |
-
"epochs =
|
543 |
"train_losses, test_losses = [], []\n",
|
544 |
"\n",
|
545 |
"for epoch in range(epochs):\n",
|
@@ -577,7 +577,7 @@
|
|
577 |
" \n",
|
578 |
" precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
|
579 |
" accuracy = accuracy_score(all_targets, all_preds) # Compute accuracy\n",
|
580 |
-
" if epoch %
|
581 |
" print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
|
582 |
]
|
583 |
},
|
@@ -620,6 +620,9 @@
|
|
620 |
"metadata": {},
|
621 |
"outputs": [],
|
622 |
"source": [
|
|
|
|
|
|
|
623 |
"conf_matrix = confusion_matrix(all_targets, all_preds)\n",
|
624 |
"labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
|
625 |
" # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
|
@@ -627,7 +630,19 @@
|
|
627 |
"# plt.title('Confusion Matrix')\n",
|
628 |
"plt.xlabel('Predicted Label')\n",
|
629 |
"plt.ylabel('True Label')\n",
|
630 |
-
"plt.show()"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
631 |
]
|
632 |
},
|
633 |
{
|
|
|
339 |
" min_samples = counts.min()\n",
|
340 |
" # Calculate 2.0 times the minimum sample size, rounded down to the nearest integer\n",
|
341 |
" # target_samples = int(2.0 * min_samples)\n",
|
342 |
+
" target_samples = 7500\n",
|
343 |
" \n",
|
344 |
" indices_to_keep = np.hstack([\n",
|
345 |
" np.random.choice(\n",
|
|
|
521 |
"# Loss and optimizer\n",
|
522 |
"criterion = nn.CrossEntropyLoss()\n",
|
523 |
"optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) \n",
|
524 |
+
"lambda_l1 = 1e-5 # L1 regularization strength"
|
525 |
]
|
526 |
},
|
527 |
{
|
|
|
539 |
"metadata": {},
|
540 |
"outputs": [],
|
541 |
"source": [
|
542 |
+
"epochs = 10\n",
|
543 |
"train_losses, test_losses = [], []\n",
|
544 |
"\n",
|
545 |
"for epoch in range(epochs):\n",
|
|
|
577 |
" \n",
|
578 |
" precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
|
579 |
" accuracy = accuracy_score(all_targets, all_preds) # Compute accuracy\n",
|
580 |
+
" if epoch % 2==0:\n",
|
581 |
" print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
|
582 |
]
|
583 |
},
|
|
|
620 |
"metadata": {},
|
621 |
"outputs": [],
|
622 |
"source": [
|
623 |
+
"print(np.unique(all_targets, return_counts=True))\n",
|
624 |
+
"print(np.unique(all_preds, return_counts=True))\n",
|
625 |
+
"\n",
|
626 |
"conf_matrix = confusion_matrix(all_targets, all_preds)\n",
|
627 |
"labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
|
628 |
" # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
|
|
|
630 |
"# plt.title('Confusion Matrix')\n",
|
631 |
"plt.xlabel('Predicted Label')\n",
|
632 |
"plt.ylabel('True Label')\n",
|
633 |
+
"plt.show()\n",
|
634 |
+
"\n",
|
635 |
+
"def showClassWiseAcc(conf_matrix):\n",
|
636 |
+
" # Calculate accuracy per class\n",
|
637 |
+
" class_accuracies = conf_matrix.diagonal() / conf_matrix.sum(axis=1)\n",
|
638 |
+
"\n",
|
639 |
+
" # Prepare accuracy data for writing to file\n",
|
640 |
+
" accuracy_data = \"\\n\".join([f\"Accuracy for class {i}: {class_accuracies[i]:.4f}\" for i in range(len(class_accuracies))])\n",
|
641 |
+
"\n",
|
642 |
+
" # Print accuracy per class and write to a file\n",
|
643 |
+
" print(accuracy_data) # Print to console\n",
|
644 |
+
"\n",
|
645 |
+
"showClassWiseAcc(conf_matrix)"
|
646 |
]
|
647 |
},
|
648 |
{
|