为什么羊驼群里混入一只虎鲸,关于Orca和逐步蒸馏

kevin zhou
Jan 22, 2024

--

作为系列的第三章,也算是个小收尾

前两章传送门:

不敢想象自动给训练数据打标签能有多爽 (qq.com)

读书人想要点数据,怎么能叫偷呢?要叫借, 也可以叫Self-Instruct (qq.com)

自从Llama诞生以来,几乎已经成为了开源世界的模型标准,而诸多基于Llama重训练和微调的各个版本也应运而生,其中比较有代表性的有以下这些:

好似庞大的家族图谱一样,说白了,就是基于Llama通过利用不同的数据集,制定差异化的下游任务,不同训练和微调方法产生了多种多样的模型。

然后在羊驼家族里,今年年中的时候混入了一只奇怪的生物Orca,直译就是虎鲸

这是由微软研究院的同事通过一种新的训练方式来利用13B的模型(别和我犟说13B也不一定是Llama,如果你看过论文第13页3.2章的话),它可以在13B的模型上产生匹配GPT-3.5能力的推理能力,在BBH测试中取得了非常好的成绩(这和普通的打榜模型利用eval数据去刷榜还是有本质区别的)

论文地址:2306.02707.pdf (arxiv.org)

那么这个Orca究竟和我前面讲的另外两篇模型训练有什么区别呢?从严格意义上来讲,这些方式都是与SFT(有监督的指令精调)范围,但是Orca和前两种靠取得GPT的label,和问答数据有本质的区别。

这个区别就是小型化的LLM从GPT想获取什么,与前两章介绍的方法不同的是,Orca想取得的东西,并非GPT的数据,而是GPT的思维方式!

这个东西听起来好抽象,让我们先从一个基本概念说起

COT(Chain of thought)

论文地址:https://arxiv.org/abs/2201.11903

直译过来就是思维链,我用简单的方式来解释就是人类对每次认真回答的问答都是经过深思熟虑得出来的结论,那么什么是深思熟虑,其实可以简单的理解为任务拆分和关联

思维链则是LLM版本的深思熟虑,(当然是低阶版本,因为用起来容易,后面我会讲高阶版本),拿以下的个例子来说明:

同样一道算术题,大家知道GPT经常会在回答一些数学问题的时候出错,如果你可以像右边的图一样通过在few_shot的prompt中加入你解题的思路就会极大的提升做题的准确率

Tips: 其实我有比论文更简单的方式,只要你在每句prompt之后都加入一个后缀的语句”please give a solution or answer step by step”,你就可以触发GPT的COT机制,而且往往如果你对某些问题没有办法很好的分解思路,我倒不推荐你如例子里那样去做,反而会将GPT引入错误的方向,用我这个方法,没错的,你的准确率会提升一个甚至几个档次,但是由此而产生的Token数量增加和Cost增加也是成正比的,这篇文章您看到这,其实这就已经是赚到了,一般人我不告诉他

Orca论文里也举了这样的例子

和上图对比,下图这种prompt方式叫做论文中称作Role of System Instructions,本质上也是一种获得COT的方式

对于一个LLM的训练任务,一般来讲训练数据长这样

或者从LLM上借出来时候长这样:

{“instruction”: “Arrange the words in the given sentence to form a grammatically correct sentence.”, “input”: “the quickly brown fox jumped”, “output”: “the brown fox jumped quickly”}

以上的类型,基本是由定义的文本,标签,或者问答语句以及各种配合下游任务的catelog字段来组成的jsonl文件

不同于一般的训练数据,Orca的训练数据结构都是这样的:

⟨ System message, User query, LLM response ⟩

其中最为关键的就是关于System message的定义,他是驱动LLM产生COT回答的关键

来自原论文的初始16个system message的format:

训练任务被分为5个子sub-colletion,每个sub-collection是由不同的task组成,每个task由一系列的query来组成,每个sub-collection都拥有各自领域的专业数据,每个query生成的也都源自于专业领域的数据(这里的COT是collection的名字和上文的COT数据不是一个意思)

关于数据集的准备方面,最后论文里提到了一共准备了500万条的instructions(queries augmented with system messages),那么这部分数据用GPT-3.5的API来生成COT的问答,另外又从500万条里面采样了100万条用GPT-4的API来生成COT的问答(看出来控制成本多严格了吗,即使是自己家的机器也不敢随便浪费)

通过以上的方式,Orca并没有像其他的Llama衍生模型一样是靠”借”GPT的数据来完成自己的训练(过拟合,欠泛化),而是更着重于学习GPT的推理思路和过程,它的损失函数也是针对这些来设计的,会更有效率的学习偏重于逻辑思考的token(文中未给出详细的设计方法),通过这些举措在根本上实现了一条不同的路。

其实在传统的深度学习领域有一门训练方法叫做蒸馏,某种程度上跟我这三章讲的东西虽然不是一个套路,但是其实中心思想都差不多,本质上是一种学习或者叫权重的迁移(和迁移学习不是一回事)

蒸馏这些年不那么流行的原因在于其实用传统方法,老师训练学生模型一点也不简单,而且要用大量的原始数据驱动(因为要尽可能去拟合和老师和学生模型之间的Loss funtion的值),所以其实不如quantize用起来舒服,因为蒸馏虽然是一种训练方式,但是本质是其实可以理解为它是为推理服务的,因为模型小型化的本质绝对不是为了训练用的(这话大家可以深入思考一下)

Google在今年的7月5号推出了一篇论文

原文地址:2305.02301.pdf (arxiv.org)

其实要我说,这个很大程度上借鉴了Orca的思想,当然COT也是人家Google提出的,所以也别在意谁先谁后了

Distilling step-by-step,即逐步蒸馏,那和传统的蒸馏比起来,它有什么特点呢?

传统蒸馏的时候,学生和老师模型在数据训练中求的是答案之间的损失函数,训练的目的是尽可能缩小这个老师和学生之前的答案的差距,通过这个模式来训练,势必要用大量的数据来训练,而逐步蒸馏和Orca一样的,它在设计蒸馏的时候,学习的可不是老师的答案,而是COT那套东西,当然这个论文造了个名字叫合理性(ralations)

大致原理就是,给定一个大模型和一个无标记的数据。然后设计prompt模版(和Orca一样)

然后就是如何将上面得到的叫COT也好,叫合理性(Rationales)用于下游的任务中。作者将学习合理性(Rationales)作为一个多任务学习问题。让接下来的模型不仅需要预测输出的标签,还需要生成相关的合理性(Rationales)。并且比Orca论文写的好的地方我认为是,直接提出了合理性损失函数,通过缩减老师模型和学生模型的合理性之间的所示函数的值,让训练拟合

总结一下就是不管是哪条大路最后都通向罗马,甚至这两个方法几乎是一模一样,在两个顶级研究团队同时对同样的思路发表论文,虽然这在历史上不是第一次,也不会是最后一次,但可以从侧面说明这个理论的重要性和确定性,两篇论文都给了很好的和其他模型的对比数据来说明问题,由于篇幅关系,我就不沾了

这是借东风系列的第三篇,第三种借一般人其实是很难借到的,目前Orca和逐步蒸馏法都没有开源项目,如果靠自己写还是有一定难度的(Label和损失函数的定义就不太好弄),后续如果有好的项目关于这块的实操部分,虽然这个三章结束了,我可能会再写一篇实战篇,敬请期待

--

--