
Seven Palliatives to Catastrophic Interference

Introduction

TLDR: Continual learning is the ability to learn new stuff without overwriting old stuff. Modern deep nets cannot continually learn because they are distributed information processing systems optimized through backpropagation.
Frontier AIs — relevantly, LLMs — still suffer from at least five major bottlenecks separating them from useful, general-purpose intelligence (of which the brain is the only known example). These are:
  1. Continual Learning
  2. Sensorimotor Grounding
  3. Sample Efficiency
  4. Alignment
  5. Resource Budget
These are not exhaustive, nor are they independent. Moreover, I don't intend to claim that solving all of these problems is necessary on the path toward generally capable AI. The metaphor between deep nets and the brain is at best tenuous, so we shouldn't assume what worked for humans will be needed for silicon-based life.
In this piece, I'd like to focus on the first cited bottleneck: continual learning. In particular, I want to give an overview of the problem and the various palliatives we've developed for it over the past 30 years.
Mars Perseverance Rover
####

The Continual Learning Problem

In plain words, continual learning is the ability to rapidly accommodate new information without (severely) overwriting old information. Its complement is catastrophic interference: new information destructively overwriting previously learned information.
Your brain can continually learn effortlessly: learning to speak French does not catastrophically interfere with your knowledge of calculus.¹
The continual learning problem was first formalized by Michael McCloskey and Neal Cohen (1989), two cognitive neuroscientists from Hopkins and Illinois. They characterized this problem as an inherent feature of distributed information systems.
In a connectionist model with distributed representations […] each connection weight is involved in responding to many different inputs. Thus, adjustment of weights to encode the desired response to a new input pattern will necessarily alter the network's response to other inputs as well (147).
When you look at deep nets as these distributed systems, the presence of catastrophic interference feels almost obvious. But the brain is a big and noisy distributed system — so why does it not suffer the same fate? Perhaps it has to do with the learning algorithm of our brain.
Every modern deep net is optimized through backpropagation (AKA steepest descent)². Systems trained with backpropagation optimize for a single objective: to minimize loss over a training set. While it's common in ML research to add many terms to this loss — effectively optimizing over multiple objectives — the number of terms is finite, and each must be readily calculable from a single $(x, y)$ pair. Backpropagation will have no consideration for anything not in its loss function, and therefore gives the network no reason to prioritize desirable properties such as generality, which are not readily encodable as loss terms.
This is the same reason that all deep nets have an implicit i.i.d. constraint. That is, deep nets cannot learn sequentially, building on previously mastered material the way humans do, but must instead randomly sample from the full training distribution. This is fundamentally different from how we learn. Human knowledge is built on a dependency graph: we need addition to ground multiplication, multiplication to ground exponentiation, algebra to ground calculus. An LLM, on the other hand, has no such structure. Within a single training batch, the model might see a Shakespeare sonnet, a partial differential equation, a Reddit thread about cats, and a snippet of Java — with no guarantee that earlier batches built up any prerequisite understanding. While this is not necessarily harmful during pretraining, it limits the ability of deployed models to benefit from a sequential, test-time stream of new data.
Sequential learning: optimizing for one class (region A) then another (region B) can miss the joint optimum (region C)
IID learning interleaves tasks (zigzag) and can reach regions where the model performs well on both classes
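As a toy sketch of the difference (the model, datasets, and `train_step` helper here are illustrative, not from any particular codebase):

from torch.utils.data import ConcatDataset, DataLoader

# Sequential: train task A to convergence, then task B. Nothing anchors
# the task-A solution, so task-B gradients freely overwrite it.
for x, y in DataLoader(task_A, shuffle=True):
    train_step(model, x, y)
for x, y in DataLoader(task_B, shuffle=True):
    train_step(model, x, y)

# i.i.d.: shuffle the union, so every batch mixes both tasks and the
# gradients balance both objectives at each step.
mixed = DataLoader(ConcatDataset([task_A, task_B]), shuffle=True)
for x, y in mixed:
    train_step(model, x, y)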

Why care?

There are many good reasons to desire continual learning.
  1. In the context of LLMs, you'll find they are most useful when you direct them very precisely toward the right problem (prompt engineering). But distilling our preferences, style, and constraints into text is burdensome and difficult to maintain. Perhaps more importantly, text is a lossy medium for transmitting feedback, which is why lived experience (not books) is the ultimate teacher for humans. Enabling an LLM to continuously adapt to our feedback, much like a coworker or intern would, transforms these tools from approximate aids into systems we can trust.
  2. For other tool-based AI applications, consider deploying an intelligent Mars rover. Imagine NASA sends a rover to Mars with an autonomous driving system that performs perfectly on simulated terrain. As the rover drives around, it encounters something entirely out-of-distribution (perhaps a Martian threatening it?). A human might reason through the situation, simulate possibilities, and adapt. The rover has no such ability. We need to teach this rover (and any other task-focused AI system) to learn continually so it doesn't get stuck again.
Clearly, continual learning is a problem worth solving. I can't claim that it's truly a barrier on the path toward general learning systems, but its consequences are so vast that even incremental progress in this domain seems worthwhile.

Summary

In summary, modern deep nets suffer from catastrophic interference: they cannot learn a new task without severely degrading performance on previously mastered ones. This flaw is inherent to any distributed information processing system, especially one optimized through a steepest-descent learning algorithm, and it leads to the constraint that all data must be presented i.i.d.
####

Notation

The rest of this blog will focus on palliatives that reduce catastrophic forgetting in neural networks. Before we dive into these, it's helpful to formalize our notation so we don't have to re-establish context for each method.
We'll consider a model with parameters $\theta$ trained sequentially on tasks $T_A, T_B, \ldots, T_Z$, each with dataset $\mathcal{D}_k = \{(x_i, y_i)\}$. I write $\theta_k^*$ for the SGD-optimized parameters after training on task $T_k$. Each task has a loss $\mathcal{L}_k(\theta)$ computed over data $\mathcal{D}_k$. Regularization methods will often modify this loss by adding a term measuring deviation from the prior solution. Methods that don't regularize instead modify the dataset, the architecture, or the optimization procedure itself.
We'll discuss these methods in the context of a multi-class classification task, though there are ways to modify all these methods for regression, RL, self-supervised learning, etc.

Cheat sheet

Refer back as needed
  • $\theta$ — model parameters at any given time
  • $\theta_k^*$ — parameters after training on task $T_k$
  • $f_\theta(\cdot)$ — model parameterized by $\theta$
  • $T_k$ — the $k$-th task
  • $\mathcal{D}_k$ — dataset for task $T_k$
  • $\mathcal{L}_k(\theta)$ — loss on task $T_k$
  • $R(\theta, \theta_A^*)$ — regularization term measuring deviation from the prior solution
  • $\lambda$ — stability-plasticity tradeoff coefficient
  • $N_k$ — number of samples in dataset $\mathcal{D}_k$
  • $D$ — total number of parameters
  • $C$ — number of classes
####

Palliatives for Catastrophic Interference

In this section, we'll fly over the past three decades of research to see which methods have emerged as the best solutions to catastrophic interference. The palliatives below are applicable to any modern deep net. In order, we'll discuss:
  1. Self-distillation (regularization)
  2. Replay buffers (data-centric)
  3. Extending context length (infrastructure)
  4. Node sharpening (parameter isolation)
  5. Progressive neural nets (parameter expansive)
  6. Orthogonal gradient descent (regularization)
  7. Elastic weight consolidation (regularization)
Before diving into these, I want to provide a high-level overview of each to refer back to.
  • Self-distillation restricts the solution space of the newly trained model such that it overlaps with the prior version of itself.
  • Replay and rehearsal continually retrain on old information by cycling a small, random (and ideally maximally informative) subset of prior training data into each new mini-batch.
  • Extending the context length is a palliative pursued mainly for LLMs, giving the model more memory in which to persist its learnings across sessions.
  • Node sharpening focuses on turning a highly distributed network into a sparse one, such that new knowledge is slotted into a distinct, non-colliding space rather than overwriting a distributed representation for many inputs. MoE is a modern application of these techniques to LLMs.
  • Progressive neural nets add new capacity for each task rather than repurposing existing parameters, sidestepping interference entirely at the cost of parameter growth.
  • Orthogonal gradient descent packs multiple tasks into the same parameters by optimizing each subsequent task within a subspace of the loss landscape orthogonal to previous tasks' optimization directions.
  • Elastic weight consolidation penalizes changes to parameters that were important for earlier tasks, using the Fisher information matrix as a proxy for importance.
Multiple techniques are often used together, but by far the most durable palliatives in industry and research are replay buffers and distillation. Practitioners should reach for these first. The later methods are more exotic, yet fascinating in their own right.
####
The first two — distillation and replay — remain the simplest and most successful methods available.

Self-Distillation

Self-distillation is a regularization method, keeping the predictions of the task $T_B$ model close to the predictions of the task $T_A$ model. Broadly, regularization in continual learning constrains optimization to remain in regions of parameter space that preserve the model's behavior on previously learned tasks, by penalizing updates in directions that are estimated to be important for past performance. In regularization, the loss might be encoded as

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \lambda \, R(\theta, \theta_A^*)$$

We're balancing the loss on the new task with $R(\theta, \theta_A^*)$, some measure of functional deviation from the old model parameterized by $\theta_A^*$.
Based on Figure 1 from EWC - Kirkpatrick et al. (2016)
In the figure, black is the optimization direction if we merely optimized for task B performance through gradient descent, green is the desired optimization direction that preserves performance on both tasks, and red is what happens if something has gone horribly wrong. (This picture should remind you of the earlier picture demonstrating the imperative of i.i.d. learning, which is itself a form of regularization.)

Algorithm

In self-distillation, $R(\theta, \theta_A^*)$ is the divergence between the current outputs and the previous task outputs from parameters $\theta_A^*$. This divergence is measured through KL divergence, per Hinton et al., which encodes a notion of distance between two distributions (more). Thus,

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \lambda \, \text{KL}\bigl(f_{\theta_A^*}(x) \,\|\, f_\theta(x)\bigr)$$
In Learning without Forgetting, the authors rewrite the KL divergence as the cross entropy $H(f_{\theta_A^*}(x), f_\theta(x))$ minus the entropy $H(f_{\theta_A^*}(x))$. This is mathematically equivalent, though less intuitive — thus the code below uses the former definition. The authors also add an extra L2 weight decay term to the loss, which is a standard part of AdamW.
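To see the equivalence, write $p = f_{\theta_A^*}(x)$ and $q = f_\theta(x)$:

$$\text{KL}(p \,\|\, q) = \sum_c p_c \log \frac{p_c}{q_c} = \underbrace{-\sum_c p_c \log q_c}_{H(p,\, q)} - \underbrace{\left(-\sum_c p_c \log p_c\right)}_{H(p)}$$

Since $H(p)$ doesn't depend on $\theta$, both forms produce identical gradients.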
One final detail to highlight is the use of temperature softening (dividing logits by $T$) to make the output distribution of both models less peaky (i.e. predicting one class with $0.99$ confidence). This gives $f_\theta$ more information about $f_{\theta_A^*}$'s distribution, minimizing zero gradients or excessive focus on the argmax class.

Code

import torch
from torch.nn.functional import kl_div, log_softmax, softmax
from torch.optim import AdamW

model_B = swap_head(model_A, task_B_classes)  # current model, initialized from A
model_A.eval()                                # frozen teacher
optimizer = AdamW(model_B.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()       # task-B loss
T, lam = 2.0, 1.0                             # temperature and distillation weight

for x, y in dataloader:
    B_logits = model_B(x)
    with torch.no_grad():
        A_logits = model_A(x)  # no gradients needed for the teacher

    loss_new = criterion(B_logits, y)  # e.g. cross entropy on task-B labels
    loss_distill = kl_div(
        log_softmax(B_logits / T, dim=-1),  # student log-probs
        softmax(A_logits / T, dim=-1),      # teacher probs
        reduction="batchmean",
    ) * (T ** 2)  # T^2 keeps gradient magnitudes comparable across temperatures

    loss = loss_new + lam * loss_distill
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Limitations

  • Distillation occurs purely in the output space, meaning it has no explicit awareness of which parameters matter most for task A (compared to methods like EWC below), leading to slightly unpredictable retention
  • Often practically inferior to replay when memory of previous task data is allowed (which we'll discuss next!)

Recent extensions

Shenfeld et al. recently extended this work by using a demonstration-conditioned copy of the current model as the teacher — free to learn in-context from these demonstrations — keeping distillation on-policy. On-policy just means we learn from data collected by the policy we're currently improving, as opposed to off-policy learning, which uses data collected by a different policy (like an older copy of itself, as above).
Their method, referred to as SDFT, uses the same model in two roles: a student conditioned only on the query $x$, producing $\widehat{Y_S} = f_\theta(\cdot \mid x)$, and a teacher — the same model, but conditioned on both the query and an expert demonstration $c$ — producing $\widehat{Y_T} = f_\theta(\cdot \mid x, c)$. The teacher doesn't need a separate frozen copy of a previous model; it is the current model, just given more context. Training then minimizes the reverse KL divergence $\text{KL}(\widehat{Y_S} \,\|\, \widehat{Y_T})$ on trajectories sampled from the student itself. This is different from the classical self-distillation above, where the teacher is a frozen snapshot of the old model $f_{\theta_A^*}$ and therefore an increasingly off-policy target as the student improves. This applies most relevantly to LLMs and other large transformer models, for which in-context learning emerges as a capability.
Keeping learning on-policy increasingly looks like a necessary condition for continual learning (1) (2). The intuition here is that it's much easier to learn new knowledge grounded in what you already know than it is to learn something completely from scratch.
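To make this concrete, here's a minimal sketch of what one SDFT-style step might look like for a causal LM. This is my own illustrative reconstruction, not the authors' code: it assumes a Hugging Face-style interface (`model(ids).logits`, `model.generate`), and applies a per-token reverse KL on the student's own sampled trajectory.

import torch
import torch.nn.functional as F

def sdft_step(model, optimizer, query_ids, demo_ids, max_new_tokens=64):
    q_len = query_ids.shape[1]

    # Student role: sample a trajectory conditioned on the query alone
    with torch.no_grad():
        traj = model.generate(query_ids, max_new_tokens=max_new_tokens,
                              do_sample=True)
    sampled = traj[:, q_len:]  # the generated tokens

    # Student distributions at each sampled position (gradients flow here)
    s_logits = model(traj).logits[:, q_len - 1 : -1]  # predicts sampled tokens
    s_logp = F.log_softmax(s_logits, dim=-1)

    # Teacher role: the SAME model, additionally conditioned on the demonstration
    with torch.no_grad():
        t_in = torch.cat([demo_ids, traj], dim=1)
        t_logits = model(t_in).logits[:, demo_ids.shape[1] + q_len - 1 : -1]
        t_logp = F.log_softmax(t_logits, dim=-1)

    # Per-token reverse KL(student || teacher) on student-sampled prefixes
    loss = (s_logp.exp() * (s_logp - t_logp)).sum(-1).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()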
While SDFT is much less general than the self-distillation described above, it's an exciting new area of research I wanted to surface.
####

Replay Buffers and Rehearsal

Replay buffers are a very simple and effective method to reduce catastrophic interference. They do so by mixing new task data with a curated buffer of past data.
Replay buffers grew out of work in reinforcement learning, a domain where balancing learning from the experiences of old policies (off-policy learning) against the current policy (on-policy learning) is imperative. For our purposes, we'll discuss replay buffers in the context of supervised continual learning. Supervised CL can borrow the RL machinery because both settings face a non-stationary data stream that can be stabilized by mixing in past data.

Algorithm

  1. Maintain a fixed-capacity buffer $\mathcal{B}$
  2. On arrival of a new task with data $\mathcal{D}_B$: for each training step, sample a batch from $\mathcal{D}_B$ and a batch from $\mathcal{B}$, and train $f_\theta$ on their union: $$\mathcal{L} = \mathbb{E}_{(x,y) \sim \mathcal{D}_B}[\ell(y, f_\theta(x))] + \mathbb{E}_{(x,y) \sim \mathcal{B}}[\ell(y, f_\theta(x))]$$
  3. Buffer update: after (or during) training on $\mathcal{D}_B$, select samples from $\mathcal{D}_B$ to add to $\mathcal{B}$. If $\mathcal{B}$ is at capacity, evict existing samples to make room
    • Insertion policy $\pi_{\text{add}}$: which samples from $\mathcal{D}_B$ are stored? (Default: all)
    • Eviction policy $\pi_{\text{evict}}$: which samples in $\mathcal{B}$ are removed? (Default: uniform random)
Follow-up research explores methods to compress the buffer intelligently by modifying $\pi_{\text{add}}$, $\pi_{\text{evict}}$, or both. The main modifications maximize the difficulty, diversity, or representativeness of $\mathcal{B}$.
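Here's a minimal sketch of the default recipe. Reservoir sampling is one common way to implement uniform-random insertion and eviction over a stream (my choice here, not mandated by any one paper); `model`, `criterion`, `optimizer`, and `task_B_loader` are presumed to exist.

import random
import torch

class ReplayBuffer:
    def __init__(self, capacity=2000):
        self.capacity = capacity
        self.data = []  # list of (x, y) pairs (assumed tensors)
        self.seen = 0   # total samples offered so far

    def add(self, x, y):
        # reservoir sampling keeps a uniform random sample over the stream
        self.seen += 1
        if len(self.data) < self.capacity:
            self.data.append((x, y))
        else:
            i = random.randrange(self.seen)
            if i < self.capacity:
                self.data[i] = (x, y)  # uniform-random eviction

    def sample(self, k):
        batch = random.sample(self.data, min(k, len(self.data)))
        xs, ys = zip(*batch)
        return torch.stack(xs), torch.stack(ys)

buffer = ReplayBuffer()
for x, y in task_B_loader:                    # new-task stream D_B
    loss = criterion(model(x), y)             # loss on the D_B batch
    if buffer.data:
        bx, by = buffer.sample(k=x.shape[0])  # replay batch from the buffer
        loss = loss + criterion(model(bx), by)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    for xi, yi in zip(x, y):                  # pi_add: offer everything
        buffer.add(xi, yi)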

Practical considerations

For a "bag of tricks" to make replay work well, and fast, I urge you to check out Buzzega et al. (2020). Empirically, even a few samples per class provides meaningful stability, with diminishing returns after 20\sim 20 samples per class (table 1).

Limitations

  • Compute overhead from the buffer, as (many) additional forward and backward passes are needed to progress on new tasks
  • Building replay buffers for LLMs is difficult due to licensing and scale, though training a generator is a common workaround.
####

Extending Context Length

This method is different from the rest I'll discuss because it isn't an algorithmic change or addition. More and more I hear it come up as "the only thing we might need," so I feel obligated to discuss it.
Since GPT-3, LLMs have demonstrated remarkable in-context learning, so there's reason to believe merely increasing the space LLMs have to do this will be sufficient to solve continual learning. Dario Amodei — the CEO of Anthropic — at least seems to think so.
Effectively, the context window of an LLM serves as its memory: a record of its past existence, its learnings, and the desires of the human steering it. The thought is that if we can expand this memory to include everything the LLM has failed on and had to learn, it'll be able to mimic the learning process of humans.
The argument for context length is that, ad infinitum, a context window stores everything necessary to continually adapt and engage with new problems, provided the model is given a playground harness to experiment in. Models will need to systematically attempt and prune solutions, recording their attempts along the way to converge toward a refined solution.
To make you skeptical of this, here's one thought experiment I'll call the "extreme amnesiac." Imagine that each morning your memory was completely wiped, and the only artifact you have of your existence, knowledge, experiences and values was a big book that you keep adding to each day called "My Life." Each morning, you reread that book (which you can do very quickly!) to remember everything about yourself, and each night you add to it to capture today's learnings. Intuitively, this distillation from lived experience to text loses many vital details about you.
While it's possible that context length will be sufficient, I doubt it for the three reasons below. In any case, labs will continue working on this problem irrespective of my conviction, so research will give us the answer soon.

Limitations

  • Attention scales quadratically in context length in both memory and FLOPs
  • Context rot demonstrates a nonlinear relationship between the number of tokens in context and their utility in producing a better response.
  • Text is a lossy representation of true guidance, which for humans includes (at least) sensorimotor grounding, social feedback, and temporal context.
For more on this, I liked this recent LessWrong post.
####

Node Sharpening

Node sharpening squeezes activations toward a sparse representation, so that newly incorporated learning can slot into an unoccupied (or less occupied) part of the network. Node sharpening is essentially the intermediate between a look-up table (great continual learning, terrible generalization) and a neural net (great generalization, terrible continual learning).
The goal is to force the network to activate a few distinct nodes strongly for each distinct input, as opposed to a spread-out, largely uniform activation pattern shared across all inputs. Hence, we call this a parameter isolation technique.
The continual learning and generalizability tradeoff balanced by node sharpening

Code

Node sharpening is typically implemented as an additional term in the loss, which pulls the hidden state toward a sharpened version of itself (implemented as a low-temperature softmax), while the forward pass learns from the sharpened representation.
import torch
import torch.nn as nn

def sharpen(x, temperature=5.0):
    # low-temperature softmax: scaling activations up before the softmax
    # concentrates mass on a few nodes
    return torch.softmax(x * temperature, dim=-1)

class Model(nn.Module):
    ...
    def forward(self, x):
        h = self.hidden_layers(x)
        h_sharp = sharpen(h)               # sharpen the intermediate activations
        out = self.output_layers(h_sharp)  # learn on the sharpened representation
        return out, h, h_sharp             # return both to compute sharpness loss

# Train loop with added loss to sharpen
for x, y in dataloader:
    pred, h, h_sharp = model(x)
    sharpness_loss = ((h - h_sharp) ** 2).sum()  # pull h toward its sharp version
    loss = criterion(pred, y) + lam * sharpness_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Node sharpening is rarely used now, as deep networks naturally become sparse while still generalizing extremely well. Once again, Moore's Law rendered this algorithmic solution inferior to scale.

Limitations

  • Sharpening trades generalization abilities for continual learning abilities
  • Modern deep representations naturally become sparse
####

Progressive Neural Networks

Progressive neural networks solve the problem of catastrophic forgetting by instantiating a new copy of the network for each of the $K$ tasks being solved, while retaining learned connections to the prior task-specific networks. Essentially, they are an even sparser representation than the one produced by node sharpening, with individual networks dedicated to tasks. Because PNNs add new parameters per task, we call this a parameter expansive method.
PNNs practically guarantee no forgetting, and they achieve a higher performance upper bound by benefiting from transfer learning on previous tasks. However, this comes at the cost of parameter count growing quadratically with the number of tasks.
Figure 1 from Rusu et al. (2016)
In this figure, each column represents a trained network for tasks 1, 2, and 3 (left to right). Later columns maintain connections to all previous columns' hidden states at each layer of the forward pass. The [a] boxes represent adapters, which are just MLPs acting on the concatenated vector of all past columns' hidden states.
The forward pass for each hidden state $h_i^{(k)}$ is written as:

$$h_i^{(k)} = f\left( W_i^{(k)} h_{i-1}^{(k)} + \sum_{j < k} U_i^{(k:j)} \, h_{i-1}^{(j)} \right)$$
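As a sketch, one hidden layer of column $k$ might look like this in PyTorch (sizes and names are illustrative; the adapter MLPs are omitted for brevity):

import torch
import torch.nn as nn

class ProgressiveLayer(nn.Module):
    def __init__(self, dim, num_prev_columns):
        super().__init__()
        self.W = nn.Linear(dim, dim)  # this column's own weights W_i^(k)
        self.U = nn.ModuleList(       # lateral connections U_i^(k:j), j < k
            [nn.Linear(dim, dim, bias=False) for _ in range(num_prev_columns)]
        )

    def forward(self, h_own, h_prev_list):
        # h_own: h_{i-1}^{(k)}; h_prev_list: h_{i-1}^{(j)} from frozen columns
        out = self.W(h_own)
        for U_j, h_j in zip(self.U, h_prev_list):
            out = out + U_j(h_j.detach())  # previous columns stay frozen
        return torch.relu(out)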
Progressive neural nets work, but the quadratic parameter expense is untenable under almost all circumstances, so they are usually passed over in favor of fine-tuning plus regularization.

Limitations

  • Parameter count of PNN grows quadratically with the number of tasks

Orthogonal Gradient Descent

Orthogonal gradient descent (OGD) is an improvement to standard SGD optimization that exploits the overparameterization of deep networks. OGD projects gradient updates from a new task $T_B$ onto the subspace orthogonal to all previous gradient updates from task $T_A$. This ensures that optimizing for any new task is a non-competing objective with the optimization of previous tasks.
OGD relies on deep neural networks being massively overparameterized – meaning many paths exist through the loss landscape to reach the same minimum. A simple analogy: suppose you want to represent a two-dimensional point $(x, y)$ with three parameters $(a, b, c)$. In this setup, as long as $a = x, b = y$ is maintained, $c$ is completely free. Overparameterized networks have the same kind of slack, which is exactly what OGD exploits. As an aside, this overparameterization of neural nets actually explains much of the success of deep learning (see Double Descent).

Algorithm

The following describes the OGD algorithm for a multi-class classification task over $C$ classes, as formalized in the paper.
In OGD, we maintain a buffer $S$ of mutually orthogonal gradient vectors accumulated from past tasks. When we later train on a new task, each gradient update is projected to be orthogonal to $S$, ensuring it doesn't interfere with what was previously learned.
During training on task $T_A$ with dataset $\mathcal{D}_A$, SGD produces gradient vectors $\nabla_\theta f_k(x_i^A, \theta)$ for each sample $i \in \{0, \ldots, N_A - 1\}$ and class $k \in \{0, \ldots, C-1\}$. Ideally, every future task $T_B$ gradient would be orthogonal to all $N_A \times C$ of these vectors. However, this is intractable in both compute and memory. Two simplifications make it practical: first, we only compute gradients with respect to the ground-truth class, reducing the set by a factor of $C$; second, the authors find that even a small fixed buffer of $\sim 200$ out of the $N_A$ total samples is empirically sufficient for non-competing updates.
$S$ is updated at a per-task cadence: we complete training on task $A$, then fill $S$ with $\sim 200$ gradient vectors from that task. Each vector added is orthogonalized against all existing vectors in $S$ via Gram-Schmidt. Theoretically, $S$ can hold at most $D$ vectors (the parameter dimension), at which point it forms an orthogonal basis of the full gradient space. Naturally, some directions will be far more informative than others for a given task, which helps explain why $\sim 200$ vectors maintain comparable performance despite being far from the theoretical maximum.
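Here's a minimal sketch of the projection machinery, assuming per-sample gradients have already been flattened into vectors; the helper names are mine, and `lr`, `model`, `criterion`, and `task_B_loader` are presumed to exist.

import torch

def project_out(g, basis):
    # remove from g its component along each (orthonormal) stored direction
    for v in basis:
        g = g - (g @ v) * v
    return g

def extend_basis(basis, grads, max_new=200):
    # Gram-Schmidt: orthogonalize task gradients against the basis, keep ~200
    for g in grads[:max_new]:
        g = project_out(g, basis)
        if g.norm() > 1e-10:
            basis.append(g / g.norm())

# Training on task B: project each step's gradient before applying it
basis = []  # filled with task-A gradient directions via extend_basis(...)
for x, y in task_B_loader:
    loss = criterion(model(x), y)
    model.zero_grad()
    loss.backward()
    g = torch.cat([p.grad.flatten() for p in model.parameters()])
    g = project_out(g, basis)  # orthogonal to all stored directions
    idx = 0
    with torch.no_grad():
        for p in model.parameters():  # write projected gradient back, then step
            n = p.numel()
            p -= lr * g[idx:idx + n].view_as(p)
            idx += n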
Pseudo-algorithm from Farajtabar et al. (2018)

Limitations

  • Storage grows with the number of tasks: $|S| = \#\text{tasks} \times 200$ vectors, each of dimension $D$
  • $D$ (the parameter count) can be very large (LLMs exceed a trillion parameters), so even a single gradient vector can be difficult to store
  • Rigid orthogonality is theoretically pretty but practically overly conservative
Follow-up work (GPM, 2021) uses SVD on network activations to reduce the storage size of the gradient subspace, while more recent work (STIL, 2023) argues that orthogonality should be enforced in the weight space — not the gradient space — since that is where the optimization occurs. Evidently, this line of work remains active.

Elastic Weight Consolidation

Elastic Weight Consolidation is the final regularization palliative we'll explore, penalizing changes to parameters deemed important to earlier tasks through the proxy of Fisher information. As a reminder, the loss used in regularization methods is:
$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \lambda \, R(\theta, \theta_A^*)$$
In EWC, $R(\theta, \theta_A^*) = \frac{1}{2}\sum_i F_i(\theta_i - \theta_{A,i}^*)^2$ — penalizing squared deviations of the parameters to which task A is most sensitive, with Fisher information $F_i$ measuring that sensitivity. We'll try to better understand this choice below.

Understanding Fisher Information

Fisher information tells us how sensitive a model parameterized by $\theta$ with a probability distribution $p(x;\theta)$ is to changes in any one of its parameters $\theta_i$. Formally, $F(\theta) = \mathbb{E}\left[\left(\frac{\partial}{\partial \theta} \log p(x;\theta)\right)^2\right]$, which is the variance of the score function $s(\theta) = \frac{\partial}{\partial\theta}\log p(x;\theta)$.
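The variance reduces to $\mathbb{E}[s(\theta)^2]$ because the score has mean zero (under mild regularity conditions that let us swap the derivative and integral):

$$\mathbb{E}[s(\theta)] = \int p(x;\theta) \, \frac{\partial_\theta \, p(x;\theta)}{p(x;\theta)} \, dx = \frac{\partial}{\partial\theta} \int p(x;\theta) \, dx = \frac{\partial}{\partial\theta} 1 = 0$$

so $\mathrm{Var}[s(\theta)] = \mathbb{E}[s(\theta)^2] = F(\theta)$.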
At a glance, it's not obvious where the $\log$ comes from, nor why the variance of a gradient would give any insight into parameter sensitivity. As with most ML derivations, the $\log$ is both a calculation simplifier (it turns a product of likelihoods into a sum, which is easy to take gradients of) and a normalizer: the expectation weights each gradient by how likely $x$ is under the training distribution $\mathcal{D}$. The variance of the gradient of the log likelihood says "$\theta$ is *this* sensitive to the underlying distribution of $x$." Capturing this variability gives some sense of each parameter's importance and its relation to the original task distribution.
While the Fisher information of all $D$ parameters produces a $(D, D)$ matrix, the paper uses only its diagonal, costing $O(D)$ storage instead of $O(D^2)$. This explicitly assumes that parameter importances are independent, which is wrong but a helpful assumption.
For more about Fisher information, I urge you to read Anwi's post on the topic, which does a great job at elucidating its runes and incantations.

Code

import torch
import torch.nn.functional as F

def compute_fisher_diagonal(model, dataset) -> dict:
    """Compute the diagonal Fisher: F_ii = E[(d/d_theta_i log p(y|x))^2]"""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    N = len(dataset)

    model.eval()
    for x, gt_label in dataset:
        model.zero_grad()  # clear gradients
        logits = model(x)  # (B, C)
        loss = F.nll_loss(F.log_softmax(logits, dim=1), gt_label)
        loss.backward()    # computes gradients

        for n, p in model.named_parameters():
            # same shape as params p, e.g. (H, W)
            fisher[n] += p.grad ** 2 / N

    return fisher

def get_ewc_loss(model, fisher, p_old):
    """Compute the EWC penalty as the sum of Fisher-weighted
    squared parameter shifts"""
    loss = 0
    for n, p in model.named_parameters():
        delta = p - p_old[n]  # shape e.g. (H, W)
        loss += (fisher[n] * delta ** 2).sum()  # scalar

    return loss

# USAGE
# 1. Train on task A normally
train(model, dataset_A)
# 2. Snapshot params at theta_A*
p_old = {n: p.detach().clone() for n, p in model.named_parameters()}
# 3. Compute Fisher at theta_A* over a subset of D_A
fisher = compute_fisher_diagonal(model, dataset_A_subset)  # ~200 samples
# 4. Train on task B with EWC
for x, y in dataset_B:
    pred = model(x)
    loss = criterion(pred, y) + lam * get_ewc_loss(model, fisher, p_old)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Limitations

  • The diagonal Fisher is a crude approximation of the true Fisher information, limiting effectiveness
  • λ\lambda is a new hyperparameter that's hard to balance
  • Practically outperformed by a replay buffer of $\sim$10–20 samples per class (above)
####

Conclusion

While a variety of tools and techniques enable pseudo continual learning — replay and fine-tuning with regularization work best — we are still very far from a system able to both generalize and incorporate new knowledge as well as our brain. To build increasingly powerful and impactful socioeconomic machines (and our beloved Mars rover), continual learning remains one of the tightest bottlenecks to research our way through.

Footnotes

  1. A reader might push back that the brain actually does catastrophically forget, as we don't remember a class we took 10 years ago nearly as well as the class we're taking now. The nuance here is in the severity of forgetting, not the mere presence of it. Long-term depression is inherent to the brain, as old synaptic connections weaken, causing retrieval difficulty. But is that information really lost?
  2. Any modern optimizer, like Adam or RMSProp, is still fundamentally gradient-based, just with extra bells and whistles.