Add files using upload-large-folder tool
Builder Script/builder.script.trainner.ipynb
ADDED
@@ -0,0 +1,565 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97b4efc3-1879-4441-af52-de470fbc3ae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q evaluate datasets accelerate\n",
    "!pip install -q transformers\n",
    "!pip install -q huggingface_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae923886-86f3-431d-b701-1200110b429c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q imbalanced-learn\n",
    "# Skip this installation if you are running in a Google Colab notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "126923c7-d53f-42d8-8f06-2ea05609ab0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q numpy\n",
    "# Skip this installation if you are running in a Google Colab notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e628805-b90b-4b98-ae97-9f8a8142767f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q pillow==11.0.0\n",
    "# Skip this installation if you are running in a Google Colab notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b58fab4c-211f-4b7b-b7c4-dd76e20c1beb",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q torchvision\n",
    "# Skip this installation if you are running in a Google Colab notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7454ffa-885e-44ba-8259-d8c45f8ec72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q matplotlib\n",
    "!pip install -q scikit-learn\n",
    "# Skip this installation if you are running in a Google Colab notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4987ed31-c012-434b-9ea7-78da17061d5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import gc\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import itertools\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, f1_score\n",
    "from imblearn.over_sampling import RandomOverSampler\n",
    "import evaluate\n",
    "from datasets import Dataset, Image, ClassLabel\n",
    "from transformers import (\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    "    ViTImageProcessor,\n",
    "    ViTForImageClassification,\n",
    "    DefaultDataCollator\n",
    ")\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.transforms import (\n",
    "    CenterCrop,\n",
    "    Compose,\n",
    "    Normalize,\n",
    "    RandomRotation,\n",
    "    RandomResizedCrop,\n",
    "    RandomHorizontalFlip,\n",
    "    RandomAdjustSharpness,\n",
    "    Resize,\n",
    "    ToTensor\n",
    ")\n",
    "\n",
    "#.......................................................................\n",
    "\n",
    "# Retain this part if you are working outside Google Colab notebooks.\n",
    "# Note: this PIL import shadows the `Image` feature imported from datasets\n",
    "# above; that feature is never used below, so the shadowing is harmless here.\n",
    "from PIL import Image, ExifTags\n",
    "\n",
    "#.......................................................................\n",
    "\n",
    "from PIL import Image as PILImage\n",
    "from PIL import ImageFile\n",
    "# Enable loading truncated images\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "236bc802-54ba-44d1-b35b-62f548832935",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "dataset = load_dataset(\"--your--dataset--goes--here--\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d57e17cc-72b2-4fde-9855-751cf3440624",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "file_names = []\n",
    "labels = []\n",
    "\n",
    "for example in dataset:\n",
    "    file_path = str(example['image'])\n",
    "    label = example['label']\n",
    "\n",
    "    file_names.append(file_path)\n",
    "    labels.append(label)\n",
    "\n",
    "print(len(file_names), len(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e52c85d2-a245-47c5-9403-5a9cf4e4269d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame.from_dict({\"image\": file_names, \"label\": labels})\n",
    "print(df.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beba86dd-0605-4ebf-8ebb-97d6ad9e5edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the last expression of a cell is rendered, so show the head explicitly.\n",
    "display(df.head())\n",
    "df['label'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6defc1e9-4f46-49b6-addc-f422c38fe7e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Balance the classes by randomly duplicating rows of the minority class.\n",
    "y = df[['label']]\n",
    "df = df.drop(['label'], axis=1)\n",
    "ros = RandomOverSampler(random_state=83)\n",
    "df, y_resampled = ros.fit_resample(df, y)\n",
    "del y\n",
    "df['label'] = y_resampled\n",
    "del y_resampled\n",
    "gc.collect()"
   ]
  },
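  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-class-balance-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: Counter is imported above but never used; it gives a quick\n",
    "# confirmation that RandomOverSampler produced equal class counts.\n",
    "print(\"Class counts after oversampling:\", Counter(df['label']))"
   ]
  },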
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129d278c-3899-49d2-b06f-a0b2f22f4c4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview two samples; without display() only the last expression renders.\n",
    "display(dataset[0][\"image\"])\n",
    "dataset[99][\"image\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bffc8755-c4ac-41be-b8ab-f9a6e0dbcca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_subset = labels[:5]\n",
    "print(labels_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d003f439-09d1-41e6-9f34-213c4ee38593",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_list = ['Issue In Deepfake', 'High Quality Deepfake']\n",
    "\n",
    "label2id, id2label = {}, {}\n",
    "for i, label in enumerate(labels_list):\n",
    "    label2id[label] = i\n",
    "    id2label[i] = label\n",
    "\n",
    "ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)\n",
    "\n",
    "print(\"Mapping of IDs to Labels:\", id2label, '\\n')\n",
    "print(\"Mapping of Labels to IDs:\", label2id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fbf1f1b-5936-48be-bc99-6897fea94794",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ClassLabel.str2int also accepts a list, so this works with batched=True.\n",
    "def map_label2id(example):\n",
    "    example['label'] = ClassLabels.str2int(example['label'])\n",
    "    return example\n",
    "\n",
    "dataset = dataset.map(map_label2id, batched=True)\n",
    "\n",
    "dataset = dataset.cast_column('label', ClassLabels)\n",
    "\n",
    "# Stratified 60/40 train/test split.\n",
    "dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column=\"label\")\n",
    "\n",
    "train_data = dataset['train']\n",
    "\n",
    "test_data = dataset['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8a4f7ca-4dff-4446-acaf-f3e7630b678d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_str = \"google/vit-base-patch16-224-in21k\"\n",
    "processor = ViTImageProcessor.from_pretrained(model_str)\n",
    "\n",
    "image_mean, image_std = processor.image_mean, processor.image_std\n",
    "size = processor.size[\"height\"]\n",
    "\n",
    "_train_transforms = Compose(\n",
    "    [\n",
    "        Resize((size, size)),\n",
    "        RandomRotation(90),\n",
    "        RandomAdjustSharpness(2),\n",
    "        ToTensor(),\n",
    "        Normalize(mean=image_mean, std=image_std)\n",
    "    ]\n",
    ")\n",
    "\n",
    "_val_transforms = Compose(\n",
    "    [\n",
    "        Resize((size, size)),\n",
    "        ToTensor(),\n",
    "        Normalize(mean=image_mean, std=image_std)\n",
    "    ]\n",
    ")\n",
    "\n",
    "def train_transforms(examples):\n",
    "    examples['pixel_values'] = [_train_transforms(image.convert(\"RGB\")) for image in examples['image']]\n",
    "    return examples\n",
    "\n",
    "def val_transforms(examples):\n",
    "    examples['pixel_values'] = [_val_transforms(image.convert(\"RGB\")) for image in examples['image']]\n",
    "    return examples\n",
    "\n",
    "train_data.set_transform(train_transforms)\n",
    "test_data.set_transform(val_transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c8a93ca-e4ff-42e2-b58d-445afa0cfee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(examples):\n",
    "    pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
    "    labels = torch.tensor([example['label'] for example in examples])\n",
    "    return {\"pixel_values\": pixel_values, \"labels\": labels}"
   ]
  },
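  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-collate-smoke-test",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: DataLoader is imported above but never used. Pulling a single\n",
    "# batch through collate_fn is a cheap way to verify the transform and\n",
    "# collation pipeline before launching a full training run.\n",
    "loader = DataLoader(train_data, batch_size=4, collate_fn=collate_fn)\n",
    "batch = next(iter(loader))\n",
    "print(batch[\"pixel_values\"].shape, batch[\"labels\"])"
   ]
  },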
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11e0c254-ebb1-4100-a389-9e661d0810ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ViTForImageClassification.from_pretrained(model_str, num_labels=len(labels_list))\n",
    "model.config.id2label = id2label\n",
    "model.config.label2id = label2id\n",
    "\n",
    "# Trainable parameter count, in millions.\n",
    "print(model.num_parameters(only_trainable=True) / 1e6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bea51959-9abc-4afc-aee6-0e774f8db9c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions = eval_pred.predictions\n",
    "    label_ids = eval_pred.label_ids\n",
    "\n",
    "    predicted_labels = predictions.argmax(axis=1)\n",
    "    acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']\n",
    "\n",
    "    return {\n",
    "        \"accuracy\": acc_score\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5ea0bbc-51a3-4b98-823e-10819ffda292",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=\"deepfake_vit\",\n",
    "    logging_dir='./logs',\n",
    "    eval_strategy=\"epoch\",  # spelled evaluation_strategy on transformers < 4.41\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=8,\n",
    "    num_train_epochs=4,\n",
    "    weight_decay=0.02,\n",
    "    warmup_steps=50,\n",
    "    remove_unused_columns=False,\n",
    "    save_strategy='epoch',\n",
    "    load_best_model_at_end=True,\n",
    "    save_total_limit=1,\n",
    "    report_to=\"none\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a965131-c670-43b1-a153-c1a4df611189",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=train_data,\n",
    "    eval_dataset=test_data,\n",
    "    data_collator=collate_fn,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=processor,  # recent transformers releases prefer processing_class=processor\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad42ea98-86d6-420e-befe-2ef77eadd76d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline metrics before fine-tuning.\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df43c341-0e55-41ef-a274-731c88b9b5d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28866dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics after fine-tuning.\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ec258d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = trainer.predict(test_data)\n",
    "print(outputs.metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c12a6b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_true = outputs.label_ids\n",
    "y_pred = outputs.predictions.argmax(1)\n",
    "\n",
    "def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues, figsize=(10, 8)):\n",
    "    plt.figure(figsize=figsize)\n",
    "\n",
    "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    plt.title(title)\n",
    "    plt.colorbar()\n",
    "\n",
    "    tick_marks = np.arange(len(classes))\n",
    "    plt.xticks(tick_marks, classes, rotation=90)\n",
    "    plt.yticks(tick_marks, classes)\n",
    "\n",
    "    fmt = '.0f'\n",
    "    thresh = cm.max() / 2.0\n",
    "    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
    "        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "\n",
    "    plt.ylabel('True label')\n",
    "    plt.xlabel('Predicted label')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Note: this rebinds the name 'accuracy', shadowing the evaluate metric above.\n",
    "accuracy = accuracy_score(y_true, y_pred)\n",
    "f1 = f1_score(y_true, y_pred, average='macro')\n",
    "\n",
    "print(f\"Accuracy: {accuracy:.4f}\")\n",
    "print(f\"F1 Score: {f1:.4f}\")\n",
    "\n",
    "if len(labels_list) <= 150:\n",
    "    cm = confusion_matrix(y_true, y_pred)\n",
    "    plot_confusion_matrix(cm, labels_list, figsize=(8, 6))\n",
    "\n",
    "print()\n",
    "print(\"Classification report:\")\n",
    "print()\n",
    "print(classification_report(y_true, y_pred, target_names=labels_list, digits=4))"
   ]
  },
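  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-roc-auc-example",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: roc_auc_score is imported above but never used. For this\n",
    "# binary task it can be computed from the positive-class probabilities\n",
    "# obtained by softmaxing the raw logits in outputs.predictions.\n",
    "probs = torch.softmax(torch.from_numpy(outputs.predictions), dim=-1).numpy()\n",
    "print(f\"ROC AUC: {roc_auc_score(y_true, probs[:, 1]):.4f}\")"
   ]
  },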
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9889438c",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "688e3d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "#upload to hub\n",
    "from huggingface_hub import notebook_login\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fad56df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import HfApi\n",
    "\n",
    "api = HfApi()\n",
    "repo_id = \"prithivMLmods/deepfake_vit\"\n",
    "\n",
    "try:\n",
    "    api.create_repo(repo_id)\n",
    "    print(f\"Repo {repo_id} created\")\n",
    "except Exception:\n",
    "    print(f\"Repo {repo_id} already exists\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5e1559f",
   "metadata": {},
   "outputs": [],
   "source": [
    "api.upload_folder(\n",
    "    folder_path=\"deepfake_vit\",\n",
    "    path_in_repo=\".\",\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"model\",\n",
    "    revision=\"main\"\n",
    ")"
   ]
  },
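  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-inference-example",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: quick inference check against the uploaded checkpoint.\n",
    "# \"/path/to/test_image.jpg\" is a placeholder; point it at any local image.\n",
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(\"image-classification\", model=repo_id)\n",
    "print(classifier(\"/path/to/test_image.jpg\"))"
   ]
  }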
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}