Knowledge Distillation — Techniques for Efficient Inference of LLMs (IV/IV)

Andrei Apostol
MantisNLP
Nov 8, 2023

Welcome to the fourth and final part of our blog series on techniques for efficient inference of LLMs. In this segment, we explore Knowledge Distillation (KD).

KD involves leveraging a larger, pre-trained model (teacher) to guide the training of a smaller (student) model, allowing the student to learn from the rich knowledge and representations captured by the teacher. We will go over various forms of KD, and how they contribute to creating models with reduced parameters while preserving high performance.

We will also present KD methods that do not require access to the model’s internals (useful in the case of closed-source LLMs).

Applications of Knowledge Distillation in NLP

Knowledge distillation, originally formulated in Hinton’s seminal paper [16], is the practice of including the outputs of a larger, pre-trained model (teacher) in the training process of a smaller (student) model. While it can take many forms, it is typically done by constructing a special loss function that encourages the student’s output logits to match those of the teacher. In this context, the teacher model can also refer to an ensemble of trained models (which typically achieves a higher accuracy than a single trained model).

This process, known as response-based distillation, is illustrated below:

Response-based distillation. From a survey paper

There are two other forms of distillation, namely:

  • Feature-based distillation, whereby the intermediate layer activations of the teacher model are used as targets for the intermediate activations of the student (a minimal sketch of this variant appears after this list). The rationale is that the intermediate activations capture representative features from the data.
  • Relation-based knowledge distillation, whereby rather than using the activations directly, one uses the relationships between feature maps. These can be represented, among others, via correlation maps or feature embeddings.
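
To make the feature-based variant more concrete, below is a minimal PyTorch sketch of such a loss. The layer choices and the optional projection layer are illustrative assumptions, not a prescription from any particular paper.

```python
import torch.nn as nn
import torch.nn.functional as F

def feature_distillation_loss(student_hidden, teacher_hidden, projection=None):
    """MSE between intermediate activations of the student and the teacher.

    student_hidden: (batch, seq_len, d_student) activations from a chosen student layer
    teacher_hidden: (batch, seq_len, d_teacher) activations from a chosen teacher layer
    projection:     optional linear layer mapping the student's hidden size to the
                    teacher's, needed when the two dimensions differ (an assumption
                    made here purely for illustration).
    """
    if projection is not None:
        student_hidden = projection(student_hidden)
    return F.mse_loss(student_hidden, teacher_hidden)

# Illustrative usage: match a mid-depth student layer to a mid-depth teacher layer.
# proj = nn.Linear(d_student, d_teacher)
# loss = feature_distillation_loss(student_outputs.hidden_states[3],
#                                  teacher_outputs.hidden_states[6].detach(), proj)
```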

For more information on the different types of KD, we refer the reader to this blog post or the survey paper that is cited. For our purposes, we will focus on how KD has been applied in the context of NLP models, and LLMs more specifically.

One notable success was DistilBERT, a model trained to replicate the outputs of the pre-trained BERT model. Specifically, it uses a loss composed of three terms: a distillation loss, a masked language modeling (MLM) loss, and a cosine embedding loss. The distillation loss is defined as:

L_ce = Σ_i t_i · log(s_i)

where t_i is the probability estimated by the teacher, and s_i the one estimated by the student.

Moreover, the authors add the regular MLM loss. The cosine embedding loss minimizes the distance between the student’s and the teacher’s hidden states in the embedding space, and was found to be helpful in their experiments.
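
For intuition, here is a minimal PyTorch sketch of a response-based distillation term of this form. The temperature value and the batch averaging are illustrative choices, not the exact DistilBERT recipe.

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Response-based distillation: cross-entropy between the teacher's softened
    probabilities t_i and the student's softened probabilities s_i."""
    t = F.softmax(teacher_logits / temperature, dim=-1)          # teacher probs t_i
    log_s = F.log_softmax(student_logits / temperature, dim=-1)  # log student probs
    # Minimize -sum_i t_i * log(s_i), averaged over the batch.
    return -(t * log_s).sum(dim=-1).mean()
```

In DistilBERT, this term is combined with the regular MLM loss and the cosine embedding loss, with their relative weights treated as hyperparameters.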

The resulting model has 40% fewer parameters, runs 60% faster, and maintains 97% of BERT’s accuracy. This makes it a convenient model to use for downstream tasks.

Knowledge Distillation through synthetic data generation

With the advent of LLMs, directly using softmax outputs for knowledge distillation can incur a prohibitive training cost. Moreover, when one cannot access logits or intermediate activations from the teacher model (e.g. a closed-source LLM such as OpenAI’s GPT3), applying classical knowledge distillation becomes cumbersome.

As such, new forms of knowledge distillation have emerged, for instance using LLMs to generate synthetic data. Specifically, the GPTMix authors use prompt engineering together with few-shot examples drawn from the real data distribution.

Fig. 2 from GPTMix

The prompt is constructed by specifying the task type and label type. Examples, together with their labels, are selected from the dataset and introduced into the prompt. The model is then prompted to generate additional samples.
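
As an illustration, here is a minimal Python sketch of how such a prompt could be assembled for a sentiment classification task like SST-2. The template wording and the example reviews are hypothetical, not the exact prompt used by the authors.

```python
import random

def build_augmentation_prompt(examples, task_description, label_names, k=3):
    """Build a few-shot prompt asking the LLM to generate a new labeled example.

    examples:         list of (text, label) pairs drawn from the real dataset
    task_description: short description of the task, e.g. "movie review"
    label_names:      allowed label strings, e.g. ["positive", "negative"]
    """
    header = (f"Each item below is a {task_description} together with its sentiment, "
              f"which is one of: {', '.join(label_names)}.\n")
    shots = random.sample(examples, k)
    body = "".join(f"Review: {text}\nSentiment: {label}\n" for text, label in shots)
    # The trailing "Review:" cue prompts the model to continue with a new example.
    return header + body + "Review:"

# Illustrative usage with hypothetical data:
# prompt = build_augmentation_prompt(
#     [("A delightful, warm film.", "positive"),
#      ("Dull and overlong.", "negative"),
#      ("An instant classic.", "positive")],
#     task_description="movie review",
#     label_names=["positive", "negative"])
```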

Training a downstream BERT model on an (intentionally limited) subset of SST-2 mixed together with these synthetic samples vastly reduces the error rate. The numbers are laid out in the table below.

Table 2 from GPTMix

The authors argue that training a downstream model on this data is a form of knowledge distillation: while the logits are not used directly in this case, the outputs are, and these are a product of the teacher’s internal model of the world.

Generating Instruction Datasets

Synthetic data is not the only way outputs from LLMs can be used in distillation.

Instruction-following models have recently emerged as a friendlier, more intuitive alternative to classical models, eliminating most of the need for prompting.

They take pre-trained language models as a starting point and apply supervised finetuning on instruction-following data, i.e. text samples following an (instruction, response) template. Once this is done, an additional, optional third stage can be performed: reinforcement learning from human feedback (RLHF).

In the RLHF stage, human labelers are presented with pairs of model outputs generated from the same prompt and must choose the preferred one. Once this preference dataset is collected, a reward model is trained on it. The final model is then finetuned using this reward model as an optimization target, aligning its outputs with human preferences. We refer interested readers to Chip Huyen’s excellent blog post on this topic.

The pipeline can be seen below:

Fig. from Chip Huyen’s blogpost on RLHF
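
To make the reward-modeling step concrete, here is a minimal PyTorch sketch of the pairwise preference loss commonly used for this purpose; the exact setup in any given RLHF pipeline may differ, and the reward model itself is left abstract.

```python
import torch.nn.functional as F

def preference_loss(reward_chosen, reward_rejected):
    """Pairwise loss for reward-model training.

    reward_chosen:   scalar rewards r(prompt, preferred response), shape (batch,)
    reward_rejected: scalar rewards r(prompt, rejected response),  shape (batch,)

    Minimizing -log sigmoid(r_chosen - r_rejected) pushes the reward of the
    human-preferred response above that of the rejected one.
    """
    return -F.logsigmoid(reward_chosen - reward_rejected).mean()
```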

Note, however, that gathering instruction data, as well as human preference data, is a prohibitively costly process. That is in addition to the compute required to train the reward model and to finetune the model for two additional rounds.

Due to this, recent models such as Alpaca have adopted a different strategy. They start from an open-source pre-trained model, LLaMA-7B, and finetune it on instruction data generated by OpenAI’s GPT3, following the strategy from the Self-Instruct paper.

This process is illustrated below:

Fig. from the Alpaca announcement page

The authors start from 175 human-written instruction-following prompts, to be used as seeds. These are then used as few-shot examples for GPT3 to generate more instruction-following data.

They generate approximately 52k examples, costing around $600, and use these to do supervised finetuning on the open-source LLaMA model.
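
For intuition, here is a minimal Python sketch of how an (instruction, response) pair might be flattened into a single training string for supervised finetuning. The template below is illustrative, not necessarily the exact one used for Alpaca.

```python
def format_example(instruction, response, input_text=None):
    """Flatten an (instruction, response) pair into a single training string.

    input_text: optional extra context for instructions that operate on an input.
    """
    if input_text:
        prompt = ("Below is an instruction that describes a task, paired with an input.\n\n"
                  f"### Instruction:\n{instruction}\n\n"
                  f"### Input:\n{input_text}\n\n"
                  "### Response:\n")
    else:
        prompt = ("Below is an instruction that describes a task.\n\n"
                  f"### Instruction:\n{instruction}\n\n"
                  "### Response:\n")
    return prompt + response

# Illustrative usage with a hypothetical example:
# text = format_example("Summarize the following paragraph.",
#                       "Knowledge distillation transfers knowledge from a large model ...",
#                       input_text="Knowledge distillation is the practice of ...")
```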

Note, however, that training a downstream model on the outputs of another model can affect the ability to run it for commercial purposes, depending on how the teacher model is licensed.

Wrapping Up

In the final part of our blog series, we explored Knowledge Distillation (KD) and its applications in NLP and LLMs.

In a nutshell, KD involves using a larger, pre-trained model (teacher) to guide the training of a smaller (student) model, resulting in more efficient models with fewer parameters while maintaining high performance. Additionally, there are forms of KD that allow treating the model as a black box, with techniques like synthetic data generation and self-instruct finetuning.

By leveraging KD, practitioners can create faster, more scalable, and resource-efficient language models.
