AbstractPhil committed (verified)
Commit c6749b5 · 1 Parent(s): cd84b58

Create trainer_pentachora_greyscale_frequency_encoded.ipynb

trainer_pentachora_greyscale_frequency_encoded.ipynb ADDED
@@ -0,0 +1,1679 @@
#!/usr/bin/env python3
"""
Pentachoron Constellation with Greyscale PentaFreq Encoder
Optimized with Batched Operations and Complete Loss Functions
Apache License 2.0
Author: AbstractPhil
Assistance: GPT 4o, GPT 5, Claude Opus 4.1, Claude Sonnet 4.0, Gemini
"""

import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# ============================================================
# CONFIGURATION
# ============================================================

# Clear CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
config = {
    'input_dim': 64,
    'base_dim': 64,
    'batch_size': 2048,
    'epochs': 50,
    'lr': 1e-1,
    'num_heads': 8,
    'num_pentachoron_pairs': 1,
    'loss_weight_scalar': 0.1,
    'lambda_separation': 0.29514,
    'temp': 0.70486,
    'weight_decay': 1e-5,
}

print("\n" + "="*60)
print("PENTACHORON CONSTELLATION CONFIGURATION")
print("="*60)
for key, value in config.items():
    print(f"{key:20}: {value}")

# ============================================================
# DATASET
# ============================================================

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

# ============================================================
# SELECT YOUR DATASET HERE!
# ============================================================

DATASET_NAME = "OCTMNIST"  # Change this to any dataset below!

# Available datasets (all 28x28):
AVAILABLE_DATASETS = {
    "MNIST": "Classic handwritten digits (10 classes)",
    "FashionMNIST": "Fashion items (10 classes) - The tough one!",
    "KMNIST": "Kuzushiji-MNIST - Japanese characters (10 classes)",
    "EMNIST": "Extended MNIST - Letters & digits (47 classes)",
    "QMNIST": "MNIST with better test set (10 classes)",
    "USPS": "US Postal Service digits (10 classes)",

    # MedMNIST variants (medical images)
    "BloodMNIST": "Blood cell types (8 classes)",
    "PathMNIST": "Pathology images (9 classes)",
    "OCTMNIST": "Retinal OCT (4 classes)",
    "PneumoniaMNIST": "Chest X-Ray (2 classes)",
    "DermaMNIST": "Dermatoscope images (7 classes)",
    "RetinaMNIST": "Retina fundus (5 classes)",
    "BreastMNIST": "Breast ultrasound (2 classes)",
    "OrganAMNIST": "Abdominal CT - Axial (11 classes)",
    "OrganCMNIST": "Abdominal CT - Coronal (11 classes)",
    "OrganSMNIST": "Abdominal CT - Sagittal (11 classes)",
    "TissueMNIST": "Tissue cells (8 classes)",
}

# ---------- MedMNIST INFO + helpers ----------
try:
    import medmnist
    from medmnist import INFO as MED_INFO  # official dict
except Exception:
    medmnist = None
    MED_INFO = None

# Fallback labels/tasks/channels for the 2D sets listed above.
# Source: MedMNIST v2 dataset card / builder (labels) and project docs (tasks/channels).
FALLBACK_INFO = {
    "bloodmnist": {
        "python_class": "BloodMNIST",
        "task": "multi-class",
        "n_channels": 3,
        "label": {
            "0": "basophil",
            "1": "eosinophil",
            "2": "erythroblast",
            "3": "immature granulocytes(myelocytes, metamyelocytes and promyelocytes)",
            "4": "lymphocyte",
            "5": "monocyte",
            "6": "neutrophil",
            "7": "platelet",
        },
    },
    "pathmnist": {
        "python_class": "PathMNIST",
        "task": "multi-class",
        "n_channels": 3,
        "label": {
            "0": "adipose",
            "1": "background",
            "2": "debris",
            "3": "lymphocytes",
            "4": "mucus",
            "5": "smooth muscle",
            "6": "normal colon mucosa",
            "7": "cancer-associated stroma",
            "8": "colorectal adenocarcinoma epithelium",
        },
    },
    "octmnist": {
        "python_class": "OCTMNIST",
        "task": "multi-class",
        "n_channels": 1,
        "label": {
            "0": "choroidal neovascularization",
            "1": "diabetic macular edema",
            "2": "drusen",
            "3": "normal",
        },
    },
    "pneumoniamnist": {
        "python_class": "PneumoniaMNIST",
        "task": "binary-class",
        "n_channels": 1,
        "label": {
            "0": "normal",
            "1": "pneumonia",
        },
    },
    "dermamnist": {
        "python_class": "DermaMNIST",
        "task": "multi-class",
        "n_channels": 3,
        "label": {
            "0": "actinic keratoses and intraepithelial carcinoma",
            "1": "basal cell carcinoma",
            "2": "benign keratosis-like lesions",
            "3": "dermatofibroma",
            "4": "melanoma",
            "5": "melanocytic nevi",
            "6": "vascular lesions",
        },
    },
    "retinamnist": {
        "python_class": "RetinaMNIST",
        "task": "ordinal-regression",
        "n_channels": 3,
        "label": {  # ordinal 0..4
            "0": "0",
            "1": "1",
            "2": "2",
            "3": "3",
            "4": "4",
        },
    },
    "breastmnist": {
        "python_class": "BreastMNIST",
        "task": "binary-class",
        "n_channels": 1,
        "label": {
            "0": "malignant",
            "1": "normal, benign",
        },
    },
    "tissuemnist": {
        "python_class": "TissueMNIST",
        "task": "multi-class",
        "n_channels": 1,
        "label": {
            "0": "Collecting Duct, Connecting Tubule",
            "1": "Distal Convoluted Tubule",
            "2": "Glomerular endothelial cells",
            "3": "Interstitial endothelial cells",
            "4": "Leukocytes",
            "5": "Podocytes",
            "6": "Proximal Tubule Segments",
            "7": "Thick Ascending Limb",
        },
    },
    # The Organ* 2D sets share the same 11 organ names; channels are grayscale.
    "organamnist": {
        "python_class": "OrganAMNIST",
        "task": "multi-class",
        "n_channels": 1,
        "label": {
            "0": "liver", "1": "kidney-right", "2": "kidney-left",
            "3": "femur-right", "4": "femur-left", "5": "bladder",
            "6": "heart", "7": "lung-right", "8": "lung-left",
            "9": "spleen", "10": "pancreas",
        },
    },
    "organcmnist": {
        "python_class": "OrganCMNIST",
        "task": "multi-class",
        "n_channels": 1,
        "label": {
            "0": "liver", "1": "kidney-right", "2": "kidney-left",
            "3": "femur-right", "4": "femur-left", "5": "bladder",
            "6": "heart", "7": "lung-right", "8": "lung-left",
            "9": "spleen", "10": "pancreas",
        },
    },
    "organsmnist": {
        "python_class": "OrganSMNIST",
        "task": "multi-class",
        "n_channels": 1,
        "label": {
            "0": "liver", "1": "kidney-right", "2": "kidney-left",
            "3": "femur-right", "4": "femur-left", "5": "bladder",
            "6": "heart", "7": "lung-right", "8": "lung-left",
            "9": "spleen", "10": "pancreas",
        },
    },
}

def as_class_indices(t: torch.Tensor) -> torch.Tensor:
    """
    Normalize MedMNIST-style labels to 1D Long class indices for CE loss.
    - Accepts shapes: [], [B], [B,1], or one-hot [B,C]
    - Returns shape [B], dtype torch.long
    """
    if t.ndim == 0:  # scalar
        return t.long().view(1)
    if t.ndim == 1:
        return t.long()
    # ndim >= 2
    if t.size(-1) == 1:
        t = t.squeeze(-1)
        return t.long()
    # likely one-hot [B,C]
    return t.argmax(dim=-1).long()

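# Quick illustration (not part of the training pipeline): MedMNIST typically
# emits [B,1] integer labels, which this helper flattens for F.cross_entropy;
# one-hot labels are reduced via argmax.
_demo_labels = torch.tensor([[2], [0], [3]])                   # [B,1] as MedMNIST returns
assert as_class_indices(_demo_labels).tolist() == [2, 0, 3]
_demo_onehot = F.one_hot(torch.tensor([1, 3]), num_classes=4)  # one-hot [B,C]
assert as_class_indices(_demo_onehot).tolist() == [1, 3]
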
def get_med_info(flag: str) -> dict:
    """Return official medmnist.INFO[flag] if available, else the fallback."""
    if MED_INFO is not None and flag in MED_INFO:
        return MED_INFO[flag]
    if flag in FALLBACK_INFO:
        return FALLBACK_INFO[flag]
    raise KeyError(f"Unknown MedMNIST flag: {flag}")

def make_med_transform(n_channels: int):
    """
    ToTensor -> ensure a single gray channel -> flatten to 784 for the pipeline.
    Keeps the 28x28 target and collapses channels deterministically.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t[:1, :, :] if t.shape[0] > 1 else t),  # pick first channel if RGB
        transforms.Lambda(lambda t: t.view(-1)),
    ])

def med_class_names_from_info(info: dict):
    """Convert label dict -> ordered list by index: ['name0', 'name1', ...]"""
    label_dict = info["label"]
    return [label_dict[str(i)] for i in range(len(label_dict))]

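# Example (illustrative only): resolve metadata for the selected dataset without
# loading any data. This works even when medmnist is not installed, thanks to
# FALLBACK_INFO above.
_info = get_med_info("octmnist")
print(_info["task"], _info["n_channels"], med_class_names_from_info(_info))
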
# ============================================================
# DATASET LOADER
# ============================================================

def get_dataset(name=DATASET_NAME, batch_size=128, num_workers=2):
    """
    Universal loader for all MNIST-like datasets.
    Returns train_loader, test_loader, num_classes, class_names
    """

    print(f"\n{'='*60}")
    print(f"Loading {name}")
    print(f"Description: {AVAILABLE_DATASETS.get(name, 'Unknown dataset')}")
    print(f"{'='*60}")

    # Standard transform for all datasets
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))  # Flatten to 784
    ])

    # Grayscale conversion transform, kept for reference
    # (the MedMNIST branch builds its own transform via make_med_transform)
    transform_gray = transforms.Compose([
        transforms.Grayscale(num_output_channels=config.get("n_channels", 1)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))
    ])

    # STANDARD TORCHVISION DATASETS
    if name == "MNIST":
        train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
        test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
        num_classes = 10
        class_names = [str(i) for i in range(10)]

    elif name == "FashionMNIST":
        train_dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
        test_dataset = datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True)
        num_classes = 10
        class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                       'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    elif name == "KMNIST":
        train_dataset = datasets.KMNIST(root="./data", train=True, transform=transform, download=True)
        test_dataset = datasets.KMNIST(root="./data", train=False, transform=transform, download=True)
        num_classes = 10
        class_names = ['お', 'き', 'す', 'つ', 'な', 'は', 'ま', 'や', 'れ', 'を']

    elif name == "EMNIST":
        # Using the 'balanced' split - 47 classes (digits + letters)
        train_dataset = datasets.EMNIST(root="./data", split='balanced', train=True, transform=transform, download=True)
        test_dataset = datasets.EMNIST(root="./data", split='balanced', train=False, transform=transform, download=True)
        num_classes = 47
        class_names = [
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
            'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
            'a', 'b', 'd', 'e', 'f', 'g', 'h', 'n', 'q', 'r', 't'
        ]

    elif name == "QMNIST":
        train_dataset = datasets.QMNIST(root="./data", what='train', transform=transform, download=True)
        test_dataset = datasets.QMNIST(root="./data", what='test', transform=transform, download=True)
        num_classes = 10
        class_names = [str(i) for i in range(10)]

    elif name == "USPS":
        # USPS is 16x16, so resize to 28x28
        transform_usps = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.view(-1))
        ])
        train_dataset = datasets.USPS(root="./data", train=True, transform=transform_usps, download=True)
        test_dataset = datasets.USPS(root="./data", train=False, transform=transform_usps, download=True)
        num_classes = 10
        class_names = [str(i) for i in range(10)]

    # MEDMNIST DATASETS
    elif name in ["BloodMNIST", "PathMNIST", "OCTMNIST", "PneumoniaMNIST",
                  "DermaMNIST", "RetinaMNIST", "BreastMNIST",
                  "OrganAMNIST", "OrganCMNIST", "OrganSMNIST", "TissueMNIST"]:

        # Map UI names to medmnist flags
        medmnist_map = {
            "BloodMNIST": "bloodmnist",
            "PathMNIST": "pathmnist",
            "OCTMNIST": "octmnist",
            "PneumoniaMNIST": "pneumoniamnist",
            "DermaMNIST": "dermamnist",
            "RetinaMNIST": "retinamnist",
            "BreastMNIST": "breastmnist",
            "OrganAMNIST": "organamnist",
            "OrganCMNIST": "organcmnist",
            "OrganSMNIST": "organsmnist",
            "TissueMNIST": "tissuemnist",
        }

        dataset_flag = medmnist_map[name]
        info = get_med_info(dataset_flag)

        # Require the package to actually load data
        if medmnist is None:
            raise ImportError(
                "medmnist is not installed. Run: pip install medmnist\n"
                f"(INFO fallback is provided; DataClass={info['python_class']} needs the package.)"
            )

        DataClass = getattr(medmnist, info["python_class"])

        # Transform: force 1-channel grayscale then flatten to 784
        transform_med = make_med_transform(info["n_channels"])

        # 28x28 size (default); can be bumped to 64/128/224 via size=...
        train_dataset = DataClass(split='train', transform=transform_med, download=True, size=28)
        test_dataset = DataClass(split='test', transform=transform_med, download=True, size=28)

        num_classes = len(info["label"])
        class_names = med_class_names_from_info(info)

        print(f" MedMNIST Dataset: {dataset_flag}")
        print(f" Task: {info['task']}")
        print(f" Classes: {num_classes} | Channels: {info['n_channels']}")

    else:
        raise ValueError(f"Unknown dataset: {name}. Choose from: {list(AVAILABLE_DATASETS.keys())}")

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print("\nDataset loaded successfully!")
    print(f" Train samples: {len(train_dataset):,}")
    print(f" Test samples: {len(test_dataset):,}")
    print(f" Number of classes: {num_classes}")
    print(" Input shape: 28x28 = 784 dimensions")

    return train_loader, test_loader, num_classes, class_names

train_loader, test_loader, num_classes, class_names = get_dataset(DATASET_NAME, config['batch_size'])

config['num_classes'] = num_classes

# Alias retained from the FashionMNIST experiments; holds the class names for
# whichever dataset is selected above.
FASHION_CLASSES = class_names

print(f"\nDataset ready: {DATASET_NAME} ({num_classes} classes)")

# ============================
# ADDITIONS: saving & hub push
# ============================
import os, json, math, platform, sys, shutil, zipfile
from pathlib import Path
from datetime import datetime

# Auto-install per Phil's preference
def _ensure(pkg, pip_name=None):
    pip_name = pip_name or pkg
    try:
        __import__(pkg)
    except Exception:
        print(f"[setup] Installing {pip_name} ...")
        os.system(f"{sys.executable} -m pip install -q {pip_name}")

_ensure("safetensors")
_ensure("huggingface_hub")
_ensure("psutil")
_ensure("pandas")

from safetensors.torch import save_file as save_safetensors
from huggingface_hub import HfApi, create_repo, whoami, login
from torch.utils.tensorboard import SummaryWriter
import psutil
import pandas as pd

def _param_count(model: torch.nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())

def _timestamp():
    return datetime.now().strftime("%Y%m%d-%H%M%S")

def _resolve_repo_id(config: dict) -> str:
    rid = os.getenv("PENTACHORA_HF_REPO") or config.get("hf_repo_id")
    if not rid:
        raise RuntimeError(
            "Hugging Face repo id is not set. Set config['hf_repo_id'] or the PENTACHORA_HF_REPO env var."
        )
    return rid

def _hf_login_if_needed():
    # Use an existing login if available; otherwise try HF_TOKEN
    try:
        _ = whoami()
        return
    except Exception:
        token = os.getenv("HF_TOKEN")
        if not token:
            print("[huggingface] No active login and HF_TOKEN not set; if push fails, run huggingface-cli login.")
            return
        login(token=token, add_to_git_credential=True)

def _ensure_repo(repo_id: str):
    api = HfApi()
    create_repo(repo_id=repo_id, private=False, exist_ok=True, repo_type="model")
    return api

def _zip_dir(src_dir: Path, zip_path: Path):
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
        for p in src_dir.rglob("*"):
            z.write(p, arcname=p.relative_to(src_dir))

def save_and_push_artifacts(
    encoder: nn.Module,
    constellation: nn.Module,
    diagnostic_head: nn.Module,
    config: dict,
    class_names: list,
    history: dict,
    best_acc: float,
    tb_log_dir: Path,
    last_confusion_png: Path | None,
    repo_subdir_root: str = "pentachora-adaptive-encoded",
):
    """
    Saves safetensors + metadata locally and pushes to the HF Hub under:
        <repo>/<repo_subdir_root>/<timestamp>/
    """
    ts = _timestamp()
    repo_id = _resolve_repo_id(config)
    _hf_login_if_needed()
    api = _ensure_repo(repo_id)

    # ---------- local layout ----------
    base_out = Path("artifacts") / repo_subdir_root / ts
    base_out.mkdir(parents=True, exist_ok=True)

    # 1) Weights: save each module separately to keep them composable
    weights_dir = base_out / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)
    save_safetensors({k: v.cpu() for k, v in encoder.state_dict().items()}, str(weights_dir / "encoder.safetensors"))
    save_safetensors({k: v.cpu() for k, v in constellation.state_dict().items()}, str(weights_dir / "constellation.safetensors"))
    save_safetensors({k: v.cpu() for k, v in diagnostic_head.state_dict().items()}, str(weights_dir / "diagnostic_head.safetensors"))

    # 2) Config
    conf_path = base_out / "config.json"
    with conf_path.open("w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, sort_keys=True)

    # 3) History (per-epoch metrics) as JSON and CSV
    hist_json = base_out / "history.json"
    with hist_json.open("w", encoding="utf-8") as f:
        json.dump(history, f, indent=2, sort_keys=True)
    max_len = max(len(history.get("train_loss", [])),
                  len(history.get("train_acc", [])),
                  len(history.get("test_acc", [])))
    df = pd.DataFrame({
        "epoch": list(range(1, max_len + 1)),
        "train_loss": history.get("train_loss", [math.nan] * max_len),
        "train_acc": history.get("train_acc", [math.nan] * max_len),
        "test_acc": history.get("test_acc", [math.nan] * max_len),
    })
    df.to_csv(base_out / "history.csv", index=False)

    # 4) Manifest
    manifest = {
        "timestamp": ts,
        "repo_id": repo_id,
        "subdirectory": f"{repo_subdir_root}/{ts}",
        "dataset_name": DATASET_NAME,
        "class_names": class_names,
        "num_classes": len(class_names),
        "models": {
            "encoder": {"params": _param_count(encoder)},
            "constellation": {"params": _param_count(constellation)},
            "diagnostic_head": {"params": _param_count(diagnostic_head)},
        },
        "results": {
            "best_test_accuracy": best_acc,
        },
        "environment": {
            "python": sys.version,
            "platform": platform.platform(),
            "torch": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_device": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None),
            "cpu_count": psutil.cpu_count(logical=True),
            "memory_gb": round(psutil.virtual_memory().total / (1024**3), 2),
        },
    }
    manifest_path = base_out / "manifest.json"
    with manifest_path.open("w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2, sort_keys=True)

    # 5) Debug info
    debug_txt = base_out / "debug.txt"
    with debug_txt.open("w", encoding="utf-8") as f:
        f.write("==== DEBUG INFO ====\n")
        f.write(f"Timestamp: {ts}\n")
        f.write(f"Repo: {repo_id}\n")
        f.write(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}\n")
        f.write(f"Encoder params: {_param_count(encoder)}\n")
        f.write(f"Constellation params: {_param_count(constellation)}\n")
        f.write(f"Diagnostic head params: {_param_count(diagnostic_head)}\n")
        f.write(f"Best test accuracy: {best_acc:.6f}\n")

    # 6) Plots (confusion matrix already saved during training; accuracy_plot.png at CWD)
    # Copy if present
    acc_plot = Path("accuracy_plot.png")
    if acc_plot.exists():
        shutil.copy2(acc_plot, base_out / "accuracy_plot.png")
    if last_confusion_png and Path(last_confusion_png).exists():
        shutil.copy2(last_confusion_png, base_out / Path(last_confusion_png).name)

    # 7) TensorBoard logs: copy the event files into artifacts and zip them for convenience
    tb_out = base_out / "tensorboard"
    tb_out.mkdir(parents=True, exist_ok=True)
    if tb_log_dir and Path(tb_log_dir).exists():
        for p in Path(tb_log_dir).glob("*"):
            shutil.copy2(p, tb_out / p.name)
    _zip_dir(tb_out, base_out / "tensorboard_events.zip")

    # 8) A small README for the snapshot
    readme = base_out / "README.md"
    readme.write_text(
        f"""# Pentachora Adaptive Encoded — {ts}

This folder is an immutable snapshot of training artifacts.

**Contents**
- `weights/*.safetensors` — encoder, constellation, diagnostic head
- `config.json` — full run configuration
- `manifest.json` — environment + model sizes + dataset
- `history.json` / `history.csv` — per-epoch metrics
- `tensorboard/` + `tensorboard_events.zip` — raw TensorBoard event files
- `accuracy_plot.png` (if available)
- `best_confusion_matrix_epoch_*.png` (if available)
- `debug.txt` — quick human-readable summary
""",
        encoding="utf-8"
    )

    # ---------- push to HF Hub ----------
    print(f"[push] Uploading to hf://{repo_id}/{repo_subdir_root}/{ts}")
    api.upload_folder(
        repo_id=repo_id,
        folder_path=str(base_out),
        path_in_repo=f"{repo_subdir_root}/{ts}",
        repo_type="model",
    )
    print("[push] ✅ Upload complete.")

    return base_out, f"{repo_subdir_root}/{ts}"

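# Illustrative call (commented out: it assumes the training loop below has
# produced `history`, `best_acc`, a TensorBoard log dir, and an optional
# confusion-matrix PNG, and that `hf_repo_id` is set in `config` or via the
# PENTACHORA_HF_REPO env var):
#
#   out_dir, subdir = save_and_push_artifacts(
#       encoder, constellation, diagnostic_head, config, class_names,
#       history, best_acc,
#       tb_log_dir=Path("runs/pentachora"),
#       last_confusion_png=None,
#   )
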
# ============================================================
# PENTAFREQ ENCODER (Original 93% Version)
# ============================================================

class PentaFreqEncoder(nn.Module):
    """
    5-frequency-band encoder designed to align with the pentachoron vertices.
    Each frequency band corresponds to one vertex of the pentachoron.

    The adjacency relationships between frequency bands naturally form
    the edge structure of the pentachoron!
    """
    def __init__(self, input_dim=784, base_dim=64):
        super().__init__()
        self.input_dim = input_dim
        self.base_dim = base_dim
        self.img_size = 28

        self.unflatten = nn.Unflatten(1, (1, 28, 28))

        # ========== 5 FREQUENCY EXTRACTORS ==========

        # Vertex 0: Ultra-high frequency (finest details, noise, texture grain)
        self.v0_ultrahigh = nn.Sequential(
            nn.Conv2d(1, 12, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(12),
            nn.ReLU(),
            # Edge enhancement
            nn.Conv2d(12, 12, kernel_size=3, padding=1, groups=12),  # Depthwise
            nn.BatchNorm2d(12),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(7),
            nn.Flatten()
        )
        self.v0_encode = nn.Linear(12 * 49, base_dim)

        # Vertex 1: High frequency (edges, sharp transitions)
        self.v1_high = nn.Sequential(
            nn.Conv2d(1, 12, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(12),
            nn.Tanh(),
            nn.MaxPool2d(2),  # 14x14
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.BatchNorm2d(12),
            nn.Tanh(),
            nn.AdaptiveAvgPool2d(7),
            nn.Flatten()
        )
        self.v1_encode = nn.Linear(12 * 49, base_dim)

        # Vertex 2: Mid frequency (local patterns, textures)
        self.v2_mid = nn.Sequential(
            nn.Conv2d(1, 12, kernel_size=5, padding=2, stride=2),  # 14x14
            nn.BatchNorm2d(12),
            nn.GELU(),
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.BatchNorm2d(12),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(7),
            nn.Flatten()
        )
        self.v2_encode = nn.Linear(12 * 49, base_dim)

        # Vertex 3: Low-mid frequency (shapes, regional features)
        self.v3_lowmid = nn.Sequential(
            nn.AvgPool2d(2),  # Start at 14x14
            nn.Conv2d(1, 12, kernel_size=7, padding=3),
            nn.BatchNorm2d(12),
            nn.SiLU(),
            nn.AvgPool2d(2),  # 7x7
            nn.Flatten()
        )
        self.v3_encode = nn.Linear(12 * 49, base_dim)

        # Vertex 4: Low frequency (global structure, overall form)
        self.v4_low = nn.Sequential(
            nn.AvgPool2d(4),  # Start at 7x7
            nn.Conv2d(1, 12, kernel_size=7, padding=3),
            nn.BatchNorm2d(12),
            nn.Sigmoid(),  # Smooth activation for global features
            nn.AdaptiveAvgPool2d(7),
            nn.Flatten()
        )
        self.v4_encode = nn.Linear(12 * 49, base_dim)

        # ========== PENTACHORON ADJACENCY MATRIX ==========
        # Defines which frequency bands are "adjacent" (connected by edges),
        # following the edge structure of a perfect pentachoron
        self.register_buffer('adjacency_matrix', self._create_pentachoron_adjacency())

        # ========== FUSION NETWORK ==========
        # Learns to combine all 5 frequency bands
        self.fusion = nn.Sequential(
            nn.Linear(base_dim * 5, base_dim * 3),
            nn.BatchNorm1d(base_dim * 3),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(base_dim * 3, base_dim * 2),
            nn.BatchNorm1d(base_dim * 2),
            nn.ReLU(),
            nn.Linear(base_dim * 2, base_dim)
        )

        # Initialize edge-detection kernels for the ultra-high-frequency branch
        self._init_edge_kernels()

    def _create_pentachoron_adjacency(self):
        """
        Create the adjacency matrix of a complete graph (pentachoron).
        In a 4-simplex, every vertex connects to every other vertex.
        """
        adj = torch.ones(5, 5) - torch.eye(5)
        return adj

    def _init_edge_kernels(self):
        """Initialize V0 with various edge-detection kernels."""
        with torch.no_grad():
            if hasattr(self.v0_ultrahigh[0], 'weight'):
                kernels = self.v0_ultrahigh[0].weight
                # Sobel X
                kernels[0, 0] = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) / 4.0
                # Sobel Y
                kernels[1, 0] = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) / 4.0
                # Laplacian
                kernels[2, 0] = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]) / 4.0
                # Roberts cross (2x2 kernel padded into 3x3)
                kernels[3, 0] = torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 0]]) / 2.0
                # Prewitt X
                kernels[4, 0] = torch.tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]) / 3.0

    def forward(self, x):
        batch_size = x.size(0)

        # Reshape to image
        x_img = self.unflatten(x)

        # ========== EXTRACT 5 FREQUENCY BANDS ==========
        # Each vertex processes a different frequency range

        # V0: Ultra-high frequency
        v0_features = self.v0_ultrahigh(x_img)
        v0 = self.v0_encode(v0_features)

        # V1: High frequency
        v1_features = self.v1_high(x_img)
        v1 = self.v1_encode(v1_features)

        # V2: Mid frequency
        v2_features = self.v2_mid(x_img)
        v2 = self.v2_encode(v2_features)

        # V3: Low-mid frequency
        v3_features = self.v3_lowmid(x_img)
        v3 = self.v3_encode(v3_features)

        # V4: Low frequency
        v4_features = self.v4_low(x_img)
        v4 = self.v4_encode(v4_features)

        # Stack all vertex features
        vertices = torch.stack([v0, v1, v2, v3, v4], dim=1)  # [B, 5, base_dim]

        # ========== COMPUTE PENTACHORON EDGE WEIGHTS ==========
        # Normalize each vertex
        vertices_norm = F.normalize(vertices, dim=2)

        # Pairwise similarities (edge strengths), batched via bmm instead of loops
        similarities = torch.bmm(vertices_norm, vertices_norm.transpose(1, 2))  # [B, 5, 5]

        # Apply the pentachoron adjacency mask
        edge_strengths = similarities * self.adjacency_matrix.unsqueeze(0)

        # ========== WEIGHTED COMBINATION BASED ON EDGE STRUCTURE ==========
        # Each vertex is weighted by its edge connections
        edge_weights = edge_strengths.sum(dim=2)  # [B, 5]
        edge_weights = F.softmax(edge_weights, dim=1)

        # Weight each frequency band (batched)
        weighted_vertices = vertices * edge_weights.unsqueeze(2)  # [B, 5, base_dim]

        # ========== FUSION ==========
        # Flatten all weighted frequency bands
        combined = weighted_vertices.flatten(1)  # [B, base_dim * 5]

        # Fuse through the network
        fused = self.fusion(combined)

        # Final normalization to the unit sphere
        output = F.normalize(fused, dim=1)

        return output

    def get_frequency_contributions(self, x):
        """
        Utility to visualize how much each frequency band contributes.
        Returns the softmax weights for each vertex/frequency band.
        """
        with torch.no_grad():
            x_img = self.unflatten(x)

            # Extract all frequency bands
            v0 = self.v0_encode(self.v0_ultrahigh(x_img))
            v1 = self.v1_encode(self.v1_high(x_img))
            v2 = self.v2_encode(self.v2_mid(x_img))
            v3 = self.v3_encode(self.v3_lowmid(x_img))
            v4 = self.v4_encode(self.v4_low(x_img))

            vertices = torch.stack([v0, v1, v2, v3, v4], dim=1)
            vertices_norm = F.normalize(vertices, dim=2)

            # Edge strengths (batched)
            similarities = torch.bmm(vertices_norm, vertices_norm.transpose(1, 2))
            edge_strengths = similarities * self.adjacency_matrix.unsqueeze(0)
            edge_weights = edge_strengths.sum(dim=2)
            edge_weights = F.softmax(edge_weights, dim=1)

            return edge_weights

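# Smoke test (illustrative): a batch of flattened 28x28 inputs should come out
# as unit-norm base_dim vectors, with per-sample frequency weights over the
# five vertices summing to 1. Uses eval() so BatchNorm running stats are untouched.
_enc = PentaFreqEncoder(input_dim=784, base_dim=config['base_dim']).to(device).eval()
_probe = torch.randn(4, 784, device=device)
with torch.no_grad():
    _z = _enc(_probe)
    _w = _enc.get_frequency_contributions(_probe)
print(_z.shape, _z.norm(dim=1))  # torch.Size([4, 64]), each norm ~1.0
print(_w.shape, _w.sum(dim=1))   # torch.Size([4, 5]), each row sums to 1.0
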
# ============================================================
# BATCHED PENTACHORON CONSTELLATION
# ============================================================

class BatchedPentachoronConstellation(nn.Module):
    """Optimized constellation with a permanent, integrated Coherence Head."""
    def __init__(self, num_classes, dim, num_pairs=5, device='cuda', lambda_sep=0.5):
        super().__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.num_pairs = num_pairs
        self.device = device
        self.lambda_separation = lambda_sep

        # Initialize all pentachora as single tensors for batched ops
        self.dispatchers = nn.Parameter(self._init_batched_pentachora())
        self.specialists = nn.Parameter(self._init_batched_pentachora())

        # Batched weights
        self.dispatcher_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1)
        self.specialist_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1)

        # Temperature per pair
        self.temps = nn.Parameter(0.3 * torch.ones(num_pairs))

        # Vertex assignments
        self.register_buffer('vertex_map', self._create_vertex_mapping())

        # Group classification heads for each vertex (None for empty vertices)
        self.group_heads = nn.ModuleList([
            nn.Linear(dim, (self.vertex_map == i).sum().item()) if (self.vertex_map == i).sum().item() > 0 else None
            for i in range(5)
        ])

        # Cross-pair attention mechanism
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=config.get('num_heads', 4),
            dropout=0.1,
            batch_first=True
        )

        # Aggregation weights for combining scores from different pairs
        self.aggregation_weights = nn.Parameter(torch.ones(num_pairs) / num_pairs)

        # Final fusion network
        self.fusion = nn.Sequential(
            nn.Linear(num_classes * num_pairs, num_classes * 2),
            nn.BatchNorm1d(num_classes * 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(num_classes * 2, num_classes)
        )

        # Integrated Coherence Head: a small MLP acting as the permanent
        # "rose_head". It learns to assess the quality/coherence of the input
        # latent vector `x`.
        self.coherence_head = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.GELU(),
            nn.Linear(dim // 2, 1)
        )

    def _init_batched_pentachora(self):
        """Initializes all pentachora for the constellation."""
        sqrt15, sqrt10, sqrt5 = np.sqrt(15), np.sqrt(10), np.sqrt(5)

        base_simplex = torch.tensor([
            [ 1.0,   0.0,        0.0,       0.0],
            [-0.25,  sqrt15/4,   0.0,       0.0],
            [-0.25, -sqrt15/12,  sqrt10/3,  0.0],
            [-0.25, -sqrt15/12, -sqrt10/6,  sqrt5/2],
            [-0.25, -sqrt15/12, -sqrt10/6, -sqrt5/2]
        ], device=self.device)

        base_simplex = F.normalize(base_simplex, dim=1)

        pentachora = torch.zeros(self.num_pairs, 5, self.dim, device=self.device)
        for i in range(self.num_pairs):
            pentachora[i, :, :4] = base_simplex * (1 + 0.1 * i)
            if self.dim > 4:
                pentachora[i, :, 4:] = torch.randn(5, self.dim - 4, device=self.device) * (random.random() * 0.25)

        return pentachora * 2.0

    def _create_vertex_mapping(self):
        """Creates a mapping from classes to the 5 pentachoron vertices."""
        mapping = torch.zeros(self.num_classes, dtype=torch.long)
        for i in range(self.num_classes):
            mapping[i] = i % 5
        return mapping

    def forward(self, x):
        batch_size = x.size(0)

        # --- Coherence gating ---
        # 1. Coherence score for the latent vector `x`.
        coherence_gate = torch.sigmoid(self.coherence_head(x))  # [batch_size, 1]

        # Distance calculations
        x_expanded = x.unsqueeze(1).unsqueeze(2)
        disp_expanded = self.dispatchers.unsqueeze(0)
        spec_expanded = self.specialists.unsqueeze(0)
        disp_dists = torch.norm(x_expanded - disp_expanded, dim=3)
        spec_dists = torch.norm(x_expanded - spec_expanded, dim=3)
        disp_weights = F.softmax(self.dispatcher_weights, dim=1).unsqueeze(0)
        spec_weights = F.softmax(self.specialist_weights, dim=1).unsqueeze(0)
        weighted_disp = disp_dists * disp_weights
        weighted_spec = spec_dists * spec_weights
        temps_clamped = torch.clamp(self.temps, 0.1, 2.0).view(1, -1, 1)

        # 2. Pre-softmax logits, modulated by the coherence score.
        disp_logits = -weighted_disp / temps_clamped
        spec_logits = -weighted_spec / temps_clamped

        modulated_disp_logits = disp_logits * coherence_gate.unsqueeze(-1)
        modulated_spec_logits = spec_logits * coherence_gate.unsqueeze(-1)

        # 3. Probabilities from the *modulated* logits.
        vertex_probs = F.softmax(modulated_disp_logits, dim=2)
        spec_probs = F.softmax(modulated_spec_logits, dim=2)

        combined_probs = 0.5 * vertex_probs + 0.5 * spec_probs

        # Score calculation using the per-vertex group heads
        all_scores = []
        for p in range(self.num_pairs):
            pair_scores = torch.zeros(batch_size, self.num_classes, device=self.device)
            for v_idx in range(5):
                classes_in_vertex = (self.vertex_map == v_idx).nonzero(as_tuple=True)[0]
                if len(classes_in_vertex) == 0:
                    continue
                v_prob = combined_probs[:, p, v_idx:v_idx + 1]
                if self.group_heads[v_idx] is not None:
                    group_logits = self.group_heads[v_idx](x)
                    gated_logits = group_logits * v_prob
                    for i, cls in enumerate(classes_in_vertex):
                        if i < gated_logits.size(1):
                            pair_scores[:, cls] = gated_logits[:, i]
            all_scores.append(pair_scores)

        all_scores_tensor = torch.stack(all_scores, dim=1)

        # Cross-attention over the average dispatcher centers.
        # NOTE: attended_features is computed for inspection but does not
        # currently feed into the score path below.
        avg_dispatcher_centers = self.dispatchers.mean(dim=1).unsqueeze(0).expand(batch_size, -1, -1)
        attended_features, _ = self.cross_attention(
            avg_dispatcher_centers, avg_dispatcher_centers, avg_dispatcher_centers
        )

        agg_weights = F.softmax(self.aggregation_weights, dim=0).view(1, -1, 1)
        weighted_scores = (all_scores_tensor * agg_weights).sum(dim=1)

        # Final fusion
        concat_scores = all_scores_tensor.flatten(1)
        fused_scores = self.fusion(concat_scores)
        final_scores = 0.6 * weighted_scores + 0.4 * fused_scores

        return final_scores, (disp_dists, spec_dists, vertex_probs)

    def regularization_loss(self, vertex_weights=None):
        """Batched regularization with optional per-vertex weighting."""
        # Geometric regularization
        disp_cm = self._batched_cayley_menger(self.dispatchers)
        spec_cm = self._batched_cayley_menger(self.specialists)
        cm_loss = torch.relu(1.0 - torch.abs(disp_cm)).sum() + torch.relu(1.0 - torch.abs(spec_cm)).sum()

        edge_loss = self._batched_edge_variance(self.dispatchers) + self._batched_edge_variance(self.specialists)

        disp_centers = self.dispatchers.mean(dim=1)
        spec_centers = self.specialists.mean(dim=1)
        cos_sims = F.cosine_similarity(disp_centers, spec_centers, dim=1)
        ortho_loss = torch.abs(cos_sims).sum() * self.lambda_separation

        separations = torch.norm(disp_centers - spec_centers, dim=1)
        sep_loss = torch.relu(2.0 - separations).sum() * self.lambda_separation

        # Dynamic per-vertex regularization
        dynamic_reg_loss = 0.0
        if vertex_weights is not None:
            vertex_weights = vertex_weights.to(self.dispatchers.device)
            disp_norms = torch.norm(self.dispatchers, p=2, dim=2)
            spec_norms = torch.norm(self.specialists, p=2, dim=2)
            weighted_disp_loss = (disp_norms * vertex_weights.unsqueeze(0)).mean()
            weighted_spec_loss = (spec_norms * vertex_weights.unsqueeze(0)).mean()
            dynamic_reg_loss = 0.1 * (weighted_disp_loss + weighted_spec_loss)

        total_loss = (cm_loss + edge_loss + ortho_loss + sep_loss) / self.num_pairs
        return total_loss + dynamic_reg_loss

    def _batched_cayley_menger(self, pentachora):
        """Compute the Cayley-Menger determinant for all pairs at once."""
        num_pairs = pentachora.shape[0]
        dists_sq = torch.cdist(pentachora, pentachora) ** 2
        cm_matrices = torch.zeros(num_pairs, 6, 6, device=self.device)
        cm_matrices[:, 0, 1:] = 1
        cm_matrices[:, 1:, 0] = 1
        cm_matrices[:, 1:, 1:] = dists_sq
        return torch.det(cm_matrices)

    def _batched_edge_variance(self, pentachora):
        """Compute edge-length variance for all pairs at once."""
        dists = torch.cdist(pentachora, pentachora)
        mask = torch.triu(torch.ones(5, 5, device=self.device), diagonal=1).bool()
        edges_list = [dists[p][mask] for p in range(self.num_pairs)]
        edges_all = torch.stack(edges_list)
        variances = edges_all.var(dim=1)
        mins = edges_all.min(dim=1)[0]
        return variances.sum() + torch.relu(0.5 - mins).sum()

    def _cayley_menger_determinant(self, vertices):
        """Single-pentachoron Cayley-Menger determinant (unbatched variant)."""
        n = vertices.shape[0]

        # Distance matrix
        dists_sq = torch.cdist(vertices.unsqueeze(0), vertices.unsqueeze(0))[0] ** 2

        # Build the Cayley-Menger matrix
        cm_matrix = torch.zeros(n + 1, n + 1, device=self.device)
        cm_matrix[0, 1:] = 1
        cm_matrix[1:, 0] = 1
        cm_matrix[1:, 1:] = dists_sq

        return torch.det(cm_matrix)

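# Geometry check (illustrative): the base 4-simplex in _init_batched_pentachora
# gives five non-coplanar vertices per pentachoron, so the batched Cayley-Menger
# determinant should be non-zero (a zero determinant would mean a degenerate,
# collapsed pentachoron).
_c = BatchedPentachoronConstellation(num_classes=config['num_classes'],
                                     dim=config['base_dim'],
                                     num_pairs=config['num_pentachoron_pairs'],
                                     device=device,
                                     lambda_sep=config['lambda_separation'])
with torch.no_grad():
    _cm = _c._batched_cayley_menger(_c.dispatchers)
print("Cayley-Menger det per pair:", _cm)  # expected: non-zero entries
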
# ============================================================
# COMPLETE LOSS FUNCTIONS
# ============================================================

def dual_contrastive_loss(latents, targets, constellation, config):
    """
    Computes a dual contrastive loss that pulls samples toward the correct
    pentachoron vertex and pushes them away from all incorrect vertices.

    Args:
        latents (torch.Tensor): The encoded feature vectors from the encoder [B, dim].
        targets (torch.Tensor): The ground truth class labels [B].
        constellation (nn.Module): The PentachoronConstellation model.
        config (dict): The configuration dictionary containing 'temp'.

    Returns:
        torch.Tensor: The total contrastive loss.
    """
    batch_size = latents.size(0)
    temp = config['temp']

    # Target vertex for each sample in the batch
    target_vertices = constellation.vertex_map[targets]  # [B]

    # Normalize latents to the unit sphere for a clean cosine similarity
    latents = F.normalize(latents, dim=1)
    batch_idx = torch.arange(batch_size, device=latents.device)

    # --- DISPATCHER LOSS ---
    # Shape: [num_pairs, 5, dim]
    disp_pentachora_norm = F.normalize(constellation.dispatchers, dim=2)

    # Cosine similarity between each latent and every dispatcher vertex.
    # latents: [B, dim] x [num_pairs, 5, dim] -> [B, num_pairs, 5]
    disp_sims = torch.einsum('bd,pvd->bpv', latents, disp_pentachora_norm)

    # Gather the similarities of the correct vertex for each sample.
    # Advanced indexing with a slice in between yields shape [B, num_pairs].
    disp_positive_sims = disp_sims[batch_idx, :, target_vertices]

    # All-vertex logits
    disp_all_logits = disp_sims / temp  # [B, num_pairs, 5]

    # InfoNCE in log-space (logsumexp) for numerical stability; algebraically
    # identical to -log(exp(pos/temp) / sum(exp(all/temp)))
    disp_loss = -(disp_positive_sims / temp - torch.logsumexp(disp_all_logits, dim=2)).mean()

    # --- SPECIALIST LOSS ---
    # Same process for the specialists
    spec_pentachora_norm = F.normalize(constellation.specialists, dim=2)
    spec_sims = torch.einsum('bd,pvd->bpv', latents, spec_pentachora_norm)
    spec_positive_sims = spec_sims[batch_idx, :, target_vertices]
    spec_all_logits = spec_sims / temp
    spec_loss = -(spec_positive_sims / temp - torch.logsumexp(spec_all_logits, dim=2)).mean()

    # Combine losses
    total_loss = disp_loss + spec_loss
    return total_loss

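# Sanity check (illustrative): the log-space InfoNCE used above matches the
# naive exp-ratio form wherever the latter does not overflow.
_s = torch.randn(3, 2, 5)   # fake similarity logits [B, pairs, 5]
_pos = _s[:, :, 0]          # pretend vertex 0 is the target everywhere
_naive = -torch.log(torch.exp(_pos) / torch.exp(_s).sum(dim=2)).mean()
_stable = -(_pos - torch.logsumexp(_s, dim=2)).mean()
assert torch.allclose(_naive, _stable, atol=1e-6)
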
# Helper functions meant to solidify the new scheduler
def get_class_similarity(constellation_model, num_classes):
    """
    Calculates pairwise class similarity based on the final layer's weights.
    Returns a [num_classes, num_classes] similarity matrix.
    """
    # Use the final fusion layer as the class representation
    final_layer = constellation_model.fusion[-1]
    weights = final_layer.weight.data.detach()  # Shape: [num_classes, feature_dim]

    # Normalize each class vector to get cosine similarity
    norm_weights = F.normalize(weights, p=2, dim=1)

    # Cosine similarity is the dot product of normalized vectors
    similarity_matrix = torch.matmul(norm_weights, norm_weights.T)

    return torch.clamp(similarity_matrix, 0.0, 1.0)  # Ensure values are in [0, 1]

def get_vertex_weights_from_confusion(conf_matrix, class_similarity, vertex_map, device):
    """
    Calculates per-vertex regularization weights based on class confusion
    and similarity.
    """
    num_classes = conf_matrix.shape[0]

    # 1. "Confusion score" for each class (1 - accuracy)
    class_totals = conf_matrix.sum(axis=1)
    class_correct = conf_matrix.diagonal()
    class_acc = np.divide(class_correct, class_totals, out=np.zeros_like(class_correct, dtype=float), where=class_totals != 0)
    confusion_scores = 1.0 - torch.tensor(class_acc, device=device, dtype=torch.float32)

    # 2. Spread the confusion using the similarity matrix (the "bell curve")
    sigma = 0.5  # Controls the width of the bell curve
    gaussian_similarity = torch.exp(-((1 - class_similarity) ** 2) / (2 * sigma ** 2))
    propagated_scores = torch.matmul(gaussian_similarity, confusion_scores)

    # 3. Map per-class scores to per-vertex scores
    vertex_problem_scores_sum = torch.zeros(5, device=device)
    vertex_counts = torch.zeros(5, device=device)
    for class_idx, vertex_idx in enumerate(vertex_map):
        vertex_problem_scores_sum[vertex_idx] += propagated_scores[class_idx]
        vertex_counts[vertex_idx] += 1

    # Safe division to average the scores for vertices with multiple classes
    vertex_problem_scores = torch.zeros_like(vertex_problem_scores_sum)
    mask = vertex_counts > 0
    vertex_problem_scores[mask] = vertex_problem_scores_sum[mask] / vertex_counts[mask]

    # 4. Convert "problem score" to "regularization weight"
    vertex_weights = 1.0 - torch.tanh(vertex_problem_scores)  # Maps scores into (0, 1]

    return F.normalize(vertex_weights, p=1, dim=0) * 5.0  # Normalize sum to 5, so the average is 1

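# Illustrative per-epoch wiring (an assumption based on the function signatures;
# the actual main loop follows later in the notebook):
#
#   _, _, conf_matrix = evaluate(encoder, constellation, test_loader, num_classes)
#   sim = get_class_similarity(constellation, num_classes)
#   vertex_weights = get_vertex_weights_from_confusion(
#       conf_matrix, sim, constellation.vertex_map, device)
#   train_epoch(encoder, constellation, optimizer, train_loader,
#               epoch, config, vertex_weights, device)
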
# ============================================================
# TRAINING FUNCTION
# ============================================================

def train_epoch(encoder, constellation, optimizer, train_loader, epoch, config, vertex_weights, device):
    """
    Performs one full training epoch using the provided dynamic regularization weights.
    """
    # Set models to training mode
    encoder.train()
    constellation.train()

    # Trackers for loss and accuracy
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # Progress bar over the training loader
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [Training]")
    for inputs, targets in pbar:
        # Move data to the configured device (GPU or CPU)
        inputs, targets = inputs.to(device), as_class_indices(targets.to(device))

        # Reset gradients from the previous iteration
        optimizer.zero_grad()

        # --- Forward pass ---
        # 1. Latent representations from the encoder
        z = encoder(inputs)
        # 2. Classification scores from the constellation
        scores, _ = constellation(z)

        # --- Loss calculation ---
        # 1. Standard cross-entropy loss for classification
        ce_loss = F.cross_entropy(scores, targets)
        # 2. Regularization loss, modulated by the dynamic per-vertex weights
        reg_loss = constellation.regularization_loss(vertex_weights=vertex_weights)
        # 3. Combine the losses
        loss = ce_loss + config['loss_weight_scalar'] * reg_loss

        # --- Backward pass and optimization ---
        loss.backward()
        # Clip gradients to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(constellation.parameters(), 1.0)
        optimizer.step()

        # --- Update statistics ---
        total_loss += loss.item() * inputs.size(0)
        preds = scores.argmax(dim=1)
        correct_predictions += (preds == targets).sum().item()
        total_samples += inputs.size(0)

        # Live metrics on the progress bar
        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{correct_predictions/total_samples:.4f}",
            'reg': f"{reg_loss.item():.4f}"
        })

    # Average loss and accuracy for the epoch
    return total_loss / total_samples, correct_predictions / total_samples

from sklearn.metrics import confusion_matrix
import seaborn as sns

@torch.no_grad()
def evaluate(encoder, constellation, test_loader, num_classes):
    encoder.eval()
    constellation.eval()

    all_preds = []
    all_targets = []

    for inputs, targets in tqdm(test_loader, desc="Evaluating"):
        inputs, targets = inputs.to(device), as_class_indices(targets.to(device))

        z = encoder(inputs)
        scores, _ = constellation(z)

        preds = scores.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

    correct = (np.array(all_preds) == np.array(all_targets)).sum()
    total = len(all_targets)

    # Confusion matrix over all classes
    conf_matrix = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes))

    # Per-class accuracies from the confusion matrix
    class_correct = conf_matrix.diagonal()
    class_total = conf_matrix.sum(axis=1)
    # Avoid division by zero for classes not present in the test set
    class_accs = np.divide(class_correct, class_total, out=np.zeros_like(class_correct, dtype=float), where=class_total != 0)

    return correct / total, list(class_accs), conf_matrix

# ============================================================
# DYNAMIC SCHEDULER
# ============================================================

class DynamicScheduler:
    """
    A custom learning-rate scheduler with warmup and reduce-on-plateau logic.
    - Warmup phase: linearly increases the LR from a small value to the initial LR.
    - Main phase: monitors a metric (e.g., test accuracy) and reduces the LR
      when the metric stops improving for `patience` epochs.
    """
    def __init__(self, optimizer, initial_lr, warmup_epochs, patience, factor=0.5, min_lr=1e-6, cooldown_epochs=2):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.warmup_epochs = warmup_epochs
        self.patience = patience
        self.factor = factor
        self.min_lr = min_lr
        self.cooldown_epochs = cooldown_epochs

        # State tracking
        self.current_epoch = 0
        self.phase = 'warmup' if warmup_epochs > 0 else 'main'
        self.best_metric = -1.0
        self.epochs_since_improvement = 0
        self.cooldown_counter = 0

        print("\n" + "="*60)
        print("INITIALIZING DYNAMIC SCHEDULER")
        print("="*60)
        print(f"{'Initial LR':<25}: {self.initial_lr}")
        print(f"{'Warmup Epochs':<25}: {self.warmup_epochs}")
        print(f"{'Patience (for plateau)':<25}: {self.patience}")
        print(f"{'Reduction Factor':<25}: {self.factor}")
        print(f"{'Cooldown Epochs':<25}: {self.cooldown_epochs}")
        print(f"{'Minimum LR':<25}: {self.min_lr}")

    def _set_lr(self, lr_value):
        """Sets the learning rate for all parameter groups in the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr_value

    def step(self, metric):
        """
        Update the learning rate based on the provided metric (e.g., test accuracy).
        Call once per epoch AFTER evaluation.
        """
        self.current_epoch += 1
        current_lr = self.optimizer.param_groups[0]['lr']

        if self.phase == 'warmup':
            # Learning rate for the current warmup step
            lr = self.initial_lr * (self.current_epoch / self.warmup_epochs)
            self._set_lr(lr)
            print(f" Scheduler (Warmup): Epoch {self.current_epoch}/{self.warmup_epochs}, LR set to {lr:.6f}")

            # Check whether the warmup phase is complete
            if self.current_epoch >= self.warmup_epochs:
                self.phase = 'main'
                self.best_metric = metric  # Initialize best metric after warmup
                print(" Scheduler: Warmup complete. Switched to main (plateau) phase.")

        elif self.phase == 'main':
            # Handle cooldown period
            if self.cooldown_counter > 0:
                self.cooldown_counter -= 1
                print(f" Scheduler (Cooldown): {self.cooldown_counter+1} epochs remaining.")
                return

            # Check for improvement
            if metric > self.best_metric:
                self.best_metric = metric
                self.epochs_since_improvement = 0
            else:
                self.epochs_since_improvement += 1
                print(f" Scheduler: No improvement for {self.epochs_since_improvement} epoch(s). Best Acc: {self.best_metric:.4f}")

            # If patience is exceeded, reduce the learning rate
            if self.epochs_since_improvement >= self.patience:
                new_lr = max(current_lr * self.factor, self.min_lr)
                if new_lr < current_lr:
                    self._set_lr(new_lr)
                    print(f" 🔥 Scheduler: Metric plateaued. Reducing LR to {new_lr:.6f}")
                    self.epochs_since_improvement = 0
                    self.cooldown_counter = self.cooldown_epochs  # Start cooldown
                else:
                    print(" Scheduler: Already at minimum LR. No change.")

1412
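+ # A minimal usage sketch of DynamicScheduler (main() below opts for
+ # CosineAnnealingLR instead, so this probe is illustrative only): call
+ # step(metric) once per epoch, after evaluation, with the metric being
+ # maximized. The dummy parameter, optimizer, and accuracy values are
+ # assumptions for the demo, not part of the training pipeline.
+ _probe_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
+ _probe_sched = DynamicScheduler(_probe_opt, initial_lr=1e-3,
+                                 warmup_epochs=2, patience=1, cooldown_epochs=0)
+ for _acc in (0.50, 0.60, 0.60, 0.60):  # improves, then plateaus
+     _probe_sched.step(_acc)            # two warmup steps, then LR cuts on the plateau
+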
+ # ============================================================
+ # MAIN TRAINING LOOP
+ # ============================================================
+ class RoseDiagnosticHead(nn.Module):
+     """
+     A simple MLP that predicts the rose_score_magnitude from a latent vector.
+     This is a "throwaway" module used for diagnostics, not for the final model's task.
+     """
+     def __init__(self, latent_dim, hidden_dim=128):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(latent_dim, hidden_dim),
+             nn.GELU(),
+             nn.LayerNorm(hidden_dim),
+             nn.Linear(hidden_dim, 1)  # Output a single scalar value
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
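+ # Shape check for the diagnostic head (illustrative dimensions only):
+ # a [B, latent_dim] batch maps to one predicted rose score per sample.
+ assert RoseDiagnosticHead(latent_dim=64)(torch.randn(8, 64)).shape == (8, 1)
+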
+ def rose_score_magnitude(x: torch.Tensor, need: torch.Tensor, relation: torch.Tensor, purpose: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+     """
+     Computes a magnitude-only Rose similarity score between `x` and `need`,
+     modulated by triadic reference vectors `relation` and `purpose`.
+     """
+     x_n = F.normalize(x, dim=-1, eps=eps)
+     n_n = F.normalize(need, dim=-1, eps=eps)
+     r_n = F.normalize(relation, dim=-1, eps=eps)
+     p_n = F.normalize(purpose, dim=-1, eps=eps)
+
+     # Core directional cosine components
+     a_n = torch.einsum('bd,bd->b', x_n, n_n)  # Batch dot product
+     a_r = torch.einsum('bd,bd->b', x_n, r_n)
+     a_p = torch.einsum('bd,bd->b', x_n, p_n)
+
+     # Triadic magnitude score: mean cosine alignment scaled by the norm of x
+     r7 = (a_n + a_r + a_p) / 3.0
+     r8 = x.norm(dim=-1)
+
+     return r7 * r8
+
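+ # A tiny worked check of the score (illustrative tensor, not real data):
+ # when x is perfectly aligned with need, relation, and purpose, every
+ # cosine term is 1, so the score collapses to ||x||.
+ _v = torch.tensor([[3.0, 4.0]])  # norm 5
+ assert torch.allclose(rose_score_magnitude(_v, _v.clone(), _v.clone(), _v.clone()),
+                       torch.tensor([5.0]))
+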
+ def RoseCrossContrastiveLoss(latents, targets, constellation, temp=0.5):
+     """
+     Computes a contrastive loss where each sample's contribution is weighted
+     by the inverse of its `rose_score_magnitude`.
+
+     Returns the final loss and the calculated rose scores for diagnostics.
+     """
+     batch_size = latents.size(0)
+     device = latents.device
+
+     # --- 1. Define the Symbolic Basis for ROSE Score ---
+     target_vertex_indices = constellation.vertex_map[targets]
+
+     # Need: Target vertices from the specialist pentachora (the ideal goal) -> [B, D]
+     need_vectors = constellation.specialists[:, target_vertex_indices, :].mean(dim=0)
+
+     # Relation: Target vertices from the dispatcher pentachora (the context) -> [B, D]
+     relation_vectors = constellation.dispatchers[:, target_vertex_indices, :].mean(dim=0)
+
+     # Purpose: The centroid of the specialist pentachora (the overall structure) -> [D] -> [B, D]
+     purpose_vectors = constellation.specialists.mean(dim=(0, 1)).unsqueeze(0).expand(batch_size, -1)
+
+     # --- 2. Calculate the ROSE Score for each sample in the batch ---
+     # rose_scores has shape [B]
+     rose_scores = rose_score_magnitude(latents, need_vectors, relation_vectors, purpose_vectors)
+
+     # --- 3. Calculate Per-Sample Inverse Weights ---
+     # (1 - tanh(x)) gives a stable, bounded weight in (0, 2):
+     # a high rose_score yields a low loss weight, a low rose_score a high one.
+     loss_weights = 1.0 - torch.tanh(rose_scores)
+
+     # --- 4. Calculate the Base Contrastive Loss (InfoNCE) ---
+     all_vertices_specialist = constellation.specialists.mean(dim=0)  # [5, D]
+     all_vertices_dispatcher = constellation.dispatchers.mean(dim=0)  # [5, D]
+
+     # Similarities to all specialist and dispatcher vertices
+     sim_specialist = F.normalize(latents) @ F.normalize(all_vertices_specialist).T  # [B, 5]
+     sim_dispatcher = F.normalize(latents) @ F.normalize(all_vertices_dispatcher).T  # [B, 5]
+
+     # Similarity to the positive (correct) vertex for each sample
+     batch_idx = torch.arange(batch_size, device=device)
+     pos_sim_specialist = sim_specialist[batch_idx, target_vertex_indices]
+     pos_sim_dispatcher = sim_dispatcher[batch_idx, target_vertex_indices]
+
+     # Per-sample InfoNCE loss for both pentachora
+     logits_specialist = -torch.log(torch.exp(pos_sim_specialist / temp) / torch.exp(sim_specialist / temp).sum(dim=1))
+     logits_dispatcher = -torch.log(torch.exp(pos_sim_dispatcher / temp) / torch.exp(sim_dispatcher / temp).sum(dim=1))
+
+     per_sample_loss = (logits_specialist + logits_dispatcher) / 2.0
+
+     # --- 5. Apply the ROSE Weights and return the mean loss ---
+     final_loss = (per_sample_loss * loss_weights).mean()
+
+     return final_loss, rose_scores.detach()  # Detach scores for diagnostic use
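+
+ # The step-3 weighting in isolation (illustrative scores, not real data):
+ # a high rose score is down-weighted, a low or negative one up-weighted,
+ # and every weight stays inside the open interval (0, 2).
+ _demo_weights = 1.0 - torch.tanh(torch.tensor([-2.0, 0.0, 2.0]))  # ~[1.964, 1.000, 0.036]
+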
+ # ============================================================
+ # MAIN FUNCTION
+ # ============================================================
+ def main():
+     print("\n" + "="*60)
+     print("PENTACHORON CONSTELLATION FINAL CONFIGURATION")
+     print("="*60)
+     for key, value in config.items():
+         print(f"{key:25}: {value}")
+
+     # Models
+     encoder = PentaFreqEncoder(config['input_dim'], config['base_dim']).to(device)
+     constellation = BatchedPentachoronConstellation(
+         config['num_classes'],
+         config['base_dim'],
+         config['num_pentachoron_pairs'],
+         device,
+         config['lambda_separation']
+     ).to(device)
+     diagnostic_head = RoseDiagnosticHead(config['base_dim']).to(device)
+
+     # Optimizer & scheduler
+     optimizer = torch.optim.AdamW(
+         list(encoder.parameters()) + list(constellation.parameters()) + list(diagnostic_head.parameters()),
+         lr=config['lr'],
+         weight_decay=config["weight_decay"]
+     )
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])
+
+     # TensorBoard logging
+     tb_dir = Path("tb_logs") / _timestamp()
+     tb_dir.mkdir(parents=True, exist_ok=True)
+     writer = SummaryWriter(log_dir=str(tb_dir))
+
+     history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+     best_acc = 0.0
+     last_conf_png = None
+     start_time = time.time()
+
+     print("\n" + "="*60)
+     print("STARTING TRAINING WITH ROSE-MODULATED LOSS")
+     print("="*60 + "\n")
+
+     for epoch in range(config['epochs']):
+         encoder.train(); constellation.train(); diagnostic_head.train()
+         total_loss = total_correct = total_samples = 0
+
+         pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
+         for inputs, targets in pbar:
+             inputs, targets = inputs.to(device), as_class_indices(targets.to(device))
+             optimizer.zero_grad()
+
+             latents = encoder(inputs)
+             scores, _ = constellation(latents)
+
+             loss_ce = F.cross_entropy(scores, targets)
+             loss_contrastive, true_rose_scores = RoseCrossContrastiveLoss(
+                 latents, targets, constellation, temp=config['temp']
+             )
+             pred_rose = diagnostic_head(latents.detach())
+             loss_diag = F.mse_loss(pred_rose.squeeze(), true_rose_scores)
+             loss_reg = constellation.regularization_loss()
+
+             loss = loss_ce + (1.0 * loss_contrastive) + (0.1 * loss_diag) + (config['loss_weight_scalar'] * loss_reg)
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
+             torch.nn.utils.clip_grad_norm_(constellation.parameters(), 1.0)
+             torch.nn.utils.clip_grad_norm_(diagnostic_head.parameters(), 1.0)
+             optimizer.step()
+
+             total_loss += loss.item() * inputs.size(0)
+             preds = scores.argmax(dim=1)
+             total_correct += (preds == targets).sum().item()
+             total_samples += inputs.size(0)
+
+             pbar.set_postfix({
+                 'loss': f"{loss.item():.4f}",
+                 'acc': f"{total_correct/total_samples:.4f}",
+                 'rose_loss': f"{loss_contrastive.item():.4f}",
+                 'diag_loss': f"{loss_diag.item():.4f}"
+             })
+
+         train_loss = total_loss / total_samples
+         train_acc = total_correct / total_samples
+
+         # Evaluation
+         test_acc, class_accs, conf_matrix = evaluate(
+             encoder, constellation, test_loader, config['num_classes']
+         )
+
+         # Log to TensorBoard
+         writer.add_scalar("Loss/train", train_loss, epoch+1)
+         writer.add_scalar("Acc/train", train_acc, epoch+1)
+         writer.add_scalar("Acc/test", test_acc, epoch+1)
+         writer.add_scalar("LR", optimizer.param_groups[0]['lr'], epoch+1)
+
+         # Scheduler
+         scheduler.step()
+
+         # History
+         history['train_loss'].append(train_loss)
+         history['train_acc'].append(train_acc)
+         history['test_acc'].append(test_acc)
+
+         print(f"\n[Epoch {epoch+1}/{config['epochs']}]")
+         print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")
+
+         if test_acc > best_acc:
+             best_acc = test_acc
+             print(f" 🎯 NEW BEST ACCURACY: {best_acc:.4f}")
+             print(" Saving new best confusion matrix heatmap...")
+
+             plt.figure(figsize=(12, 10))
+             sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
+                         xticklabels=class_names, yticklabels=class_names)
+             plt.title(f'Confusion Matrix - Epoch {epoch+1} - Accuracy: {best_acc:.4f}', fontsize=16)
+             plt.xlabel('Predicted Label', fontsize=12)
+             plt.ylabel('True Label', fontsize=12)
+             plt.tight_layout()
+             last_conf_png = f'best_confusion_matrix_epoch_{epoch+1}.png'
+             plt.savefig(last_conf_png, dpi=150)
+             plt.close()
+
+     # Final plots
+     elapsed_time = time.time() - start_time
+     print("\n" + "="*60)
+     print("TRAINING COMPLETE")
+     print("="*60)
+     print(f" Best Test Accuracy: {best_acc*100:.2f}%")
+     print(f" Total Training Time: {elapsed_time/60:.2f} minutes")
+
+     plt.figure(figsize=(12, 5))
+     plt.plot(history['train_acc'], label='Train Accuracy')
+     plt.plot(history['test_acc'], label='Test Accuracy', linewidth=2)
+     plt.title('Model Accuracy Over Epochs', fontsize=16)
+     plt.xlabel('Epoch', fontsize=12)
+     plt.ylabel('Accuracy', fontsize=12)
+     plt.legend()
+     plt.grid(True, linestyle='--', alpha=0.6)
+     plt.tight_layout()
+     plt.savefig('accuracy_plot.png', dpi=150)
+     plt.show()
+
+     # Save and push the artifact bundle
+     local_dir, hub_path = save_and_push_artifacts(
+         encoder=encoder,
+         constellation=constellation,
+         diagnostic_head=diagnostic_head,
+         config=config,
+         class_names=class_names,
+         history=history,
+         best_acc=best_acc,
+         tb_log_dir=tb_dir,
+         last_confusion_png=last_conf_png,
+         repo_subdir_root="pentachora-adaptive-encoded/" + DATASET_NAME,
+     )
+     print(f"[done] Local artifacts at: {local_dir}")
+     print(f"[done] HuggingFace path: {hub_path}")
+
+     return encoder, constellation, history
+
+ # ============================
+ # OPTIONAL: set your repo here
+ # ============================
+ # Example:
+ config['hf_repo_id'] = "AbstractPhil/pentachora-frequency-encoded"
+
+ if __name__ == "__main__":
+     encoder, constellation, history = main()
+     print("\n✨ Optimized Pentachoron Constellation Training Complete!")