7 minute read

Navigating the Challenges of Fine Tuning and Catastrophic Forgetting

Yurts summary

Open-source language models have changed the way we use AI, but there is still considerable room for improvement in meeting the diverse needs of businesses. By gathering feedback from enterprise users across different industries and updating our models in response, we can make these tools far more valuable. To that end, we have adopted an approach that leverages open-source language models to achieve closer alignment with user expectations as business needs evolve.

To align Large Language Models (LLMs) with user preferences, my preferred training method has consistently been LoRA. Because LoRA has a low memory footprint, it enables the tuning of LLMs on consumer GPUs. In addition, since LoRA trains only low-rank adapters while keeping the language model's backbone frozen, we anticipated a reduced susceptibility to catastrophic forgetting: a phenomenon common to artificial neural networks, in which a model loses previously learned information when trained on a new task.
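For readers unfamiliar with the mechanics, the snippet below is a minimal sketch of attaching LoRA adapters to a pretrained model with the Hugging Face peft library; the model checkpoint, rank, and target modules are illustrative choices, not our production configuration.

```python
# Minimal LoRA setup sketch (illustrative hyperparameters).
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                # rank of the adapter matrices
    lora_alpha=16,      # scaling factor applied to the adapter update
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attach adapters to attention projections
)

model = get_peft_model(base, lora_config)
model.print_trainable_parameters()  # typically well under 1% of all parameters
```

The backbone weights stay frozen; only the small adapter matrices receive gradients, which is what makes fine-tuning feasible on consumer GPUs.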

While applying LoRA to fine-tune language models on two sequential datasets, I noticed a significant drop in the trained model's performance on the first dataset, i.e., catastrophic forgetting. This was particularly surprising since, intuitively, one would expect the trained low-rank adapters to alter the frozen language model's backbone only minimally; that is, the net change to the model's parameters should be very small.

Despite the minimal change in the model's parameters after LoRA training, this catastrophic forgetting of previous information is alarming. It suggests that LoRA may be sub-optimal for applications requiring continual alignment of LLMs to user preferences, particularly in enterprise environments with limited on-premise training budgets.

To alleviate this catastrophic forgetting, I compared LoRA with Functionally Invariant Paths (FIP), a technique I developed at Caltech that takes into account the geometry of the loss landscape during training on a new task. Although the changes to the model's parameters are significantly larger than with LoRA training, the model maintains its performance on the previous task while becoming adept at the new one. I believe that, with appropriate adjustments to reduce the memory requirement of the FIP algorithm, it could become an effective technique for regularly aligning LLMs with user preferences, particularly in environments with restricted on-premises training budgets.

Read our full research below:

Continual Alignment of LLMs with User Preferences: FIP's Edge over LoRA in Memory Preservation

In collaboration with Guruprasad Raghavan (Yurts AI), Surya Narayanan Hari (Caltech), Matt Thomson (Caltech)

Large Language Models (LLMs) have transformed how we interact with AI, showcasing exceptional abilities in generating human-like text. However, for LLMs to be truly effective across different industries, they must be robust to inter-disciplinary transfer and align with user preferences and ethical standards. Techniques such as Reinforcement Learning from Human Feedback (RLHF) [1], Direct Preference Optimization (DPO) [2], and Inverse Preference Learning (IPL) [3] have been developed to achieve this alignment. By integrating these methods, LLMs can better meet the specific needs and values of users, facilitating wider adoption and more meaningful applications in various sectors.
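For concreteness, DPO [2] casts alignment as a supervised objective over preference pairs: given a prompt $x$ with a preferred response $y_w$ and a dispreferred response $y_l$, it fine-tunes the policy $\pi_\theta$ against a frozen reference policy $\pi_{\text{ref}}$ by minimizing

$$
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
$$

where $\sigma$ is the logistic function and $\beta$ controls how far the policy may drift from the reference model.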

The challenge emerges when aiming to continuously align LLMs with user preferences through Direct Preference Optimization (DPO), especially when user preference data is collected regularly, such as monthly. The goal is to fine-tune the base LLM with this data on a monthly basis, ensuring the model not only adapts to the latest preferences but also retains and integrates all previously learned preferences, achieving a compounding effect of alignment over time.

In practice, however, the continual fine-tuning of LLMs—as well as other neural networks like CNNs and LSTMs—for alignment with user preferences often fails due to catastrophic forgetting. This phenomenon, where a neural network forgets previously learned information (Task A) when learning new information (Task B) sequentially, has been a recognized challenge since the 1980s, as highlighted in the seminal work of McCloskey and Cohen [4]. This inherent limitation complicates the goal of achieving a compounding effect of LLM alignment over time.

What can be done to alleviate catastrophic forgetting?

A battery of techniques has been developed to address catastrophic forgetting in neural networks, namely Elastic Weight Consolidation (EWC) [5], Synaptic Intelligence (SI) [6], Gradient Episodic Memory (GEM) [7], Brain-Inspired Replay (BIR) [8], Functionally Invariant Paths (FIP) [9], and Low-Rank Adaptation (LoRA) [10]. Most of these techniques have been successfully applied in the image domain, specifically to image classification, but they have not been extensively explored for natural language processing.

Below, we will closely examine two of these techniques, (a) LoRA and (b) FIP, for enabling networks to continually learn on language tasks. LoRA works by training only low-rank adapters for a new task while keeping the network's entire backbone frozen. This enables fine-tuning of very large networks with low computational resources. In addition to the memory savings, the authors suggest that the technique alleviates catastrophic forgetting: because the low-rank adapters, when added to the network's backbone, do not shift the network far in weight space, the trained network should remain close to the original network and thus retain its functional performance.
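Concretely, for a frozen weight matrix $W_0 \in \mathbb{R}^{d \times k}$, LoRA [10] parameterizes the task-specific update as a product of two trainable low-rank matrices:

$$
W = W_0 + \Delta W = W_0 + \frac{\alpha}{r} B A, \qquad B \in \mathbb{R}^{d \times r}, \; A \in \mathbb{R}^{r \times k}, \; r \ll \min(d, k)
$$

Only $A$ and $B$ receive gradients, so the trainable parameter count per matrix drops from $dk$ to $r(d + k)$, and the frozen backbone $W_0$ is never modified.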

LoRA employs low-rank adapters when learning a new task, leading to the popular belief that it could prevent catastrophic forgetting, i.e., that the model would retain previous task information while learning new information. Contrary to expectations, our experiments reveal that applying LoRA does not mitigate catastrophic forgetting in the context of continual learning. Instead, FIP uncovers neural networks, or LLMs in this particular instance, that simultaneously retain performance on the previous task while picking up the new task. This effectiveness stems from the FIP technique's approach of modeling the network's weight space as a curved Riemannian manifold, ensuring that the network trained on the new task remains close in functional space to the original network while traversing weight space.

A major reason LoRA does not perform as expected is the faulty assumption that keeping the weights (or parameters) of the LLM "close enough" to the base LLM will preserve functional performance. This assumption falls short when the underlying loss landscape is highly rugged and non-convex. In such landscapes, we need techniques that consider the geometry of the LLM's functional space during continual training, which is precisely what FIP accomplishes.
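To make the contrast concrete, here is a simplified PyTorch sketch of a training step that penalizes drift in function space: alongside the new-task loss, it keeps the model's outputs close to those of a frozen copy of the base network on reference inputs. This captures the spirit of the FIP objective; the actual algorithm in [9] constructs explicit paths on a Riemannian weight manifold, so the function name, loss weighting (lam), and choice of reference inputs here are illustrative.

```python
import torch
import torch.nn.functional as F

def functional_proximity_step(model, base_model, optimizer,
                              task_batch, reference_inputs, lam=1.0):
    """One training step combining a new-task loss with a penalty on how far
    the model's *outputs* drift from the frozen base model's outputs.
    lam is an illustrative trade-off between plasticity and stability."""
    optimizer.zero_grad()

    # Standard supervised loss on the new task (e.g., IMDB sentiment labels).
    inputs, labels = task_batch
    task_loss = F.cross_entropy(model(inputs), labels)

    # Functional-proximity penalty: match the frozen base network's outputs
    # on reference inputs, staying close in function space, not weight space.
    with torch.no_grad():
        base_out = base_model(reference_inputs)
    drift = F.mse_loss(model(reference_inputs), base_out)

    loss = task_loss + lam * drift
    loss.backward()
    optimizer.step()
    return task_loss.item(), drift.item()
```

Note that nothing here constrains how far the weights travel; the regularizer acts only on the network's behavior, which is why large parameter changes can coexist with preserved performance on the old task.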

We exemplify this by training a BERT model on two classification tasks: the first entails classifying Yelp reviews (where the model receives a Yelp user review and predicts the user's star rating) [yelp-huggingface], and the second involves classifying IMDB reviews (where the model processes an IMDB user review to determine if it is positive or negative) [imdb-huggingface]. 

Experimental setup

We train BERT on two tasks in a sequential manner: (i) predicting the star rating a user gave in a Yelp review, and (ii) predicting the positive/negative sentiment of an IMDB review.
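A minimal sketch of the data and model setup, using the public Hugging Face datasets referenced above; the checkpoint and preprocessing details are illustrative rather than our exact training recipe.

```python
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Task A: Yelp reviews with 5-way star-rating labels.
yelp = load_dataset("yelp_review_full")
# Task B: IMDB reviews with binary sentiment labels.
imdb = load_dataset("imdb")

# Task A head: 5 output classes (1-5 stars). After fine-tuning on Yelp, the
# same backbone continues training on IMDB (2 classes), and we then
# re-evaluate on Yelp to measure how much of Task A has been forgotten.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=5
)
```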

Table 1: Naive Fine-tuning; LoRA Fine-tuning; FIP Fine-tuning on two sequential tasks (Yelp → IMDB)
Table 2: Change in performance of final network from Yelp trained network, post IMDB training

The tables and plots above clearly demonstrate that both naive fine-tuning of BERT and LoRA-based tuning cause the model to forget its previous task—predicting the number of stars in Yelp reviews—while excelling at the new task of predicting the sentiment of IMDB reviews. In contrast, FIP ensures that the network retains performance on the prior Yelp task while learning the new IMDB task.

A deep dive: what is going on when we FIP-tune, and why isn't LoRA working as expected?

LoRA is performing as expected. Specifically, LoRA is a method that freezes the base neural network and solely adjusts the low-rank adapters. This ensures that the model, fine-tuned for a new task, remains "close" in the parameter (or weights) space to the original base neural network.

From Figure 1 (below), we observe that the fractional change in the LoRA adapter weights after fine-tuning on the second task, measured via the Frobenius norm, is minimal compared to the change in model weights after FIP tuning. This indicates that the LoRA-tuned network (on IMDB) remains very close to the base model (trained on Yelp reviews), yet exhibits poor performance on the Yelp task. Conversely, the FIP-tuned network's weights are farther from the base model's (trained on Yelp) while still maintaining performance on it.

Figure 1: Fractional change in the network's weights (Frobenius norm)
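For readers who want to reproduce this kind of measurement, the helper below is a rough sketch of an aggregate fractional change, computed as a ratio of Frobenius norms over all parameters; our exact bookkeeping of which matrices enter Figure 1 may differ.

```python
import torch

def fractional_weight_change(base_state, tuned_state):
    """||W_tuned - W_base||_F / ||W_base||_F, aggregated over all parameters."""
    diff_sq, base_sq = 0.0, 0.0
    for name, w_base in base_state.items():
        w_base = w_base.float()
        w_tuned = tuned_state[name].float()
        diff_sq += torch.linalg.norm(w_tuned - w_base).item() ** 2
        base_sq += torch.linalg.norm(w_base).item() ** 2
    return (diff_sq ** 0.5) / (base_sq ** 0.5)

# Usage: compare checkpoints before and after training on the second task.
# change = fractional_weight_change(yelp_model.state_dict(),
#                                   imdb_model.state_dict())
```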

This observation suggests that proximity in weight space to the base model is not particularly meaningful, especially when the loss landscape is highly non-convex—or, in other words, very rugged. Instead, FIP, which maintains closeness to the base network in functional space (as opposed to weight space), retains performance on the previous task while also performing well on the newly tuned task (IMDB).

For a deep dive on the FIP algorithm, definitely check out our paper: Engineering flexible machine learning systems by traversing functionally invariant paths.

References: 

  1. Ouyang, Long, et al. "Training language models to follow instructions with human feedback." Advances in Neural Information Processing Systems 35 (2022): 27730-27744.
  2. Rafailov, Rafael, et al. "Direct preference optimization: Your language model is secretly a reward model." arXiv preprint arXiv:2305.18290 (2023).
  3. Hejna, Joey, and Dorsa Sadigh. "Inverse preference learning: Preference-based RL without a reward function." Advances in Neural Information Processing Systems 36 (2024).
  4. McCloskey, Michael, and Neal J. Cohen. "Catastrophic interference in connectionist networks: The sequential learning problem." Psychology of Learning and Motivation. Vol. 24. Academic Press, 1989. 109-165.
  5. Kirkpatrick, James, et al. "Overcoming catastrophic forgetting in neural networks." Proceedings of the National Academy of Sciences 114.13 (2017): 3521-3526.
  6. Zenke, Friedemann, Ben Poole, and Surya Ganguli. "Continual learning through synaptic intelligence." International Conference on Machine Learning. PMLR, 2017.
  7. Lopez-Paz, David, and Marc'Aurelio Ranzato. "Gradient episodic memory for continual learning." Advances in Neural Information Processing Systems 30 (2017).
  8. Van de Ven, Gido M., Hava T. Siegelmann, and Andreas S. Tolias. "Brain-inspired replay for continual learning with artificial neural networks." Nature Communications 11.1 (2020): 4069.
  9. Raghavan, Guruprasad, and Matt Thomson. "Engineering flexible machine learning systems by traversing functionally invariant paths in weight space." arXiv preprint arXiv:2205.00334 (2022).
  10. Hu, Edward J., et al. "LoRA: Low-rank adaptation of large language models." arXiv preprint arXiv:2106.09685 (2021).

Frequently asked questions

What is catastrophic forgetting, and why does it pose a challenge in continuously learning Generative AI systems?
Catastrophic forgetting is a phenomenon in artificial neural networks where a model forgets previously learned information upon learning new tasks. It's particularly challenging as it hampers the model's ability to adapt to new data while retaining essential earlier knowledge, crucial for continually learning systems.
How does the LoRA technique attempt to fine-tune LLMs while minimizing catastrophic forgetting, and what limitations has it encountered?
LoRA fine-tunes an LLM by freezing the base model's weights and training only small low-rank adapter matrices, which keeps memory requirements low and keeps the fine-tuned model close to the base model in parameter space. The expectation was that this proximity in weight space would also preserve performance on earlier tasks. In our sequential fine-tuning experiments, however, LoRA-tuned models still lost substantial performance on the first task, showing that closeness in weight space does not by itself prevent catastrophic forgetting.
In simple terms, what is the Functionally Invariant Paths (FIP) method, and how does it work?
The Functionally Invariant Paths (FIP) method adjusts the neural network's weight space by modeling it as a curved Riemannian manifold. This ensures that while the network learns new tasks, it remains functionally close to its original configuration, effectively retaining its performance on previous tasks despite substantial changes in parameters.
How does the efficiency and application of FIP differ from that of LoRA when fine-tuning LLMs?
FIP differs from LoRA in its approach to mitigating catastrophic forgetting. While LoRA maintains the model's parameter proximity, FIP emphasizes functional closeness, leading to superior retention of prior task performance. FIP's approach, focusing on the geometry of the learning landscape, enables it to outperform LoRA in sequential learning tasks, making it a more versatile solution for real-world applications.
What could be the broader impact of addressing catastrophic forgetting on the future development and application of Generative AI technologies?
Mitigating catastrophic forgetting can improve generative AI, making it better at learning continuously and at building competence across multiple knowledge domains, such as healthcare, aerospace, and manufacturing. In addition to fine-tuning systems for different knowledge domains, efficient continual learning will enable continual alignment of LLMs with user feedback. This means Generative AI can handle tasks across different industries and can be gradually improved to better align with the preferences of an enterprise's knowledge workers.
written by
Guruprasad Raghavan
Lead Research Scientist and Founder