
The following PiT_MNIST_V1.0.ipynb is a direct implementation of the PiT pixel transformer described in the 2024 paper "An Image is Worth More Than 16x16 Patches: Exploring Transformers on Individual Pixels" (https://arxiv.org/html/2406.09415v1), which describes "directly treating each individual pixel as a token and achieve highly performant results". This script applies the PiT model architecture, without any modifications, to the standard MNIST digit-classification dataset provided in the Google Colab sample_data folder.

The script was run for 25 epochs and reached 92.30% accuracy on the validation set by Epoch 15 (Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%). Train loss dropped at every epoch, and validation accuracy rose almost monotonically, with minor dips at Epochs 14, 19, 21, and 23.

Final Test Accuracy: 95.01% (25 Epochs). Final Test Loss: 0.1662.

Ran PiT_MNIST_V1.0.ipynb on an A100 GPU (Python 3 Google Compute Engine backend). Colab resource usage from 7:40 PM to 8:01 PM:

System RAM 2.8 / 83.5 GB

GPU RAM 6.5 / 40.0 GB

Disk 37.7 / 112.6 GB

# ==============================================================================
# PiT_MNIST_V1.0.py [in colab: PiT_MNIST_V1.0.ipynb]
#
# ML-Engineer LLM Agent Implementation
#
# Description:
#   This script implements a Pixel Transformer (PiT) for MNIST classification,
#   based on the paper "An Image is Worth More Than 16x16 Patches"
#   (arXiv:2406.09415). It treats each pixel as an individual token, forgoing
#   the patch-based approach of traditional Vision Transformers.
#
# Designed for Google Colab using the sample_data/mnist_*.csv files.
# ==============================================================================

import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math

# --- 1. Configuration & Hyperparameters ---
# These parameters are chosen to be reasonable for the MNIST task and
# inspired by the "Tiny" or "Small" variants in the paper.
CONFIG = {
    "train_file": "/content/sample_data/mnist_train_small.csv",
    "test_file": "/content/sample_data/mnist_test.csv",
    "image_size": 28,
    "num_classes": 10,
    "embed_dim": 128,       # 'd' in the paper. Dimension for each pixel embedding.
    "num_layers": 6,        # Number of Transformer Encoder layers.
    "num_heads": 8,         # Number of heads in Multi-Head Self-Attention. Must be a divisor of embed_dim.
    "mlp_dim": 512,         # Hidden dimension of the MLP block inside the Transformer. (4 * embed_dim is common)
    "dropout": 0.1,
    "batch_size": 128,
    "epochs": 25,           # Increased epochs for better convergence on the small dataset.
    "learning_rate": 1e-4,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}
CONFIG["sequence_length"] = CONFIG["image_size"] * CONFIG["image_size"]  # 784 for MNIST

print("--- Configuration ---")
for key, value in CONFIG.items():
    print(f"{key}: {value}")
print("---------------------\n")

# --- 2. Data Loading and Preprocessing ---
class MNIST_CSV_Dataset(Dataset):
    """Custom PyTorch Dataset for loading MNIST data from CSV files."""

    def __init__(self, file_path):
        # NOTE: pd.read_csv treats the first row as a header by default. If the
        # CSV has no header row, pass header=None to keep the first sample.
        df = pd.read_csv(file_path)
        # Column 0 holds the digit label; the remaining 784 columns hold pixels.
        self.labels = torch.tensor(df.iloc[:, 0].values, dtype=torch.long)
        # Normalize pixel values to [0, 1] and keep as float
        self.pixels = torch.tensor(df.iloc[:, 1:].values, dtype=torch.float32) / 255.0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The PiT's projection layer expects in_features=1,
        # so each pixel must be a tensor of shape (1,).
        # We reshape the 784 pixels to (784, 1).
        return self.pixels[idx].unsqueeze(-1), self.labels[idx]

# --- 3. Pixel Transformer (PiT) Model Architecture ---
class PixelTransformer(nn.Module):
    """
    Pixel Transformer (PiT) model.
    Treats each pixel as a token and uses a Transformer Encoder for classification.
    """

    def __init__(self, seq_len, num_classes, embed_dim, num_layers, num_heads, mlp_dim, dropout):
        super().__init__()

        # 1. Pixel Projection: Each pixel (a single value) is projected to embed_dim.
        # This is the core "pixels-as-tokens" step.
        self.pixel_projection = nn.Linear(1, embed_dim)

        # 2. CLS Token: A learnable parameter that is prepended to the sequence of
        # pixel embeddings. Its output state is used for classification.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # 3. Position Embedding: Learnable embeddings to encode spatial information.
        # Size is (seq_len + 1) to account for the CLS token.
        # This removes the inductive bias of fixed positional encodings.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))

        self.dropout = nn.Dropout(dropout)

        # 4. Transformer Encoder: The main workhorse of the model.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
            batch_first=True  # Important for (batch, seq, feature) input format
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 5. Classification Head: LayerNorm plus a linear layer on the CLS token's output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        # Input x shape: (batch_size, seq_len, 1) -> (B, 784, 1)

        # Project pixels to embedding dimension
        x = self.pixel_projection(x)  # (B, 784, 1) -> (B, 784, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, 785, embed_dim)

        # Add position embedding
        x = x + self.position_embedding  # (B, 785, embed_dim)
        x = self.dropout(x)

        # Pass through Transformer Encoder
        x = self.transformer_encoder(x)  # (B, 785, embed_dim)

        # Extract the CLS token's output (at position 0)
        cls_output = x[:, 0]  # (B, embed_dim)

        # Pass through the classification head to get logits
        logits = self.mlp_head(cls_output)  # (B, num_classes)

        return logits

# --- 4. Training and Evaluation Functions ---
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for pixels, labels in progress_bar:
        pixels, labels = pixels.to(device), labels.to(device)

        # Forward pass
        logits = model(pixels)
        loss = criterion(logits, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    # Mean of the per-batch losses
    return total_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for pixels, labels in progress_bar:
            pixels, labels = pixels.to(device), labels.to(device)

            logits = model(pixels)
            loss = criterion(logits, labels)

            total_loss += loss.item()
            predicted = logits.argmax(dim=1)  # Index of the highest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress_bar.set_postfix(acc=100. * correct / total)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

# --- 5. Main Execution Block ---
if __name__ == "__main__":
    device = CONFIG["device"]

    # Load full training data and split into train/validation sets
    # This helps monitor overfitting, as mnist_train_small is quite small.
    full_train_dataset = MNIST_CSV_Dataset(CONFIG["train_file"])
    train_indices, val_indices = train_test_split(
        range(len(full_train_dataset)),
        test_size=0.1,  # 10% for validation
        random_state=42
    )
    train_dataset = torch.utils.data.Subset(full_train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_train_dataset, val_indices)
    test_dataset = MNIST_CSV_Dataset(CONFIG["test_file"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

    print("\nData loaded.")
    print(f"  Training samples:   {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")
    print(f"  Test samples:       {len(test_dataset)}\n")

    # Initialize model, loss function, and optimizer
    model = PixelTransformer(
        seq_len=CONFIG["sequence_length"],
        num_classes=CONFIG["num_classes"],
        embed_dim=CONFIG["embed_dim"],
        num_layers=CONFIG["num_layers"],
        num_heads=CONFIG["num_heads"],
        mlp_dim=CONFIG["mlp_dim"],
        dropout=CONFIG["dropout"]
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model initialized on {device}.")
    print(f"Total trainable parameters: {total_params:,}\n")

    criterion = nn.CrossEntropyLoss()
    # AdamW is often preferred for Transformers
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"])

    # Training loop
    best_val_acc = 0
    print("--- Starting Training ---")
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(
            f"Epoch {epoch+1:02}/{CONFIG['epochs']} | "
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.2f}%"
        )

        # Save only on a strict improvement, so ties keep the earlier checkpoint
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print("  -> New best validation accuracy! Saving model state.")
            torch.save(model.state_dict(), "PiT_MNIST_best.pth")

    print("--- Training Finished ---\n")

    # Final evaluation on the test set using the best model
    print("--- Evaluating on Test Set ---")
    model.load_state_dict(torch.load("PiT_MNIST_best.pth", map_location=device))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Final Test Loss: {test_loss:.4f}")
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    print("----------------------------\n")


[The PiT_MNIST_V1.0.ipynb script ran out of memory in CPU mode, but ran and trained quickly in A100 GPU mode.]
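One plausible contributor to the out-of-memory failure is the attention score matrix, which grows quadratically with the 785-token sequence (784 pixels plus the CLS token). The sketch below is a rough estimate only. It assumes the scores are materialized in float32 for every head and ignores weights, other activations, and optimizer state. Fused attention kernels may avoid materializing this matrix, so treat the figure as an order-of-magnitude guide. The helper name `attention_score_bytes` is hypothetical and not part of the script.

```python
def attention_score_bytes(batch_size, num_heads, seq_len, bytes_per_value=4):
    """Bytes needed to hold one layer's attention scores of shape (B, heads, L, L)."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value


# Values taken from CONFIG in this README: 28x28 image, batch 128, 8 heads.
tokens = 28 * 28 + 1  # 784 pixel tokens plus the CLS token
per_layer = attention_score_bytes(batch_size=128, num_heads=8, seq_len=tokens)

print(f"{tokens} tokens per image")
print(f"{per_layer / 2**30:.2f} GiB of float32 scores per layer (assumed materialized)")
```

With six encoder layers and activations kept for the backward pass, an estimate of this size per layer gives a sense of why per-pixel tokens are far more demanding than patch tokens.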

--- Configuration ---
train_file: /content/sample_data/mnist_train_small.csv
test_file: /content/sample_data/mnist_test.csv
image_size: 28
num_classes: 10
embed_dim: 128
num_layers: 6
num_heads: 8
mlp_dim: 512
dropout: 0.1
batch_size: 128
epochs: 25
learning_rate: 0.0001
device: cuda
sequence_length: 784

Data loaded.
  Training samples:   17999
  Validation samples: 2000
  Test samples:       9999
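The sample counts above are each one short of a round number. A likely explanation, assuming the Colab CSV files hold 20,000 and 10,000 headerless rows, is that `pd.read_csv` consumes the first row of each file as a header. The sketch below reproduces the logged sizes under that assumption. It mirrors how `train_test_split` rounds a fractional `test_size` up, and `split_sizes` is a hypothetical helper, not part of the script.

```python
import math


def split_sizes(n_samples, holdout_fraction):
    """Train/hold-out sizes when the hold-out share is rounded up."""
    n_holdout = math.ceil(n_samples * holdout_fraction)
    return n_samples - n_holdout, n_holdout


rows_in_train_csv = 20_000  # assumption: headerless file, first row read as header
rows_in_test_csv = 10_000   # assumption: same layout as the training file

n_train, n_val = split_sizes(rows_in_train_csv - 1, 0.1)
n_test = rows_in_test_csv - 1
print(n_train, n_val, n_test)
```

If that assumption holds, passing `header=None` to `pd.read_csv` would keep the first sample of each file.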

Model initialized on cuda.
Total trainable parameters: 1,292,042
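The reported parameter count can be reproduced by hand from the configuration. The sketch below tallies the weights and biases of each component, assuming the usual layout of `nn.TransformerEncoderLayer` (a fused Q/K/V projection, an output projection, two feed-forward linear layers, and two LayerNorms). `count_pit_params` is a hypothetical helper, not part of the script.

```python
def count_pit_params(seq_len, num_classes, d, num_layers, mlp_dim):
    """Tally trainable parameters for the PiT configuration used in this README."""

    def linear(n_in, n_out):
        return n_in * n_out + n_out  # weights + bias

    layer_norm = 2 * d  # scale + shift

    # Pixel projection, CLS token, and learnable position embeddings
    embed = linear(1, d) + d + (seq_len + 1) * d
    # Fused Q/K/V projection plus the attention output projection
    attention = linear(d, 3 * d) + linear(d, d)
    feed_forward = linear(d, mlp_dim) + linear(mlp_dim, d)
    encoder_layer = attention + feed_forward + 2 * layer_norm
    # LayerNorm followed by the final linear classifier
    head = layer_norm + linear(d, num_classes)
    return embed + num_layers * encoder_layer + head


total = count_pit_params(seq_len=784, num_classes=10, d=128, num_layers=6, mlp_dim=512)
print(f"{total:,}")  # 1,292,042, matching the training log
```

The number of heads does not appear in the tally: splitting `embed_dim` across heads changes how the projections are used, not how many weights they hold. The position embeddings alone account for 785 * 128 = 100,480 parameters.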

--- Starting Training ---
Epoch 01/25 | Train Loss: 2.2063 | Val Loss: 2.0610 | Val Acc: 22.75%
  -> New best validation accuracy! Saving model state.
Epoch 02/25 | Train Loss: 1.9907 | Val Loss: 1.7945 | Val Acc: 32.00%
  -> New best validation accuracy! Saving model state.
Epoch 03/25 | Train Loss: 1.5767 | Val Loss: 1.1938 | Val Acc: 58.35%
  -> New best validation accuracy! Saving model state.
Epoch 04/25 | Train Loss: 1.0441 | Val Loss: 0.7131 | Val Acc: 77.10%
  -> New best validation accuracy! Saving model state.
Epoch 05/25 | Train Loss: 0.7299 | Val Loss: 0.5490 | Val Acc: 82.95%
  -> New best validation accuracy! Saving model state.
Epoch 06/25 | Train Loss: 0.5935 | Val Loss: 0.4821 | Val Acc: 84.60%
  -> New best validation accuracy! Saving model state.
Epoch 07/25 | Train Loss: 0.5311 | Val Loss: 0.4021 | Val Acc: 86.95%
  -> New best validation accuracy! Saving model state.
Epoch 08/25 | Train Loss: 0.4682 | Val Loss: 0.3680 | Val Acc: 88.05%
  -> New best validation accuracy! Saving model state.
Epoch 09/25 | Train Loss: 0.4264 | Val Loss: 0.3446 | Val Acc: 89.20%
  -> New best validation accuracy! Saving model state.
Epoch 10/25 | Train Loss: 0.4038 | Val Loss: 0.3163 | Val Acc: 89.95%
  -> New best validation accuracy! Saving model state.
Epoch 11/25 | Train Loss: 0.3641 | Val Loss: 0.2941 | Val Acc: 90.80%
  -> New best validation accuracy! Saving model state.
Epoch 12/25 | Train Loss: 0.3447 | Val Loss: 0.2759 | Val Acc: 91.45%
  -> New best validation accuracy! Saving model state.
Epoch 13/25 | Train Loss: 0.3181 | Val Loss: 0.2603 | Val Acc: 92.05%
  -> New best validation accuracy! Saving model state.
Epoch 14/25 | Train Loss: 0.3023 | Val Loss: 0.2695 | Val Acc: 91.90%
Epoch 15/25 | Train Loss: 0.2800 | Val Loss: 0.2441 | Val Acc: 92.30%
  -> New best validation accuracy! Saving model state.
Epoch 16/25 | Train Loss: 0.2677 | Val Loss: 0.2377 | Val Acc: 92.65%
  -> New best validation accuracy! Saving model state.
Epoch 17/25 | Train Loss: 0.2535 | Val Loss: 0.2143 | Val Acc: 93.80%
  -> New best validation accuracy! Saving model state.
Epoch 18/25 | Train Loss: 0.2395 | Val Loss: 0.2059 | Val Acc: 94.05%
  -> New best validation accuracy! Saving model state.
Epoch 19/25 | Train Loss: 0.2276 | Val Loss: 0.2126 | Val Acc: 93.75%
Epoch 20/25 | Train Loss: 0.2189 | Val Loss: 0.1907 | Val Acc: 94.40%
  -> New best validation accuracy! Saving model state.
Epoch 21/25 | Train Loss: 0.2113 | Val Loss: 0.1892 | Val Acc: 94.35%
Epoch 22/25 | Train Loss: 0.2004 | Val Loss: 0.1775 | Val Acc: 94.50%
  -> New best validation accuracy! Saving model state.
Epoch 23/25 | Train Loss: 0.1927 | Val Loss: 0.1912 | Val Acc: 94.15%
Epoch 24/25 | Train Loss: 0.1836 | Val Loss: 0.1746 | Val Acc: 94.75%
  -> New best validation accuracy! Saving model state.
Epoch 25/25 | Train Loss: 0.1804 | Val Loss: 0.1642 | Val Acc: 94.75%
--- Training Finished ---

--- Evaluating on Test Set ---
Final Test Loss: 0.1662
Final Test Accuracy: 95.01%
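Because the script saves a checkpoint only when validation accuracy strictly improves, the test figures above come from the weights saved at Epoch 24, not from the final epoch, which tied at 94.75%. The sketch below replays that selection rule on the validation accuracies copied from the training log and lists the epochs where accuracy dipped. `best_checkpoint_epoch` is a hypothetical helper, not part of the script.

```python
# Validation accuracy (%) for Epochs 1..25, copied from the training log.
VAL_ACC = [22.75, 32.00, 58.35, 77.10, 82.95, 84.60, 86.95, 88.05, 89.20,
           89.95, 90.80, 91.45, 92.05, 91.90, 92.30, 92.65, 93.80, 94.05,
           93.75, 94.40, 94.35, 94.50, 94.15, 94.75, 94.75]


def best_checkpoint_epoch(val_accuracies):
    """Return (epoch, accuracy) of the last strict improvement, 1-indexed."""
    best_acc, best_epoch = 0.0, None
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best_acc:  # strict '>' means a tie keeps the earlier checkpoint
            best_acc, best_epoch = acc, epoch
    return best_epoch, best_acc


saved_epoch, saved_acc = best_checkpoint_epoch(VAL_ACC)
# Epochs whose validation accuracy fell below the previous epoch's value
dips = [e for e in range(2, len(VAL_ACC) + 1) if VAL_ACC[e - 1] < VAL_ACC[e - 2]]

print(f"Checkpoint used for testing: Epoch {saved_epoch} ({saved_acc:.2f}%)")
print(f"Epochs with a dip in validation accuracy: {dips}")
```

Switching the comparison to `>=` would instead keep the most recent of any tied checkpoints.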


license: apache-2.0