Yegiiii committed
Commit c209d46 · verified · 1 Parent(s): 80a32e1

Upload 5 files
Files changed (5)
  1. app.py +70 -0
  2. lora.py +27 -0
  3. module.py +262 -0
  4. utils.py +52 -0
  5. vit_base_clip_rank4.ckpt +3 -0
app.py ADDED
@@ -0,0 +1,70 @@
+ import streamlit as st
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from module import myModule
+
+ IMG_SIZE = (224, 224)
+ STATS = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ # Test-time-augmentation transform for the input image
+ TTA_TRANSFORM = T.Compose([
+     T.Resize(IMG_SIZE),
+     T.AutoAugment(),
+     T.ToTensor(),
+     T.Normalize(**STATS)
+ ])
+
+
+ st.set_page_config(
+     page_title="Identify the deity using Computer Vision.",
+     layout="centered",
+     initial_sidebar_state="collapsed",
+     menu_items={
+         'Get Help': 'https://www.extremelycoolapp.com/help',
+         'Report a bug': "https://www.extremelycoolapp.com/bug",
+         'About': "# This is an *extremely* cool app!"
+     }
+ )
+
+
+ st.title(":sparkles: I:orange[deity]fy")
+ st.header("Discover the deity with a snap.")
+
+ # Load the fine-tuned checkpoint and put the model in inference mode
+ model = myModule.load_from_checkpoint("checkpoints/vit_base_clip_rank4.ckpt")
+ model.to("cpu")
+ model.eval()
+
+
+ def predict(image):
+     """Run test-time augmentation and return the top-3 probabilities and class indices."""
+     # Apply the TTA transform 10 times and average the augmented tensors
+     with Image.open(image).convert('RGB') as img:
+         img_tensor = torch.stack([TTA_TRANSFORM(img) for _ in range(10)])
+         img_tensor = torch.mean(img_tensor, dim=0).unsqueeze(0)
+
+     # Make a prediction
+     with torch.no_grad():
+         logits = model(img_tensor)
+
+     # Get the top 3 predictions and their probabilities
+     probs = torch.softmax(logits, dim=1)
+     topk = torch.topk(probs, k=3)
+     values, indices = topk.values, topk.indices
+
+     values = values.squeeze().cpu().numpy().tolist()
+     indices = indices.squeeze().cpu().numpy().tolist()
+
+     return values, indices
+
+
+ # Upload image through Streamlit
+ img = st.file_uploader(label='choose a file', type=['png', 'jpg', 'jpeg'], label_visibility="hidden")
+
+
+ if img is not None:
+     # Make predictions when the user clicks the "Predict" button
+     if st.button("Predict"):
+         values, indices = predict(img)
+         # Display the top 3 predictions as a bar chart
+         st.bar_chart({label: prob for label, prob in zip(indices, values)}, color="#FFC101")
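A note on the TTA design in predict(): the ten augmented views are averaged in pixel space, so the model sees a single averaged tensor. A common alternative is one forward pass per view with the probabilities averaged afterwards; a minimal sketch of that variant (the function name and the views argument are illustrative, not part of app.py):

    import torch

    def predict_avg_probs(model, views):
        # views: list of transformed (C, H, W) tensors for one image
        with torch.no_grad():
            batch = torch.stack(views)                  # (N, C, H, W)
            probs = torch.softmax(model(batch), dim=1)  # one prediction per view
        return probs.mean(dim=0)                        # average over the N views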
lora.py ADDED
@@ -0,0 +1,27 @@
+ # Code adapted from https://github.com/JamesQFreeman/LoRA-ViT
+
+ import torch.nn as nn
+
+
+ class LoRA_qkv(nn.Module):
+     """LoRA qkv module for a Vision Transformer.
+
+     Wraps the original fused qkv projection and adds low-rank updates
+     to the query and value projections only.
+     """
+     def __init__(
+         self,
+         qkv: nn.Module,
+         linear_a_q: nn.Module,
+         linear_b_q: nn.Module,
+         linear_a_v: nn.Module,
+         linear_b_v: nn.Module,
+     ):
+         super().__init__()
+         self.qkv = qkv
+         self.dim = qkv.in_features
+         self.q_lora = nn.Sequential(linear_a_q, linear_b_q)
+         self.v_lora = nn.Sequential(linear_a_v, linear_b_v)
+
+     def forward(self, x):
+         qkv = self.qkv(x)                # (B, N, 3 * dim)
+         new_q = self.q_lora(x)           # low-rank update for q
+         new_v = self.v_lora(x)           # low-rank update for v
+         qkv[:, :, : self.dim] += new_q   # q occupies the first dim-wide slice
+         qkv[:, :, -self.dim :] += new_v  # v occupies the last dim-wide slice
+         return qkv
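A minimal sketch of how LoRA_qkv is spliced into a ViT attention block, mirroring what add_lora in module.py does below; the width and rank values are illustrative:

    import torch
    import torch.nn as nn
    from lora import LoRA_qkv

    dim, rank = 768, 4                      # ViT-Base width, LoRA rank 4
    qkv = nn.Linear(dim, 3 * dim)           # stands in for blk.attn.qkv
    wrapped = LoRA_qkv(
        qkv,
        nn.Linear(dim, rank, bias=False),   # A_q: down-projection for q
        nn.Linear(rank, dim, bias=False),   # B_q: up-projection for q
        nn.Linear(dim, rank, bias=False),   # A_v: down-projection for v
        nn.Linear(rank, dim, bias=False),   # B_v: up-projection for v
    )
    x = torch.randn(1, 197, dim)            # (batch, tokens, dim)
    print(wrapped(x).shape)                 # torch.Size([1, 197, 2304])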
module.py ADDED
@@ -0,0 +1,262 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as T
+ from os import path
+ from typing import Optional
+ from torch.utils.data import DataLoader, WeightedRandomSampler
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ from torch.nn import CrossEntropyLoss
+ from torchmetrics.functional import accuracy
+ from timm import create_model, list_models
+ from timm.models.vision_transformer import VisionTransformer
+ from torchvision.datasets import ImageFolder
+ from lightning import LightningDataModule, LightningModule
+ from huggingface_hub import PyTorchModelHubMixin, login
+
+ from utils import AverageMeter
+ from lora import LoRA_qkv
+
+
+ PRE_SIZE = (256, 256)
+ IMG_SIZE = (224, 224)
+
+ STATS = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ DATASET_DIRECTORY = path.join(path.dirname(__file__), "datasets")
+ CHECKPOINT_DIRECTORY = path.join(path.dirname(__file__), "checkpoints")
+
+ TRANSFORMS = {
+     "train": T.Compose([
+         T.Resize(PRE_SIZE),
+         T.RandomCrop(IMG_SIZE),
+         T.ToTensor(),
+         T.Normalize(**STATS)
+     ]),
+     "val": T.Compose([
+         T.Resize(PRE_SIZE),
+         T.CenterCrop(IMG_SIZE),
+         T.ToTensor(),
+         T.Normalize(**STATS)
+     ])
+ }
+
+
+ class myDataModule(LightningDataModule):
+     """
+     Lightning DataModule for loading and preparing the image dataset.
+
+     Args:
+         ds_name (str): Name of the dataset directory.
+         batch_size (int): Batch size for the data loaders.
+         num_workers (int): Number of workers for the data loaders.
+     """
+     def __init__(self, ds_name: str = "deities", batch_size: int = 32, num_workers: int = 8):
+         super().__init__()
+
+         self.ds_path = path.join(DATASET_DIRECTORY, ds_name)
+         assert path.exists(self.ds_path), f"Dataset {ds_name} not found in {DATASET_DIRECTORY}."
+
+         self.ds_name = ds_name
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+
+     def setup(self, stage=None):
+         if stage == "fit" or stage is None:
+             self.train_ds = ImageFolder(root=path.join(self.ds_path, 'train'), transform=TRANSFORMS['train'])
+             self.val_ds = ImageFolder(root=path.join(self.ds_path, 'val'), transform=TRANSFORMS['val'])
+             # Number of classes
+             self.num_classes = len(self.train_ds.classes)
+
+     def train_dataloader(self) -> DataLoader:
+         # Weighted random sampler to counter class imbalance.
+         # ImageFolder exposes the integer labels via .targets, so the
+         # counts can be built without decoding any images.
+         class_samples = [0] * self.num_classes
+         for label in self.train_ds.targets:
+             class_samples[label] += 1
+         weights = [1.0 / class_samples[label] for label in self.train_ds.targets]
+         self.sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
+         return DataLoader(dataset=self.train_ds, batch_size=self.batch_size,
+                           sampler=self.sampler, num_workers=self.num_workers, persistent_workers=True)
+
+     def val_dataloader(self) -> DataLoader:
+         return DataLoader(dataset=self.val_ds, batch_size=self.batch_size,
+                           shuffle=False, num_workers=self.num_workers, persistent_workers=True)
+
+
+ class myModule(LightningModule, PyTorchModelHubMixin):
+     """
+     Lightning Module for training and evaluating the image classification model.
+
+     Args:
+         model_name (str): Name of the Vision Transformer model.
+         num_classes (int): Number of classes in the dataset.
+         freeze_flag (bool): Flag to freeze the base model parameters.
+         use_lora (bool): Flag to use LoRA (Low-Rank Adaptation) for fine-tuning.
+         rank (int): Rank for LoRA if use_lora is True.
+         learning_rate (float): Learning rate for the optimizer.
+         weight_decay (float): Weight decay for the optimizer.
+         push_to_hf (bool): Flag to push the model to the Hugging Face Hub.
+         commit_message (str): Commit message for the Hub push.
+         repo_id (str): Hugging Face repo id.
+     """
+     def __init__(self,
+                  model_name: str = "vit_tiny_patch16_224",
+                  num_classes: int = 25,
+                  freeze_flag: bool = True,
+                  use_lora: bool = False,
+                  rank: Optional[int] = None,
+                  learning_rate: float = 3e-4,
+                  weight_decay: float = 2e-5,
+                  push_to_hf: bool = True,
+                  commit_message: str = "my model",
+                  repo_id: str = "Yegiiii/ideityfy"
+                  ):
+         super().__init__()
+         self.save_hyperparameters()
+         self.model_name = model_name
+         self.num_classes = num_classes
+         self.freeze_flag = freeze_flag
+         self.rank = rank
+         self.use_lora = use_lora
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+         self.push_to_hf = push_to_hf
+         self.commit_message = commit_message
+         self.repo_id = repo_id
+
+         assert model_name in list_models(), f"Timm model name {model_name} not available."
+         timm_model = create_model(model_name, pretrained=True)
+         assert isinstance(timm_model, VisionTransformer), f"{model_name} is not a Vision Transformer."
+         self.model = timm_model
+
+         if freeze_flag:
+             # Freeze the timm model parameters
+             self.freeze()
+
+         if use_lora:
+             # Add LoRA matrices to the timm model
+             assert freeze_flag, "Set freeze_flag to True when using LoRA fine-tuning."
+             assert rank, "Rank can't be None."
+             self.add_lora()
+
+         # Replace the classification head; its parameters stay trainable
+         self.model.reset_classifier(num_classes)
+
+         # Loss function
+         self.criterion = CrossEntropyLoss()
+
+         # Validation metrics
+         self.top1_acc = AverageMeter()
+         self.top3_acc = AverageMeter()
+         self.top5_acc = AverageMeter()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.model(x)
+
+     def on_fit_start(self) -> None:
+         num_classes = self.trainer.datamodule.num_classes
+         assert num_classes == self.num_classes, (
+             f"Number of classes provided in the argument ({self.num_classes}) "
+             f"does not match the number of classes in the dataset ({num_classes})."
+         )
+
+     def on_fit_end(self) -> None:
+         if self.push_to_hf:
+             login()
+             self.push_to_hub(repo_id=self.repo_id, commit_message=self.commit_message)
+
+     def configure_optimizers(self):
+         # Optimize only the trainable (non-frozen) parameters
+         optimizer = AdamW(params=filter(lambda param: param.requires_grad, self.model.parameters()),
+                           lr=self.learning_rate, weight_decay=self.weight_decay)
+         scheduler = CosineAnnealingLR(optimizer, self.trainer.max_epochs, 1e-6)
+         return ([optimizer], [scheduler])
+
+     def shared_step(self, x: torch.Tensor, y: torch.Tensor):
+         logits = self(x)
+         loss = self.criterion(logits, y)
+         return logits, loss
+
+     def training_step(self, batch, batch_idx) -> torch.Tensor:
+         x, y = batch
+         _, loss = self.shared_step(x, y)
+
+         self.log("train_loss", loss, prog_bar=True, logger=True, on_epoch=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx) -> dict:
+         x, y = batch
+         logits, loss = self.shared_step(x, y)
+
+         # torchmetrics >= 0.11 requires an explicit task argument
+         self.top1_acc(
+             val=accuracy(logits, y, task="multiclass", average="weighted", top_k=1, num_classes=self.num_classes))
+         self.top3_acc(
+             val=accuracy(logits, y, task="multiclass", average="weighted", top_k=3, num_classes=self.num_classes))
+         self.top5_acc(
+             val=accuracy(logits, y, task="multiclass", average="weighted", top_k=5, num_classes=self.num_classes))
+
+         metric_dict = {
+             "val_loss": loss,
+             "top1_acc": self.top1_acc.avg,
+             "top3_acc": self.top3_acc.avg,
+             "top5_acc": self.top5_acc.avg
+         }
+
+         self.log_dict(metric_dict, prog_bar=True, logger=True, on_epoch=True)
+         return metric_dict
+
+     def on_validation_epoch_end(self) -> None:
+         self.top1_acc.reset()
+         self.top3_acc.reset()
+         self.top5_acc.reset()
+
+     def add_lora(self):
+         # Wrap every attention block's fused qkv projection with LoRA_qkv,
+         # keeping references to the A/B matrices for initialization.
+         self.w_As = []
+         self.w_Bs = []
+
+         for blk in self.model.blocks:
+             w_qkv_linear = blk.attn.qkv
+             self.dim = w_qkv_linear.in_features
+             lora_a_linear_q = nn.Linear(self.dim, self.rank, bias=False)
+             lora_b_linear_q = nn.Linear(self.rank, self.dim, bias=False)
+             lora_a_linear_v = nn.Linear(self.dim, self.rank, bias=False)
+             lora_b_linear_v = nn.Linear(self.rank, self.dim, bias=False)
+             self.w_As.append(lora_a_linear_q)
+             self.w_Bs.append(lora_b_linear_q)
+             self.w_As.append(lora_a_linear_v)
+             self.w_Bs.append(lora_b_linear_v)
+             blk.attn.qkv = LoRA_qkv(w_qkv_linear, lora_a_linear_q,
+                                     lora_b_linear_q, lora_a_linear_v, lora_b_linear_v)
+
+         # Standard LoRA init: A is Kaiming-uniform, B starts at zero so the
+         # low-rank update is a no-op before training.
+         for w_A in self.w_As:
+             nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
+         for w_B in self.w_Bs:
+             nn.init.zeros_(w_B.weight)
+
+
+ if __name__ == "__main__":
+     # from torchinfo import summary
+     # module = myModule(freeze_flag=False)
+     # summary(module, (1, 3, 224, 224))
+
+     from datasets import load_dataset
+
+     dataset = load_dataset("Yegiiii/deities")
+     print(dataset)
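A hedged end-to-end training sketch for the module above, assuming a datasets/deities ImageFolder tree with train/ and val/ splits; the timm model name is an assumption inferred from the vit_base_clip_rank4.ckpt filename:

    from lightning import Trainer
    from module import myDataModule, myModule

    dm = myDataModule(ds_name="deities", batch_size=32, num_workers=8)
    model = myModule(model_name="vit_base_patch16_clip_224",  # assumed timm name
                     num_classes=25, use_lora=True, rank=4,
                     push_to_hf=False)  # skip the Hub login for a local run
    trainer = Trainer(max_epochs=10, default_root_dir="checkpoints")
    trainer.fit(model, datamodule=dm)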
utils.py ADDED
@@ -0,0 +1,52 @@
+ import torch
+ import platform
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value."""
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def __call__(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ def getPlatform():
+     plt = platform.system()
+     if plt == 'Darwin':
+         return 'mac'
+     return plt
+
+
+ def hasGPU(plt: str):
+     # On macOS check for the Metal (MPS) backend, otherwise for CUDA
+     if plt == 'mac':
+         return torch.backends.mps.is_available()
+     return torch.cuda.is_available()
+
+
+ def getDevice(plt: str):
+     if plt == 'mac':
+         return torch.device('mps')
+     return torch.device('cuda')
+
+
+ def disableWarnings():
+     import warnings
+     warnings.filterwarnings("ignore", category=UserWarning, module="transformers.utils.generic")
+     warnings.filterwarnings("ignore", category=UserWarning, module="trl.trainer.ppo_config")
+     warnings.filterwarnings("ignore", message="torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly")
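A small usage sketch for the helpers above (the batch numbers are made up):

    import torch
    from utils import AverageMeter, getPlatform, hasGPU, getDevice

    meter = AverageMeter()
    for batch_acc, batch_size in [(0.75, 32), (0.80, 32), (0.90, 16)]:
        meter(batch_acc, n=batch_size)        # running weighted average
    print(f"mean accuracy: {meter.avg:.4f}")  # 0.8000

    plt = getPlatform()                       # 'mac' on macOS, else the OS name
    device = getDevice(plt) if hasGPU(plt) else torch.device('cpu')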
vit_base_clip_rank4.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2028c9135feda494935c91ad8b33089df6e6df3ef13b73d4f8a187af41feb5f9
+ size 345320306