Adityak204 committed on
Commit
905e42f
·
1 Parent(s): 6d04701
Files changed (6) hide show
  1. .gitignore +74 -0
  2. __init__.py +0 -0
  3. app.py +148 -0
  4. data/imagenet-simple-labels.json +1000 -0
  5. pl_train.py +361 -0
  6. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .env
28
+ .venv
29
+
30
+ # IDE
31
+ .idea/
32
+ .vscode/
33
+ *.swp
34
+ *.swo
35
+ .project
36
+ .pydevproject
37
+ .settings
38
+
39
+ # Jupyter Notebook
40
+ .ipynb_checkpoints
41
+ *.ipynb_checkpoints/
42
+
43
+ # PyTorch
44
+ *.pth
45
+ *.pt
46
+ *.pkl
47
+
48
+ # Logs and databases
49
+ *.log
50
+ *.sqlite
51
+ *.db
52
+
53
+ # OS generated files
54
+ .DS_Store
55
+ .DS_Store?
56
+ ._*
57
+ .Spotlight-V100
58
+ .Trashes
59
+ ehthumbs.db
60
+ Thumbs.db
61
+
62
+ # Project specific
63
+ runs/
64
+ checkpoints/
65
+ outputs/
66
+ logs/
67
+ lightning_logs/
68
+ gradio_cached_examples/
69
+
70
+ flagged/
71
+ *.pt
72
+ *.jpeg
73
+ *.png
74
+ *.jpg
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import json
6
+ from pathlib import Path
7
+ import os
8
+ from huggingface_hub import hf_hub_download
9
+
10
+
11
class ModelPredictor:
    """Top-5 ImageNet classifier backed by a checkpoint from the Hugging Face Hub.

    On construction the checkpoint is downloaded, loaded into an
    ``ImageNetModule``, moved to the selected device and switched to eval mode.
    """

    def __init__(
        self,
        model_repo: str,
        model_filename: str,
        device: str = None,
    ):
        """
        Args:
            model_repo: Repository id passed to ``hf_hub_download``.
            model_filename: Checkpoint file name inside that repository.
            device: Torch device string. When falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = (
            device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        )

        # Load the model
        checkpoint_path = hf_hub_download(repo_id=model_repo, filename=model_filename)
        self.model = self.load_model(checkpoint_path)
        self.model.to(self.device)
        self.model.eval()

        # Setup transforms (resize -> center crop -> tensor -> normalize)
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Load ImageNet class labels
        self.class_labels = self.load_imagenet_labels()

    def load_model(self, checkpoint_path: str):
        """Load the trained model from checkpoint.

        Args:
            checkpoint_path: Local path of the checkpoint file.
        Returns:
            An ``ImageNetModule`` with the checkpoint's ``state_dict`` loaded.
        """
        # Imported lazily, as in the original code.
        from pl_train import ImageNetModule

        # NOTE(security): torch.load may unpickle arbitrary objects (depending
        # on the torch version / weights_only setting); only load checkpoints
        # from a repository you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # The constructor arguments below only build the module; the weights
        # come from the checkpoint. Presumably they do not matter for
        # inference -- confirm against ImageNetModule.
        model = ImageNetModule(
            learning_rate=0.156,
            batch_size=1,
            num_workers=0,  # Set to 0 for Gradio
            max_epochs=40,
            train_path="",
            val_path="",
            checkpoint_dir="",
        )
        model.load_state_dict(checkpoint["state_dict"])
        return model

    def load_imagenet_labels(self):
        """Load ImageNet class labels.

        Returns:
            Dict mapping the 0-based class index (as a string) to its label.
            Labels come from ``data/imagenet-simple-labels.json`` when that
            file exists, otherwise placeholder names ``class_<i>`` are used.
        """
        # For HuggingFace Spaces, we'll look for the labels file in the same directory
        labels_path = Path("data/imagenet-simple-labels.json")

        if labels_path.exists():
            with open(labels_path, encoding="utf-8") as f:
                data = json.load(f)
            # Keys must be 0-based: predict() looks labels up by the raw index
            # returned from torch.topk, and the fallback below is 0-based too.
            return {str(i): name for i, name in enumerate(data)}
        return {str(i): f"class_{i}" for i in range(1000)}  # Fallback

    def predict(self, image):
        """
        Make prediction for a single image
        Args:
            image: PIL Image, path to an image (str or Path), or an array
                accepted by ``Image.fromarray``
        Returns:
            Dictionary of class labels and probabilities (top 5, as floats)
        """
        if isinstance(image, (str, Path)):
            image = Image.open(image)
        elif not isinstance(image, Image.Image):
            # Array-like input; PIL images are used directly.
            image = Image.fromarray(image)
        image = image.convert("RGB")

        # Add the batch dimension expected by the model.
        image_tensor = self.transform(image).unsqueeze(0)
        image_tensor = image_tensor.to(self.device)

        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        # Get top 5 predictions
        top_probs, top_indices = torch.topk(probabilities, 5)

        # Create results dictionary
        results = {}
        for prob, idx in zip(top_probs[0], top_indices[0]):
            class_name = self.class_labels[str(idx.item())]
            results[class_name] = float(prob)

        return results
101
+
102
+
103
# Initialize the predictor.
# Hub coordinates of the checkpoint; replace both values to serve your own model.
_CHECKPOINT_SOURCE = {
    "model_repo": "Adityak204/ResNetVision-1K",
    "model_filename": "resnet50-epoch36-acc60.3506.ckpt",
}
predictor = ModelPredictor(**_CHECKPOINT_SOURCE)
108
+
109
+
110
def predict_image(image):
    """
    Gradio interface function

    Args:
        image: Image delivered by the Gradio input component, or None when
            nothing was provided.
    Returns:
        Dictionary of class labels to float probabilities, or None when no
        image was given.
    """
    if image is None:
        return None

    # Hand the raw float confidences to gr.Label: it expects numeric values
    # and does its own ranking/rendering, so pre-formatted "12.34%" strings
    # are not valid confidences.
    return predictor.predict(image)
118
+
119
+
120
# Example images offered in the UI; used only when every file exists locally.
_EXAMPLE_IMAGES = [
    "ResNetVision-1K/data/ILSVRC2012_val_00000048.JPEG",
    "ResNetVision-1K/data/ILSVRC2012_val_00000090.JPEG",
    "ResNetVision-1K/data/ILSVRC2012_val_00000.JPEG",
]

# Create Gradio interface
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="ImageNet-1K Classification",
    description="Upload an image to classify it into one of 1000 ImageNet categories",
    # Each example is passed as a one-element row, but the existence check
    # must run on the plain path strings: Path() raises TypeError when it is
    # given a list.
    examples=(
        [[path] for path in _EXAMPLE_IMAGES]
        if all(Path(path).exists() for path in _EXAMPLE_IMAGES)
        else None
    ),
    analytics_enabled=False,
)

# Launch the app
if __name__ == "__main__":
    iface.launch()
data/imagenet-simple-labels.json ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ["tench",
2
+ "goldfish",
3
+ "great white shark",
4
+ "tiger shark",
5
+ "hammerhead shark",
6
+ "electric ray",
7
+ "stingray",
8
+ "cock",
9
+ "hen",
10
+ "ostrich",
11
+ "brambling",
12
+ "goldfinch",
13
+ "house finch",
14
+ "junco",
15
+ "indigo bunting",
16
+ "American robin",
17
+ "bulbul",
18
+ "jay",
19
+ "magpie",
20
+ "chickadee",
21
+ "American dipper",
22
+ "kite",
23
+ "bald eagle",
24
+ "vulture",
25
+ "great grey owl",
26
+ "fire salamander",
27
+ "smooth newt",
28
+ "newt",
29
+ "spotted salamander",
30
+ "axolotl",
31
+ "American bullfrog",
32
+ "tree frog",
33
+ "tailed frog",
34
+ "loggerhead sea turtle",
35
+ "leatherback sea turtle",
36
+ "mud turtle",
37
+ "terrapin",
38
+ "box turtle",
39
+ "banded gecko",
40
+ "green iguana",
41
+ "Carolina anole",
42
+ "desert grassland whiptail lizard",
43
+ "agama",
44
+ "frilled-necked lizard",
45
+ "alligator lizard",
46
+ "Gila monster",
47
+ "European green lizard",
48
+ "chameleon",
49
+ "Komodo dragon",
50
+ "Nile crocodile",
51
+ "American alligator",
52
+ "triceratops",
53
+ "worm snake",
54
+ "ring-necked snake",
55
+ "eastern hog-nosed snake",
56
+ "smooth green snake",
57
+ "kingsnake",
58
+ "garter snake",
59
+ "water snake",
60
+ "vine snake",
61
+ "night snake",
62
+ "boa constrictor",
63
+ "African rock python",
64
+ "Indian cobra",
65
+ "green mamba",
66
+ "sea snake",
67
+ "Saharan horned viper",
68
+ "eastern diamondback rattlesnake",
69
+ "sidewinder",
70
+ "trilobite",
71
+ "harvestman",
72
+ "scorpion",
73
+ "yellow garden spider",
74
+ "barn spider",
75
+ "European garden spider",
76
+ "southern black widow",
77
+ "tarantula",
78
+ "wolf spider",
79
+ "tick",
80
+ "centipede",
81
+ "black grouse",
82
+ "ptarmigan",
83
+ "ruffed grouse",
84
+ "prairie grouse",
85
+ "peacock",
86
+ "quail",
87
+ "partridge",
88
+ "grey parrot",
89
+ "macaw",
90
+ "sulphur-crested cockatoo",
91
+ "lorikeet",
92
+ "coucal",
93
+ "bee eater",
94
+ "hornbill",
95
+ "hummingbird",
96
+ "jacamar",
97
+ "toucan",
98
+ "duck",
99
+ "red-breasted merganser",
100
+ "goose",
101
+ "black swan",
102
+ "tusker",
103
+ "echidna",
104
+ "platypus",
105
+ "wallaby",
106
+ "koala",
107
+ "wombat",
108
+ "jellyfish",
109
+ "sea anemone",
110
+ "brain coral",
111
+ "flatworm",
112
+ "nematode",
113
+ "conch",
114
+ "snail",
115
+ "slug",
116
+ "sea slug",
117
+ "chiton",
118
+ "chambered nautilus",
119
+ "Dungeness crab",
120
+ "rock crab",
121
+ "fiddler crab",
122
+ "red king crab",
123
+ "American lobster",
124
+ "spiny lobster",
125
+ "crayfish",
126
+ "hermit crab",
127
+ "isopod",
128
+ "white stork",
129
+ "black stork",
130
+ "spoonbill",
131
+ "flamingo",
132
+ "little blue heron",
133
+ "great egret",
134
+ "bittern",
135
+ "crane (bird)",
136
+ "limpkin",
137
+ "common gallinule",
138
+ "American coot",
139
+ "bustard",
140
+ "ruddy turnstone",
141
+ "dunlin",
142
+ "common redshank",
143
+ "dowitcher",
144
+ "oystercatcher",
145
+ "pelican",
146
+ "king penguin",
147
+ "albatross",
148
+ "grey whale",
149
+ "killer whale",
150
+ "dugong",
151
+ "sea lion",
152
+ "Chihuahua",
153
+ "Japanese Chin",
154
+ "Maltese",
155
+ "Pekingese",
156
+ "Shih Tzu",
157
+ "King Charles Spaniel",
158
+ "Papillon",
159
+ "toy terrier",
160
+ "Rhodesian Ridgeback",
161
+ "Afghan Hound",
162
+ "Basset Hound",
163
+ "Beagle",
164
+ "Bloodhound",
165
+ "Bluetick Coonhound",
166
+ "Black and Tan Coonhound",
167
+ "Treeing Walker Coonhound",
168
+ "English foxhound",
169
+ "Redbone Coonhound",
170
+ "borzoi",
171
+ "Irish Wolfhound",
172
+ "Italian Greyhound",
173
+ "Whippet",
174
+ "Ibizan Hound",
175
+ "Norwegian Elkhound",
176
+ "Otterhound",
177
+ "Saluki",
178
+ "Scottish Deerhound",
179
+ "Weimaraner",
180
+ "Staffordshire Bull Terrier",
181
+ "American Staffordshire Terrier",
182
+ "Bedlington Terrier",
183
+ "Border Terrier",
184
+ "Kerry Blue Terrier",
185
+ "Irish Terrier",
186
+ "Norfolk Terrier",
187
+ "Norwich Terrier",
188
+ "Yorkshire Terrier",
189
+ "Wire Fox Terrier",
190
+ "Lakeland Terrier",
191
+ "Sealyham Terrier",
192
+ "Airedale Terrier",
193
+ "Cairn Terrier",
194
+ "Australian Terrier",
195
+ "Dandie Dinmont Terrier",
196
+ "Boston Terrier",
197
+ "Miniature Schnauzer",
198
+ "Giant Schnauzer",
199
+ "Standard Schnauzer",
200
+ "Scottish Terrier",
201
+ "Tibetan Terrier",
202
+ "Australian Silky Terrier",
203
+ "Soft-coated Wheaten Terrier",
204
+ "West Highland White Terrier",
205
+ "Lhasa Apso",
206
+ "Flat-Coated Retriever",
207
+ "Curly-coated Retriever",
208
+ "Golden Retriever",
209
+ "Labrador Retriever",
210
+ "Chesapeake Bay Retriever",
211
+ "German Shorthaired Pointer",
212
+ "Vizsla",
213
+ "English Setter",
214
+ "Irish Setter",
215
+ "Gordon Setter",
216
+ "Brittany Spaniel",
217
+ "Clumber Spaniel",
218
+ "English Springer Spaniel",
219
+ "Welsh Springer Spaniel",
220
+ "Cocker Spaniels",
221
+ "Sussex Spaniel",
222
+ "Irish Water Spaniel",
223
+ "Kuvasz",
224
+ "Schipperke",
225
+ "Groenendael",
226
+ "Malinois",
227
+ "Briard",
228
+ "Australian Kelpie",
229
+ "Komondor",
230
+ "Old English Sheepdog",
231
+ "Shetland Sheepdog",
232
+ "collie",
233
+ "Border Collie",
234
+ "Bouvier des Flandres",
235
+ "Rottweiler",
236
+ "German Shepherd Dog",
237
+ "Dobermann",
238
+ "Miniature Pinscher",
239
+ "Greater Swiss Mountain Dog",
240
+ "Bernese Mountain Dog",
241
+ "Appenzeller Sennenhund",
242
+ "Entlebucher Sennenhund",
243
+ "Boxer",
244
+ "Bullmastiff",
245
+ "Tibetan Mastiff",
246
+ "French Bulldog",
247
+ "Great Dane",
248
+ "St. Bernard",
249
+ "husky",
250
+ "Alaskan Malamute",
251
+ "Siberian Husky",
252
+ "Dalmatian",
253
+ "Affenpinscher",
254
+ "Basenji",
255
+ "pug",
256
+ "Leonberger",
257
+ "Newfoundland",
258
+ "Pyrenean Mountain Dog",
259
+ "Samoyed",
260
+ "Pomeranian",
261
+ "Chow Chow",
262
+ "Keeshond",
263
+ "Griffon Bruxellois",
264
+ "Pembroke Welsh Corgi",
265
+ "Cardigan Welsh Corgi",
266
+ "Toy Poodle",
267
+ "Miniature Poodle",
268
+ "Standard Poodle",
269
+ "Mexican hairless dog",
270
+ "grey wolf",
271
+ "Alaskan tundra wolf",
272
+ "red wolf",
273
+ "coyote",
274
+ "dingo",
275
+ "dhole",
276
+ "African wild dog",
277
+ "hyena",
278
+ "red fox",
279
+ "kit fox",
280
+ "Arctic fox",
281
+ "grey fox",
282
+ "tabby cat",
283
+ "tiger cat",
284
+ "Persian cat",
285
+ "Siamese cat",
286
+ "Egyptian Mau",
287
+ "cougar",
288
+ "lynx",
289
+ "leopard",
290
+ "snow leopard",
291
+ "jaguar",
292
+ "lion",
293
+ "tiger",
294
+ "cheetah",
295
+ "brown bear",
296
+ "American black bear",
297
+ "polar bear",
298
+ "sloth bear",
299
+ "mongoose",
300
+ "meerkat",
301
+ "tiger beetle",
302
+ "ladybug",
303
+ "ground beetle",
304
+ "longhorn beetle",
305
+ "leaf beetle",
306
+ "dung beetle",
307
+ "rhinoceros beetle",
308
+ "weevil",
309
+ "fly",
310
+ "bee",
311
+ "ant",
312
+ "grasshopper",
313
+ "cricket",
314
+ "stick insect",
315
+ "cockroach",
316
+ "mantis",
317
+ "cicada",
318
+ "leafhopper",
319
+ "lacewing",
320
+ "dragonfly",
321
+ "damselfly",
322
+ "red admiral",
323
+ "ringlet",
324
+ "monarch butterfly",
325
+ "small white",
326
+ "sulphur butterfly",
327
+ "gossamer-winged butterfly",
328
+ "starfish",
329
+ "sea urchin",
330
+ "sea cucumber",
331
+ "cottontail rabbit",
332
+ "hare",
333
+ "Angora rabbit",
334
+ "hamster",
335
+ "porcupine",
336
+ "fox squirrel",
337
+ "marmot",
338
+ "beaver",
339
+ "guinea pig",
340
+ "common sorrel",
341
+ "zebra",
342
+ "pig",
343
+ "wild boar",
344
+ "warthog",
345
+ "hippopotamus",
346
+ "ox",
347
+ "water buffalo",
348
+ "bison",
349
+ "ram",
350
+ "bighorn sheep",
351
+ "Alpine ibex",
352
+ "hartebeest",
353
+ "impala",
354
+ "gazelle",
355
+ "dromedary",
356
+ "llama",
357
+ "weasel",
358
+ "mink",
359
+ "European polecat",
360
+ "black-footed ferret",
361
+ "otter",
362
+ "skunk",
363
+ "badger",
364
+ "armadillo",
365
+ "three-toed sloth",
366
+ "orangutan",
367
+ "gorilla",
368
+ "chimpanzee",
369
+ "gibbon",
370
+ "siamang",
371
+ "guenon",
372
+ "patas monkey",
373
+ "baboon",
374
+ "macaque",
375
+ "langur",
376
+ "black-and-white colobus",
377
+ "proboscis monkey",
378
+ "marmoset",
379
+ "white-headed capuchin",
380
+ "howler monkey",
381
+ "titi",
382
+ "Geoffroy's spider monkey",
383
+ "common squirrel monkey",
384
+ "ring-tailed lemur",
385
+ "indri",
386
+ "Asian elephant",
387
+ "African bush elephant",
388
+ "red panda",
389
+ "giant panda",
390
+ "snoek",
391
+ "eel",
392
+ "coho salmon",
393
+ "rock beauty",
394
+ "clownfish",
395
+ "sturgeon",
396
+ "garfish",
397
+ "lionfish",
398
+ "pufferfish",
399
+ "abacus",
400
+ "abaya",
401
+ "academic gown",
402
+ "accordion",
403
+ "acoustic guitar",
404
+ "aircraft carrier",
405
+ "airliner",
406
+ "airship",
407
+ "altar",
408
+ "ambulance",
409
+ "amphibious vehicle",
410
+ "analog clock",
411
+ "apiary",
412
+ "apron",
413
+ "waste container",
414
+ "assault rifle",
415
+ "backpack",
416
+ "bakery",
417
+ "balance beam",
418
+ "balloon",
419
+ "ballpoint pen",
420
+ "Band-Aid",
421
+ "banjo",
422
+ "baluster",
423
+ "barbell",
424
+ "barber chair",
425
+ "barbershop",
426
+ "barn",
427
+ "barometer",
428
+ "barrel",
429
+ "wheelbarrow",
430
+ "baseball",
431
+ "basketball",
432
+ "bassinet",
433
+ "bassoon",
434
+ "swimming cap",
435
+ "bath towel",
436
+ "bathtub",
437
+ "station wagon",
438
+ "lighthouse",
439
+ "beaker",
440
+ "military cap",
441
+ "beer bottle",
442
+ "beer glass",
443
+ "bell-cot",
444
+ "bib",
445
+ "tandem bicycle",
446
+ "bikini",
447
+ "ring binder",
448
+ "binoculars",
449
+ "birdhouse",
450
+ "boathouse",
451
+ "bobsleigh",
452
+ "bolo tie",
453
+ "poke bonnet",
454
+ "bookcase",
455
+ "bookstore",
456
+ "bottle cap",
457
+ "bow",
458
+ "bow tie",
459
+ "brass",
460
+ "bra",
461
+ "breakwater",
462
+ "breastplate",
463
+ "broom",
464
+ "bucket",
465
+ "buckle",
466
+ "bulletproof vest",
467
+ "high-speed train",
468
+ "butcher shop",
469
+ "taxicab",
470
+ "cauldron",
471
+ "candle",
472
+ "cannon",
473
+ "canoe",
474
+ "can opener",
475
+ "cardigan",
476
+ "car mirror",
477
+ "carousel",
478
+ "tool kit",
479
+ "carton",
480
+ "car wheel",
481
+ "automated teller machine",
482
+ "cassette",
483
+ "cassette player",
484
+ "castle",
485
+ "catamaran",
486
+ "CD player",
487
+ "cello",
488
+ "mobile phone",
489
+ "chain",
490
+ "chain-link fence",
491
+ "chain mail",
492
+ "chainsaw",
493
+ "chest",
494
+ "chiffonier",
495
+ "chime",
496
+ "china cabinet",
497
+ "Christmas stocking",
498
+ "church",
499
+ "movie theater",
500
+ "cleaver",
501
+ "cliff dwelling",
502
+ "cloak",
503
+ "clogs",
504
+ "cocktail shaker",
505
+ "coffee mug",
506
+ "coffeemaker",
507
+ "coil",
508
+ "combination lock",
509
+ "computer keyboard",
510
+ "confectionery store",
511
+ "container ship",
512
+ "convertible",
513
+ "corkscrew",
514
+ "cornet",
515
+ "cowboy boot",
516
+ "cowboy hat",
517
+ "cradle",
518
+ "crane (machine)",
519
+ "crash helmet",
520
+ "crate",
521
+ "infant bed",
522
+ "Crock Pot",
523
+ "croquet ball",
524
+ "crutch",
525
+ "cuirass",
526
+ "dam",
527
+ "desk",
528
+ "desktop computer",
529
+ "rotary dial telephone",
530
+ "diaper",
531
+ "digital clock",
532
+ "digital watch",
533
+ "dining table",
534
+ "dishcloth",
535
+ "dishwasher",
536
+ "disc brake",
537
+ "dock",
538
+ "dog sled",
539
+ "dome",
540
+ "doormat",
541
+ "drilling rig",
542
+ "drum",
543
+ "drumstick",
544
+ "dumbbell",
545
+ "Dutch oven",
546
+ "electric fan",
547
+ "electric guitar",
548
+ "electric locomotive",
549
+ "entertainment center",
550
+ "envelope",
551
+ "espresso machine",
552
+ "face powder",
553
+ "feather boa",
554
+ "filing cabinet",
555
+ "fireboat",
556
+ "fire engine",
557
+ "fire screen sheet",
558
+ "flagpole",
559
+ "flute",
560
+ "folding chair",
561
+ "football helmet",
562
+ "forklift",
563
+ "fountain",
564
+ "fountain pen",
565
+ "four-poster bed",
566
+ "freight car",
567
+ "French horn",
568
+ "frying pan",
569
+ "fur coat",
570
+ "garbage truck",
571
+ "gas mask",
572
+ "gas pump",
573
+ "goblet",
574
+ "go-kart",
575
+ "golf ball",
576
+ "golf cart",
577
+ "gondola",
578
+ "gong",
579
+ "gown",
580
+ "grand piano",
581
+ "greenhouse",
582
+ "grille",
583
+ "grocery store",
584
+ "guillotine",
585
+ "barrette",
586
+ "hair spray",
587
+ "half-track",
588
+ "hammer",
589
+ "hamper",
590
+ "hair dryer",
591
+ "hand-held computer",
592
+ "handkerchief",
593
+ "hard disk drive",
594
+ "harmonica",
595
+ "harp",
596
+ "harvester",
597
+ "hatchet",
598
+ "holster",
599
+ "home theater",
600
+ "honeycomb",
601
+ "hook",
602
+ "hoop skirt",
603
+ "horizontal bar",
604
+ "horse-drawn vehicle",
605
+ "hourglass",
606
+ "iPod",
607
+ "clothes iron",
608
+ "jack-o'-lantern",
609
+ "jeans",
610
+ "jeep",
611
+ "T-shirt",
612
+ "jigsaw puzzle",
613
+ "pulled rickshaw",
614
+ "joystick",
615
+ "kimono",
616
+ "knee pad",
617
+ "knot",
618
+ "lab coat",
619
+ "ladle",
620
+ "lampshade",
621
+ "laptop computer",
622
+ "lawn mower",
623
+ "lens cap",
624
+ "paper knife",
625
+ "library",
626
+ "lifeboat",
627
+ "lighter",
628
+ "limousine",
629
+ "ocean liner",
630
+ "lipstick",
631
+ "slip-on shoe",
632
+ "lotion",
633
+ "speaker",
634
+ "loupe",
635
+ "sawmill",
636
+ "magnetic compass",
637
+ "mail bag",
638
+ "mailbox",
639
+ "tights",
640
+ "tank suit",
641
+ "manhole cover",
642
+ "maraca",
643
+ "marimba",
644
+ "mask",
645
+ "match",
646
+ "maypole",
647
+ "maze",
648
+ "measuring cup",
649
+ "medicine chest",
650
+ "megalith",
651
+ "microphone",
652
+ "microwave oven",
653
+ "military uniform",
654
+ "milk can",
655
+ "minibus",
656
+ "miniskirt",
657
+ "minivan",
658
+ "missile",
659
+ "mitten",
660
+ "mixing bowl",
661
+ "mobile home",
662
+ "Model T",
663
+ "modem",
664
+ "monastery",
665
+ "monitor",
666
+ "moped",
667
+ "mortar",
668
+ "square academic cap",
669
+ "mosque",
670
+ "mosquito net",
671
+ "scooter",
672
+ "mountain bike",
673
+ "tent",
674
+ "computer mouse",
675
+ "mousetrap",
676
+ "moving van",
677
+ "muzzle",
678
+ "nail",
679
+ "neck brace",
680
+ "necklace",
681
+ "nipple",
682
+ "notebook computer",
683
+ "obelisk",
684
+ "oboe",
685
+ "ocarina",
686
+ "odometer",
687
+ "oil filter",
688
+ "organ",
689
+ "oscilloscope",
690
+ "overskirt",
691
+ "bullock cart",
692
+ "oxygen mask",
693
+ "packet",
694
+ "paddle",
695
+ "paddle wheel",
696
+ "padlock",
697
+ "paintbrush",
698
+ "pajamas",
699
+ "palace",
700
+ "pan flute",
701
+ "paper towel",
702
+ "parachute",
703
+ "parallel bars",
704
+ "park bench",
705
+ "parking meter",
706
+ "passenger car",
707
+ "patio",
708
+ "payphone",
709
+ "pedestal",
710
+ "pencil case",
711
+ "pencil sharpener",
712
+ "perfume",
713
+ "Petri dish",
714
+ "photocopier",
715
+ "plectrum",
716
+ "Pickelhaube",
717
+ "picket fence",
718
+ "pickup truck",
719
+ "pier",
720
+ "piggy bank",
721
+ "pill bottle",
722
+ "pillow",
723
+ "ping-pong ball",
724
+ "pinwheel",
725
+ "pirate ship",
726
+ "pitcher",
727
+ "hand plane",
728
+ "planetarium",
729
+ "plastic bag",
730
+ "plate rack",
731
+ "plow",
732
+ "plunger",
733
+ "Polaroid camera",
734
+ "pole",
735
+ "police van",
736
+ "poncho",
737
+ "billiard table",
738
+ "soda bottle",
739
+ "pot",
740
+ "potter's wheel",
741
+ "power drill",
742
+ "prayer rug",
743
+ "printer",
744
+ "prison",
745
+ "projectile",
746
+ "projector",
747
+ "hockey puck",
748
+ "punching bag",
749
+ "purse",
750
+ "quill",
751
+ "quilt",
752
+ "race car",
753
+ "racket",
754
+ "radiator",
755
+ "radio",
756
+ "radio telescope",
757
+ "rain barrel",
758
+ "recreational vehicle",
759
+ "reel",
760
+ "reflex camera",
761
+ "refrigerator",
762
+ "remote control",
763
+ "restaurant",
764
+ "revolver",
765
+ "rifle",
766
+ "rocking chair",
767
+ "rotisserie",
768
+ "eraser",
769
+ "rugby ball",
770
+ "ruler",
771
+ "running shoe",
772
+ "safe",
773
+ "safety pin",
774
+ "salt shaker",
775
+ "sandal",
776
+ "sarong",
777
+ "saxophone",
778
+ "scabbard",
779
+ "weighing scale",
780
+ "school bus",
781
+ "schooner",
782
+ "scoreboard",
783
+ "CRT screen",
784
+ "screw",
785
+ "screwdriver",
786
+ "seat belt",
787
+ "sewing machine",
788
+ "shield",
789
+ "shoe store",
790
+ "shoji",
791
+ "shopping basket",
792
+ "shopping cart",
793
+ "shovel",
794
+ "shower cap",
795
+ "shower curtain",
796
+ "ski",
797
+ "ski mask",
798
+ "sleeping bag",
799
+ "slide rule",
800
+ "sliding door",
801
+ "slot machine",
802
+ "snorkel",
803
+ "snowmobile",
804
+ "snowplow",
805
+ "soap dispenser",
806
+ "soccer ball",
807
+ "sock",
808
+ "solar thermal collector",
809
+ "sombrero",
810
+ "soup bowl",
811
+ "space bar",
812
+ "space heater",
813
+ "space shuttle",
814
+ "spatula",
815
+ "motorboat",
816
+ "spider web",
817
+ "spindle",
818
+ "sports car",
819
+ "spotlight",
820
+ "stage",
821
+ "steam locomotive",
822
+ "through arch bridge",
823
+ "steel drum",
824
+ "stethoscope",
825
+ "scarf",
826
+ "stone wall",
827
+ "stopwatch",
828
+ "stove",
829
+ "strainer",
830
+ "tram",
831
+ "stretcher",
832
+ "couch",
833
+ "stupa",
834
+ "submarine",
835
+ "suit",
836
+ "sundial",
837
+ "sunglass",
838
+ "sunglasses",
839
+ "sunscreen",
840
+ "suspension bridge",
841
+ "mop",
842
+ "sweatshirt",
843
+ "swimsuit",
844
+ "swing",
845
+ "switch",
846
+ "syringe",
847
+ "table lamp",
848
+ "tank",
849
+ "tape player",
850
+ "teapot",
851
+ "teddy bear",
852
+ "television",
853
+ "tennis ball",
854
+ "thatched roof",
855
+ "front curtain",
856
+ "thimble",
857
+ "threshing machine",
858
+ "throne",
859
+ "tile roof",
860
+ "toaster",
861
+ "tobacco shop",
862
+ "toilet seat",
863
+ "torch",
864
+ "totem pole",
865
+ "tow truck",
866
+ "toy store",
867
+ "tractor",
868
+ "semi-trailer truck",
869
+ "tray",
870
+ "trench coat",
871
+ "tricycle",
872
+ "trimaran",
873
+ "tripod",
874
+ "triumphal arch",
875
+ "trolleybus",
876
+ "trombone",
877
+ "tub",
878
+ "turnstile",
879
+ "typewriter keyboard",
880
+ "umbrella",
881
+ "unicycle",
882
+ "upright piano",
883
+ "vacuum cleaner",
884
+ "vase",
885
+ "vault",
886
+ "velvet",
887
+ "vending machine",
888
+ "vestment",
889
+ "viaduct",
890
+ "violin",
891
+ "volleyball",
892
+ "waffle iron",
893
+ "wall clock",
894
+ "wallet",
895
+ "wardrobe",
896
+ "military aircraft",
897
+ "sink",
898
+ "washing machine",
899
+ "water bottle",
900
+ "water jug",
901
+ "water tower",
902
+ "whiskey jug",
903
+ "whistle",
904
+ "wig",
905
+ "window screen",
906
+ "window shade",
907
+ "Windsor tie",
908
+ "wine bottle",
909
+ "wing",
910
+ "wok",
911
+ "wooden spoon",
912
+ "wool",
913
+ "split-rail fence",
914
+ "shipwreck",
915
+ "yawl",
916
+ "yurt",
917
+ "website",
918
+ "comic book",
919
+ "crossword",
920
+ "traffic sign",
921
+ "traffic light",
922
+ "dust jacket",
923
+ "menu",
924
+ "plate",
925
+ "guacamole",
926
+ "consomme",
927
+ "hot pot",
928
+ "trifle",
929
+ "ice cream",
930
+ "ice pop",
931
+ "baguette",
932
+ "bagel",
933
+ "pretzel",
934
+ "cheeseburger",
935
+ "hot dog",
936
+ "mashed potato",
937
+ "cabbage",
938
+ "broccoli",
939
+ "cauliflower",
940
+ "zucchini",
941
+ "spaghetti squash",
942
+ "acorn squash",
943
+ "butternut squash",
944
+ "cucumber",
945
+ "artichoke",
946
+ "bell pepper",
947
+ "cardoon",
948
+ "mushroom",
949
+ "Granny Smith",
950
+ "strawberry",
951
+ "orange",
952
+ "lemon",
953
+ "fig",
954
+ "pineapple",
955
+ "banana",
956
+ "jackfruit",
957
+ "custard apple",
958
+ "pomegranate",
959
+ "hay",
960
+ "carbonara",
961
+ "chocolate syrup",
962
+ "dough",
963
+ "meatloaf",
964
+ "pizza",
965
+ "pot pie",
966
+ "burrito",
967
+ "red wine",
968
+ "espresso",
969
+ "cup",
970
+ "eggnog",
971
+ "alp",
972
+ "bubble",
973
+ "cliff",
974
+ "coral reef",
975
+ "geyser",
976
+ "lakeshore",
977
+ "promontory",
978
+ "shoal",
979
+ "seashore",
980
+ "valley",
981
+ "volcano",
982
+ "baseball player",
983
+ "bridegroom",
984
+ "scuba diver",
985
+ "rapeseed",
986
+ "daisy",
987
+ "yellow lady's slipper",
988
+ "corn",
989
+ "acorn",
990
+ "rose hip",
991
+ "horse chestnut seed",
992
+ "coral fungus",
993
+ "agaric",
994
+ "gyromitra",
995
+ "stinkhorn mushroom",
996
+ "earth star",
997
+ "hen-of-the-woods",
998
+ "bolete",
999
+ "ear of corn",
1000
+ "toilet paper"]
pl_train.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+ from typing import Optional, Tuple
4
+ import glob
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import DataLoader
10
+ from torchvision import transforms, models, datasets
11
+ from pytorch_lightning import LightningModule, Trainer
12
+ from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
13
+ from loguru import logger
14
+
15
class CustomProgressBar(TQDMProgressBar):
    """TQDM progress bar that also logs a banner line when each training epoch starts."""

    def __init__(self):
        super().__init__()
        # Do not assign ``self.enable`` here: ``enable`` is a method of the
        # progress-bar base class, and overwriting it with a bool would make
        # any later ``enable()`` call fail. The bar needs no explicit
        # enabling after construction.

    def on_train_epoch_start(self, trainer, pl_module):
        """Run the default hook, then log a separator with the current epoch number."""
        super().on_train_epoch_start(trainer, pl_module)
        logger.info(f"\n{'='*20} Epoch {trainer.current_epoch} {'='*20}")
23
+
24
+ class ImageNetModule(LightningModule):
25
def __init__(
    self,
    learning_rate: float = 0.1,
    momentum: float = 0.9,
    weight_decay: float = 1e-4,
    batch_size: int = 256,
    num_workers: int = 16,
    max_epochs: int = 90,
    train_path: str = "path/to/imagenet",
    val_path: str = "path/to/imagenet",
    checkpoint_dir: str = "checkpoints"
):
    """Create an untrained ResNet-50 and record the run configuration.

    Every argument is stored unchanged on the instance under the same
    name; the code that consumes them lives outside this method.
    """
    super().__init__()
    # save_hyperparameters() is left disabled, as in the original code.

    # Backbone: torchvision ResNet-50 built without pretrained weights.
    self.model = models.resnet50(weights=None)

    # Run configuration.
    self.learning_rate = learning_rate
    self.momentum = momentum
    self.weight_decay = weight_decay
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.max_epochs = max_epochs
    self.train_path = train_path
    self.val_path = val_path
    self.checkpoint_dir = checkpoint_dir

    # Per-step records collected during an epoch, plus the best validation
    # accuracy seen so far.
    self.training_step_outputs = []
    self.validation_step_outputs = []
    self.best_val_acc = 0.0

    # Both pipelines end with the same normalization constants.
    norm_kwargs = dict(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])

    # Training: random crop/flip/colour jitter before tensor conversion.
    train_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.ToTensor(),
        transforms.Normalize(**norm_kwargs),
    ]
    self.train_transforms = transforms.Compose(train_steps)

    # Validation: deterministic resize and centre crop.
    val_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(**norm_kwargs),
    ]
    self.val_transforms = transforms.Compose(val_steps)
76
+
77
def forward(self, x):
    """Pass ``x`` through the wrapped model and return its output unchanged."""
    outputs = self.model(x)
    return outputs
79
+
80
def training_step(self, batch, batch_idx):
    """Compute the cross-entropy loss and batch accuracy for one training batch.

    Logs ``train_loss`` and ``train_acc`` (epoch-level, shown in the progress
    bar), appends a record to ``self.training_step_outputs`` and returns the loss.
    ``batch`` is unpacked as ``(images, labels)``; ``batch_idx`` is unused.
    """
    images, labels = batch
    logits = self(images)
    loss = F.cross_entropy(logits, labels)

    # Accuracy in percent: share of samples whose arg-max class equals the label.
    predicted = torch.max(logits.data, 1)[1]
    n_correct = (predicted == labels).sum().item()
    accuracy = (n_correct / labels.size(0))*100

    # Epoch-aggregated metrics.
    for metric_name, metric_value in (('train_loss', loss), ('train_acc', accuracy)):
        self.log(metric_name, metric_value, on_step=False, on_epoch=True, prog_bar=True)

    # Keep a detached copy for the end-of-epoch summary.
    record = {'loss': loss.detach(), 'acc': torch.tensor(accuracy)}
    self.training_step_outputs.append(record)

    return loss
100
+
101
def on_train_epoch_end(self):
    """Log the epoch's mean training loss/accuracy and the current learning rate.

    The per-step records collected by ``training_step`` are averaged
    (unweighted mean over batches) and then discarded so the next epoch
    starts with an empty buffer.
    """
    records = self.training_step_outputs
    if not records:
        print("Warning: No training outputs available for this epoch")
        return

    losses = [record['loss'] for record in records]
    accuracies = [record['acc'] for record in records]
    avg_loss = torch.stack(losses).mean()
    avg_acc = torch.stack(accuracies).mean()

    # Learning rate of the first parameter group of the active optimizer.
    optimizer = self.optimizers()
    current_lr = optimizer.param_groups[0]['lr']

    logger.info(f"Training metrics - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}, LR: {current_lr:.6f}")

    records.clear()
114
+
115
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch and record its metrics.

    Args:
        batch: Pair ``(images, labels)`` as produced by the validation loader.
        batch_idx: Index of the batch within the epoch (unused).

    Returns:
        Dict with ``'val_loss'`` (loss tensor) and ``'val_acc'`` (float percentage).
    """
    inputs, targets = batch
    logits = self(inputs)
    loss = F.cross_entropy(logits, targets)

    # Top-1 accuracy of this batch, expressed as a Python float percentage.
    top1 = logits.data.argmax(dim=1)
    num_hits = (top1 == targets).sum().item()
    accuracy = (num_hits / targets.size(0)) * 100

    # Both metrics are aggregated per epoch and shown in the progress bar.
    log_options = dict(on_step=False, on_epoch=True, prog_bar=True)
    self.log('val_loss', loss, **log_options)
    self.log('val_acc', accuracy, **log_options)

    # Kept so on_validation_epoch_end can compute the epoch averages.
    step_record = {'val_loss': loss.detach(), 'val_acc': torch.tensor(accuracy)}
    self.validation_step_outputs.append(step_record)

    return {'val_loss': loss, 'val_acc': accuracy}
