This is a minimal, single-file reproduction of the grokking / delayed generalization phenomenon described in:
- Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, Vedant Misra. Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. arXiv:2201.02177 (2022).
The core observation in this setting is that a model can reach near-perfect training accuracy quickly while validation accuracy remains near chance for a long time, and then—after many additional optimization steps—validation accuracy transitions sharply to near-perfect generalization.
## Task (modular addition)

For a prime (or any integer) modulus `p`, define the dataset of all ordered pairs:

- inputs: `(a, b)` with `a, b ∈ {0, …, p-1}`
- label: `(a + b) mod p`
## Tokenization

Each example is represented as a short token sequence:

```
[a, "+", b, "="]
```

with:

- numbers: token IDs `0 … p-1`
- `"+"`: token ID `p`
- `"="`: token ID `p+1`

The model is trained to predict the label (a class in `0 … p-1`) from the sequence.
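The tokenization above can be sketched in a few lines. The function name `encode` is illustrative, not necessarily what `train_mod_add.py` calls it:

```python
def encode(a, b, p):
    """Encode one example (a, b) as the 4-token sequence [a, "+", b, "="].

    Numbers keep their own IDs 0..p-1; "+" is assigned ID p and "=" ID p+1,
    matching the token layout described above.
    """
    return [a, p, b, p + 1]

# For p = 97, the pair (3, 5) becomes [3, 97, 5, 98], with label (3 + 5) % 97.
```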
## Train/validation split

A random split is taken over the full set of `p²` pairs:

- `--train_frac` controls what fraction goes into training
- the remainder is used for validation
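The split above can be sketched as follows. This is a pure-Python stand-in under the assumption of a seeded shuffle over all `p²` ordered pairs; the actual script may implement it with torch tensors, and `make_split` is a hypothetical name:

```python
import random

def make_split(p, train_frac, seed):
    """Enumerate all p*p ordered pairs, shuffle with a seeded RNG, and split.

    Returns (train_pairs, val_pairs); the first int(train_frac * p * p)
    shuffled pairs go to training, the rest to validation.
    """
    pairs = [(a, b) for a in range(p) for b in range(p)]
    rng = random.Random(seed)
    rng.shuffle(pairs)
    n_train = int(train_frac * len(pairs))
    return pairs[:n_train], pairs[n_train:]
```

Because the RNG is seeded, the same `--seed` always reproduces the same split.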
## Model

A small causal Transformer (token embeddings + positional embeddings + `nn.TransformerEncoder` with a causal mask) reads the 4-token sequence and predicts the result using the hidden state at the final position.
## Optimization

Training uses AdamW with explicit weight decay.
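AdamW's distinguishing feature is *decoupled* weight decay: the weights are shrunk directly by `lr * weight_decay` rather than by adding an L2 term to the gradient, and a large decay (the default here is `1.0`) is commonly cited as an important ingredient for grokking. A minimal scalar sketch of one update, for illustration only (the script itself uses `torch.optim.AdamW`):

```python
def adamw_step(w, grad, m, v, t,
               lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1.0):
    """One scalar AdamW update with decoupled weight decay.

    m, v are the first/second moment estimates; t is the 1-based step count
    used for bias correction. Returns the updated (w, m, v).
    """
    m = betas[0] * m + (1 - betas[0]) * grad
    v = betas[1] * v + (1 - betas[1]) * grad * grad
    m_hat = m / (1 - betas[0] ** t)          # bias-corrected first moment
    v_hat = v / (1 - betas[1] ** t)          # bias-corrected second moment
    w = w * (1 - lr * weight_decay)          # decoupled decay, applied to w directly
    w = w - lr * m_hat / (v_hat ** 0.5 + eps)
    return w, m, v
```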
## Requirements

- PyTorch
- `wandb`
- `optuna` (optional, only if you use `--optuna`)

## Usage

Minimal run:

```
python train_mod_add.py
```

A more explicit run (defaults shown):

```
python train_mod_add.py --p 97 --train_frac 0.5 --steps 100000 --eval_every 250 --batch_size 512 --lr 1e-3 --weight_decay 1.0
```

Force CPU:

```
python train_mod_add.py --device cpu
```

The script prints metrics periodically:
- `train/acc` should typically rise toward 1.0 relatively early.
- `val/acc` will often hover near chance (`1/p`) for a long time.
- Later in training, `val/acc` can jump rapidly from ~chance to near 1.0.

For `p=97`, chance accuracy is:

```
chance_acc = 1/97 ≈ 0.0103
```

The printout includes `chance_acc` as a sanity check that you are seeing the right regime.

If you log to W&B, the qualitative pattern to watch is:

- `train/acc`: rises early and stays high
- `val/acc`: flat near chance, then a sharp transition upward
- `train/loss` vs `val/loss`: can show long periods of decoupled behavior
If you do not see delayed generalization, try these levers:

- Reduce the training set fraction
  - smaller `--train_frac` generally makes generalization harder and can make the delay more pronounced
  - try `--train_frac 0.3` or `0.2`
- Tune the learning rate and warmup
  - if unstable: lower `--lr` and/or increase `--lr_warmup_steps`
  - if nothing happens: a slightly higher `--lr` sometimes helps escape plateaus
- Increase dropout
  - if `val/acc` is unstable, try `--dropout` of `0.1` or higher
Example run that may make the delayed transition easier to see:

```
python train_mod_add.py --p 97 --train_frac 0.1 --steps 300000 --eval_every 250 --batch_size 512 --lr 5e-4 --lr_warmup_steps 100 --weight_decay 1.0 --dropout 0.0
```

## Command-line flags

- `--p`: modulus (default: `97`)
- `--train_frac`: fraction of `p²` pairs used for training (default: `0.5`)
- `--seed`: controls the random train/val split and initialization (default: `1337`)
- `--steps`: number of optimizer steps (default: `100000`)
- `--eval_every`: evaluation frequency in steps (default: `250`)
- `--batch_size`: training batch size cap (default: `512`)
- `--d_model`: embedding/hidden width (default: `128`)
- `--nhead`: attention heads (default: `4`; `d_model` must be divisible by `nhead`)
- `--d_ff`: feedforward width (default: `512`)
- `--num_layers`: Transformer encoder layers (default: `2`)
- `--dropout`: dropout probability (default: `0.1`)
- `--lr`: AdamW learning rate (default: `1e-3`)
- `--lr_warmup_steps`: linear LR warmup steps (default: `10`)
- `--cooldown_frac`: end-of-training cosine LR cooldown fraction (default: `0.0`, i.e. constant LR)
- `--weight_decay`: AdamW weight decay (default: `1.0`)
- `--device`: `cuda`, `cpu`, or `mps` (default: auto-detect)
- `--no_compile`: disables `torch.compile` (possibly required for your setup)
- `--high_precision`: disables autocast and TF32; runs in float32
- `--profile_steps`: prints a short profiler report for a few steps
- `--no_wandb`: disables W&B logging
- `--wandb_project`, `--wandb_group`, `--wandb_name`: W&B metadata
The script includes an Optuna driver intended to find configurations that reach a target validation accuracy in as few steps as possible.

Run:

```
python train_mod_add.py --optuna --optuna_n_trials 100 --optuna_steps 2000 --optuna_eval_every 25 --optuna_target_val_acc 1.0
```

Notes:

- `optuna` is imported only when `--optuna` is set.
- The objective is the number of steps required to hit `--optuna_target_val_acc`.
- Trials can be pruned early using heuristic pruning rules; disable the median-based rule with `--optuna_prune_median_off` if it is not helping.
`SimpleCausalTransformer`:

- learned token embedding + learned positional embedding
- `nn.TransformerEncoder` with:
  - GELU activation
  - pre-norm (`norm_first=True`)
  - a fixed upper-triangular causal mask
- output head:
  - takes the representation at the final position
  - maps to `p` logits via a linear layer
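The causal mask convention can be sketched in plain Python. This mirrors the boolean masks accepted by `nn.TransformerEncoder`, where `True` marks a position attention must *not* look at; the actual script may instead build a float mask with `-inf` entries, which PyTorch treats equivalently:

```python
def causal_mask(n):
    """Upper-triangular boolean causal mask for a length-n sequence.

    Entry (i, j) is True when query position i would attend to a future
    key position j > i, i.e. the attention edge that must be blocked.
    """
    return [[j > i for j in range(n)] for i in range(n)]
```

For the 4-token sequences used here, position 3 (the `"="` token, whose hidden state feeds the output head) is the only position allowed to see the whole input.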
`evaluate()` computes:

- average cross-entropy loss
- accuracy over an entire DataLoader (train set or validation set)
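The accuracy part of `evaluate()` reduces to an argmax comparison. A pure-Python stand-in for the batched tensor version, assuming `logits` is a list of per-class score lists and `labels` a list of ints:

```python
def accuracy(logits, labels):
    """Fraction of examples where the argmax logit matches the label."""
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == y
        for row, y in zip(logits, labels)
    )
    return correct / len(labels)
```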
## Citation

```bibtex
@article{power2022grokking,
  title   = {Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets},
  author  = {Power, Alethea and Burda, Yuri and Edwards, Harri and Babuschkin, Igor and Misra, Vedant},
  journal = {arXiv preprint arXiv:2201.02177},
  year    = {2022}
}
```