PiT_MNIST_Colab (MartialTerran)
The following PiT_MNIST_V1.0.ipynb is a direct implementation of the PiT (Pixel Transformer) described in the 2024 paper
An Image is Worth More Than 16 x 16 Patches: Exploring Transformers on Individual Pixels
(https://arxiv.org/html/2406.09415v1), which describes "directly treating each individual pixel as a token and achieve highly performant results".
This script applies the PiT architecture, without modifications, to MNIST handwritten-digit classification, using the mnist_train_small.csv (a 20,000-image subset of the MNIST training set) and mnist_test.csv files provided in the Google Colab sample_data folder.
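As a rough illustration of "pixels as tokens", the plain-Python sketch below (not taken from the script) turns one flattened 28 x 28 MNIST row into 784 single-value tokens and maps sequence positions back to pixel coordinates. It assumes the CSV stores pixels in row-major order and that the CLS token sits at position 0, as in the script's forward pass; `image_to_tokens` and `token_position` are hypothetical helper names.

```python
# Sketch: how a 28x28 MNIST image becomes a PiT token sequence.
# Assumption: pixels are stored row-major (784 values after the label column).
IMAGE_SIZE = 28


def image_to_tokens(pixels):
    """Turn a flat list of 784 raw pixel values (0-255) into 784 one-value tokens."""
    assert len(pixels) == IMAGE_SIZE * IMAGE_SIZE
    # Each token is a 1-element feature vector, matching the (784, 1) shape
    # that the script feeds to nn.Linear(1, embed_dim).
    return [[p / 255.0] for p in pixels]


def token_position(index):
    """Map a sequence position back to (row, col); position 0 is the CLS token."""
    if index == 0:
        return None  # the CLS token has no pixel location
    return divmod(index - 1, IMAGE_SIZE)


tokens = image_to_tokens([0] * 784)
print(len(tokens), len(tokens) + 1)  # 784 pixel tokens, 785 with CLS
print(token_position(1), token_position(29), token_position(784))
```

Nothing in the token itself says where the pixel came from; that information has to be supplied by the learnable position embedding, which is why the script sizes it to `seq_len + 1`.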
The script was run for 25 epochs and reached 92.30% accuracy on the validation set by epoch 15 (Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%).
Training loss decreased every epoch. Validation accuracy rose almost monotonically, with minor dips between epochs 13 and 14, 18 and 19, 20 and 21, and 22 and 23.
Final Test Accuracy: 95.01% (after 25 epochs, evaluated with the best-validation checkpoint saved at epoch 24, Val Acc 94.75%)
Final Test Loss: 0.1662
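A quick check of the curve, using the validation accuracies copied from the training log below. The helper names are hypothetical; `checkpoint_epoch` mirrors the script's `if val_acc > best_val_acc` rule. Because the comparison is strict, epoch 25's tie at 94.75% does not overwrite the epoch 24 checkpoint that the test evaluation loads.

```python
# Validation accuracies (%) copied from the training log, epochs 1-25.
val_acc = [22.75, 32.00, 58.35, 77.10, 82.95, 84.60, 86.95, 88.05, 89.20, 89.95,
           90.80, 91.45, 92.05, 91.90, 92.30, 92.65, 93.80, 94.05, 93.75, 94.40,
           94.35, 94.50, 94.15, 94.75, 94.75]


def dips(accs):
    """Return (previous_epoch, epoch) pairs where validation accuracy went down."""
    # accs[e] is epoch e + 1, so comparing accs[e] with accs[e - 1]
    # compares epoch e + 1 with epoch e.
    return [(e, e + 1) for e in range(1, len(accs)) if accs[e] < accs[e - 1]]


def checkpoint_epoch(accs):
    """Epoch of the last saved checkpoint under a strict 'better than best' rule."""
    best, saved = 0.0, None
    for epoch, acc in enumerate(accs, start=1):
        if acc > best:  # ties do not trigger a save
            best, saved = acc, epoch
    return saved


print(dips(val_acc))
print(checkpoint_epoch(val_acc))
```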
Ran on an A100 GPU (Colab session PiT_MNIST_V1.0.ipynb, Python 3 Google Compute Engine backend). The Colab resource monitor for the session (7:40 PM to 8:01 PM) showed System RAM 2.8 / 83.5 GB, GPU RAM 6.5 / 40.0 GB and Disk 37.7 / 112.6 GB.
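A back-of-envelope estimate (an illustration, not a measurement) of why pixel tokens are memory-hungry: with 785 tokens, every attention layer compares each token with every other one, so the score tensor grows with the square of the sequence length. The sketch assumes the (batch, heads, 785, 785) scores are fully materialized in float32, as a naive attention implementation would do, and that autograd keeps such a tensor for each of the 6 layers during training. Fused attention kernels can avoid storing the full matrix, so treat this as an upper-end figure. The 7 x 7 patch comparison is a hypothetical ViT-style split, not something the script does.

```python
# Back-of-envelope attention memory for the PiT configuration used here.
# Assumption: the full (batch, heads, seq, seq) score tensor is materialized
# in float32, as a naive attention implementation would do.
BATCH_SIZE = 128
NUM_HEADS = 8
NUM_LAYERS = 6
SEQ_LEN = 28 * 28 + 1  # 784 pixel tokens + 1 CLS token


def attention_scores_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes for one (batch, heads, seq_len, seq_len) float32 attention matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


per_layer = attention_scores_bytes(BATCH_SIZE, NUM_HEADS, SEQ_LEN)
print(f"PiT, one layer: {per_layer / 1e9:.2f} GB")
print(f"PiT, {NUM_LAYERS} layers: {NUM_LAYERS * per_layer / 1e9:.2f} GB")

# Hypothetical comparison: ViT-style 7x7 patches give 16 tokens (+1 CLS).
patch_seq_len = (28 // 7) ** 2 + 1
patch_bytes = attention_scores_bytes(BATCH_SIZE, NUM_HEADS, patch_seq_len)
print(f"7x7 patches, one layer: {patch_bytes / 1e6:.2f} MB")
```

This prints about 2.52 GB per layer and 15.14 GB across six layers for the pixel model, against about 1.18 MB per layer for the hypothetical patch model, which gives a sense of the quadratic cost behind the paper's trade-off.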
# ==============================================================================
# PiT_MNIST_V1.0.py [in colab: PiT_MNIST_V1.0.ipynb]
#
# ML-Engineer LLM Agent Implementation
#
# Description:
# This script implements a Pixel Transformer (PiT) for MNIST classification,
# based on the paper "An Image is Worth More Than 16x16 Patches"
# (arXiv:2406.09415). It treats each pixel as an individual token, forgoing
# the patch-based approach of traditional Vision Transformers.
#
# Designed for Google Colab using the sample_data/mnist_*.csv files.
# ==============================================================================
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
# --- 1. Configuration & Hyperparameters ---
# These parameters are chosen to be reasonable for the MNIST task and
# inspired by the "Tiny" or "Small" variants in the paper.
CONFIG = {
    "train_file": "/content/sample_data/mnist_train_small.csv",
    "test_file": "/content/sample_data/mnist_test.csv",
    "image_size": 28,
    "num_classes": 10,
    "embed_dim": 128,  # 'd' in the paper. Dimension for each pixel embedding.
    "num_layers": 6,  # Number of Transformer Encoder layers.
    "num_heads": 8,  # Number of heads in Multi-Head Self-Attention. Must be a divisor of embed_dim.
    "mlp_dim": 512,  # Hidden dimension of the MLP block inside the Transformer. (4 * embed_dim is common)
    "dropout": 0.1,
    "batch_size": 128,
    "epochs": 25,  # Increased epochs for better convergence on the small dataset.
    "learning_rate": 1e-4,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
CONFIG["sequence_length"] = CONFIG["image_size"] * CONFIG["image_size"]  # 784 for MNIST
print("--- Configuration ---")
for key, value in CONFIG.items():
    print(f"{key}: {value}")
print("---------------------\n")
# --- 2. Data Loading and Preprocessing ---
class MNIST_CSV_Dataset(Dataset):
    """Custom PyTorch Dataset for loading MNIST data from CSV files."""
    def __init__(self, file_path):
        # Colab's sample_data MNIST CSVs have no header row; header=None keeps the
        # first image instead of consuming it as column names.
        df = pd.read_csv(file_path, header=None)
        self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
        # Normalize pixel values to [0, 1] and keep as float
        self.pixels = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32) / 255.0
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        # The PiT's projection layer expects input of shape (in_features),
        # so for each pixel, we need a tensor of shape (1).
        # We reshape the 784 pixels to (784, 1).
        return self.pixels[idx].unsqueeze(-1), self.labels[idx]
# --- 3. Pixel Transformer (PiT) Model Architecture ---
class PixelTransformer(nn.Module):
    """
    Pixel Transformer (PiT) model.
    Treats each pixel as a token and uses a Transformer Encoder for classification.
    """
    def __init__(self, seq_len, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()
        # 1. Pixel Projection: Each pixel (a single value) is projected to embed_dim.
        # This is the core "pixels-as-tokens" step.
        self.pixel_projection = nn.Linear(1, embed_dim)
        # 2. CLS Token: A learnable parameter that is prepended to the sequence of
        # pixel embeddings. Its output state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # 3. Position Embedding: Learnable embeddings to encode spatial information.
        # Size is (seq_len + 1) to account for the CLS token.
        # This removes the inductive bias of fixed positional encodings.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        # 4. Transformer Encoder: The main workhorse of the model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True  # Important for (batch, seq, feature) input format
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # 5. Classification Head: A simple MLP head on top of the CLS token's output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
    def forward(self, x):
        # Input x shape: (batch_size, seq_len, 1) -> (B, 784, 1)
        # Project pixels to embedding dimension
        x = self.pixel_projection(x)  # (B, 784, 1) -> (B, 784, embed_dim)
        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 785, embed_dim)
        # Add position embedding
        x = x + self.position_embedding  # (B, 785, embed_dim)
        x = self.dropout(x)
        # Pass through Transformer Encoder
        x = self.transformer_encoder(x)  # (B, 785, embed_dim)
        # Extract the CLS token's output (at position 0)
        cls_output = x[:, 0]  # (B, embed_dim)
        # Pass through MLP head to get logits
        logits = self.mlp_head(cls_output)  # (B, num_classes)
        return logits
# --- 4. Training and Evaluation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for pixels, labels in progress_bar:
        pixels, labels = pixels.to(device), labels.to(device)
        # Forward pass
        logits = model(pixels)
        loss = criterion(logits, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    return total_loss / len(dataloader)
def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for pixels, labels in progress_bar:
            pixels, labels = pixels.to(device), labels.to(device)
            logits = model(pixels)
            loss = criterion(logits, labels)
            total_loss += loss.item()
            predicted = logits.argmax(dim=1)  # index of the highest logit = predicted class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(acc=100. * correct / total)
    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy
# --- 5. Main Execution Block ---
if __name__ == "__main__":
    device = CONFIG["device"]
    # Load full training data and split into train/validation sets
    # This helps monitor overfitting, as mnist_train_small is quite small.
    full_train_dataset = MNIST_CSV_Dataset(CONFIG["train_file"])
    train_indices, val_indices = train_test_split(
        range(len(full_train_dataset)),
        test_size=0.1,  # 10% for validation
        random_state=42
    )
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_indices)
    test_dataset = MNIST_CSV_Dataset(CONFIG["test_file"])
    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    print("\nData loaded.")
    print(f" Training samples: {len(train_dataset)}")
    print(f" Validation samples: {len(val_dataset)}")
    print(f" Test samples: {len(test_dataset)}\n")
    # Initialize model, loss function, and optimizer
    model = PixelTransformer(
        seq_len=CONFIG["sequence_length"],
        num_classes=CONFIG["num_classes"],
        embed_dim=CONFIG["embed_dim"],
        num_layers=CONFIG["num_layers"],
        num_heads=CONFIG["num_heads"],
        mlp_dim=CONFIG["mlp_dim"],
        dropout=CONFIG["dropout"]
    ).to(device)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized on {device}.")
    print(f"Total trainable parameters: {total_params:,}\n")
    criterion = nn.CrossEntropyLoss()
    # AdamW is often preferred for Transformers
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])
    # Training loop
    best_val_acc = 0
    print("--- Starting Training ---")
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        print(
            f"Epoch {epoch+1:02}/{CONFIG['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print(" -> New best validation accuracy! Saving model state.")
            torch.save(model.state_dict(), "PiT_MNIST_best.pth")
    print("--- Training Finished ---\n")
    # Final evaluation on the test set using the best model
    print("--- Evaluating on Test Set ---")
    model.load_state_dict(torch.load("PiT_MNIST_best.pth"))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    print("----------------------------\n")
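The sample counts in the run log (17,999 training + 2,000 validation = 19,999, and 9,999 test) are each one short of the 20,000 and 10,000 rows in Colab's mnist_train_small.csv and mnist_test.csv. That pattern matches reading a headerless CSV with pandas' default `header=0` (infer), which turns the first line into column names; passing `header=None` keeps every row. The sketch below demonstrates the effect on a hypothetical three-row CSV in the same label-then-pixels layout (only three pixel columns, to keep it short).

```python
import io

import pandas as pd

# Hypothetical headerless CSV: label first, then pixel values.
CSV_TEXT = "7,0,128,255\n2,0,0,64\n1,255,0,0\n"

# Default header=0: the first sample is consumed as column names.
default_df = pd.read_csv(io.StringIO(CSV_TEXT))
# header=None: every line is treated as a sample.
fixed_df = pd.read_csv(io.StringIO(CSV_TEXT), header=None)

print(len(default_df), len(fixed_df))
print(list(default_df.columns))
print(fixed_df.iloc[:, 0].tolist())
```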
[On a CPU-only Colab runtime, the PiT_MNIST_V1.0.ipynb script ran out of memory; in A100 GPU mode it ran and trained quickly.]
--- Configuration ---
train_file: /content/sample_data/mnist_train_small.csv
test_file: /content/sample_data/mnist_test.csv
image_size: 28
num_classes: 10
embed_dim: 128
num_layers: 6
num_heads: 8
mlp_dim: 512
dropout: 0.1
batch_size: 128
epochs: 25
learning_rate: 0.0001
device: cuda
sequence_length: 784
---------------------
Data loaded.
Training samples: 17999
Validation samples: 2000
Test samples: 9999
Model initialized on cuda.
Total trainable parameters: 1,292,042
--- Starting Training ---
Epoch 01/25 | Train Loss: 2.2063 | Val Loss: 2.0610 | Val Acc: 22.75%
-> New best validation accuracy! Saving model state.
Epoch 02/25 | Train Loss: 1.9907 | Val Loss: 1.7945 | Val Acc: 32.00%
-> New best validation accuracy! Saving model state.
Epoch 03/25 | Train Loss: 1.5767 | Val Loss: 1.1938 | Val Acc: 58.35%
-> New best validation accuracy! Saving model state.
Epoch 04/25 | Train Loss: 1.0441 | Val Loss: 0.7131 | Val Acc: 77.10%
-> New best validation accuracy! Saving model state.
Epoch 05/25 | Train Loss: 0.7299 | Val Loss: 0.5490 | Val Acc: 82.95%
-> New best validation accuracy! Saving model state.
Epoch 06/25 | Train Loss: 0.5935 | Val Loss: 0.4821 | Val Acc: 84.60%
-> New best validation accuracy! Saving model state.
Epoch 07/25 | Train Loss: 0.5311 | Val Loss: 0.4021 | Val Acc: 86.95%
-> New best validation accuracy! Saving model state.
Epoch 08/25 | Train Loss: 0.4682 | Val Loss: 0.3680 | Val Acc: 88.05%
-> New best validation accuracy! Saving model state.
Epoch 09/25 | Train Loss: 0.4264 | Val Loss: 0.3446 | Val Acc: 89.20%
-> New best validation accuracy! Saving model state.
Epoch 10/25 | Train Loss: 0.4038 | Val Loss: 0.3163 | Val Acc: 89.95%
-> New best validation accuracy! Saving model state.
Epoch 11/25 | Train Loss: 0.3641 | Val Loss: 0.2941 | Val Acc: 90.80%
-> New best validation accuracy! Saving model state.
Epoch 12/25 | Train Loss: 0.3447 | Val Loss: 0.2759 | Val Acc: 91.45%
-> New best validation accuracy! Saving model state.
Epoch 13/25 | Train Loss: 0.3181 | Val Loss: 0.2603 | Val Acc: 92.05%
-> New best validation accuracy! Saving model state.
Epoch 14/25 | Train Loss: 0.3023 | Val Loss: 0.2695 | Val Acc: 91.90%
Epoch 15/25 | Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%
-> New best validation accuracy! Saving model state.
Epoch 16/25 | Train Loss: 0.2677 | Val Loss: 0.2377 | Val Acc: 92.65%
-> New best validation accuracy! Saving model state.
Epoch 17/25 | Train Loss: 0.2535 | Val Loss: 0.2143 | Val Acc: 93.80%
-> New best validation accuracy! Saving model state.
Epoch 18/25 | Train Loss: 0.2395 | Val Loss: 0.2059 | Val Acc: 94.05%
-> New best validation accuracy! Saving model state.
Epoch 19/25 | Train Loss: 0.2276 | Val Loss: 0.2126 | Val Acc: 93.75%
Epoch 20/25 | Train Loss: 0.2189 | Val Loss: 0.1907 | Val Acc: 94.40%
-> New best validation accuracy! Saving model state.
Epoch 21/25 | Train Loss: 0.2113 | Val Loss: 0.1892 | Val Acc: 94.35%
Epoch 22/25 | Train Loss: 0.2004 | Val Loss: 0.1775 | Val Acc: 94.50%
-> New best validation accuracy! Saving model state.
Epoch 23/25 | Train Loss: 0.1927 | Val Loss: 0.1912 | Val Acc: 94.15%
Epoch 24/25 | Train Loss: 0.1836 | Val Loss: 0.1746 | Val Acc: 94.75%
-> New best validation accuracy! Saving model state.
Epoch 25/25 | Train Loss: 0.1804 | Val Loss: 0.1642 | Val Acc: 94.75%
--- Training Finished ---
--- Evaluating on Test Set ---
Final Test Loss: 0.1662
Final Test Accuracy: 95.01%
----------------------------
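The 1,292,042 trainable parameters reported in the log can be reproduced by hand from the CONFIG values. The sketch assumes PyTorch's standard parameter layout for `nn.TransformerEncoderLayer` (a packed Q/K/V input projection, an output projection, two feed-forward `Linear` layers and two `LayerNorm`s); the helper functions are hypothetical. It also shows how much of the budget goes to the learnable position embedding.

```python
# Reproduce the "Total trainable parameters" figure from the CONFIG values.
EMBED_DIM, MLP_DIM, NUM_LAYERS, NUM_CLASSES = 128, 512, 6, 10
SEQ_LEN = 28 * 28  # 784 pixel tokens


def linear(in_f, out_f):
    """Weights plus bias of an nn.Linear layer."""
    return in_f * out_f + out_f


def layer_norm(dim):
    """Scale and shift parameters of an nn.LayerNorm."""
    return 2 * dim


def encoder_layer(d, mlp):
    """Parameters of one TransformerEncoderLayer (assumed standard layout)."""
    attn = 3 * d * d + 3 * d + linear(d, d)  # packed Q/K/V projection + output projection
    ffn = linear(d, mlp) + linear(mlp, d)    # two-layer feed-forward block
    return attn + ffn + 2 * layer_norm(d)    # norm1 and norm2


pos_embedding = (SEQ_LEN + 1) * EMBED_DIM
total = (
    linear(1, EMBED_DIM)                                       # pixel_projection
    + EMBED_DIM                                                # cls_token
    + pos_embedding                                            # position_embedding
    + NUM_LAYERS * encoder_layer(EMBED_DIM, MLP_DIM)           # transformer_encoder
    + layer_norm(EMBED_DIM) + linear(EMBED_DIM, NUM_CLASSES)   # mlp_head
)
print(f"total: {total:,}")
print(f"position embedding: {pos_embedding:,}")
```

This prints 1,292,042 in total, of which 100,480 belong to the position embedding; almost all of the rest sits in the six encoder layers.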
---
license: apache-2.0
---