135
+
136
def on_validation_epoch_end(self):
    """Aggregate validation metrics and checkpoint on a new best accuracy.

    Averages the per-batch records collected by ``validation_step``
    (unweighted mean over batches), logs them as ``val_loss_epoch`` /
    ``val_acc_epoch`` and, when the mean accuracy beats ``self.best_val_acc``,
    saves a checkpoint named after the epoch and accuracy into
    ``self.checkpoint_dir``. The record buffer is always emptied afterwards.

    Best-accuracy tracking and checkpointing are skipped while the trainer is
    running its pre-training sanity validation: those few unshuffled batches
    are not a representative accuracy estimate and would otherwise overwrite
    ``best_val_acc`` and write a misleading checkpoint.
    """
    step_outputs = self.validation_step_outputs
    if not step_outputs:
        # Mirrors the guard in on_train_epoch_end; torch.stack([]) would raise.
        print("Warning: No validation outputs available for this epoch")
        return

    avg_loss = torch.stack([out['val_loss'] for out in step_outputs]).mean()
    avg_acc = torch.stack([out['val_acc'] for out in step_outputs]).mean()

    # Log final validation metrics
    self.log('val_loss_epoch', avg_loss)
    self.log('val_acc_epoch', avg_acc)

    # getattr keeps this working with trainer objects that lack the flag.
    sanity_checking = getattr(self.trainer, "sanity_checking", False)
    acc_value = float(avg_acc)

    # NOTE(review): main() trains with strategy="ddp"; each rank averages only
    # the batches it saw, so ranks could disagree on this condition. Confirm
    # whether Trainer.save_checkpoint has to be entered by every rank.
    if not sanity_checking and acc_value > self.best_val_acc:
        # Stored as a plain float so the attribute is not tied to a device.
        self.best_val_acc = acc_value
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f"resnet50-epoch{self.current_epoch:02d}-acc{acc_value:.4f}.ckpt"
        )
        self.trainer.save_checkpoint(checkpoint_path)
        logger.info(f"New best validation accuracy: {acc_value:.4f}. Saved checkpoint to {checkpoint_path}")

    logger.info(f"Validation metrics - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")

    step_outputs.clear()
