E4 : Distilling Step-by-Step
Injecting reasoning capabilities into small language models helps them outperform LLMs while requiring less compute and less training data
Paper Name : Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes
Paper URL : https://arxiv.org/abs/2305.02301
Authors : Cheng-Yu Hsieh, Chun-Liang Li, Chih-Kuan Yeh, Hootan Nakhost, Yasuhisa Fujii, Alexander Ratner, Ranjay Krishna, Chen-Yu Lee, Tomas Pfister
Conference : ACL 2023
Please find the annotated paper here
Problem Statement :
- Despite their remarkable capabilities, LLMs remain challenging to fine-tune and deploy in production solutions because of their compute and memory requirements.
- Small language models, on the other hand, lack the capabilities of an LLM but are well suited for downstream task-specific deployments because of their smaller size and compute requirements.
Standard Fine-tuning Approaches :
- Adapting pre-trained models to task-specific use cases is typically done via standard fine-tuning or task distillation of small language models, or via in-context learning with LLMs.
- Standard fine-tuning tunes a pre-trained model on human-annotated, domain-specific input data.
- Task distillation tunes a student model on a noisy, pseudo-labeled training set obtained by running a teacher model (the pre-trained model) over unlabeled training data (a minimal sketch follows this list).
- Fine-tuning LLMs themselves is impractical given the memory and compute required, so techniques like chain-of-thought (CoT) prompting are used instead to enable in-context learning with LLMs.
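To make the contrast concrete, here is a minimal sketch of standard task distillation in Python, assuming Hugging Face Transformers with a small T5 student. The `teacher_predict_label` helper is a hypothetical stand-in for querying a teacher LLM; the pseudo label it returns may be noisy.

```python
# Minimal sketch of standard task distillation (pseudo-labeling).
# Assumes: pip install torch transformers; `teacher_predict_label` is a
# hypothetical placeholder for an LLM call, not a real API.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
student = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

def teacher_predict_label(text: str) -> str:
    # Placeholder: in practice this queries the teacher LLM (e.g. PaLM)
    # and the returned pseudo label may be wrong.
    return "entailment"

def distillation_loss(unlabeled_text: str):
    pseudo_label = teacher_predict_label(unlabeled_text)
    enc = tokenizer(unlabeled_text, return_tensors="pt")
    targets = tokenizer(pseudo_label, return_tensors="pt").input_ids
    # Standard cross-entropy against the pseudo label only -- no rationale signal.
    return student(input_ids=enc.input_ids,
                   attention_mask=enc.attention_mask,
                   labels=targets).loss

loss = distillation_loss("premise: A dog runs. hypothesis: An animal moves.")
loss.backward()  # an optimizer step over the student parameters would follow
```

Note that the only supervision the student sees here is the (possibly noisy) label; distilling step-by-step, described next, adds a second supervision signal.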
Solution :
- Distilling step-by-step fuses the reasoning capabilities of an LLM with the compact size of small language models to create models that match LLM-level performance.
- Small language models are fine-tuned to predict both labels and rationales for each input.
- The predicted labels are compared against human-annotated labels (in the standard fine-tuning setting) or LLM-annotated labels (in the task distillation setting).
- The predicted rationales are compared against rationales generated by the LLM for the training data.
- A combined cross-entropy loss over labels and rationales, L = L_label + λ·L_rationale, is used to fine-tune the model (a sketch follows below).
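A minimal sketch of this combined objective, assuming a Hugging Face T5 student: the `[label]`/`[rationale]` task prefixes and the form of the loss follow the paper, while the example strings and the λ value are illustrative.

```python
# Minimal sketch of the distilling step-by-step objective:
# one student model, two task prefixes, and the combined loss
# L = L_label + lambda * L_rationale. Example inputs and lambda are illustrative.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

def seq2seq_ce(input_text: str, target_text: str) -> torch.Tensor:
    # Cross-entropy of the student generating `target_text` from `input_text`.
    enc = tokenizer(input_text, return_tensors="pt")
    labels = tokenizer(target_text, return_tensors="pt").input_ids
    return model(input_ids=enc.input_ids,
                 attention_mask=enc.attention_mask,
                 labels=labels).loss

def step_by_step_loss(question: str, label: str, rationale: str,
                      lam: float = 1.0) -> torch.Tensor:
    # Label task: target is the human- or LLM-annotated label.
    l_label = seq2seq_ce("[label] " + question, label)
    # Rationale task: target is the LLM-generated rationale for the same input.
    l_rationale = seq2seq_ce("[rationale] " + question, rationale)
    return l_label + lam * l_rationale

loss = step_by_step_loss(
    question="Sammy wanted to go to where the people were. Where might he go? "
             "(a) race track (b) populated areas (c) desert",
    label="(b) populated areas",
    rationale="The answer must be a place with a lot of people; populated "
              "areas have a lot of people.",
)
loss.backward()  # a real training loop would follow with an optimizer step
```

Because the rationale task is only used during training, inference still needs a single `[label]`-prefixed forward pass, so deployment cost is unchanged.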
Experimental Setup :
- Results are compared across standard fine-tuning vs. standard task distillation vs. distilling step-by-step vs. a few-shot CoT-prompted LLM.
- Training, validation, and test data for this approach span four datasets: e-SNLI and ANLI (natural language inference), CQA (commonsense question answering), and SVAMP (arithmetic word problems).
- 16 A100 GPUs were used for the training and inference process.
- The performance of PaLM-540B (the LLM) was compared against T5-Base (220M), T5-Large (770M), and T5-XXL (11B) models under different conditions.
Observations :
- All T5 models fine-tuned with the distilling step-by-step approach required only 12.5% of the original training data to surpass the same T5 models fine-tuned via standard fine-tuning and task distillation techniques.
- The T5-XXL model (11B) fine-tuned with distilling step-by-step consistently surpassed PaLM (540B) on the e-SNLI, ANLI, and CQA datasets.
- The T5-XXL model (11B) fine-tuned with distilling step-by-step failed to match PaLM (540B) on the SVAMP dataset, hypothesized to be due to the very small training set (approx. 800 examples).
- When the ASDiv dataset, which is similar to SVAMP, was added to the SVAMP training data, the distilling step-by-step T5-XXL (11B) was able to almost match the performance of PaLM (540B).
- The T5-Large model (770M) fine-tuned with distilling step-by-step was able to outperform PaLM using at most 80% of the training data, and in some cases matched PaLM's performance with just 10% of the training set (e-SNLI).
Conclusions :
- Distilling step-by-step is a new technique that (a) trains smaller language models that outperform LLMs, reducing compute requirements for deployment, (b) trains with less training data (around 50% on average) while still surpassing standard fine-tuning and task distillation, and (c) requires both less training data and smaller models to match or outperform LLM performance.
- One major limitation of this approach is its dependency on LLMs to generate rationales, since LLMs are still susceptible to reasoning failures on complex problems.