prithivMLmods committed
Commit b0577e0 · verified · 1 Parent(s): d7f7353

Add files using upload-large-folder tool

Builder Script/builder.script.trainner.ipynb ADDED
@@ -0,0 +1,565 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "97b4efc3-1879-4441-af52-de470fbc3ae8",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "!pip install -q evaluate datasets accelerate\n",
11
+ "!pip install -q transformers\n",
12
+ "!pip install -q huggingface_hub"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "ae923886-86f3-431d-b701-1200110b429c",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "!pip install -q imbalanced-learn\n",
23
+ "#Skip the installation if your runtime is in Google Colab notebooks."
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "id": "126923c7-d53f-42d8-8f06-2ea05609ab0e",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "!pip install -q numpy\n",
34
+ "#Skip the installation if your runtime is in Google Colab notebooks."
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "9e628805-b90b-4b98-ae97-9f8a8142767f",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "!pip install -q pillow==11.0.0\n",
45
+ "#Skip the installation if your runtime is in Google Colab notebooks."
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "b58fab4c-211f-4b7b-b7c4-dd76e20c1beb",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "!pip install -q torchvision \n",
56
+ "#Skip the installation if your runtime is in Google Colab notebooks."
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "id": "d7454ffa-885e-44ba-8259-d8c45f8ec72b",
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "!pip install -q matplotlib\n",
67
+ "!pip install -q scikit-learn\n",
68
+ "#Skip the installation if your runtime is in Google Colab notebooks."
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "4987ed31-c012-434b-9ea7-78da17061d5d",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "import warnings\n",
79
+ "warnings.filterwarnings(\"ignore\")\n",
80
+ "\n",
81
+ "import gc\n",
82
+ "import numpy as np\n",
83
+ "import pandas as pd\n",
84
+ "import itertools\n",
85
+ "from collections import Counter\n",
86
+ "import matplotlib.pyplot as plt\n",
87
+ "from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, f1_score\n",
88
+ "from imblearn.over_sampling import RandomOverSampler\n",
89
+ "import evaluate\n",
90
+ "from datasets import Dataset, Image, ClassLabel\n",
91
+ "from transformers import (\n",
92
+ " TrainingArguments,\n",
93
+ " Trainer,\n",
94
+ " ViTImageProcessor,\n",
95
+ " ViTForImageClassification,\n",
96
+ " DefaultDataCollator\n",
97
+ ")\n",
98
+ "import torch\n",
99
+ "from torch.utils.data import DataLoader\n",
100
+ "from torchvision.transforms import (\n",
101
+ " CenterCrop,\n",
102
+ " Compose,\n",
103
+ " Normalize,\n",
104
+ " RandomRotation,\n",
105
+ " RandomResizedCrop,\n",
106
+ " RandomHorizontalFlip,\n",
107
+ " RandomAdjustSharpness,\n",
108
+ " Resize,\n",
109
+ " ToTensor\n",
110
+ ")\n",
111
+ "\n",
112
+ "#.......................................................................\n",
113
+ "\n",
114
+ "#Retain this part if you're working outside Google Colab notebooks.\n",
115
+ "from PIL import Image, ExifTags\n",
116
+ "\n",
117
+ "#.......................................................................\n",
118
+ "\n",
119
+ "from PIL import Image as PILImage\n",
120
+ "from PIL import ImageFile\n",
121
+ "# Enable loading truncated images\n",
122
+ "ImageFile.LOAD_TRUNCATED_IMAGES = True"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "id": "236bc802-54ba-44d1-b35b-62f548832935",
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "from datasets import load_dataset\n",
133
+ "dataset = load_dataset(\"--your--dataset--goes--here--\", split=\"train\")"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "id": "d57e17cc-72b2-4fde-9855-751cf3440624",
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "from pathlib import Path\n",
144
+ "\n",
145
+ "file_names = []\n",
146
+ "labels = []\n",
147
+ "\n",
148
+ "for example in dataset:\n",
149
+ " file_path = str(example['image']) \n",
150
+ " label = example['label'] \n",
151
+ "\n",
152
+ " file_names.append(file_path) \n",
153
+ " labels.append(label) \n",
154
+ "\n",
155
+ "print(len(file_names), len(labels))"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "e52c85d2-a245-47c5-9403-5a9cf4e4269d",
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "df = pd.DataFrame.from_dict({\"image\": file_names, \"label\": labels})\n",
166
+ "print(df.shape)"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "beba86dd-0605-4ebf-8ebb-97d6ad9e5edd",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "df.head()\n",
177
+ "df['label'].unique()"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "6defc1e9-4f46-49b6-addc-f422c38fe7e8",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "y = df[['label']]\n",
188
+ "df = df.drop(['label'], axis=1)\n",
189
+ "ros = RandomOverSampler(random_state=83)\n",
190
+ "df, y_resampled = ros.fit_resample(df, y)\n",
191
+ "del y\n",
192
+ "df['label'] = y_resampled\n",
193
+ "del y_resampled\n",
194
+ "gc.collect()"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "id": "129d278c-3899-49d2-b06f-a0b2f22f4c4e",
201
+ "metadata": {},
202
+ "outputs": [],
203
+ "source": [
204
+ "dataset[0][\"image\"]\n",
205
+ "dataset[99][\"image\"]"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "id": "bffc8755-c4ac-41be-b8ab-f9a6e0dbcca3",
212
+ "metadata": {},
213
+ "outputs": [],
214
+ "source": [
215
+ "labels_subset = labels[:5]\n",
216
+ "print(labels_subset)"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "id": "d003f439-09d1-41e6-9f34-213c4ee38593",
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "labels_list = ['Issue In Deepfake', 'High Quality Deepfake']\n",
227
+ "\n",
228
+ "label2id, id2label = {}, {}\n",
229
+ "for i, label in enumerate(labels_list):\n",
230
+ " label2id[label] = i\n",
231
+ " id2label[i] = label\n",
232
+ "\n",
233
+ "ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)\n",
234
+ "\n",
235
+ "print(\"Mapping of IDs to Labels:\", id2label, '\\n')\n",
236
+ "print(\"Mapping of Labels to IDs:\", label2id)"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "id": "2fbf1f1b-5936-48be-bc99-6897fea94794",
243
+ "metadata": {},
244
+ "outputs": [],
245
+ "source": [
246
+ "def map_label2id(example):\n",
247
+ " example['label'] = ClassLabels.str2int(example['label'])\n",
248
+ " return example\n",
249
+ "\n",
250
+ "dataset = dataset.map(map_label2id, batched=True)\n",
251
+ "\n",
252
+ "dataset = dataset.cast_column('label', ClassLabels)\n",
253
+ "\n",
254
+ "dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column=\"label\")\n",
255
+ "\n",
256
+ "train_data = dataset['train']\n",
257
+ "\n",
258
+ "test_data = dataset['test']"
259
+ ]
260
+ },
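+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-split-sanity-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sketch (not part of the original notebook): a quick sanity check that\n",
+ "# the stratified split kept both classes in similar proportions.\n",
+ "# Uses `Counter`, which is already imported above.\n",
+ "print(\"Train distribution:\", Counter(train_data['label']))\n",
+ "print(\"Test distribution: \", Counter(test_data['label']))"
+ ]
+ },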
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "id": "d8a4f7ca-4dff-4446-acaf-f3e7630b678d",
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": [
268
+ "model_str = \"google/vit-base-patch16-224-in21k\"\n",
269
+ "processor = ViTImageProcessor.from_pretrained(model_str)\n",
270
+ "\n",
271
+ "image_mean, image_std = processor.image_mean, processor.image_std\n",
272
+ "size = processor.size[\"height\"]\n",
273
+ "\n",
274
+ "_train_transforms = Compose(\n",
275
+ " [\n",
276
+ " Resize((size, size)),\n",
277
+ " RandomRotation(90),\n",
278
+ " RandomAdjustSharpness(2),\n",
279
+ " ToTensor(),\n",
280
+ " Normalize(mean=image_mean, std=image_std)\n",
281
+ " ]\n",
282
+ ")\n",
283
+ "\n",
284
+ "_val_transforms = Compose(\n",
285
+ " [\n",
286
+ " Resize((size, size)),\n",
287
+ " ToTensor(),\n",
288
+ " Normalize(mean=image_mean, std=image_std)\n",
289
+ " ]\n",
290
+ ")\n",
291
+ "\n",
292
+ "def train_transforms(examples):\n",
293
+ " examples['pixel_values'] = [_train_transforms(image.convert(\"RGB\")) for image in examples['image']]\n",
294
+ " return examples\n",
295
+ "\n",
296
+ "def val_transforms(examples):\n",
297
+ " examples['pixel_values'] = [_val_transforms(image.convert(\"RGB\")) for image in examples['image']]\n",
298
+ " return examples\n",
299
+ "\n",
300
+ "train_data.set_transform(train_transforms)\n",
301
+ "test_data.set_transform(val_transforms)"
302
+ ]
303
+ },
304
+ {
305
+ "cell_type": "code",
306
+ "execution_count": null,
307
+ "id": "0c8a93ca-e4ff-42e2-b58d-445afa0cfee0",
308
+ "metadata": {},
309
+ "outputs": [],
310
+ "source": [
311
+ "def collate_fn(examples):\n",
312
+ " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
313
+ " labels = torch.tensor([example['label'] for example in examples])\n",
314
+ " return {\"pixel_values\": pixel_values, \"labels\": labels}"
315
+ ]
316
+ },
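+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-collator-batch-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sketch (not part of the original notebook): assemble one batch with the\n",
+ "# collator to confirm tensor shapes before training. `DataLoader` is imported\n",
+ "# above; `sample_loader` and `batch` are illustrative names.\n",
+ "sample_loader = DataLoader(train_data, batch_size=4, collate_fn=collate_fn)\n",
+ "batch = next(iter(sample_loader))\n",
+ "print(batch[\"pixel_values\"].shape, batch[\"labels\"])"
+ ]
+ },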
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "id": "11e0c254-ebb1-4100-a389-9e661d0810ff",
321
+ "metadata": {},
322
+ "outputs": [],
323
+ "source": [
324
+ "model = ViTForImageClassification.from_pretrained(model_str, num_labels=len(labels_list))\n",
325
+ "model.config.id2label = id2label\n",
326
+ "model.config.label2id = label2id\n",
327
+ "\n",
328
+ "print(model.num_parameters(only_trainable=True) / 1e6)"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": null,
334
+ "id": "bea51959-9abc-4afc-aee6-0e774f8db9c2",
335
+ "metadata": {},
336
+ "outputs": [],
337
+ "source": [
338
+ "accuracy = evaluate.load(\"accuracy\")\n",
339
+ "\n",
340
+ "def compute_metrics(eval_pred):\n",
341
+ " predictions = eval_pred.predictions\n",
342
+ " label_ids = eval_pred.label_ids\n",
343
+ "\n",
344
+ " predicted_labels = predictions.argmax(axis=1)\n",
345
+ " acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']\n",
346
+ " \n",
347
+ " return {\n",
348
+ " \"accuracy\": acc_score\n",
349
+ " }"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "id": "d5ea0bbc-51a3-4b98-823e-10819ffda292",
356
+ "metadata": {},
357
+ "outputs": [],
358
+ "source": [
359
+ "args = TrainingArguments(\n",
360
+ " output_dir=\"deepfake_vit\",\n",
361
+ " logging_dir='./logs',\n",
362
+ " evaluation_strategy=\"epoch\",\n",
363
+ " learning_rate=2e-5,\n",
364
+ " per_device_train_batch_size=32,\n",
365
+ " per_device_eval_batch_size=8,\n",
366
+ " num_train_epochs=4,\n",
367
+ " weight_decay=0.02,\n",
368
+ " warmup_steps=50,\n",
369
+ " remove_unused_columns=False,\n",
370
+ " save_strategy='epoch',\n",
371
+ " load_best_model_at_end=True,\n",
372
+ " save_total_limit=1,\n",
373
+ " report_to=\"none\"\n",
374
+ ")"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "id": "0a965131-c670-43b1-a153-c1a4df611189",
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "trainer = Trainer(\n",
385
+ " model,\n",
386
+ " args,\n",
387
+ " train_dataset=train_data,\n",
388
+ " eval_dataset=test_data,\n",
389
+ " data_collator=collate_fn,\n",
390
+ " compute_metrics=compute_metrics,\n",
391
+ " tokenizer=processor,\n",
392
+ ")"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": null,
398
+ "id": "ad42ea98-86d6-420e-befe-2ef77eadd76d",
399
+ "metadata": {},
400
+ "outputs": [],
401
+ "source": [
402
+ "trainer.evaluate()"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "id": "df43c341-0e55-41ef-a274-731c88b9b5d5",
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": [
412
+ "trainer.train()"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "id": "28866dda",
419
+ "metadata": {},
420
+ "outputs": [],
421
+ "source": [
422
+ "trainer.evaluate()"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "0ec258d9",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "outputs = trainer.predict(test_data)\n",
433
+ "print(outputs.metrics)"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": null,
439
+ "id": "c12a6b10",
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "y_true = outputs.label_ids\n",
444
+ "y_pred = outputs.predictions.argmax(1)\n",
445
+ "\n",
446
+ "def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues, figsize=(10, 8)):\n",
447
+ " \n",
448
+ " plt.figure(figsize=figsize)\n",
449
+ "\n",
450
+ " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
451
+ " plt.title(title)\n",
452
+ " plt.colorbar()\n",
453
+ "\n",
454
+ " tick_marks = np.arange(len(classes))\n",
455
+ " plt.xticks(tick_marks, classes, rotation=90)\n",
456
+ " plt.yticks(tick_marks, classes)\n",
457
+ "\n",
458
+ " fmt = '.0f'\n",
459
+ " thresh = cm.max() / 2.0\n",
460
+ " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
461
+ " plt.text(j, i, format(cm[i, j], fmt), horizontalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n",
462
+ "\n",
463
+ " plt.ylabel('True label')\n",
464
+ " plt.xlabel('Predicted label')\n",
465
+ " plt.tight_layout()\n",
466
+ " plt.show()\n",
467
+ "\n",
468
+ "accuracy = accuracy_score(y_true, y_pred)\n",
469
+ "f1 = f1_score(y_true, y_pred, average='macro')\n",
470
+ "\n",
471
+ "print(f\"Accuracy: {accuracy:.4f}\")\n",
472
+ "print(f\"F1 Score: {f1:.4f}\")\n",
473
+ "\n",
474
+ "if len(labels_list) <= 150:\n",
475
+ " cm = confusion_matrix(y_true, y_pred)\n",
476
+ " plot_confusion_matrix(cm, labels_list, figsize=(8, 6))\n",
477
+ "\n",
478
+ "print()\n",
479
+ "print(\"Classification report:\")\n",
480
+ "print()\n",
481
+ "print(classification_report(y_true, y_pred, target_names=labels_list, digits=4))"
482
+ ]
483
+ },
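+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-roc-auc-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sketch (not part of the original notebook): binary ROC-AUC from the\n",
+ "# prediction logits, using `roc_auc_score` imported above. The softmax score of\n",
+ "# class 1 is taken as the positive-class probability; `probs` is an illustrative name.\n",
+ "probs = torch.softmax(torch.tensor(outputs.predictions), dim=1).numpy()\n",
+ "print(f\"ROC-AUC: {roc_auc_score(y_true, probs[:, 1]):.4f}\")"
+ ]
+ },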
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": null,
487
+ "id": "9889438c",
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": [
491
+ "trainer.save_model()"
492
+ ]
493
+ },
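+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-inference-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sketch (not part of the original notebook): quick inference with the\n",
+ "# fine-tuned checkpoint saved to the local output_dir \"deepfake_vit\". The image\n",
+ "# path below is a placeholder; substitute your own file.\n",
+ "from transformers import pipeline\n",
+ "\n",
+ "clf = pipeline(\"image-classification\", model=\"deepfake_vit\")\n",
+ "print(clf(\"path/to/sample_image.jpg\"))"
+ ]
+ },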
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": null,
497
+ "id": "688e3d62",
498
+ "metadata": {},
499
+ "outputs": [],
500
+ "source": [
501
+ "#upload to hub\n",
502
+ "from huggingface_hub import notebook_login\n",
503
+ "notebook_login()"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": null,
509
+ "id": "fad56df2",
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": [
513
+ "from huggingface_hub import HfApi\n",
514
+ "\n",
515
+ "api = HfApi()\n",
516
+ "repo_id = f\"prithivMLmods/deepfake_vit\"\n",
517
+ "\n",
518
+ "try:\n",
519
+ " api.create_repo(repo_id)\n",
520
+ " print(f\"Repo {repo_id} created\")\n",
521
+ "\n",
522
+ "except:\n",
523
+ " \n",
524
+ " print(f\"Repo {repo_id} already exists\")"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": null,
530
+ "id": "f5e1559f",
531
+ "metadata": {},
532
+ "outputs": [],
533
+ "source": [
534
+ "api.upload_folder(\n",
535
+ " folder_path=\"deepfake_vit\", \n",
536
+ " path_in_repo=\".\", \n",
537
+ " repo_id=repo_id, \n",
538
+ " repo_type=\"model\", \n",
539
+ " revision=\"main\"\n",
540
+ ")"
541
+ ]
542
+ }
543
+ ],
544
+ "metadata": {
545
+ "kernelspec": {
546
+ "display_name": "Python 3",
547
+ "language": "python",
548
+ "name": "python3"
549
+ },
550
+ "language_info": {
551
+ "codemirror_mode": {
552
+ "name": "ipython",
553
+ "version": 3
554
+ },
555
+ "file_extension": ".py",
556
+ "mimetype": "text/x-python",
557
+ "name": "python",
558
+ "nbconvert_exporter": "python",
559
+ "pygments_lexer": "ipython3",
560
+ "version": "3.12.7"
561
+ }
562
+ },
563
+ "nbformat": 4,
564
+ "nbformat_minor": 5
565
+ }