Upload 2 files
- feature_extraction.ipynb +433 -0
- train_classifier.ipynb +745 -0
feature_extraction.ipynb
ADDED
@@ -0,0 +1,433 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Start to finish - DINOv2 feature extraction"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3AdjGBwjnr-5"
},
"outputs": [],
"source": [
"from transformers import AutoImageProcessor, AutoModel\n",
"from PIL import Image\n",
"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import requests\n",
"import torch\n",
"import cv2\n",
"import os"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qvTYvSVOkLLL"
},
"source": [
"## Initialize pre-trained image processor and model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aRlCk-Tlj8Iv",
"outputId": "fb51843c-598f-48ad-a1c0-cf8d9bab53f4",
"scrolled": true
},
"outputs": [],
"source": [
"# Adjust for cuda - takes up 2193 MiB on device\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')\n",
"model = AutoModel.from_pretrained('facebook/dinov2-large').to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## DINOv2 Feature Extraction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"import gc\n",
"\n",
"torch.cuda.empty_cache()\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Crq7KD84qz5d"
},
"outputs": [],
"source": [
"# Path to your videos\n",
"path_to_videos = './dataset-tacdec/videos'\n",
"\n",
"# Directory paths\n",
"processed_features_dir = './processed_features'\n",
"last_hidden_states_dir = os.path.join(processed_features_dir, 'last_hidden_states/')\n",
"pooler_outputs_dir = os.path.join(processed_features_dir, 'pooler_outputs/')\n",
"\n",
"# Create directories if they don't exist\n",
"os.makedirs(last_hidden_states_dir, exist_ok=True)\n",
"os.makedirs(pooler_outputs_dir, exist_ok=True)\n",
"\n",
"# Dictionary with filename as key, all feature-extracted frames as values\n",
"feature_extracted_videos = {}\n",
"\n",
"# Define batch size\n",
"batch_size = 32\n",
"\n",
"# Process each video\n",
"for video_file in tqdm(os.listdir(path_to_videos)):\n",
"    full_path = os.path.join(path_to_videos, video_file)\n",
"\n",
"    if not os.path.isfile(full_path):\n",
"        continue\n",
"\n",
"    cap = cv2.VideoCapture(full_path)\n",
"\n",
"    # List to hold all batch outputs, clear for each video\n",
"    batch_last_hidden_states = []\n",
"    batch_pooler_outputs = []\n",
"\n",
"    batch_frames = []\n",
"\n",
"    while True:\n",
"        ret, frame = cap.read()\n",
"        if not ret:\n",
"\n",
"            # Process the last batch\n",
"            if len(batch_frames) > 0:\n",
"                inputs = processor(images=batch_frames, return_tensors=\"pt\").to(device)\n",
"\n",
"                with torch.no_grad():\n",
"                    outputs = model(**inputs)\n",
"\n",
"                for key, value in outputs.items():\n",
"                    if key == 'last_hidden_state':\n",
"                        # batch_last_hidden_states.append(value.cpu().numpy())\n",
"                        batch_last_hidden_states.append(value)\n",
"                    elif key == 'pooler_output':\n",
"                        # batch_pooler_outputs.append(value.cpu().numpy())\n",
"                        batch_pooler_outputs.append(value)\n",
"                    else:\n",
"                        print('Error in key, expected last_hidden_state or pooler_output, got: ', key)\n",
"            break\n",
"\n",
"        # cv2 comes in BGR, but transformer takes RGB\n",
"        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
"        batch_frames.append(frame_rgb)\n",
"\n",
"        # Check if batch is full\n",
"        if len(batch_frames) == batch_size:\n",
"            inputs = processor(images=batch_frames, return_tensors=\"pt\").to(device)\n",
"            # outputs = model(**inputs)\n",
"            with torch.no_grad():\n",
"                outputs = model(**inputs)\n",
"            for key, value in outputs.items():\n",
"                if key == 'last_hidden_state':\n",
"                    batch_last_hidden_states.append(value)\n",
"                elif key == 'pooler_output':\n",
"                    batch_pooler_outputs.append(value)\n",
"                else:\n",
"                    print('Error in key, expected last_hidden_state or pooler_output, got: ', key)\n",
"\n",
"            # Clear batch\n",
"            batch_frames = []\n",
"\n",
"\n",
"    all_last_hidden_states = torch.cat(batch_last_hidden_states, dim=0)\n",
"    all_pooler_outputs = torch.cat(batch_pooler_outputs, dim=0)\n",
"\n",
"    # Save the tensors with the video name as filename\n",
"    pt_filename = video_file.replace('.mp4', '.pt')\n",
"    torch.save(all_last_hidden_states, os.path.join(last_hidden_states_dir, f'{pt_filename}'))\n",
"    torch.save(all_pooler_outputs, os.path.join(pooler_outputs_dir, f'{pt_filename}'))\n",
"\n",
"print('Features extracted')"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Reload features to verify"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lhs_torch = torch.load('./processed_features/last_hidden_states/1738_avxeiaxxw6ocr.pt')\n",
"po_torch = torch.load('./processed_features/pooler_outputs/1738_avxeiaxxw6ocr.pt')\n",
"\n",
"print('LHS Torch size: ', lhs_torch.size())\n",
"print('PO Torch size: ', po_torch.size())\n",
"\n",
"for i in range(all_last_hidden_states.size(0)):\n",
"    print(f\"Frame {i}:\")\n",
"    print(all_last_hidden_states[i])\n",
"    print()\n",
"    break\n",
"\n",
"for i in range(lhs_torch.size(0)):\n",
"    print(f\"Frame {i}:\")\n",
"    print(lhs_torch[i])\n",
"    print()\n",
"    break\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Different sorts of plots"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Histogram of video length in seconds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import numpy as np\n",
"\n",
"path_to_videos = './dataset-tacdec/videos'\n",
"video_lengths = []\n",
"frame_counts = []\n",
"\n",
"# Iterate through each file in the directory\n",
"for video_file in os.listdir(path_to_videos):\n",
"    full_path = os.path.join(path_to_videos, video_file)\n",
"\n",
"    if not os.path.isfile(full_path):\n",
"        continue\n",
"\n",
"    cap = cv2.VideoCapture(full_path)\n",
"\n",
"    # Calculate the length of the video\n",
"    # Note: Assuming the frame rate information is accurate\n",
"    if cap.isOpened():\n",
"        fps = cap.get(cv2.CAP_PROP_FPS)  # Frame rate\n",
"        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
"        duration = frame_count / fps if fps > 0 else 0\n",
"        video_lengths.append(duration)\n",
"        frame_counts.append(frame_count)\n",
"\n",
"    cap.release()\n",
"\n",
"np.save('./video_durations', video_lengths)\n",
"np.save('./frame_counts', frame_counts)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"\n",
"# Set the aesthetic style of the plots\n",
"sns.set(style=\"darkgrid\")\n",
"\n",
"# Plotting the histogram for video lengths\n",
"plt.figure(figsize=(12, 6))\n",
"sns.histplot(video_lengths, kde=True, color=\"blue\")\n",
"plt.title('Histogram - Video Lengths')\n",
"plt.xlabel('Length of Videos (seconds)')\n",
"plt.ylabel('Number of Videos')\n",
"\n",
"# Plotting the histogram for frame counts\n",
"plt.figure(figsize=(12, 6))\n",
"sns.histplot(frame_counts, kde=True, color=\"green\")\n",
"plt.title('Histogram - Number of Frames')\n",
"plt.xlabel('Frame Count')\n",
"plt.ylabel('Number of Videos')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Frame counts and video lengths"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sns.boxplot(x=video_lengths)\n",
"plt.title('Box Plot of Video Lengths')\n",
"plt.xlabel('Video Length (seconds)')\n",
"plt.show()\n",
"\n",
"sns.boxplot(x=frame_counts, color=\"r\")\n",
"plt.title('Box Plot of Frame Counts')\n",
"plt.xlabel('Frame Count')\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Class distributions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"path_to_labels = './dataset-tacdec/full_labels'\n",
"class_counts = {'background': 0, 'tackle-live': 0, 'tackle-replay': 0, 'tackle-live-incomplete': 0, 'tackle-replay-incomplete': 0, 'dummy_class': 0}\n",
"\n",
"# Iterate through each JSON file in the labels directory\n",
"for label_file in os.listdir(path_to_labels):\n",
"    full_path = os.path.join(path_to_labels, label_file)\n",
"\n",
"    if not os.path.isfile(full_path):\n",
"        continue\n",
"\n",
"    with open(full_path, 'r') as file:\n",
"        data = json.load(file)\n",
"        frame_sections = data['frames_sections']\n",
"\n",
"        # Extract annotations\n",
"        for section in frame_sections:\n",
"            for frame_number, frame_data in section.items():\n",
"                class_label = frame_data['radio_answer']\n",
"                if class_label in class_counts:\n",
"                    class_counts[class_label] += 1\n",
"\n",
"# Convert the dictionary to a DataFrame for Seaborn\n",
"df_class_counts = pd.DataFrame(list(class_counts.items()), columns=['Class', 'Occurrences'])\n",
"\n",
"# Save the DataFrame to a CSV file\n",
"df_class_counts.to_csv('class_distribution.csv', sep=',', index=False, encoding='utf-8')\n",
"\n",
"# Plotting the distribution using Seaborn\n",
"plt.figure(figsize=(10, 6))\n",
"sns.barplot(x='Class', y='Occurrences', data=df_class_counts, palette='viridis', alpha=0.75)\n",
"plt.title('Distribution of Frame Classes')\n",
"plt.xlabel('Class')\n",
"plt.ylabel('Number of Occurrences')\n",
"plt.xticks(rotation=45)  # Rotate class names for better readability\n",
"plt.tight_layout()  # Adjust layout to make room for the rotated x-axis labels\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Ensure df_class_counts is already created as in the previous script\n",
"\n",
"# Create a pie chart\n",
"plt.figure(figsize=(8, 8))\n",
"plt.pie(df_class_counts['Occurrences'], labels=df_class_counts['Class'],\n",
"        autopct=lambda p: '{:.1f}%'.format(p), startangle=140,\n",
"        colors=sns.color_palette('bright', len(df_class_counts)))\n",
"plt.title('Distribution of Frame Classes', fontweight='bold')\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"uzdIsbuEpF2w"
],
"provenance": []
},
"kernelspec": {
"display_name": "Python (evan31818)",
"language": "python",
"name": "evan31818"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
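For reference, a minimal sketch of how the per-video tensors saved by feature_extraction.ipynb could be reloaded and reduced to per-frame CLS tokens, assuming the ./processed_features/last_hidden_states layout created above and DINOv2-large's 1024-dimensional token embeddings; the helper name load_cls_tokens is illustrative and not defined in the notebooks.

```python
import os

import torch


def load_cls_tokens(features_dir="./processed_features/last_hidden_states"):
    """Concatenate the per-video DINOv2 features saved above into one
    (total_frames, 1024) tensor of CLS tokens, in sorted filename order."""
    cls_chunks = []
    for tensor_file in sorted(os.listdir(features_dir)):
        if not tensor_file.endswith(".pt"):
            continue
        # Each file holds (num_frames, num_tokens, 1024); token index 0 is the CLS token.
        hidden_states = torch.load(os.path.join(features_dir, tensor_file), map_location="cpu")
        cls_chunks.append(hidden_states[:, 0, :])
    return torch.cat(cls_chunks, dim=0)


# Example usage:
# cls_tokens = load_cls_tokens()
# print(cls_tokens.shape)  # expected: (total_frames, 1024)
```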
train_classifier.ipynb
ADDED
@@ -0,0 +1,745 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "27933625-f946-4fce-a622-e92ea518fad1",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## 1. Mandatory"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8674dce1-4885-4bc9-8b90-1d847c38e6f1",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, accuracy_score\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import torch.optim as optim\n",
"import torch.nn as nn\n",
"import seaborn as sns\n",
"import numpy as np\n",
"import torch\n",
"import json\n",
"import os"
]
},
{
"cell_type": "markdown",
"id": "46a4597f",
"metadata": {},
"source": [
"# 2. Complete the cells below if you did not download the DINOv2 cls-tokens together with the labels - skip to step 3 if you did."
]
},
{
"cell_type": "markdown",
"id": "1f1bd72b-ed98-4669-908c-2b103bcacda5",
"metadata": {},
"source": [
"## Load labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98e09803-9862-4e29-aaff-3bdcd4e0fe53",
"metadata": {},
"outputs": [],
"source": [
"# Path to labels\n",
"path_to_labels = '/home/evan/D1/project/code/start_end_labels'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b41d5fd2-ee4a-4f02-98b9-887e48115c47",
"metadata": {},
"outputs": [],
"source": [
"# Should be 425 files, code just to verify\n",
"num_of_labels = 0\n",
"for ind, label in enumerate(os.listdir(path_to_labels)):\n",
"    num_of_labels = ind+1\n",
"\n",
"num_of_labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ef791d8-a268-4436-ad18-150d645bef73",
"metadata": {},
"outputs": [],
"source": [
"list_of_labels = []\n",
"\n",
"categorical_mapping = {'background': 0, 'tackle-live': 1, 'tackle-replay': 2, 'tackle-live-incomplete': 3, 'tackle-replay-incomplete': 4}\n",
"\n",
"# Sort to make sure order is maintained\n",
"for ind, label in enumerate(sorted(os.listdir(path_to_labels))):\n",
"    full_path = os.path.join(path_to_labels, label)\n",
"\n",
"    with open(full_path, 'r') as file:\n",
"        data = json.load(file)\n",
"\n",
"        # Extract frame count\n",
"        frame_count = data['media_attributes']['frame_count']\n",
"\n",
"        # Extract tackles\n",
"        tackles = data['events']\n",
"\n",
"        labels_of_current_file = np.zeros(frame_count)\n",
"\n",
"        for tackle in tackles:\n",
"            # Extract variables\n",
"            tackle_class = tackle['type']\n",
"            start_frame = tackle['frame_start']\n",
"            end_frame = tackle['frame_end']\n",
"\n",
"            # Need to shift start_frame by -1 as array-indexing starts at 0, while\n",
"            # frame counting starts at 1\n",
"            for i in range(start_frame-1, end_frame, 1):\n",
"                labels_of_current_file[i] = categorical_mapping[tackle_class]\n",
"\n",
"        list_of_labels.append(labels_of_current_file)\n"
]
},
{
"cell_type": "markdown",
"id": "b302d94a-d18c-4e41-929b-3c8f4d547afa",
"metadata": {},
"source": [
"## Verify that the conversion is correct"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "286b27a8-1c9a-4ba9-9996-deeef7927195",
"metadata": {},
"outputs": [],
"source": [
"test = list_of_labels[0]\n",
"\n",
"for i in range(len(test)):\n",
"    # Should give [0,1,1,0] as 181-207 is the actual sequence, but it is moved to 180-206 with array indexing\n",
"    # starting from 0 instead of 1 like the frame counting.\n",
"    if i == 179 or i == 180 or i == 206 or i == 207:\n",
"        print(test[i])"
]
},
{
"cell_type": "markdown",
"id": "88650952-a098-4ae3-ba3b-d67f5d17c41b",
"metadata": {},
"source": [
"## Map incomplete class-labels to instances of their respective 'full-class'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c48db00-b367-4f38-aa59-de5164d11fe9",
"metadata": {},
"outputs": [],
"source": [
"class_mapping = {0:0, 1: 1, 2: 2, 3: 1, 4: 2}\n",
"prev_list_of_labels = list_of_labels\n",
"\n",
"for i, label in enumerate(list_of_labels):\n",
"    list_of_labels[i] = np.array([class_mapping[frame_class] for frame_class in label])"
]
},
{
"cell_type": "markdown",
"id": "ee69c1f0-db9d-4848-9b3c-2556e09d1991",
"metadata": {},
"source": [
"## Load DINOv2-features and extract CLS-tokens"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20b2ee27-5d94-4301-9229-aa9486360a73",
"metadata": {},
"outputs": [],
"source": [
"# Define path to DINOv2-features\n",
"path_to_tensors = '/home/evan/D1/project/code/processed_features/last_hidden_states'\n",
"path_to_first_tensor = '/home/evan/D1/project/code/processed_features/last_hidden_states/1738_avxeiaxxw6ocr.pt'\n",
"\n",
"all_cls_tokens = torch.load(path_to_first_tensor)[:,0,:]\n",
"\n",
"for index, tensor_file in enumerate(sorted(os.listdir(path_to_tensors))[1:]):  # Start from the second item\n",
"    full_path = os.path.join(path_to_tensors, tensor_file)\n",
"    cls_token = torch.load(full_path)[:,0,:]\n",
"    all_cls_tokens = torch.cat((all_cls_tokens, cls_token), dim=0)\n",
"\n",
"\n",
"# Should have shape: total_frames, feature_vector (1024)\n",
"print('CLS tokens shape: ', all_cls_tokens.shape)"
]
},
{
"cell_type": "markdown",
"id": "03c8f5ed-5b04-456d-a9fd-8d493878ea18",
"metadata": {},
"source": [
"### Reshape labels list"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9bc68a4-5c33-43b6-a9e1-febb035ea2fb",
"metadata": {},
"outputs": [],
"source": [
"all_labels_concatenated = np.concatenate(list_of_labels, axis=0)\n",
"\n",
"# Length should be total number of frames\n",
"print('Length of all labels concatenated: ', len(all_labels_concatenated))\n",
"\n",
"\n",
"\n",
"# Map incomplete instances to complete ones. As this approach only looks at 'background', 'tackle-live' and 'tackle-replay',\n",
"# the incomplete classes can be mapped to their complete counterparts, since every annotated frame is still part of a tackle.\n",
"class_mapping = {0:0, 1: 1, 2: 2, 3: 1, 4: 2}\n",
"\n",
"for i, label in enumerate(all_labels_concatenated):\n",
"    all_labels_concatenated[i] = class_mapping[label]"
]
},
{
"cell_type": "markdown",
"id": "f644964d",
"metadata": {},
"source": [
"# 3. If you downloaded the DINOv2 cls-tokens together with the labels, follow below:"
]
},
{
"cell_type": "markdown",
"id": "ab5f971c",
"metadata": {},
"source": [
"The next cell can be skipped if you completed step 2."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5e2600aa",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Set the paths to your downloaded cls tokens and labels below:\n",
"cls_path = '/home/evan/D1/project/code/full_concat_dino_features.pt'\n",
"labels_path = '/home/evan/D1/project/code/all_labels_concatenated.npy'\n",
"\n",
"all_cls_tokens = torch.load(cls_path)\n",
"all_labels_concatenated = np.load(labels_path)\n",
"\n",
"# Map incomplete instances to complete ones. As this approach only looks at 'background', 'tackle-live' and 'tackle-replay',\n",
"# the incomplete classes can be mapped to their complete counterparts, since every annotated frame is still part of a tackle.\n",
"class_mapping = {0:0, 1: 1, 2: 2, 3: 1, 4: 2}\n",
"\n",
"for i, label in enumerate(all_labels_concatenated):\n",
"    all_labels_concatenated[i] = class_mapping[label]"
]
},
{
"cell_type": "markdown",
"id": "01b360a4",
"metadata": {},
"source": [
"# 4. Follow below"
]
},
{
"cell_type": "markdown",
"id": "e4561d68-a149-4a00-9a7d-e0e69bbcfa53",
"metadata": {},
"source": [
"## Balance classes"
]
},
{
"cell_type": "markdown",
"id": "68e2e245-36d3-464e-85ae-6d5f30ebe164",
"metadata": {},
"source": [
"### Move cls-tokens to CPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61b8a9fe-d3ac-4d6c-b0a9-5c32a2593495",
"metadata": {},
"outputs": [],
"source": [
"all_cls_tokens = np.array([e.cpu().numpy() for e in all_cls_tokens])\n",
"print('Tensor shape after reshaping: ', all_cls_tokens.shape)"
]
},
{
"cell_type": "markdown",
"id": "b6074527-9ddc-4b9e-b933-a6c5af9cd134",
"metadata": {},
"source": [
"### Verify that order is correct"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea1425ae-6588-4c71-8a08-7f9c0adc7422",
"metadata": {},
"outputs": [],
"source": [
"for i in range(len(all_labels_concatenated)):\n",
"    # Should give [0,1,1,0] as 181-207 is the actual sequence, but it is moved to 180-206 with array indexing\n",
"    # starting from 0 instead of 1 like the frame counting.\n",
"    if i == 179 or i == 180 or i == 206 or i == 207:\n",
"        print(all_labels_concatenated[i])\n",
"\n",
"    if i > 210:\n",
"        break"
]
},
{
"cell_type": "markdown",
"id": "6e851954-e2d7-41fd-956f-92df09a79e8b",
"metadata": {},
"source": [
"### Function for balancing the distribution of classes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "479daf78-11c0-4ded-9bb3-8fa34d12c6d7",
"metadata": {},
"outputs": [],
"source": [
"def balance_classes(X, y):\n",
"    unique, counts = np.unique(y, return_counts=True)\n",
"    min_samples = counts.min()\n",
"    # Alternative: 2.0 times the minimum sample size, rounded down to the nearest integer\n",
"    # target_samples = int(2.0 * min_samples)\n",
"    target_samples = 5000\n",
"\n",
"    indices_to_keep = np.hstack([\n",
"        np.random.choice(\n",
"            np.where(y == label)[0],\n",
"            min(target_samples, counts[unique.tolist().index(label)]),  # Ensure not to exceed the actual count\n",
"            replace=False\n",
"        ) for label in unique\n",
"    ])\n",
"\n",
"    return X[indices_to_keep], y[indices_to_keep]"
]
},
{
"cell_type": "markdown",
"id": "6cf24d79-27d7-499e-b856-e58938cef5e7",
"metadata": {},
"source": [
"### Split into train and test, without shuffle to preserve order"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c9fbaec-2849-48d0-867d-e0ad39682135",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(all_cls_tokens, all_labels_concatenated, test_size=0.2, shuffle=False, stratify=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35fa46bb-258a-4b6e-a8c0-56c47c791d55",
"metadata": {},
"outputs": [],
"source": [
"X_train_balanced, y_train_balanced = balance_classes(X_train, y_train)\n",
"X_test_balanced, y_test_balanced = balance_classes(X_test, y_test)\n",
"print(\"Total number of samples:\", len(all_labels_concatenated))\n",
"print(\"\")\n",
"\n",
"print('Total distribution of labels: \\n', np.unique(all_labels_concatenated, return_counts=True))\n",
"print(\"\")\n",
"\n",
"\n",
"print('Distribution within training set: \\n', np.unique(y_train_balanced, return_counts=True))\n",
"print(\"\")\n",
"\n",
"print('Distribution within test set: \\n', np.unique(y_test_balanced, return_counts=True))\n",
"print(\"\")\n",
"\n",
"\n",
"print('Training shape: ', X_train_balanced.shape, y_train_balanced.shape)\n",
"print(\"\")\n",
"\n",
"print('Test shape: ', X_test_balanced.shape, y_test_balanced.shape)\n",
"print(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b6bf3b4-5d67-41b4-9c6b-8d02d3923366",
"metadata": {},
"outputs": [],
"source": [
"# Convert data to torch tensors\n",
"X_train = torch.tensor(X_train_balanced, dtype=torch.float32)\n",
"y_train = torch.tensor(y_train_balanced, dtype=torch.long)\n",
"X_test = torch.tensor(X_test_balanced, dtype=torch.float32)\n",
"y_test = torch.tensor(y_test_balanced, dtype=torch.long)"
]
},
{
"cell_type": "markdown",
"id": "7d7250f4-c820-4c00-9bde-77bdc3cdd2e2",
"metadata": {},
"source": [
"## Create datasets and DataLoaders"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "532583ed-65e9-4339-b94d-6cdb704c0ed7",
"metadata": {},
"outputs": [],
"source": [
"# Create data loaders\n",
"batch_size = 64\n",
"train_dataset = TensorDataset(X_train, y_train)\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"\n",
"test_dataset = TensorDataset(X_test, y_test)\n",
"test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n"
]
},
{
"cell_type": "markdown",
"id": "5ef7b5d4-04e1-4c2e-9476-2537a6785893",
"metadata": {},
"source": [
"## Model class"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7120ab9-c016-4eba-9588-77afde98a639",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class MultiLayerClassifier(nn.Module):\n",
"    def __init__(self, input_size, num_classes):\n",
"        super(MultiLayerClassifier, self).__init__()\n",
"\n",
"        self.fc1 = nn.Linear(input_size, 128, bias=True)\n",
"        self.dropout1 = nn.Dropout(0.5)\n",
"\n",
"        # self.fc2 = nn.Linear(512, 128)\n",
"        # self.dropout2 = nn.Dropout(0.5)\n",
"\n",
"        self.fc3 = nn.Linear(128, num_classes, bias=True)\n",
"\n",
"    def forward(self, x):\n",
"        x = F.relu(self.fc1(x))\n",
"        x = self.dropout1(x)\n",
"        # x = F.relu(self.fc2(x))\n",
"        # x = self.dropout2(x)\n",
"        x = self.fc3(x)\n",
"\n",
"        return x\n",
"\n",
"model = MultiLayerClassifier(1024, 3)\n",
"model"
]
},
{
"cell_type": "markdown",
"id": "5b0ba056-0a73-466f-b65e-a3261e1a69f1",
"metadata": {},
"source": [
"## L1-regularization function"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebd6211c-fc94-4557-947b-5a3fac89c1ba",
"metadata": {},
"outputs": [],
"source": [
"def l1_regularization(model, lambda_l1):\n",
"    l1_penalty = torch.tensor(0., device=next(model.parameters()).device)  # Keep the penalty on the same device as the model parameters\n",
"    for param in model.parameters():\n",
"        l1_penalty += torch.norm(param, 1)\n",
"    return lambda_l1 * l1_penalty"
]
},
{
"cell_type": "markdown",
"id": "00735f1f-2bf9-4aae-90c2-61e44973f699",
"metadata": {},
"source": [
"## Loss, optimizer and L1-strength initialization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4efe9d8-fc72-4701-a1a9-d463c6b33dfa",
"metadata": {},
"outputs": [],
"source": [
"# Loss and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)\n",
"lambda_l1 = 1e-3  # L1 regularization strength"
]
},
{
"cell_type": "markdown",
"id": "e87f7513-47d0-491e-9073-9289eda1b484",
"metadata": {},
"source": [
"## Training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4260c3bc-25c2-48f0-b79c-b6d7cc0c14eb",
"metadata": {},
"outputs": [],
"source": [
"epochs = 50\n",
"train_losses, test_losses = [], []\n",
"\n",
"for epoch in range(epochs):\n",
"    model.train()\n",
"    train_loss = 0\n",
"    for X_batch, y_batch in train_loader:\n",
"        optimizer.zero_grad()\n",
"        outputs = model(X_batch)\n",
"        loss = criterion(outputs, y_batch)\n",
"\n",
"        # Calculate L1 regularization penalty\n",
"        l1_penalty = l1_regularization(model, lambda_l1)\n",
"\n",
"        # Add L1 penalty to the loss\n",
"        loss += l1_penalty\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        train_loss += loss.item()\n",
"    train_losses.append(train_loss / len(train_loader))\n",
"\n",
"    model.eval()\n",
"    test_loss = 0\n",
"    all_preds, all_targets, all_outputs = [], [], []\n",
"    with torch.no_grad():\n",
"        for X_batch, y_batch in test_loader:\n",
"            outputs = model(X_batch)\n",
"            loss = criterion(outputs, y_batch)\n",
"            test_loss += loss.item()\n",
"            _, predicted = torch.max(outputs.data, 1)\n",
"            all_preds.extend(predicted.numpy())\n",
"            all_targets.extend(y_batch.numpy())\n",
"            all_outputs.extend(outputs.numpy())\n",
"    test_losses.append(test_loss / len(test_loader))\n",
"\n",
"    precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
"    accuracy = accuracy_score(all_targets, all_preds)  # Compute accuracy\n",
"    if epoch % 10 == 0:\n",
"        print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
]
},
{
"cell_type": "markdown",
"id": "615f685e-fb19-46f8-afba-b76fb730ed49",
"metadata": {},
"source": [
"## Train- vs Test-loss graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "597b4570-1579-470e-8f11-f72b7b04b816",
"metadata": {},
"outputs": [],
"source": [
"plt.plot(train_losses, label='Train Loss')\n",
"plt.plot(test_losses, label='Test Loss')\n",
"plt.legend()\n",
"plt.title('Train vs Test Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "1babe3bd-da5b-4f0d-9d83-9ca4d73922c5",
"metadata": {},
"source": [
"## Confusion matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c0b0fa3-814e-474c-bbe1-31152305e17b",
"metadata": {},
"outputs": [],
"source": [
"conf_matrix = confusion_matrix(all_targets, all_preds)\n",
"labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
"          # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
"sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)\n",
"# plt.title('Confusion Matrix')\n",
"plt.xlabel('Predicted Label')\n",
"plt.ylabel('True Label')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "480ddfd5-6ac4-46ed-92db-b556c8bfbd7d",
"metadata": {},
"source": [
"## ROC Curve"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ddc52d39-7612-43ad-ae44-345119122112",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve, auc\n",
"import matplotlib.pyplot as plt\n",
"\n",
"y_score = np.array(all_outputs)\n",
"fpr = dict()\n",
"tpr = dict()\n",
"roc_auc = dict()\n",
"n_classes = len(labels)\n",
"\n",
"y_test_one_hot = np.eye(n_classes)[y_test]\n",
"\n",
"for i in range(n_classes):\n",
"    fpr[i], tpr[i], _ = roc_curve(y_test_one_hot[:, i], y_score[:, i])\n",
"    roc_auc[i] = auc(fpr[i], tpr[i])\n",
"\n",
"# Plot all ROC curves\n",
"plt.figure()\n",
"colors = ['blue', 'red', 'green', 'darkorange', 'purple']\n",
"for i, color in zip(range(n_classes), colors):\n",
"    plt.plot(fpr[i], tpr[i], color=color, lw=2,\n",
"             label='ROC curve of class {0} (area = {1:0.2f})'\n",
"             ''.format(labels[i], roc_auc[i]))\n",
"\n",
"plt.plot([0, 1], [0, 1], 'k--', lw=2)\n",
"plt.xlim([0.0, 1.0])\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlabel('False Positive Rate')\n",
"plt.ylabel('True Positive Rate')\n",
"print('Receiver operating characteristic for multi-class')\n",
"plt.legend(loc=\"lower right\")\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"id": "45c05c14-99d8-49e6-ad64-7e6ad565c0ca",
"metadata": {},
"source": [
"## Multi-Class Precision-Recall Curve"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c779274-252f-4248-bf57-a07c665c618c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_curve\n",
"from sklearn.preprocessing import label_binarize\n",
"from itertools import cycle\n",
"\n",
"y_test_bin = label_binarize(y_test, classes=range(n_classes))\n",
"\n",
"precision_recall = {}\n",
"\n",
"for i in range(n_classes):\n",
"    precision, recall, _ = precision_recall_curve(y_test_bin[:, i], y_score[:, i])\n",
"    precision_recall[i] = (precision, recall)\n",
"\n",
"colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])\n",
"\n",
"plt.figure(figsize=(6, 4))\n",
"\n",
"for i, color in zip(range(n_classes), colors):\n",
"    precision, recall = precision_recall[i]\n",
"    plt.plot(recall, precision, color=color, lw=2, label=f'{labels[i]}')\n",
"\n",
"plt.xlabel('Recall')\n",
"plt.ylabel('Precision')\n",
"print('Multi-Class Precision-Recall Curve')\n",
"plt.legend(loc='best')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (evan31818)",
"language": "python",
"name": "evan31818"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
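As a usage note, a minimal sketch of how the trained classifier could be applied to unseen CLS tokens. It re-declares the same architecture as train_classifier.ipynb; the checkpoint path and the random input tensor are hypothetical placeholders, since the notebook itself does not save weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Same architecture as the MultiLayerClassifier defined in train_classifier.ipynb.
class MultiLayerClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.dropout1 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.dropout1(F.relu(self.fc1(x)))
        return self.fc3(x)


model = MultiLayerClassifier(1024, 3)
# 'tackle_classifier.pt' is a hypothetical checkpoint path; the notebook does not
# persist weights, so save them first with torch.save(model.state_dict(), ...).
# model.load_state_dict(torch.load('tackle_classifier.pt', map_location='cpu'))
model.eval()

# Stand-in for a (num_frames, 1024) tensor of DINOv2 CLS tokens.
new_cls_tokens = torch.randn(8, 1024)

with torch.no_grad():
    logits = model(new_cls_tokens)
    predicted = logits.argmax(dim=1)  # 0=background, 1=tackle-live, 2=tackle-replay
print(predicted)
```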