Coffee Time Papers: Grokking

Generalization Beyond Overfitting on Small Algorithmic Datasets

Dagang Wei
Jul 5, 2024

This blog post is part of the series Coffee Time Papers.

Paper

Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets, by Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, and Vedant Misra (arXiv:2201.02177).

Overview

This paper proposes small, algorithmically generated datasets as a testbed for studying how neural networks generalize rather than simply memorize. It focuses on binary operation tables (like addition or multiplication tables) in which some entries are left blank; training a neural network to fill in these blanks is like solving a puzzle.

The key findings of the paper are:

  • Generalization Beyond Overfitting: Neural networks can learn the underlying pattern in the data and generalize to the blank entries, even after initially overfitting (memorizing the training data).
  • Grokking: The authors coin the term “grokking” to describe a phenomenon where a neural network’s ability to generalize suddenly improves dramatically after a period of overfitting.
  • Data Efficiency Curves: The paper presents curves showing how the amount of training data affects the network’s ability to generalize for different binary operations.
  • Optimization and Generalization: The amount of training time needed for generalization increases rapidly as the amount of training data decreases.
  • Impact of Optimization Techniques: Techniques like weight decay (a way to prevent overfitting) significantly improve the network’s ability to generalize.
  • Visualization of Embeddings: Visualizing how the network represents the symbols in the binary operations can reveal the underlying mathematical structure.

Overall, this paper suggests that studying small algorithmic datasets can provide valuable insights into how neural networks learn and generalize, potentially leading to better training methods and architectures for more complex tasks.

Q & A

Q: What is the phenomenon of “grokking” as described in the paper?

A: Grokking is a term coined by the authors to describe a phenomenon observed in neural networks during training on small, algorithmically generated datasets. It refers to a sudden and significant improvement in the network’s ability to generalize to new, unseen data, even after a period of overfitting (memorizing the training data). Essentially, the network seems to abruptly “understand” the underlying pattern in the data, leading to a dramatic increase in performance.
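As a purely illustrative sketch (not code from the paper), one way to quantify grokking from recorded accuracy curves is to measure the gap between the step at which training accuracy saturates and the much later step at which validation accuracy does. The 99% threshold and the curve names here are assumptions for the example:

def first_step_above(accuracies, threshold=0.99):
    # Return the first training step whose accuracy reaches the threshold.
    for step, acc in enumerate(accuracies):
        if acc >= threshold:
            return step
    return None

def grokking_gap(train_accuracies, val_accuracies, threshold=0.99):
    # Steps between memorization (train accuracy high) and generalization
    # (validation accuracy high); a large gap is the signature of grokking.
    memorize = first_step_above(train_accuracies, threshold)
    generalize = first_step_above(val_accuracies, threshold)
    if memorize is None or generalize is None:
        return None  # the run never memorized or never grokked
    return generalize - memorize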

Q: What kind of datasets are used to study grokking in this paper?

A: The paper focuses on small, algorithmically generated datasets based on binary operation tables. These tables are similar to addition or multiplication tables, where the network is trained to predict the result of an operation between two elements. Some entries in the tables are left blank, and the network’s task is to fill in these blanks, effectively learning the underlying operation.
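For concreteness, here is a minimal sketch of how such a dataset could be constructed. Modular addition with modulus p = 97 matches one of the paper's main examples; the 50/50 split is an assumption for illustration:

import random

# Build the full table for a + b (mod p), then hold out half the entries.
p = 97
equations = [(a, b, (a + b) % p) for a in range(p) for b in range(p)]

random.seed(0)
random.shuffle(equations)
split = len(equations) // 2
train_set = equations[:split]  # visible entries the network trains on
test_set = equations[split:]   # "blank" entries it must learn to fill in

The network is trained only on train_set and evaluated on test_set; since the held-out entries follow the same operation, generalizing here means recovering the operation itself.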

Q: How does the amount of training data affect the network’s ability to generalize?

A: The paper shows that the amount of training data significantly impacts the network’s ability to generalize. As the amount of training data decreases, the optimization time the network needs to reach good generalization (to grok) increases rapidly. This suggests a trade-off between the amount of data and the computational resources required for training.

Q: What is the impact of weight decay on generalization?

A: Weight decay, a regularization technique that helps prevent overfitting, is found to have a substantial positive effect on the network’s ability to generalize. Among the interventions the authors tried, it more than halves the amount of training data needed for the network to achieve good performance. This suggests that weight decay pushes the network toward learning the underlying pattern in the data rather than just memorizing the training examples.

Q: How does weight decay work?

A: Weight decay works by adding a penalty term to the loss function that the network is trying to minimize during training. In its common L2 form, this penalty is proportional to the square of the network’s weights:

penalty = (weight_decay_factor / 2) * sum(weights^2)

where weight_decay_factor is a hyperparameter that controls the strength of the penalty. During training, the network adjusts its weights to minimize the total loss, which now includes the weight decay penalty. This penalty encourages the network to keep its weights small.

Smaller weights generally lead to simpler models that are less prone to overfitting. This is because large weights can cause the network to focus too much on the specific details of the training data, rather than learning the underlying patterns. By keeping the weights small, weight decay helps the network generalize better to new, unseen data.
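As a concrete sketch of the formula above (using PyTorch for illustration; model, inputs, and targets are placeholders, and the decay factor is a hypothetical value, not one from the paper):

import torch
import torch.nn.functional as F

weight_decay_factor = 1.0  # hypothetical strength of the penalty

def loss_with_weight_decay(model, inputs, targets):
    # Task loss plus the L2 penalty from the formula above.
    base_loss = F.cross_entropy(model(inputs), targets)
    penalty = sum(p.pow(2).sum() for p in model.parameters())
    return base_loss + (weight_decay_factor / 2) * penalty

In practice the same effect is usually achieved by passing a weight_decay argument to the optimizer, e.g. torch.optim.AdamW(model.parameters(), weight_decay=1.0), which applies decoupled weight decay instead of an explicit loss term.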

Q: How does the visualization of embeddings help understand the network’s learning process?

A: Visualizing the network’s internal representations (embeddings) of the symbols in the binary operation tables can reveal the underlying mathematical structure of the operations. For example, in modular arithmetic problems, the embeddings tend to form circular or cylindrical shapes, reflecting the cyclical nature of modular addition. This suggests that the network learns to represent the mathematical objects in a way that captures their inherent properties.
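As an illustrative sketch (not the authors’ code), one way to look for this structure is to project the learned embedding matrix down to two dimensions, for example with PCA. The embedding.weight attribute assumes a standard torch.nn.Embedding layer:

import torch
import matplotlib.pyplot as plt

def plot_embeddings(embedding: torch.nn.Embedding):
    # Project each symbol's embedding onto the top two principal components.
    weights = embedding.weight.detach()        # shape: (num_symbols, dim)
    centered = weights - weights.mean(dim=0)
    _, _, v = torch.pca_lowrank(centered, q=2)
    projected = centered @ v                   # shape: (num_symbols, 2)
    plt.scatter(projected[:, 0], projected[:, 1])
    for i, (x, y) in enumerate(projected):
        plt.annotate(str(i), (float(x), float(y)))  # label each symbol
    plt.title("Symbol embeddings (PCA projection)")
    plt.show()

For modular addition, the labeled points often trace out a circle, with consecutive residues adjacent to one another.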
