{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97b4efc3-1879-4441-af52-de470fbc3ae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install -q evaluate datasets accelerate\n",
    "# -U so that an older, preinstalled transformers is upgraded rather than left as-is.\n",
    "%pip install -q -U transformers\n",
    "%pip install -q huggingface_hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae923886-86f3-431d-b701-1200110b429c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip this installation if your runtime is a Google Colab notebook.\n",
    "%pip install -q imbalanced-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "126923c7-d53f-42d8-8f06-2ea05609ab0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip this installation if your runtime is a Google Colab notebook.\n",
    "%pip install -q numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e628805-b90b-4b98-ae97-9f8a8142767f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip this installation if your runtime is a Google Colab notebook.\n",
    "%pip install -q pillow==11.0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b58fab4c-211f-4b7b-b7c4-dd76e20c1beb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip this installation if your runtime is a Google Colab notebook.\n",
    "%pip install -q torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7454ffa-885e-44ba-8259-d8c45f8ec72b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip this installation if your runtime is a Google Colab notebook.\n",
    "%pip install -q matplotlib\n",
    "%pip install -q scikit-learn"
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4987ed31-c012-434b-9ea7-78da17061d5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import gc\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import itertools\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report, f1_score\n",
    "from imblearn.over_sampling import RandomOverSampler\n",
    "import evaluate\n",
    "# datasets' Image feature is imported as HFImage: PIL's Image (imported below)\n",
    "# would otherwise shadow it under the same name.\n",
    "from datasets import Dataset, ClassLabel, Image as HFImage\n",
    "from transformers import (\n",
    "    TrainingArguments,\n",
    "    Trainer,\n",
    "    ViTImageProcessor,\n",
    "    ViTForImageClassification,\n",
    "    DefaultDataCollator\n",
    ")\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.transforms import (\n",
    "    CenterCrop,\n",
    "    Compose,\n",
    "    Normalize,\n",
    "    RandomRotation,\n",
    "    RandomResizedCrop,\n",
    "    RandomHorizontalFlip,\n",
    "    RandomAdjustSharpness,\n",
    "    Resize,\n",
    "    ToTensor\n",
    ")\n",
    "\n",
    "#.......................................................................\n",
    "\n",
    "#Retain this part if you're working outside Google Colab notebooks.\n",
    "from PIL import Image, ExifTags\n",
    "\n",
    "#.......................................................................\n",
    "\n",
    "from PIL import Image as PILImage\n",
    "from PIL import ImageFile\n",
    "# Enable loading truncated images\n",
    "ImageFile.LOAD_TRUNCATED_IMAGES = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "236bc802-54ba-44d1-b35b-62f548832935",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "dataset = load_dataset(\"--your--dataset--goes--here--\", split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d57e17cc-72b2-4fde-9855-751cf3440624",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect image paths and labels without decoding every image. str(PIL image) would\n",
    "# only give the object's repr, not a path; with decode=False the Image feature yields\n",
    "# {\"path\", \"bytes\"} dicts instead. The path can be None when only bytes are stored,\n",
    "# so fall back to a row identifier in that case.\n",
    "undecoded_images = dataset.cast_column(\"image\", HFImage(decode=False))[\"image\"]\n",
    "file_names = [img[\"path\"] or f\"row_{i}\" for i, img in enumerate(undecoded_images)]\n",
    "# Column access reads only the label column, so no image is decoded here.\n",
    "labels = list(dataset[\"label\"])\n",
    "\n",
    "print(len(file_names), len(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e52c85d2-a245-47c5-9403-5a9cf4e4269d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame.from_dict({\"image\": file_names, \"label\": labels})\n",
    "print(df.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "beba86dd-0605-4ebf-8ebb-97d6ad9e5edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# display() so the preview is shown; only a cell's last expression is rendered automatically.\n",
    "display(df.head())\n",
    "df['label'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6defc1e9-4f46-49b6-addc-f422c38fe7e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this class-balanced df is only built/inspected here; the training cells below\n",
    "# split and train on `dataset`, not on this oversampled df.\n",
    "y = df[['label']]\n",
    "df = df.drop(['label'], axis=1)\n",
    "ros = RandomOverSampler(random_state=83)\n",
    "df, y_resampled = ros.fit_resample(df, y)\n",
    "del y\n",
    "df['label'] = y_resampled\n",
    "del y_resampled\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129d278c-3899-49d2-b06f-a0b2f22f4c4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show both sample images (a bare expression would render only the last one).\n",
    "display(dataset[0][\"image\"], dataset[99][\"image\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bffc8755-c4ac-41be-b8ab-f9a6e0dbcca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels_subset = labels[:5]\n",
    "print(labels_subset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d003f439-09d1-41e6-9f34-213c4ee38593",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): if the 'label' column already holds integer class ids, this order must\n",
    "# match the dataset's own class order (dataset.features['label'].names) - confirm.\n",
    "labels_list = ['Issue In Deepfake', 'High Quality Deepfake']\n",
    "\n",
    "label2id = {label: i for i, label in enumerate(labels_list)}\n",
    "id2label = {i: label for label, i in label2id.items()}\n",
    "\n",
    "ClassLabels = ClassLabel(num_classes=len(labels_list), names=labels_list)\n",
    "\n",
    "print(\"Mapping of IDs to Labels:\", id2label, '\\n')\n",
    "print(\"Mapping of Labels to IDs:\", label2id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id":
"2fbf1f1b-5936-48be-bc99-6897fea94794", "metadata": {}, "outputs": [], "source": [ "def map_label2id(example):\n", " example['label'] = ClassLabels.str2int(example['label'])\n", " return example\n", "\n", "dataset = dataset.map(map_label2id, batched=True)\n", "\n", "dataset = dataset.cast_column('label', ClassLabels)\n", "\n", "dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column=\"label\")\n", "\n", "train_data = dataset['train']\n", "\n", "test_data = dataset['test']" ] }, { "cell_type": "code", "execution_count": null, "id": "d8a4f7ca-4dff-4446-acaf-f3e7630b678d", "metadata": {}, "outputs": [], "source": [ "model_str = \"google/vit-base-patch16-224-in21k\"\n", "processor = ViTImageProcessor.from_pretrained(model_str)\n", "\n", "image_mean, image_std = processor.image_mean, processor.image_std\n", "size = processor.size[\"height\"]\n", "\n", "_train_transforms = Compose(\n", " [\n", " Resize((size, size)),\n", " RandomRotation(90),\n", " RandomAdjustSharpness(2),\n", " ToTensor(),\n", " Normalize(mean=image_mean, std=image_std)\n", " ]\n", ")\n", "\n", "_val_transforms = Compose(\n", " [\n", " Resize((size, size)),\n", " ToTensor(),\n", " Normalize(mean=image_mean, std=image_std)\n", " ]\n", ")\n", "\n", "def train_transforms(examples):\n", " examples['pixel_values'] = [_train_transforms(image.convert(\"RGB\")) for image in examples['image']]\n", " return examples\n", "\n", "def val_transforms(examples):\n", " examples['pixel_values'] = [_val_transforms(image.convert(\"RGB\")) for image in examples['image']]\n", " return examples\n", "\n", "train_data.set_transform(train_transforms)\n", "test_data.set_transform(val_transforms)" ] }, { "cell_type": "code", "execution_count": null, "id": "0c8a93ca-e4ff-42e2-b58d-445afa0cfee0", "metadata": {}, "outputs": [], "source": [ "def collate_fn(examples):\n", " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n", " labels = torch.tensor([example['label'] for 
example in examples])\n", " return {\"pixel_values\": pixel_values, \"labels\": labels}" ] }, { "cell_type": "code", "execution_count": null, "id": "11e0c254-ebb1-4100-a389-9e661d0810ff", "metadata": {}, "outputs": [], "source": [ "model = ViTForImageClassification.from_pretrained(model_str, num_labels=len(labels_list))\n", "model.config.id2label = id2label\n", "model.config.label2id = label2id\n", "\n", "print(model.num_parameters(only_trainable=True) / 1e6)" ] }, { "cell_type": "code", "execution_count": null, "id": "bea51959-9abc-4afc-aee6-0e774f8db9c2", "metadata": {}, "outputs": [], "source": [ "accuracy = evaluate.load(\"accuracy\")\n", "\n", "def compute_metrics(eval_pred):\n", " predictions = eval_pred.predictions\n", " label_ids = eval_pred.label_ids\n", "\n", " predicted_labels = predictions.argmax(axis=1)\n", " acc_score = accuracy.compute(predictions=predicted_labels, references=label_ids)['accuracy']\n", " \n", " return {\n", " \"accuracy\": acc_score\n", " }" ] }, { "cell_type": "code", "execution_count": null, "id": "d5ea0bbc-51a3-4b98-823e-10819ffda292", "metadata": {}, "outputs": [], "source": [ "args = TrainingArguments(\n", " output_dir=\"deepfake_vit\",\n", " logging_dir='./logs',\n", " evaluation_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=32,\n", " per_device_eval_batch_size=8,\n", " num_train_epochs=4,\n", " weight_decay=0.02,\n", " warmup_steps=50,\n", " remove_unused_columns=False,\n", " save_strategy='epoch',\n", " load_best_model_at_end=True,\n", " save_total_limit=1,\n", " report_to=\"none\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "0a965131-c670-43b1-a153-c1a4df611189", "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(\n", " model,\n", " args,\n", " train_dataset=train_data,\n", " eval_dataset=test_data,\n", " data_collator=collate_fn,\n", " compute_metrics=compute_metrics,\n", " tokenizer=processor,\n", ")" ] }, { "cell_type": "code", "execution_count": null, 
"id": "ad42ea98-86d6-420e-befe-2ef77eadd76d", "metadata": {}, "outputs": [], "source": [ "trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": null, "id": "df43c341-0e55-41ef-a274-731c88b9b5d5", "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "28866dda", "metadata": {}, "outputs": [], "source": [ "trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": null, "id": "0ec258d9", "metadata": {}, "outputs": [], "source": [ "outputs = trainer.predict(test_data)\n", "print(outputs.metrics)" ] }, { "cell_type": "code", "execution_count": null, "id": "c12a6b10", "metadata": {}, "outputs": [], "source": [ "y_true = outputs.label_ids\n", "y_pred = outputs.predictions.argmax(1)\n", "\n", "def plot_confusion_matrix(cm, classes, title='Confusion Matrix', cmap=plt.cm.Blues, figsize=(10, 8)):\n", " \n", " plt.figure(figsize=figsize)\n", "\n", " plt.imshow(cm, interpolation='nearest', cmap=cmap)\n", " plt.title(title)\n", " plt.colorbar()\n", "\n", " tick_marks = np.arange(len(classes))\n", " plt.xticks(tick_marks, classes, rotation=90)\n", " plt.yticks(tick_marks, classes)\n", "\n", " fmt = '.0f'\n", " thresh = cm.max() / 2.0\n", " for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n", " plt.text(j, i, format(cm[i, j], fmt), horizontalalignment=\"center\", color=\"white\" if cm[i, j] > thresh else \"black\")\n", "\n", " plt.ylabel('True label')\n", " plt.xlabel('Predicted label')\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "accuracy = accuracy_score(y_true, y_pred)\n", "f1 = f1_score(y_true, y_pred, average='macro')\n", "\n", "print(f\"Accuracy: {accuracy:.4f}\")\n", "print(f\"F1 Score: {f1:.4f}\")\n", "\n", "if len(labels_list) <= 150:\n", " cm = confusion_matrix(y_true, y_pred)\n", " plot_confusion_matrix(cm, labels_list, figsize=(8, 6))\n", "\n", "print()\n", "print(\"Classification report:\")\n", "print()\n", "print(classification_report(y_true, y_pred, 
target_names=labels_list, digits=4))" ] }, { "cell_type": "code", "execution_count": null, "id": "9889438c", "metadata": {}, "outputs": [], "source": [ "trainer.save_model()" ] }, { "cell_type": "code", "execution_count": null, "id": "688e3d62", "metadata": {}, "outputs": [], "source": [ "#upload to hub\n", "from huggingface_hub import notebook_login\n", "notebook_login()" ] }, { "cell_type": "code", "execution_count": null, "id": "fad56df2", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import HfApi\n", "\n", "api = HfApi()\n", "repo_id = f\"prithivMLmods/deepfake_vit\"\n", "\n", "try:\n", " api.create_repo(repo_id)\n", " print(f\"Repo {repo_id} created\")\n", "\n", "except:\n", " \n", " print(f\"Repo {repo_id} already exists\")" ] }, { "cell_type": "code", "execution_count": null, "id": "f5e1559f", "metadata": {}, "outputs": [], "source": [ "api.upload_folder(\n", " folder_path=\"deepfake_vit\", \n", " path_in_repo=\".\", \n", " repo_id=repo_id, \n", " repo_type=\"model\", \n", " revision=\"main\"\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }