The following PiT_MNIST_V1.0.ipynb is a direct implementation of the PiT (Pixel Transformer) described in the 2024 paper "An Image is Worth More Than 16x16 Patches: Exploring Transformers on Individual Pixels" (https://arxiv.org/html/2406.09415v1), which describes "directly treating each individual pixel as a token and achieve highly performant results". The script applies this PiT architecture, without any modifications, to the standard MNIST handwritten-digit classification dataset provided in the Google Colab sample_data folder.

The script was run for 25 epochs and reached 92.30% validation accuracy by Epoch 15 (Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%); validation accuracy peaked at 94.75% in Epochs 24 and 25. Train loss fell every epoch, and validation accuracy rose almost monotonically, with small dips only between Epochs 13 and 14, 18 and 19, 20 and 21 (0.05 points), and 22 and 23. The test set was evaluated with the best saved checkpoint, which is the Epoch 24 model: the script saves only on a strict improvement in validation accuracy, so the tied Epoch 25 did not overwrite it.

Final Test Accuracy: 95.01%
Final Test Loss: 0.1662
Ran on an A100 GPU in Google Colab (Python 3 Google Compute Engine backend). Resource usage for the session, shown from 7:40 PM to 8:01 PM:
System RAM 2.8 / 83.5 GB
GPU RAM 6.5 / 40.0 GB
Disk 37.7 / 112.6 GB
==============================================================================
PiT_MNIST_V1.0.py [in colab: PiT_MNIST_V1.0.ipynb]
ML-Engineer LLM Agent Implementation
Description:
This script implements a Pixel Transformer (PiT) for MNIST classification,
based on the paper "An Image is Worth More Than 16x16 Patches"
(arXiv:2406.09415). It treats each pixel as an individual token, forgoing
the patch-based approach of traditional Vision Transformers.
Designed for Google Colab using the sample_data/mnist_*.csv files.
==============================================================================
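A minimal pure-Python sketch of the pixels-as-tokens idea, for illustration only: the helper `image_to_tokens` is hypothetical and not part of the script. It splits a flattened, row-major 28x28 image into non-overlapping square patches. With `patch_size=1` each token is a single pixel value, the same layout as the `(784, 1)` tensor that the Dataset's `__getitem__` returns; larger patch sizes give ViT-style tokens and a much shorter sequence.

```python
def image_to_tokens(pixels, image_size=28, patch_size=1):
    """Split a flat, row-major image into non-overlapping square patch tokens.

    patch_size=1 gives PiT-style pixel tokens (one value per token);
    larger patch sizes give ViT-style patch tokens.
    """
    if image_size % patch_size != 0:
        raise ValueError("patch_size must divide image_size")
    tokens = []
    for top in range(0, image_size, patch_size):
        for left in range(0, image_size, patch_size):
            # Gather the patch row by row, so each token is itself row-major.
            token = [
                pixels[(top + dy) * image_size + (left + dx)]
                for dy in range(patch_size)
                for dx in range(patch_size)
            ]
            tokens.append(token)
    return tokens


image = [i / 783 for i in range(784)]  # dummy 28x28 image, values in [0, 1]
for p in (1, 4, 7):
    tokens = image_to_tokens(image, patch_size=p)
    print(f"patch_size={p}: {len(tokens)} tokens of dimension {len(tokens[0])}")
```

Self-attention compares every token with every other token, so its cost grows with the square of the sequence length: 784 pixel tokens produce 784² / 16² = 2,401 times as many attention scores as 16 tokens of 7x7 patches.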
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
--- 1. Configuration & Hyperparameters ---
These parameters are chosen to be reasonable for the MNIST task and
inspired by the "Tiny" or "Small" variants in the paper.
CONFIG = {
    "train_file": "/content/sample_data/mnist_train_small.csv",
    "test_file": "/content/sample_data/mnist_test.csv",
    "image_size": 28,
    "num_classes": 10,
    "embed_dim": 128,     # 'd' in the paper. Dimension for each pixel embedding.
    "num_layers": 6,      # Number of Transformer Encoder layers.
    "num_heads": 8,       # Number of heads in Multi-Head Self-Attention. Must be a divisor of embed_dim.
    "mlp_dim": 512,       # Hidden dimension of the MLP block inside the Transformer (4 * embed_dim is common).
    "dropout": 0.1,
    "batch_size": 128,
    "epochs": 25,         # Increased epochs for better convergence on the small dataset.
    "learning_rate": 1e-4,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
CONFIG["sequence_length"] = CONFIG["image_size"] * CONFIG["image_size"]  # 784 for MNIST

print("--- Configuration ---")
for key, value in CONFIG.items():
    print(f"{key}: {value}")
print("---------------------\n")
--- 2. Data Loading and Preprocessing ---
class MNIST_CSV_Dataset(Dataset):
    """Custom PyTorch Dataset for loading MNIST data from CSV files."""

    def __init__(self, file_path):
        # Colab's sample MNIST CSVs have no header row. header=None keeps the first
        # sample; with pandas' default, the first row is read as column names and
        # dropped (which is why the run logged below loaded 19,999 / 9,999 rows).
        df = pd.read_csv(file_path, header=None)
        self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
        # Normalize pixel values to [0, 1] and keep as float
        self.pixels = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32) / 255.0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The PiT's projection layer expects input of shape (in_features),
        # so for each pixel, we need a tensor of shape (1).
        # We reshape the 784 pixels to (784, 1).
        return self.pixels[idx].unsqueeze(-1), self.labels[idx]
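A small sketch of a pandas detail that affects this Dataset, using a made-up two-row CSV held in memory. `pd.read_csv` defaults to `header="infer"`, which treats the first row as column names; for a header-less file that silently drops one sample. Assuming Colab's mnist_train_small.csv and mnist_test.csv contain 20,000 and 10,000 header-less rows, this explains the logged 17,999 + 2,000 = 19,999 training/validation samples and 9,999 test samples. Passing `header=None` keeps every row.

```python
import io

import pandas as pd

# Two header-less rows in the same layout as the MNIST CSVs:
# label first, then pixel values (truncated to 3 pixels here for brevity).
csv_text = "5,0,128,255\n0,255,0,0\n"

# Default header="infer": the first data row becomes the column names.
default_df = pd.read_csv(io.StringIO(csv_text))
# header=None: every row is treated as data.
fixed_df = pd.read_csv(io.StringIO(csv_text), header=None)

print(len(default_df), list(default_df.columns))
print(len(fixed_df), fixed_df.iloc[:, 0].tolist())
```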
--- 3. Pixel Transformer (PiT) Model Architecture ---
class PixelTransformer(nn.Module):
    """
    Pixel Transformer (PiT) model.
    Treats each pixel as a token and uses a Transformer Encoder for classification.
    """

    def __init__(self, seq_len, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()

        # 1. Pixel Projection: Each pixel (a single value) is projected to embed_dim.
        #    This is the core "pixels-as-tokens" step.
        self.pixel_projection = nn.Linear(1, embed_dim)

        # 2. CLS Token: A learnable parameter that is prepended to the sequence of
        #    pixel embeddings. Its output state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # 3. Position Embedding: Learnable embeddings to encode spatial information.
        #    Size is (seq_len + 1) to account for the CLS token.
        #    This removes the inductive bias of fixed positional encodings.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # 4. Transformer Encoder: The main workhorse of the model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True,  # Important for (batch, seq, feature) input format
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 5. Classification Head: A simple MLP head on top of the CLS token's output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        # Input x shape: (batch_size, seq_len, 1) -> (B, 784, 1)

        # Project pixels to embedding dimension
        x = self.pixel_projection(x)  # (B, 784, 1) -> (B, 784, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 785, embed_dim)

        # Add position embedding
        x = x + self.position_embedding  # (B, 785, embed_dim)
        x = self.dropout(x)

        # Pass through Transformer Encoder
        x = self.transformer_encoder(x)  # (B, 785, embed_dim)

        # Extract the CLS token's output (at position 0)
        cls_output = x[:, 0]  # (B, embed_dim)

        # Pass through MLP head to get logits
        logits = self.mlp_head(cls_output)  # (B, num_classes)
        return logits
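As a sanity check on the logged "Total trainable parameters: 1,292,042", the count can be reproduced by hand from the layer shapes. This sketch assumes PyTorch's standard `nn.TransformerEncoderLayer` layout (packed Q/K/V `in_proj`, `out_proj`, two feed-forward `Linear` layers, two `LayerNorm`s) and no final norm in `nn.TransformerEncoder`; `pit_param_count` is a hypothetical helper, not part of the script.

```python
def pit_param_count(seq_len=784, num_classes=10, embed_dim=128,
                    num_layers=6, mlp_dim=512):
    """Count trainable parameters of the PixelTransformer defined above."""
    d = embed_dim
    pixel_projection = 1 * d + d                # nn.Linear(1, d): weight + bias
    cls_token = d                               # shape (1, 1, d)
    position_embedding = (seq_len + 1) * d      # shape (1, seq_len + 1, d)

    attn_in_proj = 3 * d * d + 3 * d            # packed Q, K, V projections
    attn_out_proj = d * d + d
    mlp = (d * mlp_dim + mlp_dim) + (mlp_dim * d + d)  # linear1 + linear2
    layer_norms = 2 * (2 * d)                   # norm1 and norm2: weight + bias each
    per_layer = attn_in_proj + attn_out_proj + mlp + layer_norms

    head = 2 * d + (d * num_classes + num_classes)  # LayerNorm + Linear
    return (pixel_projection + cls_token + position_embedding
            + num_layers * per_layer + head)


print(f"{pit_param_count():,}")
```

`num_heads` does not appear because splitting attention into heads only reshapes the same projections. The learnable position embeddings account for 100,480 parameters, just under 8% of the total.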
--- 4. Training and Evaluation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for pixels, labels in progress_bar:
        pixels, labels = pixels.to(device), labels.to(device)

        # Forward pass
        logits = model(pixels)
        loss = criterion(logits, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    return total_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for pixels, labels in progress_bar:
            pixels, labels = pixels.to(device), labels.to(device)

            logits = model(pixels)
            loss = criterion(logits, labels)
            total_loss += loss.item()

            _, predicted = torch.max(logits.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(acc=100. * correct / total)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy
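`evaluate()` returns `total_loss / len(dataloader)`, the average of per-batch mean losses, so every batch counts equally even though the last one is usually smaller. With the 2,000-sample validation split and `batch_size=128`, the last batch holds 80 samples (15 * 128 = 1,920). The sketch below quantifies the difference from a per-sample mean using made-up batch losses; both helper functions are hypothetical. Accuracy is unaffected, because it is already computed from per-sample counts (`correct / total`).

```python
def mean_of_batch_means(batch_means):
    """What evaluate() reports: every batch counts equally."""
    return sum(batch_means) / len(batch_means)


def per_sample_mean(batch_means, batch_sizes):
    """Mean loss over individual samples: weight each batch by its size."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# 2,000 validation samples with batch_size=128 -> 15 full batches + one batch of 80.
batch_sizes = [128] * 15 + [80]
# Hypothetical per-batch mean losses, chosen only to illustrate the effect.
batch_means = [0.20] * 15 + [0.40]

print(round(mean_of_batch_means(batch_means), 4))
print(round(per_sample_mean(batch_means, batch_sizes), 4))
```

If an exact per-sample loss matters, one option is to accumulate `loss.item() * labels.size(0)` and divide by the number of samples instead of the number of batches.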
--- 5. Main Execution Block ---
if __name__ == "__main__":
    device = CONFIG["device"]

    # Load full training data and split into train/validation sets
    # This helps monitor overfitting, as mnist_train_small is quite small.
    full_train_dataset = MNIST_CSV_Dataset(CONFIG["train_file"])
    train_indices, val_indices = train_test_split(
        range(len(full_train_dataset)),
        test_size=0.1,  # 10% for validation
        random_state=42,
    )
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_indices)
    test_dataset = MNIST_CSV_Dataset(CONFIG["test_file"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    print(f"\nData loaded.")
    print(f" Training samples: {len(train_dataset)}")
    print(f" Validation samples: {len(val_dataset)}")
    print(f" Test samples: {len(test_dataset)}\n")

    # Initialize model, loss function, and optimizer
    model = PixelTransformer(
        seq_len=CONFIG["sequence_length"],
        num_classes=CONFIG["num_classes"],
        embed_dim=CONFIG["embed_dim"],
        num_layers=CONFIG["num_layers"],
        num_heads=CONFIG["num_heads"],
        mlp_dim=CONFIG["mlp_dim"],
        dropout=CONFIG["dropout"],
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized on {device}.")
    print(f"Total trainable parameters: {total_params:,}\n")

    criterion = nn.CrossEntropyLoss()
    # AdamW is often preferred for Transformers
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])

    # Training loop
    best_val_acc = 0
    print("--- Starting Training ---")
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch+1:02}/{CONFIG['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(f" -> New best validation accuracy! Saving model state.")
            torch.save(model.state_dict(), "PiT_MNIST_best.pth")

    print("--- Training Finished ---\n")

    # Final evaluation on the test set using the best model
    print("--- Evaluating on Test Set ---")
    model.load_state_dict(torch.load("PiT_MNIST_best.pth"))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    print("----------------------------\n")
[The PiT_MNIST_V1.0.ipynb notebook ran out of memory on the Colab CPU runtime, but ran and trained quickly on the A100 GPU runtime.]
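A rough, back-of-the-envelope sketch (an estimate under stated assumptions, not a measurement) of why a 785-token sequence is memory-hungry. Assuming the attention scores are materialized in float32, one layer's (batch, heads, tokens, tokens) tensor for batch_size=128, 8 heads and 785 tokens (784 pixels plus CLS) is about 2.4 GiB, and training keeps activations from all six layers for backpropagation, about 14 GiB for the scores alone. Actual usage depends on which attention implementation PyTorch selects; fused kernels can avoid materializing the full score matrix. `attention_scores_bytes` is a hypothetical helper.

```python
def attention_scores_bytes(batch_size, num_heads, seq_len, bytes_per_elem=4):
    """Size of one layer's attention-score tensor (B, H, L, L) if materialized."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


seq_len = 28 * 28 + 1  # 784 pixel tokens + 1 CLS token
per_layer = attention_scores_bytes(batch_size=128, num_heads=8, seq_len=seq_len)
print(f"per layer: {per_layer / 2**30:.2f} GiB")
print(f"6 layers:  {6 * per_layer / 2**30:.2f} GiB")
```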
--- Configuration ---
train_file: /content/sample_data/mnist_train_small.csv
test_file: /content/sample_data/mnist_test.csv
image_size: 28
num_classes: 10
embed_dim: 128
num_layers: 6
num_heads: 8
mlp_dim: 512
dropout: 0.1
batch_size: 128
epochs: 25
learning_rate: 0.0001
device: cuda
sequence_length: 784
---------------------

Data loaded.
 Training samples: 17999
 Validation samples: 2000
 Test samples: 9999

Model initialized on cuda.
Total trainable parameters: 1,292,042

--- Starting Training ---
Epoch 01/25 | Train Loss: 2.2063 | Val Loss: 2.0610 | Val Acc: 22.75%
 -> New best validation accuracy! Saving model state.
Epoch 02/25 | Train Loss: 1.9907 | Val Loss: 1.7945 | Val Acc: 32.00%
 -> New best validation accuracy! Saving model state.
Epoch 03/25 | Train Loss: 1.5767 | Val Loss: 1.1938 | Val Acc: 58.35%
 -> New best validation accuracy! Saving model state.
Epoch 04/25 | Train Loss: 1.0441 | Val Loss: 0.7131 | Val Acc: 77.10%
 -> New best validation accuracy! Saving model state.
Epoch 05/25 | Train Loss: 0.7299 | Val Loss: 0.5490 | Val Acc: 82.95%
 -> New best validation accuracy! Saving model state.
Epoch 06/25 | Train Loss: 0.5935 | Val Loss: 0.4821 | Val Acc: 84.60%
 -> New best validation accuracy! Saving model state.
Epoch 07/25 | Train Loss: 0.5311 | Val Loss: 0.4021 | Val Acc: 86.95%
 -> New best validation accuracy! Saving model state.
Epoch 08/25 | Train Loss: 0.4682 | Val Loss: 0.3680 | Val Acc: 88.05%
 -> New best validation accuracy! Saving model state.
Epoch 09/25 | Train Loss: 0.4264 | Val Loss: 0.3446 | Val Acc: 89.20%
 -> New best validation accuracy! Saving model state.
Epoch 10/25 | Train Loss: 0.4038 | Val Loss: 0.3163 | Val Acc: 89.95%
 -> New best validation accuracy! Saving model state.
Epoch 11/25 | Train Loss: 0.3641 | Val Loss: 0.2941 | Val Acc: 90.80%
 -> New best validation accuracy! Saving model state.
Epoch 12/25 | Train Loss: 0.3447 | Val Loss: 0.2759 | Val Acc: 91.45%
 -> New best validation accuracy! Saving model state.
Epoch 13/25 | Train Loss: 0.3181 | Val Loss: 0.2603 | Val Acc: 92.05%
 -> New best validation accuracy! Saving model state.
Epoch 14/25 | Train Loss: 0.3023 | Val Loss: 0.2695 | Val Acc: 91.90%
Epoch 15/25 | Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%
 -> New best validation accuracy! Saving model state.
Epoch 16/25 | Train Loss: 0.2677 | Val Loss: 0.2377 | Val Acc: 92.65%
 -> New best validation accuracy! Saving model state.
Epoch 17/25 | Train Loss: 0.2535 | Val Loss: 0.2143 | Val Acc: 93.80%
 -> New best validation accuracy! Saving model state.
Epoch 18/25 | Train Loss: 0.2395 | Val Loss: 0.2059 | Val Acc: 94.05%
 -> New best validation accuracy! Saving model state.
Epoch 19/25 | Train Loss: 0.2276 | Val Loss: 0.2126 | Val Acc: 93.75%
Epoch 20/25 | Train Loss: 0.2189 | Val Loss: 0.1907 | Val Acc: 94.40%
 -> New best validation accuracy! Saving model state.
Epoch 21/25 | Train Loss: 0.2113 | Val Loss: 0.1892 | Val Acc: 94.35%
Epoch 22/25 | Train Loss: 0.2004 | Val Loss: 0.1775 | Val Acc: 94.50%
 -> New best validation accuracy! Saving model state.
Epoch 23/25 | Train Loss: 0.1927 | Val Loss: 0.1912 | Val Acc: 94.15%
Epoch 24/25 | Train Loss: 0.1836 | Val Loss: 0.1746 | Val Acc: 94.75%
 -> New best validation accuracy! Saving model state.
Epoch 25/25 | Train Loss: 0.1804 | Val Loss: 0.1642 | Val Acc: 94.75%
--- Training Finished ---
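The epoch-by-epoch claims in the introduction can be checked directly against the log. This short sketch copies the logged Val Acc values, lists the epoch pairs where validation accuracy dropped, and replays the script's strict `val_acc > best_val_acc` rule to show which epoch's weights `PiT_MNIST_best.pth` holds when the test set is evaluated.

```python
# Val Acc values copied from the training log above (Epochs 1-25).
val_acc = [22.75, 32.00, 58.35, 77.10, 82.95, 84.60, 86.95, 88.05, 89.20, 89.95,
           90.80, 91.45, 92.05, 91.90, 92.30, 92.65, 93.80, 94.05, 93.75, 94.40,
           94.35, 94.50, 94.15, 94.75, 94.75]

# (epoch, next_epoch) pairs where validation accuracy dropped.
dips = [(e, e + 1) for e in range(1, len(val_acc))
        if val_acc[e] < val_acc[e - 1]]

# Replay the script's checkpoint rule: save only on a strict improvement.
best_acc, best_epoch = 0.0, None
for epoch, acc in enumerate(val_acc, start=1):
    if acc > best_acc:
        best_acc, best_epoch = acc, epoch

print("accuracy dips:", dips)
print(f"checkpoint used for testing: Epoch {best_epoch} ({best_acc:.2f}%)")
```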