# ideityfy / module.py
# Hugging Face Hub page header: uploaded by Yegiiii, commit "Upload 5 files" (c209d46, verified).
import math
import torch
import torchvision.transforms as T
from os import path
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import CrossEntropyLoss
from torchmetrics.functional import accuracy
from timm import create_model, list_models
from timm.models.vision_transformer import VisionTransformer
from torchvision.datasets import ImageFolder
from utils import AverageMeter
from lightning import LightningDataModule, LightningModule
from huggingface_hub import PyTorchModelHubMixin, login
import torch.nn as nn
from lora import LoRA_qkv
# Images are first resized to PRE_SIZE and then cropped down to IMG_SIZE.
PRE_SIZE = (256, 256)
IMG_SIZE = (224, 224)

# Per-channel mean / std handed to T.Normalize in both pipelines below.
STATS = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

# Dataset and checkpoint folders are resolved relative to this file's directory.
DATASET_DIRECTORY = path.join(path.dirname(__file__), "datasets")
CHECKPOINT_DIRECTORY = path.join(path.dirname(__file__), "checkpoints")

# One preprocessing pipeline per split: training uses a random crop,
# validation uses a center crop; everything else is shared.
TRANSFORMS = dict(
    train=T.Compose(
        [
            T.Resize(PRE_SIZE),
            T.RandomCrop(IMG_SIZE),
            T.ToTensor(),
            T.Normalize(mean=STATS["mean"], std=STATS["std"]),
        ]
    ),
    val=T.Compose(
        [
            T.Resize(PRE_SIZE),
            T.CenterCrop(IMG_SIZE),
            T.ToTensor(),
            T.Normalize(mean=STATS["mean"], std=STATS["std"]),
        ]
    ),
)
class myDataModule(LightningDataModule):
    """
    Lightning DataModule for loading and preparing the image dataset.

    The dataset is read with ``ImageFolder`` from the ``train`` and ``val``
    sub-directories of ``<DATASET_DIRECTORY>/<ds_name>``.

    Args:
        ds_name (str): Name of the dataset directory.
        batch_size (int): Batch size for data loaders.
        num_workers (int): Number of workers for data loaders.
    """

    def __init__(self, ds_name: str = "deities", batch_size: int = 32, num_workers: int = 8):
        super().__init__()
        self.ds_path = path.join(DATASET_DIRECTORY, ds_name)
        assert path.exists(self.ds_path), f"Dataset {ds_name} not found in {DATASET_DIRECTORY}."
        self.ds_name = ds_name
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Create the train/val ``ImageFolder`` datasets and record the class count."""
        if stage == "fit" or stage is None:
            self.train_ds = ImageFolder(root=path.join(self.ds_path, 'train'), transform=TRANSFORMS['train'])
            self.val_ds = ImageFolder(root=path.join(self.ds_path, 'val'), transform=TRANSFORMS['val'])
            # Number of classes
            self.num_classes = len(self.train_ds.classes)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader, sampling each image inversely to its class frequency."""
        # Weighted Random sampler for imbalanced dataset.
        # Labels come from ImageFolder.targets: iterating the dataset itself
        # would load and transform every image just to read its label.
        targets = self.train_ds.targets
        class_samples = [0] * self.num_classes
        for label in targets:
            class_samples[label] += 1
        weights = [1.0 / class_samples[label] for label in targets]
        self.sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        return DataLoader(
            dataset=self.train_ds,
            batch_size=self.batch_size,
            sampler=self.sampler,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader (fixed order, no sampling)."""
        return DataLoader(
            dataset=self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
        )
class myModule(LightningModule, PyTorchModelHubMixin):
    """
    Lightning Module for training and evaluating the Image classification model.

    Args:
        model_name (str): Name of the Vision Transformer model.
        num_classes (int): Number of classes in the dataset.
        freeze_flag (bool): Flag to freeze the base model parameters.
        use_lora (bool): Flag to use LoRA (Low-Rank Adaptation) for fine-tuning.
        rank (int): Rank for LoRA if use_lora is True.
        learning_rate (float): Learning rate for the optimizer.
        weight_decay (float): Weight decay for the optimizer.
        push_to_hf (bool): Flag to push model to Huggingface Hub.
        commit_message (str): Commit message
        repo_id (str): Huggingface repo id
    """

    def __init__(self,
                 model_name: str = "vit_tiny_patch16_224",
                 num_classes: int = 25,
                 freeze_flag: bool = True,
                 use_lora: bool = False,
                 rank: int = None,
                 learning_rate: float = 3e-4,
                 weight_decay: float = 2e-5,
                 push_to_hf: bool = True,
                 commit_message: str = "my model",
                 repo_id: str = "Yegiiii/ideityfy"
                 ):
        super().__init__()
        self.save_hyperparameters()
        self.model_name = model_name
        self.num_classes = num_classes
        self.freeze_flag = freeze_flag
        self.rank = rank
        self.use_lora = use_lora
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.push_to_hf = push_to_hf
        self.commit_message = commit_message
        self.repo_id = repo_id

        assert model_name in list_models(), f"Timm model name {model_name} not available."
        timm_model = create_model(model_name, pretrained=True)
        assert isinstance(timm_model, VisionTransformer), f"{model_name} not a Vision Transformer."
        self.model = timm_model

        if freeze_flag:
            # Freeze the Timm model parameters
            self.freeze()
        if use_lora:
            # Add LoRA matrices to the Timm model
            assert freeze_flag, "Set freeze_flag to True for using LoRA fine-tuning."
            assert rank, "Rank can't be None."
            self.add_lora()
        # The LoRA layers (above) and the new classifier head (below) are created
        # after the freeze step, so they are not affected by it.
        self.model.reset_classifier(num_classes)

        # Loss function
        self.criterion = CrossEntropyLoss()

        # Validation metrics
        self.top1_acc = AverageMeter()
        self.top3_acc = AverageMeter()
        self.top5_acc = AverageMeter()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped timm model on ``x`` and return its output."""
        return self.model(x)

    def on_fit_start(self) -> None:
        """Check that ``num_classes`` agrees with the attached datamodule, if there is one."""
        datamodule = getattr(self.trainer, "datamodule", None)
        num_classes = getattr(datamodule, "num_classes", None)
        if num_classes is None:
            # No datamodule attached (or it does not expose num_classes):
            # there is nothing to compare against.
            return
        assert num_classes == self.num_classes, (
            f"Number of classes provided in the argument ({self.num_classes}) is not matching "
            f"the number of classes in the dataset ({num_classes})."
        )

    def on_fit_end(self) -> None:
        """Log in to the Hugging Face Hub and push the model when ``push_to_hf`` is set."""
        if self.push_to_hf:
            login()
            self.push_to_hub(repo_id=self.repo_id, commit_message=self.commit_message)

    def configure_optimizers(self):
        """Build AdamW over the trainable parameters of ``self.model`` with a cosine LR schedule."""
        trainable_params = [param for param in self.model.parameters() if param.requires_grad]
        optimizer = AdamW(params=trainable_params,
                          lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = CosineAnnealingLR(optimizer, self.trainer.max_epochs, 1e-6)
        return ([optimizer], [scheduler])

    def shared_step(self, x: torch.Tensor, y: torch.Tensor):
        """Return ``(logits, loss)`` for one batch."""
        logits = self(x)
        loss = self.criterion(logits, y)
        return logits, loss

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """Compute and log the training loss for one batch."""
        x, y = batch
        _, loss = self.shared_step(x, y)
        self.log("train_loss", loss, prog_bar=True, logger=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx) -> dict:
        """
        Compute the validation loss and update the top-1/3/5 accuracy meters.

        Only ``val_loss`` is logged here. The accuracy meters are logged once per
        epoch in ``on_validation_epoch_end``: logging their running averages at
        every step with ``on_epoch=True`` would make Lightning average the
        running averages, which skews the epoch value.

        Returns:
            dict: ``val_loss`` plus the current meter averages.
        """
        x, y = batch
        logits, loss = self.shared_step(x, y)
        for top_k, meter in ((1, self.top1_acc), (3, self.top3_acc), (5, self.top5_acc)):
            meter(val=accuracy(logits, y, average="weighted", top_k=top_k,
                               num_classes=self.num_classes))
        self.log("val_loss", loss, prog_bar=True, logger=True, on_epoch=True)
        return {
            "val_loss": loss,
            "top1_acc": self.top1_acc.avg,
            "top3_acc": self.top3_acc.avg,
            "top5_acc": self.top5_acc.avg
        }

    def on_validation_epoch_end(self) -> None:
        """Log the epoch's top-k accuracies from the meters, then reset them."""
        # NOTE(review): assumes AverageMeter.avg is the running mean of the values
        # recorded since the last reset() -- confirm against utils.AverageMeter.
        self.log_dict(
            {
                "top1_acc": self.top1_acc.avg,
                "top3_acc": self.top3_acc.avg,
                "top5_acc": self.top5_acc.avg
            },
            prog_bar=True, logger=True)
        self.top1_acc.reset()
        self.top3_acc.reset()
        self.top5_acc.reset()

    def add_lora(self):
        """
        Wrap the ``qkv`` projection of every transformer block with ``LoRA_qkv``.

        For each block, two pairs of bias-free linear layers of rank ``self.rank``
        (one pair for q, one for v) are created and handed to ``LoRA_qkv`` together
        with the original ``qkv`` layer. All A matrices are then initialised with
        Kaiming-uniform and all B matrices with zeros.
        """
        # Plain Python lists, used below only to (re)initialise the layers.
        # NOTE(review): the layers are presumably registered as sub-modules by
        # LoRA_qkv -- confirm in lora.py.
        self.w_As = []
        self.w_Bs = []
        for blk in self.model.blocks:
            w_qkv_linear = blk.attn.qkv
            self.dim = w_qkv_linear.in_features
            lora_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
            lora_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
            lora_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
            lora_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
            self.w_As.append(lora_a_linear_q)
            self.w_Bs.append(lora_b_linear_q)
            self.w_As.append(lora_a_linear_v)
            self.w_Bs.append(lora_b_linear_v)
            blk.attn.qkv = LoRA_qkv(w_qkv_linear, lora_a_linear_q,
                                    lora_b_linear_q, lora_a_linear_v, lora_b_linear_v)
        for w_A in self.w_As:
            nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
        for w_B in self.w_Bs:
            nn.init.zeros_(w_B.weight)
if __name__ == "__main__":
    # Optional model summary, left disabled:
    # from torchinfo import summary
    # module = myModule(freeze_flag=False)
    # summary(module, (1, 3, 224, 224))

    # Load the "Yegiiii/deities" dataset through the `datasets` library and print it.
    from datasets import load_dataset

    print(load_dataset("Yegiiii/deities"))