157
+
158
def train_dataloader(self):
    """Build the shuffled training loader over the image folder at ``self.train_path``."""
    dataset = datasets.ImageFolder(self.train_path, transform=self.train_transforms)
    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": True,
        "num_workers": self.num_workers,
        "pin_memory": True,
    }
    return DataLoader(dataset, **loader_options)
170
+
171
def val_dataloader(self):
    """Build the unshuffled validation loader over the image folder at ``self.val_path``."""
    dataset = datasets.ImageFolder(self.val_path, transform=self.val_transforms)
    loader_options = {
        "batch_size": self.batch_size,
        "shuffle": False,
        "num_workers": self.num_workers,
        "pin_memory": True,
    }
    return DataLoader(dataset, **loader_options)
183
+
184
def configure_optimizers(self):
    """Create the SGD optimizer and a per-step OneCycleLR schedule.

    The schedule length is taken from the attached trainer's
    ``estimated_stepping_batches`` instead of
    ``max_epochs * len(self.train_dataloader())``. The latter builds a second,
    unsharded loader (rescanning the dataset) whose length ignores how the
    trainer actually splits batches across devices, and it relies on
    ``self.max_epochs`` matching the trainer's own setting; with several
    devices the cycle would be planned for more steps than are ever taken.

    Returns:
        Dict with the ``"optimizer"`` and an ``"lr_scheduler"`` entry that
        steps the scheduler after every optimizer step.
    """
    optimizer = torch.optim.SGD(
        self.parameters(),
        lr=self.learning_rate,
        momentum=self.momentum,
        weight_decay=self.weight_decay
    )

    # Number of optimizer steps the trainer expects to run in total.
    total_steps = self.trainer.estimated_stepping_batches

    # OneCycleLR scheduler. With cycle_momentum=True the scheduler also drives
    # the optimizer's momentum between base_momentum and max_momentum.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=self.learning_rate,
        total_steps=total_steps,
        pct_start=0.3,
        anneal_strategy='cos',
        div_factor=25.0,
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
    )

    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step"
        }
    }
