Inside ReAugKD: Amazon’s Method that Combines RAG and Knowledge Distillation for Generative AI Models
The technique uses a teacher-student approach to optimize the size and knowledge of foundation models.
Building smaller and more efficient foundation models is one of the key priorities for the next wave of generative AI. Among current trends, retrieval-augmented generation (RAG) has emerged as one of the dominant techniques for expanding the knowledge of foundation models with external data sources. Similarly, knowledge distillation (KD) is one of the most interesting methods for reducing the size of foundation models.
Could we combine the two?
Recently, Amazon Science published a paper introducing a novel approach called retrieval-augmented knowledge distillation (ReAugKD), which aims to bridge the gap between the efficiency of smaller “student” models and the accuracy of their larger “teacher” counterparts.
The concept behind ReAugKD is simple. The teacher model's data representations (embeddings) and predictions are stored in a lookup table, and these stored entries are then used to guide the student model's predictions for similar inputs. Although the approach was demonstrated on language models, it is general enough to be adapted to other kinds of task-specific external knowledge.
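To make the lookup-table idea concrete, here is a minimal Python sketch of a store of teacher embeddings and predictions that is queried with a new embedding. It assumes cosine similarity, top-k retrieval, softmax weighting of the retrieved neighbors, and a simple linear blend with the student's own prediction. The class `TeacherLookupTable`, the function `augmented_prediction`, and the parameters `k`, `temperature`, and `beta` are hypothetical names for illustration; the paper's exact retrieval and aggregation rules may differ.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def softmax(scores, temperature=1.0):
    """Numerically stable, temperature-scaled softmax."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


class TeacherLookupTable:
    """Hypothetical lookup table of (teacher embedding, teacher prediction) pairs."""

    def __init__(self):
        self.embeddings = []   # projected teacher embeddings
        self.predictions = []  # teacher class-probability vectors

    def add(self, embedding, prediction):
        self.embeddings.append(embedding)
        self.predictions.append(prediction)

    def retrieve(self, query, k=2, temperature=0.1):
        """Similarity-weighted average of the teacher predictions of the k nearest entries."""
        sims = [cosine(query, e) for e in self.embeddings]
        top = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
        weights = softmax([sims[i] for i in top], temperature)
        num_classes = len(self.predictions[0])
        return [
            sum(w * self.predictions[i][c] for w, i in zip(weights, top))
            for c in range(num_classes)
        ]


def augmented_prediction(student_pred, retrieved_pred, beta=0.5):
    """Blend the student's own prediction with the retrieved teacher prediction."""
    return [beta * s + (1 - beta) * r for s, r in zip(student_pred, retrieved_pred)]
```

For example, if the table holds two entries close to `[1.0, 0.0]` and one close to `[0.0, 1.0]`, querying with `[1.0, 0.05]` and `k=2` returns roughly the average of the first two entries' teacher predictions. The linear scan keeps the sketch short; a real system would more likely use a vector index.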
To evaluate ReAugKD, the researchers compared it against ten existing models on six natural language processing tasks covering paraphrasing, natural-language inference, and question answering. ReAugKD was the top performer on five of the tasks and came second on the sixth. Averaged across tasks, it sets a new state of the art while adding a latency overhead of only 3%.
Training ReAugKD
ReAugKD is trained in two steps. The first step starts from a teacher model that has been fine-tuned for a specific downstream task. A linear-projection layer is added on top of the teacher's encoder to map its embeddings to the dimension of the student model's embeddings. The parameters of this projection layer are optimized with a supervised contrastive loss, which treats training examples with the same label as positives and contrasts them against negatives randomly sampled from the batch.
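The first step can be sketched in plain Python as a linear projection followed by a supervised contrastive loss. This is an illustration under assumptions, not Amazon's code: `project` and `supervised_contrastive_loss` are hypothetical helpers, the loss follows the common SupCon formulation (every other item in the batch appears in the denominator, rather than a random subsample of negatives), and the temperature of 0.1 is an arbitrary choice. Only the loss is computed here; actually training the projection weights would be left to an autodiff framework.

```python
import math


def project(weights, vector):
    """Linear projection: a (d_out x d_in) weight matrix times a d_in-dimensional vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weights]


def normalize(vector):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def supervised_contrastive_loss(embeddings, labels, temperature=0.1):
    """SupCon-style loss: same-label pairs are positives, every other batch item is in the denominator."""
    z = [normalize(e) for e in embeddings]
    n = len(z)
    total, anchors = 0.0, 0
    for i in range(n):
        positives = [p for p in range(n) if p != i and labels[p] == labels[i]]
        if not positives:
            continue  # an anchor with no positive in the batch contributes nothing
        logits = {
            a: sum(x * y for x, y in zip(z[i], z[a])) / temperature
            for a in range(n) if a != i
        }
        peak = max(logits.values())
        log_denom = peak + math.log(sum(math.exp(v - peak) for v in logits.values()))
        # Average negative log-probability of the positives for this anchor.
        total += -sum(logits[p] - log_denom for p in positives) / len(positives)
        anchors += 1
    return total / anchors if anchors else 0.0
```

With projected embeddings where same-label examples point in similar directions, the loss is close to zero; assigning labels that cut across those clusters makes it much larger, which is the signal that pulls the projection toward a label-aware embedding space.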
The second step generates the resized (projected) teacher embeddings and the teacher predictions for the input data that will be used to train the student model. A critical part of this step is building a similarity matrix over the teacher embeddings, which quantifies how similar each input's embedding is to the embeddings of all other inputs in the dataset.
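Building the teacher similarity matrix is a pairwise computation. The sketch below assumes cosine similarity as the similarity measure; `similarity_matrix` is a hypothetical helper, and passing a second set of embeddings as `cols` produces a rectangular matrix, for instance one comparing student embeddings against teacher embeddings.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def similarity_matrix(rows, cols=None):
    """Return S with S[i][j] = cos(rows[i], cols[j]).

    With cols=None this is the square teacher-teacher matrix, which is
    symmetric and has ones on its diagonal.
    """
    cols = rows if cols is None else cols
    return [[cosine(u, v) for v in cols] for u in rows]
```

For N inputs the teacher-teacher matrix has N x N entries, so in practice it would typically be computed per batch or with vectorized tensor operations rather than nested Python loops.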
The heart of the method is the training of the student model, which relies on similarity matrices for both the student and the teacher embeddings. The loss function minimizes the Kullback–Leibler (KL) divergence between the teacher-teacher similarity distribution and the teacher-student similarity distribution. Aligning these two distributions gives the student and the teacher a consistent notion of similarity, which matters at inference time, when the student's embeddings are used to look up similar entries among the teacher's stored embeddings.
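The following sketch shows one way to compute this similarity-alignment term. It assumes that each row of a similarity matrix is turned into a probability distribution with a temperature-scaled softmax, that the student embeddings already have the same dimension as the projected teacher embeddings, and that the divergence is taken as KL(teacher-teacher || teacher-student). The helper names and the temperature value are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def softmax(scores, temperature):
    """Numerically stable, temperature-scaled softmax."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; q is strictly positive after a softmax."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def similarity_kl_loss(teacher_emb, student_emb, temperature=0.1):
    """Mean KL between teacher-teacher and teacher-student similarity distributions."""
    total = 0.0
    for t_i, s_i in zip(teacher_emb, student_emb):
        # Row i of the teacher-teacher matrix, as a distribution over all inputs.
        p = softmax([cosine(t_i, t_j) for t_j in teacher_emb], temperature)
        # Row i of the teacher-student matrix: student embedding i vs. every teacher embedding.
        q = softmax([cosine(s_i, t_j) for t_j in teacher_emb], temperature)
        total += kl_divergence(p, q)
    return total / len(teacher_emb)
```

When the student's embeddings coincide with the teacher's, the two distributions are identical and the loss is zero; a student whose embeddings point toward the wrong teacher entries receives a large penalty.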
The loss function also incorporates a second term: the standard cross-entropy loss between the student's predictions and the teacher's predictions, which transfers the teacher's output behavior to the student as in conventional knowledge distillation.
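The second term and its combination with the KL term can be sketched as follows. The sketch assumes the teacher's class probabilities are used as soft targets and introduces a hypothetical weighting parameter `alpha` between the two terms; the paper's actual weighting scheme may differ.

```python
import math


def cross_entropy(teacher_probs, student_probs, eps=1e-12):
    """H(teacher, student) = -sum_c teacher[c] * log(student[c]), clipped for stability."""
    return -sum(t * math.log(max(s, eps)) for t, s in zip(teacher_probs, student_probs))


def combined_loss(similarity_kl, teacher_preds, student_preds, alpha=0.5):
    """Hypothetical weighted sum of the similarity KL term and the mean prediction cross-entropy."""
    ce = sum(
        cross_entropy(t, s) for t, s in zip(teacher_preds, student_preds)
    ) / len(teacher_preds)
    return alpha * similarity_kl + (1 - alpha) * ce
```

With soft targets, the cross-entropy is smallest when the student's distribution matches the teacher's, so the two terms together push the student to match both the teacher's embedding geometry and its predictions.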
Some Cool Results
Amazon Science evaluated ReAugKD by distilling the 12-layer BERT-Base model into a more efficient six-layer BERT model, measuring performance on six datasets from the GLUE benchmark.
ReAugKD achieves state-of-the-art performance on five of the six datasets, with an average improvement of 0.42% over the previous best KD approach. On two of the benchmark tasks, the gains reach 1.37% and 1.43%.
Amazon Science also evaluated a variant of ReAugKD that uses knowledge-base retrieval. This version improves on ReAugKD without retrieval by 0.45%, indicating that retrieval augmentation contributes to the method's gains.
In summary, ReAugKD outperformed existing distillation techniques on most of the benchmark datasets tested, and adding knowledge-base retrieval further improved its results, making it a promising approach for compressing large models.