#!/usr/bin/env python3 """ Pentachoron Constellation with Greyscale PentaFreq Encoder Optimized with Batched Operations and Complete Loss Functions Apache License 2.0 Author: AbstractPhil Assistance: GPT 4o, GPT 5, Claude Opus 4.1, Claude Sonnet 4.0, Gemini """ import torch import torch.nn as nn import torch.nn.functional as F from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm import time import torch import torchvision from torchvision import datasets, transforms from torch.utils.data import DataLoader import numpy as np import random # ============================================================ # CONFIGURATION # ============================================================ # Clear CUDA cache if torch.cuda.is_available(): torch.cuda.empty_cache() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Hyperparameters config = { 'input_dim': 64, 'base_dim': 64, 'batch_size': 2048, 'epochs': 50, 'lr': 1e-1, 'num_heads': 8, 'num_pentachoron_pairs': 1, 'loss_weight_scalar': 0.1, 'lambda_separation': 0.29514, 'temp': 0.70486, "weight_decay": 1e-5, } print("\n" + "="*60) print("PENTACHORON CONSTELLATION CONFIGURATION") print("="*60) for key, value in config.items(): print(f"{key:20}: {value}") # ============================================================ # DATASET # ============================================================ transform = transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1)) ]) # ============================================================ # SELECT YOUR DATASET HERE! # ============================================================ DATASET_NAME = "OCTMNIST" # Change this to any dataset below! # Available datasets (all 28x28): AVAILABLE_DATASETS = { "MNIST": "Classic handwritten digits (10 classes)", "FashionMNIST": "Fashion items (10 classes) - The tough one!", "KMNIST": "Kuzushiji-MNIST - Japanese characters (10 classes)", "EMNIST": "Extended MNIST - Letters & digits (47 classes)", "QMNIST": "MNIST with better test set (10 classes)", "USPS": "US Postal Service digits (10 classes)", # MedMNIST variants (medical images) "BloodMNIST": "Blood cell types (8 classes)", "PathMNIST": "Pathology images (9 classes)", "OCTMNIST": "Retinal OCT (4 classes)", "PneumoniaMNIST": "Chest X-Ray (2 classes)", "DermaMNIST": "Dermatoscope images (7 classes)", "RetinaMNIST": "Retina fundus (5 classes)", "BreastMNIST": "Breast ultrasound (2 classes)", "OrganAMNIST": "Abdominal CT - Axial (11 classes)", "OrganCMNIST": "Abdominal CT - Coronal (11 classes)", "OrganSMNIST": "Abdominal CT - Sagittal (11 classes)", "TissueMNIST": "Tissue cells (8 classes)", } # ---------- MedMNIST INFO + helpers ---------- try: import medmnist from medmnist import INFO as MED_INFO # official dict except Exception: medmnist = None MED_INFO = None # Fallback labels/tasks/channels for the 2D sets you listed. # Source: MedMNIST v2 dataset card / builder (labels) and project docs (tasks/channels). FALLBACK_INFO = { "bloodmnist": { "python_class": "BloodMNIST", "task": "multi-class", "n_channels": 3, "label": { "0": "basophil", "1": "eosinophil", "2": "erythroblast", "3": "immature granulocytes(myelocytes, metamyelocytes and promyelocytes)", "4": "lymphocyte", "5": "monocyte", "6": "neutrophil", "7": "platelet", }, }, "pathmnist": { "python_class": "PathMNIST", "task": "multi-class", "n_channels": 3, "label": { "0": "adipose", "1": "background", "2": "debris", "3": "lymphocytes", "4": "mucus", "5": "smooth muscle", "6": "normal colon mucosa", "7": "cancer-associated stroma", "8": "colorectal adenocarcinoma epithelium", }, }, "octmnist": { "python_class": "OCTMNIST", "task": "multi-class", "n_channels": 1, "label": { "0": "choroidal neovascularization", "1": "diabetic macular edema", "2": "drusen", "3": "normal", }, }, "pneumoniamnist": { "python_class": "PneumoniaMNIST", "task": "binary-class", "n_channels": 1, "label": { "0": "normal", "1": "pneumonia", }, }, "dermamnist": { "python_class": "DermaMNIST", "task": "multi-class", "n_channels": 3, "label": { "0": "actinic keratoses and intraepithelial carcinoma", "1": "basal cell carcinoma", "2": "benign keratosis-like lesions", "3": "dermatofibroma", "4": "melanoma", "5": "melanocytic nevi", "6": "vascular lesions", }, }, "retinamnist": { "python_class": "RetinaMNIST", "task": "ordinal-regression", "n_channels": 3, "label": { # ordinal 0..4 "0": "0", "1": "1", "2": "2", "3": "3", "4": "4", }, }, "breastmnist": { "python_class": "BreastMNIST", "task": "binary-class", "n_channels": 1, "label": { "0": "malignant", "1": "normal, benign", }, }, "tissuemnist": { "python_class": "TissueMNIST", "task": "multi-class", "n_channels": 1, "label": { "0": "Collecting Duct, Connecting Tubule", "1": "Distal Convoluted Tubule", "2": "Glomerular endothelial cells", "3": "Interstitial endothelial cells", "4": "Leukocytes", "5": "Podocytes", "6": "Proximal Tubule Segments", "7": "Thick Ascending Limb", }, }, # The Organ* 2D sets share the same 11 organ names; channels are grayscale. "organamnist": { "python_class": "OrganAMNIST", "task": "multi-class", "n_channels": 1, "label": { "0": "liver", "1": "kidney-right", "2": "kidney-left", "3": "femur-right", "4": "femur-left", "5": "bladder", "6": "heart", "7": "lung-right", "8": "lung-left", "9": "spleen", "10": "pancreas", }, }, "organcmnist": { "python_class": "OrganCMNIST", "task": "multi-class", "n_channels": 1, "label": { "0": "liver", "1": "kidney-right", "2": "kidney-left", "3": "femur-right", "4": "femur-left", "5": "bladder", "6": "heart", "7": "lung-right", "8": "lung-left", "9": "spleen", "10": "pancreas", }, }, "organsmnist": { "python_class": "OrganSMNIST", "task": "multi-class", "n_channels": 1, "label": { "0": "liver", "1": "kidney-right", "2": "kidney-left", "3": "femur-right", "4": "femur-left", "5": "bladder", "6": "heart", "7": "lung-right", "8": "lung-left", "9": "spleen", "10": "pancreas", }, }, } def as_class_indices(t: torch.Tensor) -> torch.Tensor: """ Normalize MedMNIST-style labels to 1D Long class indices for CE loss. - Accepts shapes: [], [B], [B,1], or one-hot [B,C] - Returns shape [B], dtype torch.long """ if t.ndim == 0: # scalar return t.long().view(1) if t.ndim == 1: return t.long() # ndims >= 2 if t.size(-1) == 1: t = t.squeeze(-1) return t.long() # likely one-hot [B,C] return t.argmax(dim=-1).long() def get_med_info(flag: str) -> dict: """Return official medmnist.INFO[flag] if available, else fallback.""" if MED_INFO is not None and flag in MED_INFO: return MED_INFO[flag] if flag in FALLBACK_INFO: return FALLBACK_INFO[flag] raise KeyError(f"Unknown MedMNIST flag: {flag}") def make_med_transform(n_channels: int): """ ToTensor -> ensure single gray channel -> flatten to 784 for your pipeline. We keep your 28x28 target and collapse channels deterministically. """ return transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda t: t[:1, :, :] if t.shape[0] > 1 else t), # pick first channel if RGB transforms.Lambda(lambda t: t.view(-1)), ]) def med_class_names_from_info(info: dict): """Convert label dict -> ordered list by index: ['name0','name1',...]""" label_dict = info["label"] return [label_dict[str(i)] for i in range(len(label_dict))] # ============================================================ # DATASET LOADER # ============================================================ def get_dataset(name=DATASET_NAME, batch_size=128, num_workers=2): """ Universal loader for all MNIST-like datasets. Returns train_loader, test_loader, num_classes, class_names """ print(f"\n{'='*60}") print(f"Loading {name}") print(f"Description: {AVAILABLE_DATASETS.get(name, 'Unknown dataset')}") print(f"{'='*60}") # Standard transform for all datasets transform = transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1)) # Flatten to 784 ]) # Special transform for grayscale conversion if needed transform_gray = transforms.Compose([ transforms.Grayscale(num_output_channels=config.get("n_channels", 1)), transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1)) ]) # STANDARD TORCHVISION DATASETS if name == "MNIST": train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True) test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) num_classes = 10 class_names = [str(i) for i in range(10)] elif name == "FashionMNIST": train_dataset = datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True) test_dataset = datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True) num_classes = 10 class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] elif name == "KMNIST": train_dataset = datasets.KMNIST(root="./data", train=True, transform=transform, download=True) test_dataset = datasets.KMNIST(root="./data", train=False, transform=transform, download=True) num_classes = 10 class_names = ['お', 'き', 'す', 'つ', 'な', 'は', 'ま', 'や', 'れ', 'を'] elif name == "EMNIST": # Using 'balanced' split - 47 classes (digits + letters) train_dataset = datasets.EMNIST(root="./data", split='balanced', train=True, transform=transform, download=True) test_dataset = datasets.EMNIST(root="./data", split='balanced', train=False, transform=transform, download=True) num_classes = 47 # class_names = [str(i) for i in range(47)] # Mix of digits and letters class_names = [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'd', 'e', 'f', 'g', 'h', 'n', 'q', 'r', 't' ] elif name == "QMNIST": train_dataset = datasets.QMNIST(root="./data", what='train', transform=transform, download=True) test_dataset = datasets.QMNIST(root="./data", what='test', transform=transform, download=True) num_classes = 10 class_names = [str(i) for i in range(10)] elif name == "USPS": # USPS is 16x16, need to resize transform_usps = transforms.Compose([ transforms.Resize((28, 28)), transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1)) ]) train_dataset = datasets.USPS(root="./data", train=True, transform=transform_usps, download=True) test_dataset = datasets.USPS(root="./data", train=False, transform=transform_usps, download=True) num_classes = 10 class_names = [str(i) for i in range(10)] # MEDMNIST DATASETS elif name in ["BloodMNIST", "PathMNIST", "OCTMNIST", "PneumoniaMNIST", "DermaMNIST", "RetinaMNIST", "BreastMNIST", "OrganAMNIST", "OrganCMNIST", "OrganSMNIST", "TissueMNIST"]: # Map UI names to medmnist flags medmnist_map = { "BloodMNIST": "bloodmnist", "PathMNIST": "pathmnist", "OCTMNIST": "octmnist", "PneumoniaMNIST": "pneumoniamnist", "DermaMNIST": "dermamnist", "RetinaMNIST": "retinamnist", "BreastMNIST": "breastmnist", "OrganAMNIST": "organamnist", "OrganCMNIST": "organcmnist", "OrganSMNIST": "organsmnist", "TissueMNIST": "tissuemnist", } dataset_flag = medmnist_map[name] info = get_med_info(dataset_flag) # Require the package to actually load data if medmnist is None: raise ImportError( "medmnist is not installed. Run: pip install medmnist\n" f"(INFO fallback is provided; DataClass={info['python_class']} needs the package.)" ) DataClass = getattr(medmnist, info["python_class"]) # Transform: force 1-channel grayscale then flatten to 784 transform_med = make_med_transform(info["n_channels"]) # 28x28 size (default); you can bump to 64/128/224 by size=... train_dataset = DataClass(split='train', transform=transform_med, download=True, size=28) test_dataset = DataClass(split='test', transform=transform_med, download=True, size=28) num_classes = len(info["label"]) class_names = med_class_names_from_info(info) print(f" MedMNIST Dataset: {dataset_flag}") print(f" Task: {info['task']}") print(f" Classes: {num_classes} | Channels: {info['n_channels']}") else: raise ValueError(f"Unknown dataset: {name}. Choose from: {list(AVAILABLE_DATASETS.keys())}") # Create data loaders train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) print(f"\nDataset loaded successfully!") print(f" Train samples: {len(train_dataset):,}") print(f" Test samples: {len(test_dataset):,}") print(f" Number of classes: {num_classes}") print(f" Input shape: 28x28 = 784 dimensions") return train_loader, test_loader, num_classes, class_names #train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=2) #test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=2) train_loader, test_loader, num_classes, class_names = get_dataset(DATASET_NAME, config['batch_size']) config['num_classes'] = num_classes FASHION_CLASSES = class_names #[ # '0', '1', '2', '3', '4', '5', '6', '7', '8', '9' #'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', #'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot' #] print(f"\nDataset loaded:") #print(f" Train: {len(train_dataset):,} samples") #print(f" Test: {len(test_dataset):,} samples") # ============================ # ADDITIONS: saving & hub push # ============================ import os, json, math, platform, sys, shutil, zipfile from pathlib import Path from datetime import datetime # Auto-install per Phil’s preference def _ensure(pkg, pip_name=None): pip_name = pip_name or pkg try: __import__(pkg) except Exception: print(f"[setup] Installing {pip_name} ...") os.system(f"{sys.executable} -m pip install -q {pip_name}") _ensure("safetensors") _ensure("huggingface_hub") _ensure("psutil") _ensure("pandas") from safetensors.torch import save_file as save_safetensors from huggingface_hub import HfApi, create_repo, whoami, login from torch.utils.tensorboard import SummaryWriter import psutil import pandas as pd def _param_count(model: torch.nn.Module) -> int: return sum(p.numel() for p in model.parameters()) def _timestamp(): return datetime.now().strftime("%Y%m%d-%H%M%S") def _resolve_repo_id(config: dict) -> str: rid = os.getenv("PENTACHORA_HF_REPO") or config.get("hf_repo_id") if not rid: raise RuntimeError( "Hugging Face repo id is not set. Set config['hf_repo_id'] or PENTACHORA_HF_REPO env var." ) return rid def _hf_login_if_needed(): # Use existing login if available; otherwise try HF_TOKEN try: _ = whoami() return except Exception: token = os.getenv("HF_TOKEN") if not token: print("[huggingface] No active login and HF_TOKEN not set; if push fails, run huggingface-cli login.") return login(token=token, add_to_git_credential=True) def _ensure_repo(repo_id: str): api = HfApi() create_repo(repo_id=repo_id, private=False, exist_ok=True, repo_type="model") return api def _zip_dir(src_dir: Path, zip_path: Path): with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z: for p in src_dir.rglob("*"): z.write(p, arcname=p.relative_to(src_dir)) def save_and_push_artifacts( encoder: nn.Module, constellation: nn.Module, diagnostic_head: nn.Module, config: dict, class_names: list, history: dict, best_acc: float, tb_log_dir: Path, last_confusion_png: Path | None, repo_subdir_root: str = "pentachora-adaptive-encoded", ): """ Saves safetensors + metadata locally and pushes to HF Hub under: /// """ ts = _timestamp() repo_id = _resolve_repo_id(config) _hf_login_if_needed() api = _ensure_repo(repo_id) # ---------- local layout ---------- base_out = Path("artifacts") / repo_subdir_root / ts base_out.mkdir(parents=True, exist_ok=True) # 1) Weights weights_dir = base_out / "weights" weights_dir.mkdir(parents=True, exist_ok=True) # Save each module separately to keep them composable save_safetensors({k: v.cpu() for k, v in encoder.state_dict().items()}, str(weights_dir / "encoder.safetensors")) save_safetensors({k: v.cpu() for k, v in constellation.state_dict().items()}, str(weights_dir / "constellation.safetensors")) save_safetensors({k: v.cpu() for k, v in diagnostic_head.state_dict().items()}, str(weights_dir / "diagnostic_head.safetensors")) # 2) Config conf_path = base_out / "config.json" with conf_path.open("w", encoding="utf-8") as f: json.dump(config, f, indent=2, sort_keys=True) # 3) History (per-epoch metrics) and CSV hist_json = base_out / "history.json" with hist_json.open("w", encoding="utf-8") as f: json.dump(history, f, indent=2, sort_keys=True) # CSV max_len = max(len(history.get("train_loss", [])), len(history.get("train_acc", [])), len(history.get("test_acc", []))) df = pd.DataFrame({ "epoch": list(range(1, max_len + 1)), "train_loss": history.get("train_loss", [math.nan]*max_len), "train_acc": history.get("train_acc", [math.nan]*max_len), "test_acc": history.get("test_acc", [math.nan]*max_len), }) df.to_csv(base_out / "history.csv", index=False) # 4) Manifest manifest = { "timestamp": ts, "repo_id": repo_id, "subdirectory": f"{repo_subdir_root}/{ts}", "dataset_name": DATASET_NAME, "class_names": class_names, "num_classes": len(class_names), "models": { "encoder": {"params": _param_count(encoder)}, "constellation": {"params": _param_count(constellation)}, "diagnostic_head": {"params": _param_count(diagnostic_head)}, }, "results": { "best_test_accuracy": best_acc, }, "environment": { "python": sys.version, "platform": platform.platform(), "torch": torch.__version__, "cuda_available": torch.cuda.is_available(), "cuda_device": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None), "cpu_count": psutil.cpu_count(logical=True), "memory_gb": round(psutil.virtual_memory().total / (1024**3), 2), }, } manifest_path = base_out / "manifest.json" with manifest_path.open("w", encoding="utf-8") as f: json.dump(manifest, f, indent=2, sort_keys=True) # 5) Debug info debug_txt = base_out / "debug.txt" with debug_txt.open("w", encoding="utf-8") as f: f.write("==== DEBUG INFO ====\n") f.write(f"Timestamp: {ts}\n") f.write(f"Repo: {repo_id}\n") f.write(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}\n") f.write(f"Encoder params: {_param_count(encoder)}\n") f.write(f"Constellation params: {_param_count(constellation)}\n") f.write(f"Diagnostic head params: {_param_count(diagnostic_head)}\n") f.write(f"Best test accuracy: {best_acc:.6f}\n") # 6) Plots (confusion matrix already saved during training; accuracy_plot.png at CWD) # Copy if present acc_plot = Path("accuracy_plot.png") if acc_plot.exists(): shutil.copy2(acc_plot, base_out / "accuracy_plot.png") if last_confusion_png and Path(last_confusion_png).exists(): shutil.copy2(last_confusion_png, base_out / Path(last_confusion_png).name) # 7) TensorBoard ("the tensorflow") logs # We copy the event files into artifacts, and zip them for convenience tb_out = base_out / "tensorboard" tb_out.mkdir(parents=True, exist_ok=True) if tb_log_dir and Path(tb_log_dir).exists(): for p in Path(tb_log_dir).glob("*"): shutil.copy2(p, tb_out / p.name) _zip_dir(tb_out, base_out / "tensorboard_events.zip") # 8) Also save a small README readme = base_out / "README.md" readme.write_text( f"""# Pentachora Adaptive Encoded — {ts} This folder is an immutable snapshot of training artifacts. **Contents** - `weights/*.safetensors` — encoder, constellation, diagnostic head - `config.json` — full run configuration - `manifest.json` — environment + model sizes + dataset - `history.json` / `history.csv` — per-epoch metrics - `tensorboard/` + `tensorboard_events.zip` — raw TB event files ("the tensorflow") - `accuracy_plot.png` (if available) - `best_confusion_matrix_epoch_*.png` (if available) - `debug.txt` — quick human-readable summary """, encoding="utf-8" ) # ---------- push to HF Hub ---------- print(f"[push] Uploading to hf://{repo_id}/{repo_subdir_root}/{ts}") api.upload_folder( repo_id=repo_id, folder_path=str(base_out), path_in_repo=f"{repo_subdir_root}/{ts}", repo_type="model", ) print("[push] ✅ Upload complete.") return base_out, f"{repo_subdir_root}/{ts}" # ============================================================ # PENTAFREQ ENCODER (Original 93% Version) # ============================================================ class PentaFreqEncoder(nn.Module): """ 5-Frequency Band Encoder designed to perfectly align with pentachoron vertices. Each frequency band corresponds to one vertex of the pentachoron. The adjacency relationships between frequency bands naturally form the edge structure of the pentachoron! """ def __init__(self, input_dim=784, base_dim=64): super().__init__() self.input_dim = input_dim self.base_dim = base_dim self.img_size = 28 self.unflatten = nn.Unflatten(1, (1, 28, 28)) # ========== 5 FREQUENCY EXTRACTORS ========== # Vertex 0: Ultra-High Frequency (finest details, noise, texture grain) self.v0_ultrahigh = nn.Sequential( nn.Conv2d(1, 12, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(12), nn.ReLU(), # Edge enhancement nn.Conv2d(12, 12, kernel_size=3, padding=1, groups=12), # Depthwise nn.BatchNorm2d(12), nn.ReLU(), nn.AdaptiveAvgPool2d(7), nn.Flatten() ) self.v0_encode = nn.Linear(12 * 49, base_dim) # Vertex 1: High Frequency (edges, sharp transitions) self.v1_high = nn.Sequential( nn.Conv2d(1, 12, kernel_size=3, padding=1, stride=1), nn.BatchNorm2d(12), nn.Tanh(), nn.MaxPool2d(2), # 14x14 nn.Conv2d(12, 12, kernel_size=3, padding=1), nn.BatchNorm2d(12), nn.Tanh(), nn.AdaptiveAvgPool2d(7), nn.Flatten() ) self.v1_encode = nn.Linear(12 * 49, base_dim) # Vertex 2: Mid Frequency (local patterns, textures) self.v2_mid = nn.Sequential( nn.Conv2d(1, 12, kernel_size=5, padding=2, stride=2), # 14x14 nn.BatchNorm2d(12), nn.GELU(), nn.Conv2d(12, 12, kernel_size=3, padding=1), nn.BatchNorm2d(12), nn.GELU(), nn.AdaptiveAvgPool2d(7), nn.Flatten() ) self.v2_encode = nn.Linear(12 * 49, base_dim) # Vertex 3: Low-Mid Frequency (shapes, regional features) self.v3_lowmid = nn.Sequential( nn.AvgPool2d(2), # Start with 14x14 nn.Conv2d(1, 12, kernel_size=7, padding=3), nn.BatchNorm2d(12), nn.SiLU(), nn.AvgPool2d(2), # 7x7 nn.Flatten() ) self.v3_encode = nn.Linear(12 * 49, base_dim) # Vertex 4: Low Frequency (global structure, overall form) self.v4_low = nn.Sequential( nn.AvgPool2d(4), # Start with 7x7 nn.Conv2d(1, 12, kernel_size=7, padding=3), nn.BatchNorm2d(12), nn.Sigmoid(), # Smooth activation for global features nn.AdaptiveAvgPool2d(7), nn.Flatten() ) self.v4_encode = nn.Linear(12 * 49, base_dim) # ========== PENTACHORON ADJACENCY MATRIX ========== # Defines which frequency bands are "adjacent" (connected by edges) # This follows the edge structure of a perfect pentachoron self.register_buffer('adjacency_matrix', self._create_pentachoron_adjacency()) # ========== FUSION NETWORK ========== # Learns to combine all 5 frequency bands self.fusion = nn.Sequential( nn.Linear(base_dim * 5, base_dim * 3), nn.BatchNorm1d(base_dim * 3), nn.ReLU(), nn.Dropout(0.2), nn.Linear(base_dim * 3, base_dim * 2), nn.BatchNorm1d(base_dim * 2), nn.ReLU(), nn.Linear(base_dim * 2, base_dim) ) # Initialize edge detection kernels for ultra-high frequency self._init_edge_kernels() def _create_pentachoron_adjacency(self): """ Create adjacency matrix for a complete graph (pentachoron). In a 4-simplex, every vertex connects to every other vertex. """ adj = torch.ones(5, 5) - torch.eye(5) return adj def _init_edge_kernels(self): """Initialize V0 with various edge detection kernels.""" with torch.no_grad(): if hasattr(self.v0_ultrahigh[0], 'weight'): kernels = self.v0_ultrahigh[0].weight # Sobel X kernels[0, 0] = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]) / 4.0 # Sobel Y kernels[1, 0] = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]) / 4.0 # Laplacian kernels[2, 0] = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]) / 4.0 # Roberts Cross kernels[3, 0] = torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 0]]) / 2.0 # Prewitt X kernels[4, 0] = torch.tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]) / 3.0 def forward(self, x): batch_size = x.size(0) # Reshape to image x_img = self.unflatten(x) # ========== EXTRACT 5 FREQUENCY BANDS ========== # Each vertex processes a different frequency range # V0: Ultra-high frequency v0_features = self.v0_ultrahigh(x_img) v0 = self.v0_encode(v0_features) # V1: High frequency v1_features = self.v1_high(x_img) v1 = self.v1_encode(v1_features) # V2: Mid frequency v2_features = self.v2_mid(x_img) v2 = self.v2_encode(v2_features) # V3: Low-mid frequency v3_features = self.v3_lowmid(x_img) v3 = self.v3_encode(v3_features) # V4: Low frequency v4_features = self.v4_low(x_img) v4 = self.v4_encode(v4_features) # Stack all vertex features vertices = torch.stack([v0, v1, v2, v3, v4], dim=1) # [B, 5, base_dim] # ========== COMPUTE PENTACHORON EDGE WEIGHTS ========== # Normalize each vertex vertices_norm = F.normalize(vertices, dim=2) # Compute pairwise similarities (edge strengths) - BATCHED # Use bmm for efficiency instead of loops similarities = torch.bmm(vertices_norm, vertices_norm.transpose(1, 2)) # [B, 5, 5] # Apply pentachoron adjacency mask edge_strengths = similarities * self.adjacency_matrix.unsqueeze(0) # ========== WEIGHTED COMBINATION BASED ON EDGE STRUCTURE ========== # Each vertex is weighted by its edge connections edge_weights = edge_strengths.sum(dim=2) # [B, 5] edge_weights = F.softmax(edge_weights, dim=1) # Weight each frequency band - BATCHED weighted_vertices = vertices * edge_weights.unsqueeze(2) # [B, 5, base_dim] # ========== FUSION ========== # Flatten all weighted frequency bands combined = weighted_vertices.flatten(1) # [B, base_dim * 5] # Fuse through network fused = self.fusion(combined) # Final normalization to unit sphere output = F.normalize(fused, dim=1) return output def get_frequency_contributions(self, x): """ Utility function to visualize how much each frequency band contributes. Returns the weights for each vertex/frequency band. """ with torch.no_grad(): # Run forward pass to get edge weights x_img = self.unflatten(x) # Extract all frequencies v0 = self.v0_encode(self.v0_ultrahigh(x_img)) v1 = self.v1_encode(self.v1_high(x_img)) v2 = self.v2_encode(self.v2_mid(x_img)) v3 = self.v3_encode(self.v3_lowmid(x_img)) v4 = self.v4_encode(self.v4_low(x_img)) vertices = torch.stack([v0, v1, v2, v3, v4], dim=1) vertices_norm = F.normalize(vertices, dim=2) # Compute edge strengths - BATCHED similarities = torch.bmm(vertices_norm, vertices_norm.transpose(1, 2)) edge_strengths = similarities * self.adjacency_matrix.unsqueeze(0) edge_weights = edge_strengths.sum(dim=2) edge_weights = F.softmax(edge_weights, dim=1) return edge_weights # ============================================================ # BATCHED PENTACHORON CONSTELLATION # ============================================================ class BatchedPentachoronConstellation(nn.Module): """Optimized constellation with a permanent, integrated Coherence Head.""" def __init__(self, num_classes, dim, num_pairs=5, device='cuda', lambda_sep=0.5): super().__init__() self.num_classes = num_classes self.dim = dim self.num_pairs = num_pairs self.device = device self.lambda_separation = lambda_sep # Initialize all pentachora as single tensors for batched ops self.dispatchers = nn.Parameter(self._init_batched_pentachora()) self.specialists = nn.Parameter(self._init_batched_pentachora()) # Batched weights self.dispatcher_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1) self.specialist_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1) # Temperature per pair self.temps = nn.Parameter(0.3 * torch.ones(num_pairs)) # Vertex assignments self.register_buffer('vertex_map', self._create_vertex_mapping()) # Group classification heads for each vertex self.group_heads = nn.ModuleList([ nn.Linear(dim, (self.vertex_map == i).sum().item()) if (self.vertex_map == i).sum().item() > 0 else None for i in range(5) ]) # Cross-pair attention mechanism self.cross_attention = nn.MultiheadAttention( embed_dim=dim, num_heads=config.get('num_heads', 4), dropout=0.1, batch_first=True ) # Aggregation weights for combining scores from different pairs self.aggregation_weights = nn.Parameter(torch.ones(num_pairs) / num_pairs) # Final fusion network self.fusion = nn.Sequential( nn.Linear(num_classes * num_pairs, num_classes * 2), nn.BatchNorm1d(num_classes * 2), nn.ReLU(), nn.Dropout(0.2), nn.Linear(num_classes * 2, num_classes) ) ### ADDED: Integrated Coherence Head ### # This small MLP acts as the permanent "rose_head". It learns to assess # the quality/coherence of the input latent vector `x`. self.coherence_head = nn.Sequential( nn.Linear(dim, dim // 2), nn.GELU(), nn.Linear(dim // 2, 1) ) def _init_batched_pentachora(self): """Initializes all pentachora for the constellation.""" sqrt15, sqrt10, sqrt5 = np.sqrt(15), np.sqrt(10), np.sqrt(5) base_simplex = torch.tensor([ [ 1.0, 0.0, 0.0, 0.0], [-0.25, sqrt15/4, 0.0, 0.0], [-0.25, -sqrt15/12, sqrt10/3, 0.0], [-0.25, -sqrt15/12, -sqrt10/6, sqrt5/2], [-0.25, -sqrt15/12, -sqrt10/6, -sqrt5/2] ], device=self.device) base_simplex = F.normalize(base_simplex, dim=1) pentachora = torch.zeros(self.num_pairs, 5, self.dim, device=self.device) for i in range(self.num_pairs): pentachora[i, :, :4] = base_simplex * (1 + 0.1 * i) if self.dim > 4: pentachora[i, :, 4:] = torch.randn(5, self.dim - 4, device=self.device) * (random.random() * 0.25) return pentachora * 2.0 def _create_vertex_mapping(self): """Creates a mapping from classes to the 5 pentachoron vertices.""" mapping = torch.zeros(self.num_classes, dtype=torch.long) for i in range(self.num_classes): mapping[i] = i % 5 return mapping def forward(self, x): batch_size = x.size(0) ### MODIFIED: Coherence Gating ### # 1. Calculate the coherence score for the latent vector `x`. coherence_gate = torch.sigmoid(self.coherence_head(x)) # Shape: [batch_size, 1] # Distance calculations x_expanded = x.unsqueeze(1).unsqueeze(2) disp_expanded = self.dispatchers.unsqueeze(0) spec_expanded = self.specialists.unsqueeze(0) disp_dists = torch.norm(x_expanded - disp_expanded, dim=3) spec_dists = torch.norm(x_expanded - spec_expanded, dim=3) disp_weights = F.softmax(self.dispatcher_weights, dim=1).unsqueeze(0) spec_weights = F.softmax(self.specialist_weights, dim=1).unsqueeze(0) weighted_disp = disp_dists * disp_weights weighted_spec = spec_dists * spec_weights temps_clamped = torch.clamp(self.temps, 0.1, 2.0).view(1, -1, 1) ### MODIFIED: Apply Coherence to Vertex Logits ### # 2. Calculate pre-softmax "logits" and modulate with the coherence score. disp_logits = -weighted_disp / temps_clamped spec_logits = -weighted_spec / temps_clamped modulated_disp_logits = disp_logits * coherence_gate.unsqueeze(-1) modulated_spec_logits = spec_logits * coherence_gate.unsqueeze(-1) # 3. Calculate probabilities from the *modulated* logits. vertex_probs = F.softmax(modulated_disp_logits, dim=2) spec_probs = F.softmax(modulated_spec_logits, dim=2) combined_probs = 0.5 * vertex_probs + 0.5 * spec_probs # Score calculation using group heads all_scores = [] for p in range(self.num_pairs): pair_scores = torch.zeros(batch_size, self.num_classes, device=self.device) for v_idx in range(5): classes_in_vertex = (self.vertex_map == v_idx).nonzero(as_tuple=True)[0] if len(classes_in_vertex) == 0: continue v_prob = combined_probs[:, p, v_idx:v_idx+1] if self.group_heads[v_idx] is not None: group_logits = self.group_heads[v_idx](x) gated_logits = group_logits * v_prob for i, cls in enumerate(classes_in_vertex): if i < gated_logits.size(1): pair_scores[:, cls] = gated_logits[:, i] all_scores.append(pair_scores) all_scores_tensor = torch.stack(all_scores, dim=1) # Cross-attention and aggregation avg_dispatcher_centers = self.dispatchers.mean(dim=1).unsqueeze(0).expand(batch_size, -1, -1) attended_features, _ = self.cross_attention( avg_dispatcher_centers, avg_dispatcher_centers, avg_dispatcher_centers ) agg_weights = F.softmax(self.aggregation_weights, dim=0).view(1, -1, 1) weighted_scores = (all_scores_tensor * agg_weights).sum(dim=1) # Final fusion concat_scores = all_scores_tensor.flatten(1) fused_scores = self.fusion(concat_scores) final_scores = 0.6 * weighted_scores + 0.4 * fused_scores return final_scores, (disp_dists, spec_dists, vertex_probs) def regularization_loss(self, vertex_weights=None): """BATCHED regularization with optional per-vertex weighting.""" # Original Geometric Regularization disp_cm = self._batched_cayley_menger(self.dispatchers) spec_cm = self._batched_cayley_menger(self.specialists) cm_loss = torch.relu(1.0 - torch.abs(disp_cm)).sum() + torch.relu(1.0 - torch.abs(spec_cm)).sum() edge_loss = self._batched_edge_variance(self.dispatchers) + self._batched_edge_variance(self.specialists) disp_centers = self.dispatchers.mean(dim=1) spec_centers = self.specialists.mean(dim=1) cos_sims = F.cosine_similarity(disp_centers, spec_centers, dim=1) ortho_loss = torch.abs(cos_sims).sum() * self.lambda_separation separations = torch.norm(disp_centers - spec_centers, dim=1) sep_loss = torch.relu(2.0 - separations).sum() * self.lambda_separation # Dynamic Vertex Regularization dynamic_reg_loss = 0.0 if vertex_weights is not None: vertex_weights = vertex_weights.to(self.dispatchers.device) disp_norms = torch.norm(self.dispatchers, p=2, dim=2) spec_norms = torch.norm(self.specialists, p=2, dim=2) weighted_disp_loss = (disp_norms * vertex_weights.unsqueeze(0)).mean() weighted_spec_loss = (spec_norms * vertex_weights.unsqueeze(0)).mean() dynamic_reg_loss = 0.1 * (weighted_disp_loss + weighted_spec_loss) total_loss = (cm_loss + edge_loss + ortho_loss + sep_loss) / self.num_pairs return total_loss + dynamic_reg_loss def _batched_cayley_menger(self, pentachora): """Compute Cayley-Menger determinant for all pairs at once.""" num_pairs = pentachora.shape[0] dists_sq = torch.cdist(pentachora, pentachora) ** 2 cm_matrices = torch.zeros(num_pairs, 6, 6, device=self.device) cm_matrices[:, 0, 1:] = 1 cm_matrices[:, 1:, 0] = 1 cm_matrices[:, 1:, 1:] = dists_sq return torch.det(cm_matrices) def _batched_edge_variance(self, pentachora): """Compute edge variance for all pairs at once.""" dists = torch.cdist(pentachora, pentachora) mask = torch.triu(torch.ones(5, 5, device=self.device), diagonal=1).bool() edges_list = [dists[p][mask] for p in range(self.num_pairs)] edges_all = torch.stack(edges_list) variances = edges_all.var(dim=1) mins = edges_all.min(dim=1)[0] return variances.sum() + torch.relu(0.5 - mins).sum() def _cayley_menger_determinant(self, vertices): """Compute Cayley-Menger determinant for pentachoron validity.""" n = vertices.shape[0] # Distance matrix dists_sq = torch.cdist(vertices.unsqueeze(0), vertices.unsqueeze(0))[0] ** 2 # Build Cayley-Menger matrix cm_matrix = torch.zeros(n+1, n+1, device=self.device) cm_matrix[0, 1:] = 1 cm_matrix[1:, 0] = 1 cm_matrix[1:, 1:] = dists_sq return torch.det(cm_matrix) # ============================================================ # COMPLETE LOSS FUNCTIONS # ============================================================ def dual_contrastive_loss(latents, targets, constellation, config): """ Computes a dual contrastive loss for pulling samples to the correct pentachoron vertex and pushing them away from all incorrect vertices. Args: latents (torch.Tensor): The encoded feature vectors from the encoder [B, dim]. targets (torch.Tensor): The ground truth class labels [B]. constellation (nn.Module): The PentachoronConstellation model. config (dict): The configuration dictionary containing 'temp'. Returns: torch.Tensor: The total contrastive loss. """ batch_size = latents.size(0) device = latents.device temp = config['temp'] # Get the target vertex for each sample in the batch target_vertices = constellation.vertex_map[targets] # [B] # Normalize latents to be on the unit sphere for a clean cosine similarity latents = F.normalize(latents, dim=1) # --- DISPATCHER LOSS --- # Shape: [num_pairs, 5, dim] disp_pentachora_norm = F.normalize(constellation.dispatchers, dim=2) # The fix: Repeat the dispatcher tensor for each item in the batch disp_pentachora_expanded = disp_pentachora_norm.unsqueeze(0).expand(batch_size, -1, -1, -1) # [B, num_pairs, 5, dim] # Compute cosine similarity between each latent and all dispatcher vertices # latents: [B, 1, dim], disp_pentachora_expanded: [B, num_pairs, 5, dim] # Resulting shape: [B, num_pairs, 5] disp_sims = torch.einsum('bd,bpvd->bpv', latents, F.normalize(disp_pentachora_expanded, dim=3)) # Gather the similarities for the correct vertices for each sample # disp_sims[i, p, target_vertices[i]] disp_positive_sims = disp_sims[torch.arange(batch_size), :, target_vertices] # [B, num_pairs] # Calculate negative logits by taking similarities of all vertices disp_all_logits = disp_sims / temp # [B, num_pairs, 5] # Calculate InfoNCE loss for dispatchers disp_loss = -torch.log(torch.exp(disp_positive_sims / temp) / torch.exp(disp_all_logits).sum(dim=2)).mean() # --- SPECIALIST LOSS --- # Same process for the specialists spec_pentachora_norm = F.normalize(constellation.specialists, dim=2) spec_pentachora_expanded = spec_pentachora_norm.unsqueeze(0).expand(batch_size, -1, -1, -1) spec_sims = torch.einsum('bd,bpvd->bpv', latents, F.normalize(spec_pentachora_expanded, dim=3)) spec_positive_sims = spec_sims[torch.arange(batch_size), :, target_vertices] spec_all_logits = spec_sims / temp spec_loss = -torch.log(torch.exp(spec_positive_sims / temp) / torch.exp(spec_all_logits).sum(dim=2)).mean() # Combine losses total_loss = disp_loss + spec_loss return total_loss # Helper functions meant to solidify the new scheduler def get_class_similarity(constellation_model, num_classes): """ Calculates pairwise class similarity based on the final layer's weights. Returns a [num_classes, num_classes] similarity matrix. """ # Use the final fusion layer as the class representation final_layer = constellation_model.fusion[-1] weights = final_layer.weight.data.detach() # Shape: [num_classes, feature_dim] # Normalize each class vector to get cosine similarity norm_weights = F.normalize(weights, p=2, dim=1) # Cosine similarity is the dot product of normalized vectors similarity_matrix = torch.matmul(norm_weights, norm_weights.T) return torch.clamp(similarity_matrix, 0.0, 1.0) # Ensure values are [0, 1] def get_vertex_weights_from_confusion(conf_matrix, class_similarity, vertex_map, device): """ Calculates per-vertex regularization weights based on class confusion and similarity. """ num_classes = conf_matrix.shape[0] # 1. Calculate a "confusion score" for each class (1 - accuracy) class_totals = conf_matrix.sum(axis=1) class_correct = conf_matrix.diagonal() class_acc = np.divide(class_correct, class_totals, out=np.zeros_like(class_correct, dtype=float), where=class_totals!=0) confusion_scores = 1.0 - torch.tensor(class_acc, device=device, dtype=torch.float32) # 2. Spread the confusion using the similarity matrix (the "bell curve") sigma = 0.5 # Controls the width of the bell curve gaussian_similarity = torch.exp(-((1 - class_similarity)**2) / (2 * sigma**2)) propagated_scores = torch.matmul(gaussian_similarity, confusion_scores) # 3. Map per-class scores to per-vertex scores vertex_problem_scores_sum = torch.zeros(5, device=device) vertex_counts = torch.zeros(5, device=device) for class_idx, vertex_idx in enumerate(vertex_map): vertex_problem_scores_sum[vertex_idx] += propagated_scores[class_idx] vertex_counts[vertex_idx] += 1 # --- CORRECTED LINE --- # Perform safe division to average the scores for vertices with multiple classes vertex_problem_scores = torch.zeros_like(vertex_problem_scores_sum) mask = vertex_counts > 0 vertex_problem_scores[mask] = vertex_problem_scores_sum[mask] / vertex_counts[mask] # 4. Convert "problem score" to "regularization weight" vertex_weights = 1.0 - torch.tanh(vertex_problem_scores) # Maps scores to a (0, 1) range return F.normalize(vertex_weights, p=1, dim=0) * 5.0 # Normalize sum to 5, so avg is 1 # ============================================================ # TRAINING FUNCTIONS # ============================================================ # In the TRAINING FUNCTIONS section # ============================================================ # TRAINING FUNCTION # ============================================================ def train_epoch(encoder, constellation, optimizer, train_loader, epoch, config, vertex_weights, device): """ Performs one full training epoch using the provided dynamic regularization weights. """ # Set models to training mode encoder.train() constellation.train() # Initialize trackers for loss and accuracy total_loss = 0.0 correct_predictions = 0 total_samples = 0 # Create a progress bar for the training loader pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']} [Training]") for inputs, targets in pbar: # Move data to the configured device (GPU or CPU) inputs, targets = inputs.to(device), as_class_indices(targets.to(device)) # Reset gradients from the previous iteration optimizer.zero_grad() # --- Forward Pass --- # 1. Get latent representations from the encoder z = encoder(inputs) # 2. Get classification scores from the constellation scores, _ = constellation(z) # --- Loss Calculation --- # 1. Standard cross-entropy loss for classification ce_loss = F.cross_entropy(scores, targets) # 2. Regularization loss, now modulated by our dynamic per-vertex weights reg_loss = constellation.regularization_loss(vertex_weights=vertex_weights) # 3. Combine the losses loss = ce_loss + config['loss_weight_scalar'] * reg_loss # --- Backward Pass and Optimization --- # 1. Compute gradients loss.backward() # 2. Clip gradients to prevent exploding gradients torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(constellation.parameters(), 1.0) # 3. Update model weights optimizer.step() # --- Update Statistics --- total_loss += loss.item() * inputs.size(0) preds = scores.argmax(dim=1) correct_predictions += (preds == targets).sum().item() total_samples += inputs.size(0) # Update the progress bar with live metrics pbar.set_postfix({ 'loss': f"{loss.item():.4f}", 'acc': f"{correct_predictions/total_samples:.4f}", 'reg': f"{reg_loss.item():.4f}" }) # Return the average loss and accuracy for the epoch return total_loss / total_samples, correct_predictions / total_samples from sklearn.metrics import confusion_matrix import seaborn as sns @torch.no_grad() def evaluate(encoder, constellation, test_loader, num_classes): # Added num_classes encoder.eval() constellation.eval() all_preds = [] all_targets = [] for inputs, targets in tqdm(test_loader, desc="Evaluating"): inputs, targets = inputs.to(device), as_class_indices(targets.to(device)) z = encoder(inputs) scores, _ = constellation(z) preds = scores.argmax(dim=1) all_preds.extend(preds.cpu().numpy()) all_targets.extend(targets.cpu().numpy()) correct = (np.array(all_preds) == np.array(all_targets)).sum() total = len(all_targets) # Calculate confusion matrix conf_matrix = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes)) # Calculate per-class accuracies from the confusion matrix class_correct = conf_matrix.diagonal() class_total = conf_matrix.sum(axis=1) # Avoid division by zero for classes not present in the test set class_accs = np.divide(class_correct, class_total, out=np.zeros_like(class_correct, dtype=float), where=class_total!=0) return correct/total, list(class_accs), conf_matrix # ============================================================ # DYNAMIC SCHEDULER # ============================================================ class DynamicScheduler: """ A custom learning rate scheduler with warmup and reduce-on-plateau logic. - Warmup Phase: Linearly increases LR from a small value to the initial LR. - Main Phase: Monitors a metric (e.g., test accuracy) and reduces the LR when the metric stops improving for a 'patience' number of epochs. """ def __init__(self, optimizer, initial_lr, warmup_epochs, patience, factor=0.5, min_lr=1e-6, cooldown_epochs=2): self.optimizer = optimizer self.initial_lr = initial_lr self.warmup_epochs = warmup_epochs self.patience = patience self.factor = factor self.min_lr = min_lr self.cooldown_epochs = cooldown_epochs # State tracking self.current_epoch = 0 self.phase = 'warmup' if warmup_epochs > 0 else 'main' self.best_metric = -1.0 self.epochs_since_improvement = 0 self.cooldown_counter = 0 print("\n" + "="*60) print("INITIALIZING DYNAMIC SCHEDULER") print("="*60) print(f"{'Initial LR':<25}: {self.initial_lr}") print(f"{'Warmup Epochs':<25}: {self.warmup_epochs}") print(f"{'Patience (for plateau)':<25}: {self.patience}") print(f"{'Reduction Factor':<25}: {self.factor}") print(f"{'Cooldown Epochs':<25}: {self.cooldown_epochs}") print(f"{'Minimum LR':<25}: {self.min_lr}") def _set_lr(self, lr_value): """Sets the learning rate for all parameter groups in the optimizer.""" for param_group in self.optimizer.param_groups: param_group['lr'] = lr_value def step(self, metric): """ Update the learning rate based on the provided metric (e.g., test accuracy). This should be called once per epoch AFTER evaluation. """ self.current_epoch += 1 current_lr = self.optimizer.param_groups[0]['lr'] if self.phase == 'warmup': # Calculate the learning rate for the current warmup step lr = self.initial_lr * (self.current_epoch / self.warmup_epochs) self._set_lr(lr) print(f" Scheduler (Warmup): Epoch {self.current_epoch}/{self.warmup_epochs}, LR set to {lr:.6f}") # Check if warmup phase is complete if self.current_epoch >= self.warmup_epochs: self.phase = 'main' self.best_metric = metric # Initialize best metric after warmup print(" Scheduler: Warmup complete. Switched to main (plateau) phase.") elif self.phase == 'main': # Handle cooldown period if self.cooldown_counter > 0: self.cooldown_counter -= 1 print(f" Scheduler (Cooldown): {self.cooldown_counter+1} epochs remaining.") return # Check for improvement if metric > self.best_metric: self.best_metric = metric self.epochs_since_improvement = 0 else: self.epochs_since_improvement += 1 print(f" Scheduler: No improvement for {self.epochs_since_improvement} epoch(s). Best Acc: {self.best_metric:.4f}") # If patience is exceeded, reduce learning rate if self.epochs_since_improvement >= self.patience: new_lr = max(current_lr * self.factor, self.min_lr) if new_lr < current_lr: self._set_lr(new_lr) print(f" 🔥 Scheduler: Metric plateaued. Reducing LR to {new_lr:.6f}") self.epochs_since_improvement = 0 self.cooldown_counter = self.cooldown_epochs # Start cooldown else: print(" Scheduler: Already at minimum LR. No change.") # ============================================================ # MAIN TRAINING LOOP # ============================================================ class RoseDiagnosticHead(nn.Module): """ A simple MLP to predict the rose_score_magnitude from a latent vector. This is a "throwaway" module used for diagnostics, not for the final model's task. """ def __init__(self, latent_dim, hidden_dim=128): super().__init__() self.net = nn.Sequential( nn.Linear(latent_dim, hidden_dim), nn.GELU(), nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, 1) # Output a single scalar value ) def forward(self, x): return self.net(x) def rose_score_magnitude(x: torch.Tensor, need: torch.Tensor, relation: torch.Tensor, purpose: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: """ Computes a magnitude-only Rose similarity score between `x` and `need`, modulated by triadic reference vectors `relation` and `purpose`. """ x_n = F.normalize(x, dim=-1, eps=eps) n_n = F.normalize(need, dim=-1, eps=eps) r_n = F.normalize(relation, dim=-1, eps=eps) p_n = F.normalize(purpose, dim=-1, eps=eps) # Core directional cosine components a_n = torch.einsum('bd,bd->b', x_n, n_n) # Batch dot product a_r = torch.einsum('bd,bd->b', x_n, r_n) a_p = torch.einsum('bd,bd->b', x_n, p_n) # Triadic magnitude score r7 = (a_n + a_r + a_p) / 3.0 r8 = x.norm(dim=-1) return r7 * r8 def RoseCrossContrastiveLoss(latents, targets, constellation, temp=0.5): """ Computes a contrastive loss where each sample's contribution is weighted by the inverse of its `rose_score_magnitude`. Returns the final loss and the calculated rose scores for diagnostics. """ batch_size = latents.size(0) device = latents.device # --- 1. Define the Symbolic Basis for ROSE Score --- target_vertex_indices = constellation.vertex_map[targets] # Need: Target vertices from the specialist pentachora (the ideal goal) # [B, D] need_vectors = constellation.specialists[:, target_vertex_indices, :].mean(dim=0) # Relation: Target vertices from the dispatcher pentachora (the context) # [B, D] relation_vectors = constellation.dispatchers[:, target_vertex_indices, :].mean(dim=0) # Purpose: The centroid of the specialist pentachora (the overall structure) # [D] -> [B, D] purpose_vectors = constellation.specialists.mean(dim=(0, 1)).unsqueeze(0).expand(batch_size, -1) # --- 2. Calculate the ROSE Score for each sample in the batch --- # rose_scores will have shape [B] rose_scores = rose_score_magnitude(latents, need_vectors, relation_vectors, purpose_vectors) # --- 3. Calculate Per-Sample Inverse Weights --- # We use (1 - tanh(x)) to create a stable, bounded weight between (0, 2). # High rose_score -> low loss weight. Low rose_score -> high loss weight. loss_weights = 1.0 - torch.tanh(rose_scores) # --- 4. Calculate Base Contrastive Loss (InfoNCE) --- all_vertices_specialist = constellation.specialists.mean(dim=0) # [5, D] all_vertices_dispatcher = constellation.dispatchers.mean(dim=0) # [5, D] # Similarities to all specialist and dispatcher vertices sim_specialist = F.normalize(latents) @ F.normalize(all_vertices_specialist).T # [B, 5] sim_dispatcher = F.normalize(latents) @ F.normalize(all_vertices_dispatcher).T # [B, 5] # Get the similarity to the positive (correct) vertex for each sample pos_sim_specialist = sim_specialist[torch.arange(batch_size), target_vertex_indices] pos_sim_dispatcher = sim_dispatcher[torch.arange(batch_size), target_vertex_indices] # Calculate the per-sample InfoNCE loss for both pentachora logits_specialist = -torch.log(torch.exp(pos_sim_specialist / temp) / torch.exp(sim_specialist / temp).sum(dim=1)) logits_dispatcher = -torch.log(torch.exp(pos_sim_dispatcher / temp) / torch.exp(sim_dispatcher / temp).sum(dim=1)) per_sample_loss = (logits_specialist + logits_dispatcher) / 2.0 # --- 5. Apply the ROSE Weights and return the mean loss --- final_loss = (per_sample_loss * loss_weights).mean() return final_loss, rose_scores.detach() # Detach scores for diagnostic use # ============================================================ # MAIN FUNCTION # ============================================================ def main(): print("\n" + "="*60) print("PENTACHORON CONSTELLATION FINAL CONFIGURATION") print("="*60) for key, value in config.items(): print(f"{key:25}: {value}") # Models encoder = PentaFreqEncoder(config['input_dim'], config['base_dim']).to(device) constellation = BatchedPentachoronConstellation( config['num_classes'], config['base_dim'], config['num_pentachoron_pairs'], device, config['lambda_separation'] ).to(device) diagnostic_head = RoseDiagnosticHead(config['base_dim']).to(device) # Optimizer & scheduler optimizer = torch.optim.AdamW( list(encoder.parameters()) + list(constellation.parameters()) + list(diagnostic_head.parameters()), lr=config['lr'], weight_decay=config["weight_decay"] ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs']) # TensorBoard ("the tensorflow") tb_dir = Path("tb_logs") / _timestamp() tb_dir.mkdir(parents=True, exist_ok=True) writer = SummaryWriter(log_dir=str(tb_dir)) history = {'train_loss': [], 'train_acc': [], 'test_acc': []} best_acc = 0.0 last_conf_png = None start_time = time.time() print("\n" + "="*60) print("STARTING TRAINING WITH ROSE-MODULATED LOSS") print("="*60 + "\n") for epoch in range(config['epochs']): encoder.train(); constellation.train(); diagnostic_head.train() total_loss = total_correct = total_samples = 0 pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}") for inputs, targets in pbar: inputs, targets = inputs.to(device), as_class_indices(targets.to(device)) optimizer.zero_grad() latents = encoder(inputs) scores, _ = constellation(latents) loss_ce = F.cross_entropy(scores, targets) loss_contrastive, true_rose_scores = RoseCrossContrastiveLoss( latents, targets, constellation, temp=config['temp'] ) pred_rose = diagnostic_head(latents.detach()) loss_diag = F.mse_loss(pred_rose.squeeze(), true_rose_scores) loss_reg = constellation.regularization_loss() loss = loss_ce + (1.0 * loss_contrastive) + (0.1 * loss_diag) + (config['loss_weight_scalar'] * loss_reg) loss.backward() torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(constellation.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(diagnostic_head.parameters(), 1.0) optimizer.step() total_loss += loss.item() * inputs.size(0) preds = scores.argmax(dim=1) total_correct += (preds == targets).sum().item() total_samples += inputs.size(0) pbar.set_postfix({ 'loss': f"{loss.item():.4f}", 'acc': f"{total_correct/total_samples:.4f}", 'rose_loss': f"{loss_contrastive.item():.4f}", 'diag_loss': f"{loss_diag.item():.4f}" }) train_loss = total_loss / total_samples train_acc = total_correct / total_samples # Evaluation test_acc, class_accs, conf_matrix = evaluate( encoder, constellation, test_loader, config['num_classes'] ) # Log to TensorBoard writer.add_scalar("Loss/train", train_loss, epoch+1) writer.add_scalar("Acc/train", train_acc, epoch+1) writer.add_scalar("Acc/test", test_acc, epoch+1) writer.add_scalar("LR", optimizer.param_groups[0]['lr'], epoch+1) # Scheduler scheduler.step() # History history['train_loss'].append(train_loss) history['train_acc'].append(train_acc) history['test_acc'].append(test_acc) print(f"\n[Epoch {epoch+1}/{config['epochs']}]") print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}") if test_acc > best_acc: best_acc = test_acc print(f" 🎯 NEW BEST ACCURACY: {best_acc:.4f}") print(" Saving new best confusion matrix heatmap...") import seaborn as sns plt.figure(figsize=(12, 10)) sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names) plt.title(f'Confusion Matrix - Epoch {epoch+1} - Accuracy: {best_acc:.4f}', fontsize=16) plt.xlabel('Predicted Label', fontsize=12) plt.ylabel('True Label', fontsize=12) plt.tight_layout() last_conf_png = f'best_confusion_matrix_epoch_{epoch+1}.png' plt.savefig(last_conf_png, dpi=150) plt.close() # Final plots elapsed_time = time.time() - start_time print("\n" + "="*60) print("TRAINING COMPLETE") print("="*60) print(f" Best Test Accuracy: {best_acc*100:.2f}%") print(f" Total Training Time: {elapsed_time/60:.2f} minutes") plt.figure(figsize=(12, 5)) plt.plot(history['train_acc'], label='Train Accuracy') plt.plot(history['test_acc'], label='Test Accuracy', linewidth=2) plt.title('Model Accuracy Over Epochs', fontsize=16) plt.xlabel('Epoch', fontsize=12) plt.ylabel('Accuracy', fontsize=12) plt.legend() plt.grid(True, linestyle='--', alpha=0.6) plt.tight_layout() plt.savefig('accuracy_plot.png', dpi=150) plt.show() # Save and push bundle local_dir, hub_path = save_and_push_artifacts( encoder=encoder, constellation=constellation, diagnostic_head=diagnostic_head, config=config, class_names=class_names, history=history, best_acc=best_acc, tb_log_dir=tb_dir, last_confusion_png=last_conf_png, repo_subdir_root="pentachora-adaptive-encoded/" + DATASET_NAME, ) print(f"[done] Local artifacts at: {local_dir}") print(f"[done] HuggingFace path: {hub_path}") return encoder, constellation, history # ============================ # OPTIONAL: set your repo here # ============================ # Example: config['hf_repo_id'] = "AbstractPhil/pentachora-frequency-encoded" if __name__ == "__main__": encoder, constellation, history = main() print("\n✨ Optimized Pentachoron Constellation Training Complete!")