import math
from os import path
from typing import Optional

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder

from huggingface_hub import PyTorchModelHubMixin, login
from lightning import LightningDataModule, LightningModule
from timm import create_model, list_models
from timm.models.vision_transformer import VisionTransformer
from torchmetrics.functional import accuracy

from lora import LoRA_qkv
from utils import AverageMeter

# Resize to PRE_SIZE, then crop to the 224x224 input size of the *_224 ViT models.
PRE_SIZE = (256, 256)
IMG_SIZE = (224, 224)

# ImageNet normalization statistics, matching the timm pretrained weights.
STATS = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
DATASET_DIRECTORY = path.join(path.dirname(__file__), "datasets")
CHECKPOINT_DIRECTORY = path.join(path.dirname(__file__), "checkpoints")

TRANSFORMS = {
    "train": T.Compose([
        T.Resize(PRE_SIZE),
        T.RandomCrop(IMG_SIZE),
        T.ToTensor(),
        T.Normalize(**STATS)
    ]),
    "val": T.Compose([
        T.Resize(PRE_SIZE),
        T.CenterCrop(IMG_SIZE),
        T.ToTensor(),
        T.Normalize(**STATS)
    ])
}


class myDataModule(LightningDataModule):
    """
    Lightning DataModule for loading and preparing the image dataset.

    Args:
        ds_name (str): Name of the dataset directory.
        batch_size (int): Batch size for data loaders.
        num_workers (int): Number of workers for data loaders.
    """

    def __init__(self, ds_name: str = "deities", batch_size: int = 32, num_workers: int = 8):
        super().__init__()

        self.ds_path = path.join(DATASET_DIRECTORY, ds_name)
        assert path.exists(self.ds_path), f"Dataset {ds_name} not found in {DATASET_DIRECTORY}."

        self.ds_name = ds_name
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            self.train_ds = ImageFolder(root=path.join(self.ds_path, "train"), transform=TRANSFORMS["train"])
            self.val_ds = ImageFolder(root=path.join(self.ds_path, "val"), transform=TRANSFORMS["val"])

            self.num_classes = len(self.train_ds.classes)

    def train_dataloader(self) -> DataLoader:
        # Weight every sample by the inverse frequency of its class so the sampler
        # draws a roughly class-balanced stream from an imbalanced dataset.
        # Reading labels from ImageFolder.targets avoids decoding every image here.
        class_samples = [0] * self.num_classes
        for label in self.train_ds.targets:
            class_samples[label] += 1
        weights = [1.0 / class_samples[label] for label in self.train_ds.targets]
        self.sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        # persistent_workers keeps the worker processes alive between epochs.
        return DataLoader(dataset=self.train_ds, batch_size=self.batch_size,
                          sampler=self.sampler, num_workers=self.num_workers, persistent_workers=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.val_ds, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers, persistent_workers=True)


class myModule(LightningModule, PyTorchModelHubMixin):
    """
    Lightning Module for training and evaluating the image classification model.

    Args:
        model_name (str): Name of the timm Vision Transformer model.
        num_classes (int): Number of classes in the dataset.
        freeze_flag (bool): Flag to freeze the base model parameters.
        use_lora (bool): Flag to use LoRA (Low-Rank Adaptation) for fine-tuning.
        rank (int): Rank of the LoRA matrices (required if use_lora is True).
        learning_rate (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay for the optimizer.
        push_to_hf (bool): Flag to push the model to the Hugging Face Hub.
        commit_message (str): Commit message for the Hub push.
        repo_id (str): Hugging Face Hub repository id.
    """

    def __init__(self,
                 model_name: str = "vit_tiny_patch16_224",
                 num_classes: int = 25,
                 freeze_flag: bool = True,
                 use_lora: bool = False,
                 rank: Optional[int] = None,
                 learning_rate: float = 3e-4,
                 weight_decay: float = 2e-5,
                 push_to_hf: bool = True,
                 commit_message: str = "my model",
                 repo_id: str = "Yegiiii/ideityfy"
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.model_name = model_name
        self.num_classes = num_classes
        self.freeze_flag = freeze_flag
        self.rank = rank
        self.use_lora = use_lora
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.push_to_hf = push_to_hf
        self.commit_message = commit_message
        self.repo_id = repo_id

        assert model_name in list_models(), f"Timm model name {model_name} not available."
        timm_model = create_model(model_name, pretrained=True)
        assert isinstance(timm_model, VisionTransformer), f"{model_name} is not a Vision Transformer."
        self.model = timm_model

        if freeze_flag:
            # Freeze all pretrained parameters; only the new classifier head
            # (and the LoRA matrices, if enabled) remain trainable.
            self.freeze()

        if use_lora:
            assert freeze_flag, "Set freeze_flag to True when using LoRA fine-tuning."
            assert rank, "Rank can't be None when use_lora is True."
            self.add_lora()

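        # reset_classifier() installs a fresh, trainable classification head sized for
        # num_classes; with the backbone frozen, this head (plus any LoRA matrices)
        # is what the optimizer actually updates.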
        self.model.reset_classifier(num_classes)

        self.criterion = CrossEntropyLoss()

        # Running averages of the weighted top-k validation accuracies over an epoch.
        self.top1_acc = AverageMeter()
        self.top3_acc = AverageMeter()
        self.top5_acc = AverageMeter()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def on_fit_start(self) -> None:
        num_classes = self.trainer.datamodule.num_classes
        assert num_classes == self.num_classes, (
            f"Number of classes provided in the argument ({self.num_classes}) does not match "
            f"the number of classes in the dataset ({num_classes})."
        )

    def on_fit_end(self) -> None:
        if self.push_to_hf:
            # Authenticate with the Hugging Face Hub (prompts for a token if none is
            # configured), then push the trained weights.
            login()
            self.push_to_hub(repo_id=self.repo_id, commit_message=self.commit_message)

    def configure_optimizers(self):
        # Optimize only the parameters that are still trainable (classifier head and,
        # optionally, the LoRA matrices).
        optimizer = AdamW(params=filter(lambda param: param.requires_grad, self.model.parameters()),
                          lr=self.learning_rate, weight_decay=self.weight_decay)

        # Cosine decay from the initial learning rate down to 1e-6 over max_epochs.
        scheduler = CosineAnnealingLR(optimizer, self.trainer.max_epochs, 1e-6)
        return ([optimizer], [scheduler])

    def shared_step(self, x: torch.Tensor, y: torch.Tensor):
        logits = self(x)
        loss = self.criterion(logits, y)
        return logits, loss

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        x, y = batch
        _, loss = self.shared_step(x, y)

        self.log("train_loss", loss, prog_bar=True, logger=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx) -> dict:
        x, y = batch
        logits, loss = self.shared_step(x, y)

        # Weighted top-k accuracies (torchmetrics >= 0.11 requires the task argument).
        self.top1_acc(
            val=accuracy(logits, y, task="multiclass", average="weighted", top_k=1, num_classes=self.num_classes))
        self.top3_acc(
            val=accuracy(logits, y, task="multiclass", average="weighted", top_k=3, num_classes=self.num_classes))
        self.top5_acc(
            val=accuracy(logits, y, task="multiclass", average="weighted", top_k=5, num_classes=self.num_classes))

        metric_dict = {
            "val_loss": loss,
            "top1_acc": self.top1_acc.avg,
            "top3_acc": self.top3_acc.avg,
            "top5_acc": self.top5_acc.avg
        }

        self.log_dict(metric_dict, prog_bar=True, logger=True, on_epoch=True)
        return metric_dict

    def on_validation_epoch_end(self) -> None:
        # Reset the running accuracy meters so each validation epoch starts fresh.
        self.top1_acc.reset()
        self.top3_acc.reset()
        self.top5_acc.reset()

    def add_lora(self):
        # Inject LoRA adapters into every transformer block: for each attention qkv
        # projection, create rank-`self.rank` A/B matrices for the query and value
        # projections and wrap the frozen qkv linear layer in LoRA_qkv, which adds
        # the low-rank updates to its output.
        self.w_As = []
        self.w_Bs = []

        for blk in self.model.blocks:
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            lora_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
            lora_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
            lora_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
            lora_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
            self.w_As.append(lora_a_linear_q)
            self.w_Bs.append(lora_b_linear_q)
            self.w_As.append(lora_a_linear_v)
            self.w_Bs.append(lora_b_linear_v)
            blk.attn.qkv = LoRA_qkv(w_qkv_linear, lora_a_linear_q,
                                    lora_b_linear_q, lora_a_linear_v, lora_b_linear_v)

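        # LoRA initialization, following the usual LoRA recipe: A keeps the default
        # kaiming-uniform Linear init while B starts at zero, so every low-rank update
        # B @ A is zero at the start of fine-tuning and the adapted model initially
        # behaves exactly like the frozen pretrained backbone.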
        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)


if __name__ == "__main__":

    # Quick sanity check: load the hosted dataset from the Hugging Face Hub and
    # print its splits.
    from datasets import load_dataset

    dataset = load_dataset("Yegiiii/deities")
    print(dataset)
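
    # Hedged training sketch (assumptions: the dataset has also been exported locally
    # to datasets/deities/{train,val} in ImageFolder layout, and the Trainer settings
    # below are illustrative defaults, not this project's actual configuration). The
    # RUN_TRAINING_SKETCH environment variable is a hypothetical opt-in guard.
    import os

    if os.environ.get("RUN_TRAINING_SKETCH") == "1":
        from lightning import Trainer

        datamodule = myDataModule(ds_name="deities", batch_size=32, num_workers=8)
        model = myModule(model_name="vit_tiny_patch16_224", num_classes=25,
                         use_lora=True, rank=4, push_to_hf=False)
        trainer = Trainer(max_epochs=10, accelerator="auto",
                          default_root_dir=CHECKPOINT_DIRECTORY)
        trainer.fit(model, datamodule=datamodule)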