213
+
214
def setup_logging(log_dir="logs"):
    """Configure the shared logger with a console sink and a timestamped file sink.

    Args:
        log_dir: Directory for the log file; created if it does not exist.

    Returns:
        Path of the log file created for this run.
    """
    os.makedirs(log_dir, exist_ok=True)
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{run_stamp}.log")

    def _console_sink(message):
        # NOTE(review): records presumably already end with a newline, so
        # print() may add a blank line after each one - confirm if intended.
        print(message)

    # Drop any previously registered handlers before installing ours.
    logger.remove()

    console_options = {
        "format": "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | {message}",
        "colorize": True,
        "level": "INFO",
    }
    logger.add(_console_sink, **console_options)

    file_options = {
        "format": "{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
        "level": "INFO",
        "rotation": "100 MB",
        "retention": "30 days",
    }
    logger.add(log_file, **file_options)

    logger.info(f"Logging setup complete. Logs will be saved to: {log_file}")
    return log_file
237
+
238
def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the path of the most advanced checkpoint in ``checkpoint_dir``.

    Checkpoints are ranked by the epoch number embedded in their file name
    (``epoch=N``, ``epochN``, ``epoch_N`` or ``epoch-N``). A file without an
    epoch number is ranked by its modification time instead; since that is a
    large integer, such a file outranks every epoch-numbered one.

    Args:
        checkpoint_dir: Directory to search for ``*.ckpt`` files.

    Returns:
        The selected checkpoint path, or ``None`` when the directory holds no
        checkpoint files. If ranking fails (for example a file disappears
        while it is being inspected), the most recently modified file is
        returned instead.
    """
    # Look for checkpoint files with different possible patterns
    patterns = [
        "*.ckpt",                  # Generic checkpoint files
        "resnet50-epoch*.ckpt",    # Our custom format
        "*epoch=*.ckpt",           # PyTorch Lightning default format
        "checkpoint_epoch*.ckpt",  # Another common format
    ]

    matches = []
    for pattern in patterns:
        matches.extend(glob.glob(os.path.join(checkpoint_dir, pattern)))
    # The patterns overlap, so drop duplicates while keeping discovery order.
    all_checkpoints = list(dict.fromkeys(matches))

    if not all_checkpoints:
        logger.info("No existing checkpoints found.")
        return None

    def extract_info(checkpoint_path: str) -> Tuple[int, float]:
        """Extract epoch and optional accuracy from checkpoint filename."""
        filename = os.path.basename(checkpoint_path)

        # Accept "epoch=12", "epoch12", "epoch_12" and "epoch-12".
        epoch_match = re.search(r'epoch[=_-]?(\d+)', filename)
        if epoch_match:
            epoch = int(epoch_match.group(1))
        else:
            # If no epoch found, fall back to the file modification time
            epoch = int(os.path.getmtime(checkpoint_path))

        # A well-formed number only: the previous pattern ``[\d.]+`` also
        # swallowed the dot in front of the ".ckpt" suffix ("53.7369."),
        # which made float() raise for every file in our custom format.
        acc_match = re.search(r'acc[_-]?(\d+(?:\.\d+)?)', filename)
        acc = float(acc_match.group(1)) if acc_match else 0.0

        return epoch, acc

    try:
        # Parse each file once and reuse the result for ranking and logging.
        info_by_path = {path: extract_info(path) for path in all_checkpoints}
        latest_checkpoint = max(all_checkpoints, key=lambda path: info_by_path[path][0])
        epoch, acc = info_by_path[latest_checkpoint]
        logger.info(f"Found latest checkpoint: {latest_checkpoint}")
        logger.info(f"Epoch: {epoch}" + (f", Accuracy: {acc:.4f}" if acc > 0 else ""))
        return latest_checkpoint
    except (OSError, ValueError) as e:
        logger.error(f"Error processing checkpoints: {str(e)}")
        # If there's any error in parsing, return the most recently modified file
        latest_checkpoint = max(all_checkpoints, key=os.path.getmtime)
        logger.info(f"Falling back to most recently modified checkpoint: {latest_checkpoint}")
        return latest_checkpoint
297
+
298
+
299
def main():
    """Train ResNet-50 on ImageNet, resuming from a pinned checkpoint.

    Sets up logging, reports the PyTorch/CUDA environment, builds the
    Lightning module and a 4-GPU DDP trainer, then runs ``trainer.fit``.
    All dataset, checkpoint and log locations are hard-coded for the
    training machine.
    """
    checkpoint_dir = "/home/ec2-user/ebs/volumes/era_session9"
    log_file = setup_logging(log_dir=checkpoint_dir)

    logger.info("Starting training with configuration:")
    logger.info(f"PyTorch version: {torch.__version__}")
    cuda_available = torch.cuda.is_available()
    logger.info(f"CUDA available: {cuda_available}")
    if cuda_available:
        device_count = torch.cuda.device_count()
        logger.info(f"CUDA device count: {device_count}")
        device_names = [torch.cuda.get_device_name(i) for i in range(device_count)]
        logger.info(f"CUDA devices: {device_names}")

    # The resume checkpoint is pinned here; automatic discovery through
    # find_latest_checkpoint(checkpoint_dir) is currently bypassed.
    latest_checkpoint = "/home/ec2-user/ebs/volumes/era_session9/resnet50-epoch18-acc53.7369.ckpt"

    module_options = {
        "learning_rate": 0.156,
        "batch_size": 256,
        "num_workers": 16,
        "max_epochs": 60,
        "train_path": "/home/ec2-user/ebs/volumes/imagenet/ILSVRC/Data/CLS-LOC/train",
        "val_path": "/home/ec2-user/ebs/volumes/imagenet/imagenet_validation",
        "checkpoint_dir": checkpoint_dir,
    }
    model = ImageNetModule(**module_options)

    logger.info("Model configuration:")
    logger.info(f"Learning rate: {model.learning_rate}")
    logger.info(f"Batch size: {model.batch_size}")
    logger.info(f"Number of workers: {model.num_workers}")
    logger.info(f"Max epochs: {model.max_epochs}")

    progress_bar = CustomProgressBar()

    trainer_options = {
        "max_epochs": 60,
        "accelerator": "gpu",
        "devices": 4,
        "strategy": "ddp",
        "precision": 16,
        "callbacks": [progress_bar],
        "enable_progress_bar": True,
    }
    trainer = Trainer(**trainer_options)

    logger.info("Starting training")

    try:
        if not latest_checkpoint:
            logger.info("Starting training from scratch")
            trainer.fit(model)
        else:
            logger.info(f"Resuming training from checkpoint: {latest_checkpoint}")
            trainer.fit(model, ckpt_path=latest_checkpoint)

        logger.info("Training completed successfully")
    except Exception as e:
        # Record the failure in the log, then let it propagate to the caller.
        logger.error(f"Training failed with error: {str(e)}")
        raise
    finally:
        logger.info(f"Training session ended. Log file: {log_file}")
358
+
359
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ typing
2
+ torch
3
+ torchvision
4
+ pytorch-lightning
5
+ loguru
6
+ Pillow
7
+ huggingface_hub