0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

一文解析PPO算法原理

深度学习自然语言处理 来源:深度学习自然语言处理 2024-01-09 12:12 次阅读

作者:LLM-Finder,某厂研究大语言模型和多模态学习

写这篇文章的动机

1. 在笔者看来RLHF是LLMs智能的关键之一;

2. 国内厂商在这方面投入比较少,目前看起来并没有很重视;

3. 大家偏向于认为ChatGPT的RLHF做法最多的线索来源于InstructGPT,但是InstructGPT原文的描述也挺含糊的,很多东西只能靠猜和结合开源的实现来解读;

4. 通常学习强化学习所依赖链路比较长,笔者希望以最直观的方式帮助大家通关。

笔者会分两篇文章来介绍,第一篇是理论篇,第二篇是实践篇。读者会在第一篇学习到PPO的原理和instrcutGPT中的RLHF做法;在第二篇中学习到目前影响比较大的开源RLHF实现。

e7082e78-ae34-11ee-8b88-92fbcf53809c.jpg

据公开可获得的信息来看,ChatGPT需要有大致三个阶段的训练过程,如上图所示:

1.Pretraining: 在大规模“无监督”的语料上训练,训练任务是预测下一个词。

2.Supervised Fine-Tuning(SFT):在人类标注上进行微调,所谓人类标注就是人类写Prompt,人类写答案。然后语言模型学习模仿人类是如何作答的。这部分通常要求数据集多样性很好,也因为标注成本很高,通常量级很小。

3.Reinforcement Learning with human feedback(RLHF):对于同一个Prompt把模型的多个输出给人类排序,获取人类偏好标注。用人类的偏好标注,训练一个reward model。训练得到的reward model会作为PPO算法中的reawrd function,来继续优化SFT得到的模型。

通常来说,第一步最有资源门槛,第三步最有技术门槛(同时也需要大量的资源),第二步最简单。所以目前很多厂商是直接拿了开源的第一步的模型,做SFT,或者continue-pretrain(比较小规模的无监督训练)再SFT。他们PR的时候可能会嘴一句,无需复杂的RLHF,只需做细致的微调也能达到很好的效果。

后面两个步骤,通常被视作是人类偏好对齐(alignment),让模型更好地跟随人类的指令作回复。而一些研究发现,对齐后的模型是会有对齐税的现象的(alignment tax),即在通用能力上会有所下降。

因此,不少人是这样认为的:第一步预训练得到的模型就已经决定了后续模型的能力上限;后面两步要做的事情仅仅是在尽可能减少对齐税的情况下,对齐人类偏好。

这里可以分两种情况分析:

• SFT过数据太多遍了,导致大模型出现遗忘;

•安全性对齐很多模型能回答的问题,强制不让回答肯定会对模型能力有所牵制。

在笔者看来,某种意义下RL提供了对LLM的response的Global-level的监督,在一些需要答案非常精确的场景上,RL可能可以发挥出更大的威力。这个看法的依据也很朴素:

1. 比如在coding、数学推导等场景,只要response在关键的地方犯了一点点错给人的感觉就是模型不会,但是SFT的loss可能区分不出来是犯错了还是只是写法风格的差异。

2. SFT给定了标准答案,LLM的上限可能会被标注者的水平所限制;RLHF则只给定了人类偏好,得到了一定(有可能是很大)程度的解放,模型有可能探索出更高程度的智能。这一点并不是无中生有的想法,在游戏AI领域有太多的验证,即在模仿人类玩法(imitation learning)之后,再用RL训练出来的模型,就是能获得更高的智能。这里语言模型跟游戏又有多少本质的区别呢。

InstructGPT中的RLHF

这里简要带过具体数据构造和训练细节,后面会专门有一篇对InstructGPT像素级的解读。

如前文所述,InstructGPT也是包含3阶段的训练,同时我们应该注意到,RLHF这一步训练,实则包含两步训练:

1. 训练Reward Model(RM);

2. 用Reward Model和SFT Model构造Reward Function,基于PPO算法来训练LLM。

数据集

SFT、RM和PPO用到的数据集数据量如下表所示:

e70eebc8-ae34-11ee-8b88-92fbcf53809c.jpg

注意,上表统计的是prompts数量,在RM数据中每个prompt,对应会有4~9个responses。

