by AK and the research community

Mar 14

Grokking at the Edge of Numerical Stability

Grokking, the sudden generalization that occurs after prolonged overfitting, is a surprising phenomenon challenging our understanding of deep learning. Although significant progress has been made in understanding grokking, the reasons behind the delayed generalization and its dependence on regularization remain unclear. In this work, we argue that without regularization, grokking tasks push models to the edge of numerical stability, introducing floating point errors in the Softmax function, which we refer to as Softmax Collapse (SC). We demonstrate that SC prevents grokking and that mitigating SC enables grokking without regularization. Investigating the root cause of SC, we find that beyond the point of overfitting, the gradients strongly align with what we call the naïve loss minimization (NLM) direction. This component of the gradient does not alter the model's predictions but decreases the loss by scaling the logits, typically by scaling the weights along their current direction. We show that this scaling of the logits explains the delay in generalization characteristic of grokking and eventually leads to SC, halting further learning. To validate our hypotheses, we introduce two key contributions that address the challenges in grokking tasks: StableMax, a new activation function that prevents SC and enables grokking without regularization, and perpGrad, a training algorithm that promotes quick generalization in grokking tasks by preventing NLM altogether. These contributions provide new insights into grokking, elucidating its delayed generalization, reliance on regularization, and the effectiveness of existing grokking-inducing methods. Code for this paper is available at https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability.
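
The abstract describes three mechanisms without defining them concretely, so the following is a minimal NumPy sketch under stated assumptions and not the authors' implementation (their code is at the repository linked above). It assumes StableMax replaces exp with the piecewise function s(x) = x + 1 for x >= 0 and 1 / (1 - x) for x < 0, and that perpGrad removes the component of the gradient parallel to the weights. The function names `stablemax` and `perp_grad` and the toy values are illustrative only.

```python
import numpy as np

def softmax(logits):
    # Standard softmax with max-subtraction for overflow safety.
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def stablemax(logits):
    # ASSUMED form: s(x) = x + 1 if x >= 0 else 1 / (1 - x).
    # np.minimum keeps the unused branch free of division by zero.
    s = np.where(logits >= 0, logits + 1.0, 1.0 / (1.0 - np.minimum(logits, 0.0)))
    return s / s.sum()

def perp_grad(weights, grad):
    # ASSUMED form: project out the part of grad parallel to the weights,
    # so an update cannot simply rescale the weights along their direction.
    w, g = weights.ravel(), grad.ravel()
    coeff = g.dot(w) / w.dot(w)
    return (g - coeff * w).reshape(grad.shape)

# Large float32 logits: the small exponentials are absorbed in the sum,
# so the top-class probability rounds to exactly 1.0 (loss and gradient vanish).
logits = np.array([100.0, 0.0, -5.0], dtype=np.float32)
p_soft = softmax(logits)
p_stable = stablemax(logits)
print("softmax top prob:  ", p_soft[0])
print("stablemax top prob:", p_stable[0])

# Scaling the logits leaves the prediction unchanged (the NLM direction).
same_prediction = np.argmax(softmax(2 * logits)) == np.argmax(softmax(logits))

# Toy weights and gradient to check that the projected gradient is orthogonal.
w = np.array([[1.0, 2.0], [3.0, 4.0]])
g = np.array([[0.5, -1.0], [2.0, 0.25]])
g_perp = perp_grad(w, g)
print("g_perp . w =", float(np.sum(g_perp * w)))
```

In this toy case the float32 softmax saturates at exactly 1.0 while the piecewise-linear alternative stays strictly below 1, which is the kind of floating point absorption the abstract calls Softmax Collapse.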

Deep Networks Always Grok and Here is Why

Grokking, or delayed generalization, is a phenomenon where generalization in a deep neural network (DNN) occurs long after achieving near zero training error. Previous studies have reported the occurrence of grokking in specific controlled settings, such as DNNs initialized with large-norm parameters or transformers trained on algorithmic datasets. We demonstrate that grokking is actually much more widespread and materializes in a wide range of practical settings, such as training of a convolutional neural network (CNN) on CIFAR10 or a ResNet on Imagenette. We introduce the new concept of delayed robustness, whereby a DNN groks adversarial examples and becomes robust, long after interpolation and/or generalization. We develop an analytical explanation for the emergence of both delayed generalization and delayed robustness based on a new measure of the local complexity of a DNN's input-output mapping. Our local complexity measures the density of the so-called 'linear regions' (also known as spline partition regions) that tile the DNN input space, and serves as a useful progress measure for training. We provide the first evidence that for classification problems, the linear regions undergo a phase transition during training, after which they migrate away from the training samples (making the DNN mapping smoother there) and towards the decision boundary (making the DNN mapping less smooth there). Grokking occurs after this phase transition, as a robust partition of the input space emerges thanks to the linearization of the DNN mapping around the training points. Website: https://bit.ly/grok-adversarial

Grokking as the Transition from Lazy to Rich Training Dynamics

We propose that the grokking phenomenon, where the train loss of a neural network decreases much earlier than its test loss, can arise due to a neural network transitioning from lazy training dynamics to a rich, feature learning regime. To illustrate this mechanism, we study the simple setting of vanilla gradient descent on a polynomial regression problem with a two-layer neural network which exhibits grokking without regularization in a way that cannot be explained by existing theories. We identify sufficient statistics for the test loss of such a network, and tracking these over training reveals that grokking arises in this setting when the network first attempts to fit a kernel regression solution with its initial features, followed by late-time feature learning where a generalizing solution is identified after train loss is already low. We provide an asymptotic theoretical description of the grokking dynamics in this model using dynamical mean field theory (DMFT) for high-dimensional data. We find that the key determinants of grokking are the rate of feature learning -- which can be controlled precisely by parameters that scale the network output -- and the alignment of the initial features with the target function y(x). We argue this delayed generalization arises when (1) the top eigenvectors of the initial neural tangent kernel and the task labels y(x) are misaligned, (2) the dataset size is large enough that the network can eventually generalize, but not so large that train loss perfectly tracks test loss at all epochs, and (3) the network begins training in the lazy regime and so does not learn features immediately. We conclude with evidence that this transition from lazy (linear model) to rich training (feature learning) can control grokking in more general settings, like on MNIST, one-layer Transformers, and student-teacher networks.

Grokking Tickets: Lottery Tickets Accelerate Grokking

Grokking is one of the most surprising puzzles in neural network generalization: a network first reaches a memorization solution with perfect training accuracy and poor generalization, but with further training, it reaches a perfectly generalized solution. We aim to analyze the mechanism of grokking from the perspective of the lottery ticket hypothesis, identifying the process of finding the lottery tickets (good sparse subnetworks) as the key to describing the transitional phase between memorization and generalization. We refer to these subnetworks as "Grokking tickets", which are identified via magnitude pruning after perfect generalization. First, using "Grokking tickets", we show that the lottery tickets drastically accelerate grokking compared to the dense networks on various configurations (MLP and Transformer, and arithmetic and image classification tasks). Additionally, to verify that "Grokking tickets" are a more critical factor than weight norms, we compared the "good" subnetworks with a dense network having the same L1 and L2 norms. Results show that the subnetworks generalize faster than the controlled dense model. In further investigations, we discovered that at an appropriate pruning rate, grokking can be achieved even without weight decay. We also show that the speedup does not happen when using tickets identified at the memorization solution or at the transition between memorization and generalization, or when pruning networks at initialization (Random pruning, Grasp, SNIP, and Synflow). The results indicate that the weight norm of network parameters is not enough to explain the process of grokking, and that finding good subnetworks is important for describing the transition from memorization to generalization. The implementation code can be accessed via this link: https://github.com/gouki510/Grokking-Tickets.
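
Magnitude pruning, which the abstract names as the way tickets are identified, can be sketched as follows. This is a minimal illustration and not the authors' code: it assumes pruning is applied to a single weight matrix at a time and that the resulting mask is then applied to the initial weights, as in the usual lottery ticket procedure. The helper name `magnitude_mask`, the toy weights, and the pruning rate are all hypothetical.

```python
import numpy as np

def magnitude_mask(weights, pruning_rate):
    """Boolean mask that drops the `pruning_rate` fraction of
    entries with the smallest absolute value (True = keep)."""
    n_prune = int(round(pruning_rate * weights.size))
    mask = np.ones(weights.size, dtype=bool)
    if n_prune > 0:
        magnitudes = np.abs(weights).ravel()
        # argpartition places the n_prune smallest magnitudes first.
        prune_idx = np.argpartition(magnitudes, n_prune - 1)[:n_prune]
        mask[prune_idx] = False
    return mask.reshape(weights.shape)

# Toy stand-ins: weights after generalization and weights at initialization.
trained = np.array([[0.1, -2.0], [0.05, 3.0]])
initial = np.array([[0.3, 0.4], [-0.2, 0.1]])

mask = magnitude_mask(trained, pruning_rate=0.5)
ticket = initial * mask  # sparse subnetwork that would be retrained from init
print(mask)
print(ticket)
```

Here the mask is derived from the generalized solution, matching the abstract's statement that tickets are identified after perfect generalization. Whether pruning is global or layer-wise in the paper is not stated in the abstract.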