Grokking in AI
A mechanistic interpretability perspective into how neural networks learn
Disclaimer: I do not take credit for introducing this concept or research. This article is a simplified version of the tremendous work done by Neel Nanda et al., which I really enjoyed, and it borrows heavily from that work.
Imagine learning to ride a bike. At first, you might memorize the steps: how to pedal, how to balance, how to steer. You might struggle and fall, relying heavily on what you’ve been told. But then, after many attempts, there’s a moment where it all clicks: your body suddenly understands the balance and coordination needed, and you can ride smoothly without thinking about each action. This is akin to grokking in neural networks: a sudden shift from rote learning (memorization) to true understanding (generalization). This is fundamental to how AI models learn, and it has big implications for AI safety.
What is Grokking?
Back in 2021, OpenAI researchers discovered something peculiar while training small AI models on simple tasks like modular addition (adding two numbers and taking the remainder after dividing by a fixed modulus; for example, 7 + 5 = 12, which leaves a remainder of 3 when divided by 9). Initially, the model just memorized the training data. But after training for much longer, it suddenly learned to solve the problem for any pair of numbers, even ones it had never seen before.
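For concreteness, the whole task fits in a couple of lines of Python. The modulus 9 matches the example above; the experiments this article summarizes used a larger prime modulus.

```python
p = 9                 # modulus from the example above
a, b = 7, 5
print((a + b) % p)    # prints 3: the answer the model must predict for the pair (7, 5)
```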
This graph shows the “grokking” phenomenon. Initially, the model’s performance on the training data (blue line) improves rapidly as it memorizes, while its performance on unseen test data (orange line) remains poor. Then, suddenly, its test performance dramatically improves — it has “grokked” how to do modular addition!
This phenomenon is like watching someone learn multiplication tables. At first, they just memorize that 6 × 7 = 42 and 8 × 9 = 72. Then, one day, they suddenly understand the pattern and can multiply any two numbers. In AI, this shift from memorization to generalization can happen suddenly and without any prior signal.
Why Does It Matter?
Grokking could be a key to understanding how AI systems learn and develop new capabilities. Let us explore why understanding grokking is a major focus in AI safety research:
- Emergent Capabilities: If AI systems can suddenly “grok” complex tasks, they might also suddenly develop dangerous capabilities like deception or situational awareness. This makes AI safety harder because we might not see warning signs.
- Understanding Generalization: To ensure AI systems are safe, we need to understand how they learn to generalize. Do they learn aligned, safe behaviors, or deceptive, misaligned ones?
- The Power of Interpretability: By “reverse engineering” AI systems, we can understand their inner workings. In this case, the AI learned to do modular addition using trigonometry and Fourier transforms — an unusual but working algorithm for performing modular addition.
The Science Behind Grokking
Grokking is deeply related to “phase changes” in model training. A phase change is when a model’s performance on a task suddenly improves, and it happens when different components of the model (the neural network) start working together in a coordinated way. The graph above shows a phase change in a model trained on infinite data: note the sudden drop in loss.
When the model is instead trained on finite data with high regularization, it shows grokking. This is shown in the graph below, where there are multiple phase changes: one when the training loss suddenly dips around epoch 10k, and another when both training and test loss suddenly dip close to epoch 200k, signalling generalization.
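To make that setup concrete, here is a minimal training sketch in the spirit of these experiments. It is not the exact architecture from the original work (which used a small transformer rather than an MLP), and the 30% train split, layer sizes, and weight-decay value are illustrative choices; the key ingredients are the small, finite dataset of all (a, b) pairs and the strong weight decay.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 113  # a prime modulus, as in the original grokking experiments

# Every (a, b) pair with label (a + b) mod p; keep only 30% for training so the
# model *can* memorize, while weight decay pushes it toward generalizing.
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))
labels = (pairs[:, 0] + pairs[:, 1]) % p
perm = torch.randperm(len(pairs))
n_train = int(0.3 * len(pairs))
train_idx, test_idx = perm[:n_train], perm[n_train:]

# A deliberately tiny model: embed each operand, concatenate, one hidden layer.
class ModAddNet(nn.Module):
    def __init__(self, d=128, hidden=512):
        super().__init__()
        self.embed = nn.Embedding(p, d)
        self.mlp = nn.Sequential(nn.Linear(2 * d, hidden), nn.ReLU(), nn.Linear(hidden, p))

    def forward(self, ab):
        e = self.embed(ab)             # (batch, 2, d)
        return self.mlp(e.flatten(1))  # (batch, p) logits over possible answers

model = ModAddNet()
# High weight decay is the "high regularization" that makes grokking appear.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1.0)
loss_fn = nn.CrossEntropyLoss()

for step in range(50_000):  # grokking typically needs a long training run
    opt.zero_grad()
    loss = loss_fn(model(pairs[train_idx]), labels[train_idx])
    loss.backward()
    opt.step()
    if step % 1_000 == 0:
        with torch.no_grad():
            test_acc = (model(pairs[test_idx]).argmax(-1) == labels[test_idx]).float().mean().item()
        print(f"step {step}: train loss {loss.item():.4f}, test acc {test_acc:.3f}")
```

If a run like this grokks, the printed numbers trace the graphs above: training loss collapses early while test accuracy sits near chance, and only much later does test accuracy climb sharply.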
In the modular addition task mentioned previously, the model’s learning involved two main competing algorithms:
- Memorization: Just remember every answer, like a lookup table.
- Generalization: Learn to do modular addition properly.
Surprisingly, the generalization algorithm involves converting numbers into waves (using a technique called Fourier transforms), then using trigonometry to add the waves. This is a valid way to do modular addition, but not one a human would typically think of!
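To see why waves plus trigonometry can implement modular addition, here is a small NumPy sketch of the idea. The modulus p = 113 matches the original experiments, but the particular “key frequencies” below are an illustrative choice rather than the ones any specific trained model learns. The trick is the angle-addition identities: from cos(wa), sin(wa), cos(wb), and sin(wb) you can compute cos(w(a+b)) and sin(w(a+b)), and summing cos(w(a+b-c)) over a few frequencies peaks exactly at c = (a + b) mod p.

```python
import numpy as np

p = 113            # modulus used in the original grokking experiments
ks = [3, 7, 19]    # a few "key frequencies" (illustrative choice)

def mod_add_via_waves(a, b):
    """Recover (a + b) mod p using only cosines and sines, mirroring the
    Fourier-and-trigonometry algorithm described above."""
    c = np.arange(p)          # every candidate answer
    logits = np.zeros(p)
    for k in ks:
        w = 2 * np.pi * k / p
        # Angle-addition identities give the waves for a + b.
        cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)  # cos(w(a+b))
        sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)  # sin(w(a+b))
        # cos(w(a+b-c)) equals 1 exactly when c == (a + b) mod p; summing over
        # several frequencies makes that answer stand out (constructive interference).
        logits += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)
    return int(np.argmax(logits))

assert mod_add_via_waves(57, 80) == (57 + 80) % p   # (57 + 80) mod 113 == 24
```

In the real model these computations are not written out explicitly like this; they are distributed across the learned embeddings and weights, which is what the reverse-engineering work uncovered.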
This graph shows how much the model “cares” about different wave frequencies. The sparsity (many zero or near-zero values) shows that it’s focusing on just a few key frequencies. This is evidence that the model is indeed using waves to represent numbers!
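For readers who want to see how such a measurement is made, here is a sketch, assuming access to a trained model’s embedding matrix (the random matrix below is only a stand-in). The idea is to rewrite the embedding in a Fourier basis over the inputs and measure how much norm each frequency carries; in a grokked model, almost all of it concentrates in a handful of key frequencies.

```python
import numpy as np

p, d_model = 113, 128
W_E = np.random.randn(p, d_model)  # stand-in for the learned embedding (one row per input 0..p-1)

# Orthonormal Fourier basis over the inputs: a constant direction plus a
# cosine and a sine wave at each frequency k = 1 .. (p - 1) / 2.
x = np.arange(p)
rows = [np.ones(p) / np.sqrt(p)]
for k in range(1, (p - 1) // 2 + 1):
    rows.append(np.cos(2 * np.pi * k * x / p) * np.sqrt(2 / p))
    rows.append(np.sin(2 * np.pi * k * x / p) * np.sqrt(2 / p))
F = np.stack(rows)                 # shape (p, p)

# Project the embedding onto the Fourier basis and measure the norm of each
# component. Sparse norms (a few large values, the rest near zero) are the
# evidence that the model represents numbers as a few key waves.
coeffs = F @ W_E                   # (p, d_model)
norms = np.linalg.norm(coeffs, axis=1)
print(norms.round(2))              # random weights give a flat profile; a grokked model does not
```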
Here’s what happened during training:
- Memorization (Early): The model memorizes answers. It’s easy but inefficient.
- Interpolation (Middle): The model starts finding patterns that let it store answers more efficiently. This same gradual process is what eventually leads to generalization.
- Generalization (Late): The generalizing algorithm suddenly “clicks,” and the model’s performance jumps. This is the grokking moment!
- Cleanup: The model starts to discard the memorization algorithm since it’s now unnecessary.
- Stability: The model has now transitioned to using the generalizing circuit.
Implications for AI Safety
- Phase Changes Everywhere: If most AI capabilities emerge through phase changes, we should expect many “grokking-like” events in large language models. This means AI systems might suddenly become deceptive or situationally aware without warning.
- Training Dynamics Matter: An AI that understands it’s being trained might try to influence its own training to stay misaligned. To prevent this, we need to shape the training dynamics so it never becomes deceptive in the first place.
- Interpretability is Key: We need to understand what’s happening inside AI systems to ensure they’re aligned with human values. This involves reverse-engineering learning algorithms/dynamics of models.
Limitations and Future Work
The learning dynamics behind grokking are fascinating, but this line of research has limitations and gaps left to explore. The models studied are tiny compared to models like GPT-3, and they were trained on a very simple task. Real-world AI models have many more parameters and learn much more complex tasks.
Future work could investigate whether large language models show phase changes, study how different training techniques affect grokking, and see if the trigonometry-based algorithm appears in language models when they do math. Most excitingly, researchers could train models on “interpretability-based metrics” to see if we can discourage undesirable behaviors like deception.
Conclusions
Grokking is a significant milestone in the training of neural networks, where models transition from memorization to generalization through phase changes. This phenomenon, explored through the lens of mechanistic interpretability, sheds light on the inner workings of neural networks and their capabilities. As we build more powerful AI systems, understanding these learning dynamics will be the key to ensuring they remain safe and aligned with human values.
I will be posting regularly about the AI-safety literature I read and, hopefully, projects I build on this platform. My goal with this blog is to develop a better understanding of what I read and build by distilling it for an external audience. I also want to create a useful corpus of information for other students and learners. Follow if interested, and hit me up if any clarifications are required or errors are present.
Cheers!