在构造RM数据的时候,作者采集了用户的prompts,每个prompts包含4~9个模型的输出,模型的输出会给标注员进行排序。

训练Reward Model(RM)

目标:给pormpt-response pair打分,拟合人类的偏好。

模型:这InstructGPT的paper中,虽然用了1.3B、6B和175B的GPT-3来做实验,但是综合考虑下,只用6B的模型来训练Reward Model,因为作者发现用175B的模型会不稳定。把最后的unembedding层换成一个输出为scalar的线性层。这里读者可能会有点混乱,众所周知,GPT的模型结构是sequence-in,sequence-out的,怎么变成scalar呢?这里文章似乎也没提到,根据笔者的判断和开源实现,推测是直接用最后一个token的输出接一个linear

Reward Model的初始化:6B的GPT-3模型在多个公开数据((ARC, BoolQ, CoQA, DROP, MultiNLI, OpenBookQA, QuAC, RACE, and Winogrande)上fintune。不过Paper中提到其实从预训练模型或者SFT模型开始训练结果也差不多。

训练:以前的做法是,RM每次比较两个模型输出的好坏,做法很简单类似对比学习,两个样本对应两个类别,RM对这两个样本分别输出两个得分,拼成一个logits向量;人类标注比较好的那个输出作为label,比如第一个比较好那么label为0,第二个比较好label为1;用cross entropy约束之。

但是作者发现这么做很容易过拟合;也不高效,因为每比较一次都要重新过一下reward model。

因此作者的做法是,在一个batch里面,把每个Prompt对应的所有的模型输出,都过一遍Reward model,并把所有两两组合都比较一遍。比如一个Prompt有K个模型输出,那么模型则只需要处理K个样本就可以一气儿做(2K)次比较。loss的设计如下:

e718f262-ae34-11ee-8b88-92fbcf53809c.png

很直观,其中,x是prompt,yw和yl分别是较好和较差的模型response,rθ(x,y)是Reward Model的输出。σ在文中似乎没有解释,不过根据公式推断和开源实现,应该是sigmod函数。

这里要注意一个细节:在RM训练完之后,会让RM的输出减去一个bias,使得reward score在人类写的答案上(labeler demonstrations)的平均分为0。这里笔者没找到具体在什么数据上统计的,猜测是在SFT数据上做的,如果有读者知道是怎么做的欢迎指出。

Reinforcement Learning(RL)

直接看需要最大化的目标函数

e71d12c0-ae34-11ee-8b88-92fbcf53809c.png

其中,πΦRL和πSFT分别是正在用RL训练的语言模型和SFT训练得到的模型。

上式中,

第一项期望式是在最大化reward的同时,最小化和SFT模型的per-token KL penalty,可以理解为是一种正则手段,两者组合成关于prompt-Responce pair最终的Reward:R(x,y)=rθ(x,y)−βlog(πΦRL(y∣x)/πSFT(y∣x))。per-token KL penalty的好处如下:

1. 充当熵红利(Entropy bonus),鼓励policy探索并阻止其坍塌为单一模式。

2. 确保策略模型产生的输出 与 Reward Model在训练期间看到的输出 不会相差太大,保证Reward的可靠性。仅含这一项就是单纯使用了PPO。这里也可以看出来,Reward model的能力可能会成为RLHF的瓶颈。

第二项期望式是可选项,注意到它其实是使用预训练的数据来做跟预训练同样的任务(predict next word),因为这一项的数据不是模型生成的其实跟RL是并行的目标。包含这一项的算法称之为PPO-ptx。

PPO算法

本小节以最小知识补充为前提,快速介绍PPO,不用犯怵,很简单而直观。

通常来说,对于一个强化学习模型,会有一个做动作的策略网络π,它根据自己观测的状态(si)做出动作(ai)跟环境交互,然后会拿到一个即刻的reward(ri), 同时进入到下一个状态(si+1);策略网络再继续观测状态si+1做下一个动作ai+1...直到达到最终状态。这样,策略网络和环境的一系列互动后最终会得到一个轨迹(trajectory):τ=s1,a1,r1,s2,a2,r2,...,sT,aT,rT。

那么,在语言模型的场景下,策略网络就是待微调的LLM,它所能做的动作就是预测下一个token,它观测的转状态就是预测下一个token时所能观测到的context(Prompt+这个token前所生成的所有tokens)。

reward除了最后一个rT等于上文提到的R(x,y)=rθ(x,y)−βlog(πRLΦ(yT∣x,y1,...,yT−1)/πSFT(yT∣x,y1,...,yT−1))

其他的ri=−βlog(πRLΦ(yi∣x,y1,...,yi−1)/πSFT(yi∣x,y1,...,yi−1))。

好,在LLM的场景中,现在可以统一一下符号:s1=x,ai=yi,si=cat(x,y1,y2,...,yi−1),其中x是prompt,yi是第i步蹦的token。看到这,了解PPO的同学基本上就清晰了RLHF具体是怎么做优化的了,可以直接跳过下面的科普部分。

因为PPO原文是基于Actor-Critic算法做的,Actor-Critic算法是进阶版的Policy Gradient算法。下面我们从policy gradient到Actor-Critic,再到PPO,帮助RL背景比较弱的读者串一遍。

Policy Gradient(PG)算法

核心要义:用“Reward”作为权重,最大化策略网络所做出的动作的概率。

伪代码核心部分一句话的事:

e72b6604-ae34-11ee-8b88-92fbcf53809c.jpg

用策略网络πθ采样出一个轨迹,然后根据即刻得到的rewardrt计算 discounted rewardRt=∑i=tTγi−tri;用Rt作为权重,最大化这个轨迹下所采取的动作的概率log(π(at∣st))⋅Rt,用梯度上升优化之。

虽然在强化学习算法中对每一步都有一个即时的“reward”,但是每一步对后面的可能状态都是有影响的。

即,后面的动作获取的即时“reward”都能累计到前面的动作的贡献。但是直接加上去可能不好,毕竟不是前面的动作直接获取的reward,但是可以打个折扣再加上去,即乘个小于1的γ。

这里面读者可能会有个问题:可是不好的动作也要最大化概率吗?

这里有必要稍微展开一下:

1.Rt也可以是负的,对负的Rt那就是最小化动作at的概率,这也是为什么前面提到要对RM的输出做归一化的其中一个原因之一。

2.即便Rt都是正的,但只要充分采样,同一个状态下相对的Rt较小的动作也是会被抑制的,因为同一个状态下的动作概率求和等于1,此消彼长,只有权重最大的动作才会得到奖励。

可是,比如同一个状态下,有两个动作的Rt是正的,但是因为动作采样本来就很稀疏的,我们很可能不幸运采样到了相对较小的Rt对应的动作,而没有采样到相对较大的。但因为它是正的,这时候当前的机制下,还是会鼓励这个动作,这样的话网络很容易一直沿着不太好的策略去优化。为了解决这个问题,我们引入Actor-Critic算法。

Actor-Critic (AC)算法

核心要义:再增加一个Critic网络来构造一个Reward baseline,只有获得的reward比baseline要好才奖励这个动作,否则抑制它。

e72f29e2-ae34-11ee-8b88-92fbcf53809c.jpg

Actor指的是策略网络πθ;Criticbϕ目的就是给定一个策略网络,预估每个状态st,策略网络所能拿到期望rewardbϕ(st)是多少。什么是期望reward,无非就是在状态st,对πθ采样不同的动作at所能获取的Rt的平均值嘛。我们要选择的动作当然是获取的reward比平均reward要好的动作,不比baseline好的动作就得抑制它。

观测上面算法2,其实对比PG算法就加了两行:

1. 原来用Reward function来加权,现在用Advantage function来加权。现在我们把bϕ(st)当作一个baseline方法所能拿到的reward, 用采样出来的at所拿到的rewardRt减去bϕ(st)作为最大化当前动作概率的权重:At=Rt−bϕ(st)。其中 A_t 通常被称作是Advantage function(或Advantage estimator),即优势函数。

2.拉近bϕ(st)和Rt的距离,初学者对这个可能会费解。实则很好理解,记住bϕ在做什么,要预估当前策略下Rt的期望,我只要不管三七二十一,每来一个动作的Rt都拉近一下距离,其实就是在预估平均值。更一般地:

其实上面用到的bϕ,它无非是换了皮的Vπθϕ(st)(简写成Vϕ(st)),即RL中的重要概念V function:给定策略πθ在st上的期望reward。那么最后一步 T 到达的state sT通常来讲是没有随机性的(比如下棋,最后一个state决定赢输就是固定的reward;LLM,最后一个token生成完,response确定了,reward也就确定了),因此rT应该和Vϕ(sT)相等。

所以我们可以重写上面的优势函数:

A^t=−Vϕ(st)+rt+γrt+1+⋯+γT−tVϕ(sT)

写成Generalized Advantage Estimation,当λ=1 下式等于上式:

A^t=δt+(γλ)δt+1+⋯+(γλ)T−t+1δT−1

其中,δt=rt+γVϕ(st+1)−Vϕ(st)是时序差分式(TD error)。

记住这个结论:这样我们可以用A^t优化πθ,现在我们可以用▽θlog(πθ(at∣st))⋅A^t来更新策略网络了。

PPO Finetuning

上面提到的算法,有一个最严重的弊端是,一个轨迹只用一次就丢掉了。可是,采样轨迹通常是很耗时的,对应到在LLM场景则需要做推理,众所周知LLM的推理是比训练费劲很多的,它需要一个一个地蹦词。可是直接用之前的策略采样出来的样本来优化现在的策略网络肯定不行,如何合理复用样本则是PPO要做的事情。

做法巨简单,大致可以用这个思想来更新:

定义 动作概率比rt(θ)=πθold(at∣st)πθ(at∣st),用▽θrt(θ)⋅A^t去梯度上升更新策略网络,注意这里stat和A^t都是只之前的策略网络πθold采样得到的。这个公式,在笔者看来没有直观的解释,需要一丢丢推导,因为是科普向这里读者先承认就好了,后面笔者会单开一篇文章再重新梳理一遍。

本质上是最大化这个目标函数:

e739bf6a-ae34-11ee-8b88-92fbcf53809c.jpg

但是如果πθ和πθold如果差别太大,就不能用这个式子优化了,PPO给出的做法是给rt(θ)卡阈值,太大或太小就不用这一步的样本更新了:

e7445ae2-ae34-11ee-8b88-92fbcf53809c.png

上面的目标函数可以分类讨论进行分析,对优势函数A^t大于0和小于0两种情况分析,这个目标函数的图像长这样:

e751e306-ae34-11ee-8b88-92fbcf53809c.jpg

观测图像:

当A^t大于0,要提高动作的概率,但是如果概率比之前大比较多了(πθ是πθold的1+ϵ倍),就不提高了

当A^t小于0,要减少动作的概率,但是如果概率比之前小比较多了(πθ是πθold的1−ϵ倍),就不减少了

伪代码如下:

e75c7d8e-ae34-11ee-8b88-92fbcf53809c.jpg

科普到此结束,看到这读者就可以看懂RLHF的代码。值得注意的是为了减少读者负担做了大量的叙述上的简化,方法上是比较完备的,但是说法上不够严谨。Again,更详细的强化学习科普会单开一篇文章。

大语言模型的PPO

稍微整理一下,符号和上面的科普部分不一致,不过应该不影响理解

1.现在我们的actor是SFT初始化的LLMπΦRL;

2.为了计算reward,我们需要两个冻住参数网络,一个RM,一个是冻住的SFT模型πSFT用来计算KL散度,参考下面两式子:rT=R(x,y)=rθ(x,y)−βlog(πRLΦ(yT∣x,y1,...,yT−1)/πSFT(yT∣x,y1,...,yT−1))其他步的ri=−βlog(πRLΦ(yi∣x,y1,...,yi−1)/πSFT(yi∣x,y1,...,yi−1));

3.为了执行PPO算法,我们需要引入一个估计V值的网络Vη,它初始化来自RM。所以统共,有4个网络,两个训练的actorπΦRL和criticVη;两个用来计算reward的SFT模型πSFT和RM模型。然后actor初始化来自SFT,critic初始化来自RM。

把这四个网络,结合reward的构造,带入到上面提到的PPO算法中,整个过程就比较清晰了。

盗一下DeepSpeed-Chat的图,图解如下:

e760d8c0-ae34-11ee-8b88-92fbcf53809c.jpg

看到这,相信读者已经可以轻易看懂的DeepSpeed-Chat代码了。‍‍

审核编辑:黄飞

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • ChatGPT
    +关注

    关注

    29

    文章

    1558

    浏览量

    7592

原文标题:PPO算法

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    了解ADAS与自动驾驶的现状、算法和技术路线

    了解ADAS与自动驾驶的现状、算法和技术路线
    的头像 发表于 07-30 09:58 1.6w次阅读

    拆解大语言模型RLHF中的PPO算法

    由于本文以大语言模型 RLHF 的 PPO 算法为主,所以希望你在阅读前先弄明白大语言模型 RLHF 的前两步,即 SFT Model 和 Reward Model 的训练过程。另外因为本文不是纯讲强化学习的文章,所以我在叙述的时候不会假设你已经非常了解强化学习了。
    的头像 发表于 12-11 18:30 2231次阅读
    拆解大语言模型RLHF中的<b class='flag-5'>PPO</b><b class='flag-5'>算法</b>

    解析BLDC电机控制算法

    许多不同的控制算法都被用以提供对于BLDC电机的控制。典型做法是,将功率晶体管用作线性稳压器来控制电机电压。当驱动高功率电机时,这种方法并不实用。
    发表于 03-27 09:31 1868次阅读

    光耦PC817中解析

    光耦PC817中解析
    发表于 08-20 14:32

    用PID算法调温的经验

    本文主要是分享资料,讲解不会太多,因为分享的资料里面就有具体的详细解析,而且百度上面也有详细的资料,所以本次博主要是讲解我用PID算法调温的经验。PID算法调整温度最大的问题的温度的
    发表于 11-23 08:27

    PID算法解析,绝对实用

    PID算法解析,绝对实用
    发表于 01-21 07:40

    C++的G代码解析算法研究

    在数控技术发展过程中,G 代码的解析优劣是促进数控技术的发展因素之。但目前的解析算法,并不能更高效的进行解析处理。经过对G 代码进行分析,
    发表于 07-21 16:36 0次下载

    基于KMP算法的串口通讯协议解析邹铁

    基于KMP算法的串口通讯协议解析_邹铁
    发表于 03-17 08:00 2次下载

    解析机器学习常用35大算法

    本文将带你遍历机器学习领域最受欢迎的算法。系统地了解这些算法有助于进步掌握机器学习。当然,本文收录的算法并不完全,分类的方式也不唯
    的头像 发表于 06-30 04:24 3876次阅读
    <b class='flag-5'>一</b><b class='flag-5'>文</b><b class='flag-5'>解析</b>机器学习常用35大<b class='flag-5'>算法</b>

    解析PLC的应用

    解析PLC的应用,具体的跟随小编起来了解下。
    的头像 发表于 07-19 11:21 5257次阅读
    <b class='flag-5'>一</b><b class='flag-5'>文</b><b class='flag-5'>解析</b>PLC的应用

    基于PPO强化学习算法的AI应用案例

    Viet Nguyen就是其中个。这位来自德国的程序员表示自己只玩到了第9个关卡。因此,他决定利用强化学习AI算法来帮他完成未通关的遗憾。
    发表于 07-29 09:30 2795次阅读

    PID算法详细解析——基于单片机

    本文主要是分享资料,讲解不会太多,因为分享的资料里面就有具体的详细解析,而且百度上面也有详细的资料,所以本次博主要是讲解我用PID算法调温的经验。PID算法调整温度最大的问题的温度的
    发表于 11-15 10:21 13次下载
    PID<b class='flag-5'>算法</b>详细<b class='flag-5'>解析</b>——基于单片机

    解析通信系统的高效正交变量优化算法

    本文讨论了算法,用于在具有正交输入向量的二维空间中找到最佳调整点。该算法根据测量数据点求解相交圆的方程。
    的头像 发表于 05-05 16:37 1712次阅读
    <b class='flag-5'>一</b><b class='flag-5'>文</b><b class='flag-5'>解析</b>通信系统的高效正交变量优化<b class='flag-5'>算法</b>

    PPO物理改性及化学改性的方法

    PPO改性方法分为物理改性(共混、填充等)和化学改性(主链、端基改性等),物理改性主要是与其他高性能树脂共混形成塑料合金,化学改性是在PPO分子链上引入活性基团改善相容性或与其他分子进行嵌段、接枝以克服自身缺陷。
    的头像 发表于 09-06 15:12 4312次阅读

    解析并查集(Union-Find)算法原理

    并查集(Union-Find)算法个专门针对「动态连通性」的算法,我之前写过两次,因为这个算法的考察频率高,而且它也是最小生成树算法的前
    发表于 03-24 18:22 688次阅读