The following PiT_MNIST_V1.0.ipynb is a direct implementation of the PiT pixel transformer described in the 2024 paper
"An Image is Worth More Than 16 x 16 Patches: Exploring Transformers on Individual Pixels"
(https://arxiv.org/html/2406.09415v1), which describes "directly treating each individual pixel as a token and achieve highly performant results".
This script applies the PiT model architecture, without any modifications, to the standard MNIST numeral-image classification dataset provided in the Google Colab sample_data folder.
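For orientation, a minimal tokenization sketch (illustrative, not part of the notebook): PiT flattens a 28 x 28 image into 784 single-pixel tokens, whereas a patch-based ViT groups pixels into larger patch tokens (14 x 14 patches are used below purely as an example, since 28 is not divisible by 16).

import torch

img = torch.rand(1, 28, 28)                    # one grayscale 28x28 image (toy example)
pixel_tokens = img.reshape(28 * 28, 1)         # PiT: 784 tokens, each holding a single pixel value
patch_tokens = (img.unfold(1, 14, 14)          # ViT-style alternative: non-overlapping 14x14 patches
                   .unfold(2, 14, 14)
                   .reshape(-1, 14 * 14))      # 4 tokens, each a flattened 196-pixel patch
print(pixel_tokens.shape, patch_tokens.shape)  # torch.Size([784, 1]) torch.Size([4, 196])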
The script was run for 25 epochs and reached 92.30% validation accuracy by epoch 15 (Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%).
Validation loss fell and validation accuracy rose almost monotonically from epoch to epoch, with minor dips in validation accuracy between Epochs 13-14, 18-19, and 22-23, while Train Loss decreased every epoch.
Final Test Accuracy: 95.01% (25 Epochs)
Final Test Loss: 0.1662
Ran on an A100 GPU (Python 3 Google Compute Engine backend). Resource usage for the ~21-minute session (7:40 PM to 8:01 PM): System RAM 2.8 / 83.5 GB, GPU RAM 6.5 / 40.0 GB, Disk 37.7 / 112.6 GB.
# ==============================================================================
# PiT_MNIST_V1.0.py [in colab: PiT_MNIST_V1.0.ipynb]
#
# ML-Engineer LLM Agent Implementation
#
# Description:
# This script implements a Pixel Transformer (PiT) for MNIST classification,
# based on the paper "An Image is Worth More Than 16x16 Patches"
# (arXiv:2406.09415). It treats each pixel as an individual token, forgoing
# the patch-based approach of traditional Vision Transformers.
#
# Designed for Google Colab using the sample_data/mnist_*.csv files.
# ==============================================================================
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math

# --- 1. Configuration & Hyperparameters ---
# These parameters are chosen to be reasonable for the MNIST task and
# inspired by the "Tiny" or "Small" variants in the paper.
CONFIG = {
    "train_file": "/content/sample_data/mnist_train_small.csv",
    "test_file": "/content/sample_data/mnist_test.csv",
    "image_size": 28,
    "num_classes": 10,
    "embed_dim": 128,      # 'd' in the paper. Dimension for each pixel embedding.
    "num_layers": 6,       # Number of Transformer Encoder layers.
    "num_heads": 8,        # Number of heads in Multi-Head Self-Attention. Must be a divisor of embed_dim.
    "mlp_dim": 512,        # Hidden dimension of the MLP block inside the Transformer. (4 * embed_dim is common)
    "dropout": 0.1,
    "batch_size": 128,
    "epochs": 25,          # Increased epochs for better convergence on the small dataset.
    "learning_rate": 1e-4,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
CONFIG["sequence_length"] = CONFIG["image_size"] * CONFIG["image_size"]  # 784 for MNIST

print("--- Configuration ---")
for key, value in CONFIG.items():
    print(f"{key}: {value}")
print("---------------------\n")

# --- 2. Data Loading and Preprocessing ---
class MNIST_CSV_Dataset(Dataset):
    """Custom PyTorch Dataset for loading MNIST data from CSV files."""
    def __init__(self, file_path):
        df = pd.read_csv(file_path)
        self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
        # Normalize pixel values to [0, 1] and keep as float
        self.pixels = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32) / 255.0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The PiT's projection layer expects input of shape (in_features),
        # so for each pixel, we need a tensor of shape (1).
        # We reshape the 784 pixels to (784, 1).
        return self.pixels[idx].unsqueeze(-1), self.labels[idx]
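# Quick shape sanity check (optional sketch, assuming the Colab sample_data
# CSVs are present): each item is a (pixels, label) pair with pixels of shape
# (784, 1), i.e. one scalar-valued token per pixel.
#   _ds = MNIST_CSV_Dataset(CONFIG["train_file"])
#   _px, _y = _ds[0]
#   assert _px.shape == (784, 1) and _y.dtype == torch.long
# NOTE: these CSVs appear to have no header row, so read_csv's default
# header=0 treats the first image as a header and drops it (the run below
# reads 19,999 rows from the train CSV and 9,999 from the test CSV);
# passing header=None would keep every row.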
# --- 3. Pixel Transformer (PiT) Model Architecture ---
class PixelTransformer(nn.Module):
    """
    Pixel Transformer (PiT) model.
    Treats each pixel as a token and uses a Transformer Encoder for classification.
    """
    def __init__(self, seq_len, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()
        # 1. Pixel Projection: Each pixel (a single value) is projected to embed_dim.
        # This is the core "pixels-as-tokens" step.
        self.pixel_projection = nn.Linear(1, embed_dim)
        # 2. CLS Token: A learnable parameter that is prepended to the sequence of
        # pixel embeddings. Its output state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # 3. Position Embedding: Learnable embeddings to encode spatial information.
        # Size is (seq_len + 1) to account for the CLS token.
        # This removes the inductive bias of fixed positional encodings.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        # 4. Transformer Encoder: The main workhorse of the model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True  # Important for (batch, seq, feature) input format
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # 5. Classification Head: A simple MLP head on top of the CLS token's output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        # Input x shape: (batch_size, seq_len, 1) -> (B, 784, 1)
        # Project pixels to embedding dimension
        x = self.pixel_projection(x)  # (B, 784, 1) -> (B, 784, embed_dim)
        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 785, embed_dim)
        # Add position embedding
        x = x + self.position_embedding  # (B, 785, embed_dim)
        x = self.dropout(x)
        # Pass through Transformer Encoder
        x = self.transformer_encoder(x)  # (B, 785, embed_dim)
        # Extract the CLS token's output (at position 0)
        cls_output = x[:, 0]  # (B, embed_dim)
        # Pass through MLP head to get logits
        logits = self.mlp_head(cls_output)  # (B, num_classes)
        return logits
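# Optional smoke test (sketch, not part of the logged run): a random batch of
# two pixel sequences should yield two 10-way logit vectors.
#   _pit = PixelTransformer(seq_len=784, num_classes=10, embed_dim=128,
#                           num_layers=6, num_heads=8, mlp_dim=512, dropout=0.1)
#   assert _pit(torch.randn(2, 784, 1)).shape == (2, 10)
# With the CONFIG above, the trainable parameters break down as:
# pixel projection 256 + CLS token 128 + position embedding 785*128 = 100,480
# + 6 encoder layers * 198,272 + head 1,546 = 1,292,042 (matching the printout below).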
# --- 4. Training and Evaluation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for pixels, labels in progress_bar:
        pixels, labels = pixels.to(device), labels.to(device)
        # Forward pass
        logits = model(pixels)
        loss = criterion(logits, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    return total_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for pixels, labels in progress_bar:
            pixels, labels = pixels.to(device), labels.to(device)
            logits = model(pixels)
            loss = criterion(logits, labels)
            total_loss += loss.item()
            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(acc=100. * correct / total)
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy
# --- 5. Main Execution Block ---
if __name__ == "__main__":
    device = CONFIG["device"]

    # Load full training data and split into train/validation sets
    # This helps monitor overfitting, as mnist_train_small is quite small.
    full_train_dataset = MNIST_CSV_Dataset(CONFIG["train_file"])
    train_indices, val_indices = train_test_split(
        range(len(full_train_dataset)),
        test_size=0.1,  # 10% for validation
        random_state=42
    )
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_indices)
    test_dataset = MNIST_CSV_Dataset(CONFIG["test_file"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    print(f"\nData loaded.")
    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")
    print(f"  Test samples: {len(test_dataset)}\n")

    # Initialize model, loss function, and optimizer
    model = PixelTransformer(
        seq_len=CONFIG["sequence_length"],
        num_classes=CONFIG["num_classes"],
        embed_dim=CONFIG["embed_dim"],
        num_layers=CONFIG["num_layers"],
        num_heads=CONFIG["num_heads"],
        mlp_dim=CONFIG["mlp_dim"],
        dropout=CONFIG["dropout"]
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized on {device}.")
    print(f"Total trainable parameters: {total_params:,}\n")

    criterion = nn.CrossEntropyLoss()
    # AdamW is often preferred for Transformers
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])

    # Training loop
    best_val_acc = 0
    print("--- Starting Training ---")
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(
            f"Epoch {epoch+1:02}/{CONFIG['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(f"  -> New best validation accuracy! Saving model state.")
            torch.save(model.state_dict(), "PiT_MNIST_best.pth")
    print("--- Training Finished ---\n")

    # Final evaluation on the test set using the best model
    print("--- Evaluating on Test Set ---")
    model.load_state_dict(torch.load("PiT_MNIST_best.pth"))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    print("----------------------------\n")
[The PiT_MNIST_V1.0.ipynb script ran out of memory on the CPU runtime, but ran and trained quickly on the A100 GPU runtime.]
--- Configuration ---
train_file: /content/sample_data/mnist_train_small.csv
test_file: /content/sample_data/mnist_test.csv
image_size: 28
num_classes: 10
embed_dim: 128
num_layers: 6
num_heads: 8
mlp_dim: 512
dropout: 0.1
batch_size: 128
epochs: 25
learning_rate: 0.0001
device: cuda
sequence_length: 784
---------------------

Data loaded.
  Training samples: 17999
  Validation samples: 2000
  Test samples: 9999

Model initialized on cuda.
Total trainable parameters: 1,292,042

--- Starting Training ---
Epoch 01/25 | Train Loss: 2.2063 | Val Loss: 2.0610 | Val Acc: 22.75%
  -> New best validation accuracy! Saving model state.
Epoch 02/25 | Train Loss: 1.9907 | Val Loss: 1.7945 | Val Acc: 32.00%
  -> New best validation accuracy! Saving model state.
Epoch 03/25 | Train Loss: 1.5767 | Val Loss: 1.1938 | Val Acc: 58.35%
  -> New best validation accuracy! Saving model state.
Epoch 04/25 | Train Loss: 1.0441 | Val Loss: 0.7131 | Val Acc: 77.10%
  -> New best validation accuracy! Saving model state.
Epoch 05/25 | Train Loss: 0.7299 | Val Loss: 0.5490 | Val Acc: 82.95%
  -> New best validation accuracy! Saving model state.
Epoch 06/25 | Train Loss: 0.5935 | Val Loss: 0.4821 | Val Acc: 84.60%
  -> New best validation accuracy! Saving model state.
Epoch 07/25 | Train Loss: 0.5311 | Val Loss: 0.4021 | Val Acc: 86.95%
  -> New best validation accuracy! Saving model state.
Epoch 08/25 | Train Loss: 0.4682 | Val Loss: 0.3680 | Val Acc: 88.05%
  -> New best validation accuracy! Saving model state.
Epoch 09/25 | Train Loss: 0.4264 | Val Loss: 0.3446 | Val Acc: 89.20%
  -> New best validation accuracy! Saving model state.
Epoch 10/25 | Train Loss: 0.4038 | Val Loss: 0.3163 | Val Acc: 89.95%
  -> New best validation accuracy! Saving model state.
Epoch 11/25 | Train Loss: 0.3641 | Val Loss: 0.2941 | Val Acc: 90.80%
  -> New best validation accuracy! Saving model state.
Epoch 12/25 | Train Loss: 0.3447 | Val Loss: 0.2759 | Val Acc: 91.45%
  -> New best validation accuracy! Saving model state.
Epoch 13/25 | Train Loss: 0.3181 | Val Loss: 0.2603 | Val Acc: 92.05%
  -> New best validation accuracy! Saving model state.
Epoch 14/25 | Train Loss: 0.3023 | Val Loss: 0.2695 | Val Acc: 91.90%
Epoch 15/25 | Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%
  -> New best validation accuracy! Saving model state.
Epoch 16/25 | Train Loss: 0.2677 | Val Loss: 0.2377 | Val Acc: 92.65%
  -> New best validation accuracy! Saving model state.
Epoch 17/25 | Train Loss: 0.2535 | Val Loss: 0.2143 | Val Acc: 93.80%
  -> New best validation accuracy! Saving model state.
Epoch 18/25 | Train Loss: 0.2395 | Val Loss: 0.2059 | Val Acc: 94.05%
  -> New best validation accuracy! Saving model state.
Epoch 19/25 | Train Loss: 0.2276 | Val Loss: 0.2126 | Val Acc: 93.75%
Epoch 20/25 | Train Loss: 0.2189 | Val Loss: 0.1907 | Val Acc: 94.40%
  -> New best validation accuracy! Saving model state.
Epoch 21/25 | Train Loss: 0.2113 | Val Loss: 0.1892 | Val Acc: 94.35%
Epoch 22/25 | Train Loss: 0.2004 | Val Loss: 0.1775 | Val Acc: 94.50%
  -> New best validation accuracy! Saving model state.
Epoch 23/25 | Train Loss: 0.1927 | Val Loss: 0.1912 | Val Acc: 94.15%
Epoch 24/25 | Train Loss: 0.1836 | Val Loss: 0.1746 | Val Acc: 94.75%
  -> New best validation accuracy! Saving model state.
Epoch 25/25 | Train Loss: 0.1804 | Val Loss: 0.1642 | Val Acc: 94.75%
--- Training Finished ---

--- Evaluating on Test Set ---
Final Test Loss: 0.1662
Final Test Accuracy: 95.01%
----------------------------
---
license: apache-2.0
---