0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

PyTorch教程-9.7. 时间反向传播

jf_pJlTbmA9 来源:PyTorch 作者:PyTorch 2023-06-05 15:44 次阅读

如果您完成了第 9.5 节中的练习,您会发现梯度裁剪对于防止偶尔出现的大量梯度破坏训练稳定性至关重要。我们暗示爆炸梯度源于长序列的反向传播。在介绍大量现代 RNN 架构之前,让我们仔细看看反向传播在数学细节中是如何在序列模型中工作的。希望这个讨论能使梯度消失和爆炸的概念更加精确。如果你还记得我们在 5.3 节介绍 MLP 时通过计算图进行前向和反向传播的讨论,那么 RNN 中的前向传播应该相对简单。在 RNN 中应用反向传播称为 时间反向传播 ( Werbos, 1990 ). 此过程要求我们一次扩展(或展开)RNN 的计算图。展开的 RNN 本质上是一个前馈神经网络,具有相同的参数在整个展开的网络中重复出现的特殊属性,出现在每个时间步长。然后,就像在任何前馈神经网络中一样,我们可以应用链式法则,通过展开的网络反向传播梯度。每个参数的梯度必须在参数出现在展开网络中的所有位置上求和。从我们关于卷积神经网络的章节中应该熟悉处理这种权重绑定。

出现并发症是因为序列可能相当长。处理由超过一千个标记组成的文本序列并不罕见。请注意,从计算(太多内存)和优化(数值不稳定)的角度来看,这都会带来问题。第一步的输入在到达输出之前要经过 1000 多个矩阵乘积,还需要另外 1000 个矩阵乘积来计算梯度。我们现在分析可能出现的问题以及如何在实践中解决它。

9.7.1. RNN 中的梯度分析

我们从 RNN 工作原理的简化模型开始。该模型忽略了有关隐藏状态细节及其更新方式的细节。这里的数学符号没有明确区分标量、向量和矩阵。我们只是想培养一些直觉。在这个简化模型中,我们表示ht作为隐藏状态, xt作为输入,和ot作为时间步的输出t. 回忆一下我们在第 9.4.2 节中的讨论,输入和隐藏状态可以在乘以隐藏层中的一个权重变量之前连接起来。因此,我们使用 wh和wo分别表示隐藏层和输出层的权重。因此,每个时间步的隐藏状态和输出是

(9.7.1)ht=f(xt,ht−1,wh),ot=g(ht,wo),

在哪里f和g分别是隐藏层和输出层的变换。因此,我们有一个价值链 {…,(xt−1,ht−1,ot−1),(xt,ht,ot),…} 通过循环计算相互依赖。前向传播相当简单。我们所需要的只是遍历(xt,ht,ot)一次三倍一个时间步长。输出之间的差异ot和想要的目标 yt然后通过所有的目标函数进行评估 T时间步长为

(9.7.2)L(x1,…,xT,y1,…,yT,wh,wo)=1T∑t=1Tl(yt,ot).

对于反向传播,事情有点棘手,尤其是当我们计算关于参数的梯度时wh目标函数的L. 具体来说,根据链式法则,

(9.7.3)∂L∂wh=1T∑t=1T∂l(yt,ot)∂wh=1T∑t=1T∂l(yt,ot)∂ot∂g(ht,wo)∂ht∂ht∂wh.

(9.7.3)中乘积的第一和第二个因子 很容易计算。第三个因素 ∂ht/∂wh事情变得棘手了,因为我们需要循环计算参数的影响wh在 ht. 根据 (9.7.1)中的循环计算,ht取决于两者ht−1 和wh, 其中计算ht−1也取决于 wh. 因此,评估的总导数ht关于wh使用链式规则收益率

(9.7.4)∂ht∂wh=∂f(xt,ht−1,wh)∂wh+∂f(xt,ht−1,wh)∂ht−1∂ht−1∂wh.

为了推导上述梯度,假设我们有三个序列 {at},{bt},{ct}令人满意a0=0和 at=bt+ctat−1为了t=1,2,…. 然后为 t≥1, 很容易证明

(9.7.5)at=bt+∑i=1t−1(∏j=i+1tcj)bi.

通过替换at,bt, 和ct根据

(9.7.6)at=∂ht∂wh,bt=∂f(xt,ht−1,wh)∂wh,ct=∂f(xt,ht−1,wh)∂ht−1,

(9.7.4)中的梯度计算 满足at=bt+ctat−1. 因此,根据 (9.7.5) ,我们可以删除(9.7.4)中的循环计算

(9.7.7)∂ht∂wh=∂f(xt,ht−1,wh)∂wh+∑i=1t−1(∏j=i+1t∂f(xj,hj−1,wh)∂hj−1)∂f(xi,hi−1,wh)∂wh.

虽然我们可以使用链式法则来计算 ∂ht/∂wh递归地,这条链会变得很长t很大。让我们讨论一些处理这个问题的策略。

9.7.1.1. 全计算

一个想法可能是计算(9.7.7)中的总和 。然而,这是非常缓慢的,梯度可能会爆炸,因为初始条件的细微变化可能会对结果产生很大影响。也就是说,我们可以看到类似于蝴蝶效应的现象,即初始条件的微小变化会导致结果发生不成比例的变化。这通常是不希望的。毕竟,我们正在寻找能够很好泛化的稳健估计器。因此,这种策略几乎从未在实践中使用过。

9.7.1.2. 截断时间步长

或者,我们可以在(9.7.7)之后截断总和 τ脚步。这是我们迄今为止一直在讨论的内容。这导致了对真实梯度的近似,简单地通过终止总和 ∂ht−τ/∂wh. 在实践中,这非常有效。这就是通常所说的随时间截断的反向传播( Jaeger, 2002 )。这样做的后果之一是该模型主要关注短期影响而不是长期后果。这实际上是可取的,因为它会使估计偏向于更简单和更稳定的模型。

9.7.1.3. 随机截断

最后,我们可以更换∂ht/∂wh通过一个随机变量,它在预期中是正确的但截断了序列。这是通过使用一系列ξt预定义的 0≤πt≤1, 在哪里P(ξt=0)=1−πt和 P(ξt=πt−1)=πt, 因此E[ξt]=1. 我们用这个来代替渐变∂ht/∂wh在 (9.7.4)中与

(9.7.8)zt=∂f(xt,ht−1,wh)∂wh+ξt∂f(xt,ht−1,wh)∂ht−1∂ht−1∂wh.

它遵循的定义ξt那 E[zt]=∂ht/∂wh. 每当ξt=0 循环计算在该时间步终止t. 这导致了不同长度序列的加权和,其中长序列很少见但适当超重。这个想法是由Tallec 和 Ollivier ( 2017 )提出的。

9.7.1.4. 比较策略

pYYBAGR9NsmAUpAXAADLdgJCruk421.svg

图 9.7.1比较 RNN 中计算梯度的策略。从上到下:随机截断、规则截断和全计算。

图 9.7.1说明了使用 RNN 的时间反向传播分析时间机器的前几个字符时的三种策略

第一行是将文本分成不同长度的段的随机截断。

第二行是将文本分成相同长度的子序列的常规截断。这就是我们在 RNN 实验中一直在做的事情。

第三行是通过时间的完整反向传播,导致计算上不可行的表达式。

不幸的是,虽然在理论上很有吸引力,但随机截断并没有比常规截断好多少,这很可能是由于多种因素造成的。首先,经过多次反向传播步骤后观察到的效果足以在实践中捕获依赖关系。其次,增加的方差抵消了更多步骤梯度更准确的事实。第三,我们实际上想要只有小范围交互的模型。因此,随着时间的推移定期截断的反向传播具有轻微的正则化效果,这可能是理想的。

9.7.2. 详细的时间反向传播

讨论完一般原理后,让我们详细讨论时间反向传播。与9.7.1节的分析不同 ,下面我们将展示如何计算目标函数对所有分解模型参数的梯度。为了简单起见,我们考虑一个没有偏置参数的 RNN,其隐藏层中的激活函数使用恒等映射(ϕ(x)=x). 对于时间步t, 让单个示例输入和目标为 xt∈Rd和yt, 分别。隐藏状态ht∈Rh和输出 ot∈Rq被计算为

(9.7.9)ht=Whxxt+Whhht−1,ot=Wqhht,

在哪里Whx∈Rh×d, Whh∈Rh×h, 和 Wqh∈Rq×h是权重参数。表示为l(ot,yt)时间步长的损失 t. 我们的目标函数,损失超过T因此,从序列开始的时间步长是

(9.7.10)L=1T∑t=1Tl(ot,yt).

为了可视化RNN计算过程中模型变量和参数之间的依赖关系,我们可以为模型绘制计算图,如图9.7.2所示。例如,时间步长 3 的隐藏状态的计算, h3, 取决于模型参数 Whx和Whh, 最后一个时间步的隐藏状态h2, 和当前时间步长的输入x3.

pYYBAGR9NsuAByZbAADzFmtQmKA820.svg

图 9.7.2显示具有三个时间步长的 RNN 模型的依赖关系的计算图。方框代表变量(未加阴影)或参数(加阴影),圆圈代表运算符。

正如刚才提到的,图 9.7.2中的模型参数是 Whx,Whh, 和 Wqh. 通常,训练此模型需要针对这些参数进行梯度计算 ∂L/∂Whx, ∂L/∂Whh, 和 ∂L/∂Wqh. 根据图 9.7.2中的依赖关系,我们可以沿箭头相反的方向遍历,依次计算并存储梯度。为了在链式法则中灵活表达不同形状的矩阵、向量和标量的乘法,我们继续使用 prod操作员如第 5.3 节所述。

首先,在任何时间步根据模型输出对目标函数进行微分t相当简单:

(9.7.11)∂L∂ot=∂l(ot,yt)T⋅∂ot∈Rq.

现在,我们可以计算目标相对于参数的梯度Wqh在输出层: ∂L/∂Wqh∈Rq×h. 根据图 9.7.2,目标L依赖于取决于 Wqh通过o1,…,oT. 使用链式规则收益率

(9.7.12)∂L∂Wqh=∑t=1Tprod(∂L∂ot,∂ot∂Wqh)=∑t=1T∂L∂otht⊤,

在哪里∂L/∂ot由(9.7.11)给出 。

接下来,如图9.7.2所示,在最后的时间步 T, 目标函数L取决于隐藏状态 hT只能通过oT. 因此,我们很容易找到梯度 ∂L/∂hT∈Rh使用链式法则:

(9.7.13)∂L∂hT=prod(∂L∂oT,∂oT∂hT)=Wqh⊤∂L∂oT.

任何时间步长都会变得更加棘手t

(9.7.14)∂L∂ht=prod(∂L∂ht+1,∂ht+1∂ht)+prod(∂L∂ot,∂ot∂ht)=Whh⊤∂L∂ht+1+Wqh⊤∂L∂ot.

为了分析,扩展任何时间步长的循环计算 1≤t≤T给

(9.7.15)∂L∂ht=∑i=tT(Whh⊤)T−iWqh⊤∂L∂oT+t−i.

我们可以从(9.7.15)中看到,这个简单的线性示例已经展示了长序列模型的一些关键问题:它涉及潜在的非常大的幂Whh⊤. 其中,小于 1 的特征值消失,大于 1 的特征值发散。这在数值上是不稳定的,表现为梯度消失和爆炸。如第 9.7.1 节所述,解决此问题的一种方法是将时间步长截断为便于计算的大小。实际上,这种截断也可以通过在给定数量的时间步后分离梯度来实现。稍后,我们将看到更复杂的序列模型(如长短期记忆)如何进一步缓解这种情况。

最后,图 9.7.2表明目标函数 L取决于模型参数Whx和 Whh通过隐藏状态在隐藏层中 h1,…,hT. 计算关于这些参数的梯度 ∂L/∂Whx∈Rh×d 和 ∂L/∂Whh∈Rh×h,我们应用给出的链式规则

(9.7.16)∂L∂Whx=∑t=1Tprod(∂L∂ht,∂ht∂Whx)=∑t=1T∂L∂htxt⊤,∂L∂Whh=∑t=1Tprod(∂L∂ht,∂ht∂Whh)=∑t=1T∂L∂htht−1⊤,

在哪里∂L/∂ht由(9.7.13)和 (9.7.14)循环计算的是影响数值稳定性的关键量。

由于时间反向传播是反向传播在 RNN 中的应用,正如我们在第 5.3 节中解释的那样,训练 RNN 交替进行正向传播和时间反向传播。此外,通过时间的反向传播依次计算并存储上述梯度。具体来说就是复用存储的中间值,避免重复计算,比如存储 ∂L/∂ht用于两者的计算∂L/∂Whx和 ∂L/∂Whh.

9.7.3. 概括

时间反向传播仅仅是反向传播对具有隐藏状态的序列模型的应用。截断是为了计算方便和数值稳定性所需要的,例如规则截断和随机截断。矩阵的高次幂会导致特征值发散或消失。这以爆炸或消失梯度的形式表现出来。为了高效计算,中间值在反向传播期间被缓存。

9.7.4. 练习

假设我们有一个对称矩阵 M∈Rn×n具有特征值 λi其对应的特征向量是 vi(i=1,…,n). 不失一般性,假设它们按顺序排列 |λi|≥|λi+1|.

显示Mk有特征值λik.

证明对于一个随机向量x∈Rn, 很有可能Mkx将与特征向量非常一致v1的 M. 将此声明正式化。

上述结果对 RNN 中的梯度意味着什么?

除了梯度裁剪,你能想到任何其他方法来应对递归神经网络中的梯度爆炸吗?

Discussions

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • pytorch
    +关注

    关注

    2

    文章

    808

    浏览量

    13218
收藏 人收藏

    评论

    相关推荐

    反向传播如何实现

    实现反向传播
    发表于 07-09 16:10

    神经网络和反向传播算法

    03_深度学习入门_神经网络和反向传播算法
    发表于 09-12 07:08

    反向传播算法的工作原理

    反向传播算法(BP算法)是目前用来训练人工神经网络的最常用且最有效的算法。作为谷歌机器学习速成课程的配套材料,谷歌推出一个演示网站,直观地介绍了反向传播算法的工作原理。
    的头像 发表于 07-02 16:01 1w次阅读
    <b class='flag-5'>反向</b><b class='flag-5'>传播</b>算法的工作原理

    深读解析反向传播算法在解决模型优化问题的方面应用

    反向传播算法隶属于深度学习,它在解决模型优化问题的方面有着重要的地位。
    的头像 发表于 11-01 15:48 5612次阅读
    深读解析<b class='flag-5'>反向</b><b class='flag-5'>传播</b>算法在解决模型优化问题的方面应用

    人工智能(AI)学习:如何讲解BP(反向传播)流程

    关于BP知乎上的解释是这样的,反向传播整个流程如下: 1)进行前向传播计算,利用前向传播公式,得到隐藏层和输出层的激活值。 2)对输出层(第l层),计算残差:
    发表于 11-03 16:55 0次下载
    人工智能(AI)学习:如何讲解BP(<b class='flag-5'>反向</b><b class='flag-5'>传播</b>)流程

    浅析深度神经网络(DNN)反向传播算法(BP)

    在 深度神经网络(DNN)模型与前向传播算法 中,我们对DNN的模型和前向传播算法做了总结,这里我们更进一步,对DNN的反向传播算法(Back Propagation,BP)做一个总结
    的头像 发表于 03-22 16:28 3663次阅读
    浅析深度神经网络(DNN)<b class='flag-5'>反向</b><b class='flag-5'>传播</b>算法(BP)

    BP(BackPropagation)反向传播神经网络介绍及公式推导

    BP(BackPropagation)反向传播神经网络介绍及公式推导(电源和地电气安全间距)-该文档为BP(BackPropagation)反向传播神经网络介绍及公式推导详述资料,讲解
    发表于 07-26 10:31 48次下载
    BP(BackPropagation)<b class='flag-5'>反向</b><b class='flag-5'>传播</b>神经网络介绍及公式推导

    详解神经网络中反向传播和梯度下降

    摘要:反向传播指的是计算神经网络参数梯度的方法。
    的头像 发表于 03-14 11:07 1031次阅读

    PyTorch教程5.3之前向传播反向传播和计算图

    电子发烧友网站提供《PyTorch教程5.3之前向传播反向传播和计算图.pdf》资料免费下载
    发表于 06-05 15:36 0次下载
    <b class='flag-5'>PyTorch</b>教程5.3之前向<b class='flag-5'>传播</b>、<b class='flag-5'>反向</b><b class='flag-5'>传播</b>和计算图

    PyTorch教程之时间反向传播

    电子发烧友网站提供《PyTorch教程之时间反向传播.pdf》资料免费下载
    发表于 06-05 09:49 0次下载
    <b class='flag-5'>PyTorch</b>教程之<b class='flag-5'>时间</b><b class='flag-5'>反向</b><b class='flag-5'>传播</b>

    PyTorch教程-5.3. 前向传播反向传播和计算图

    5.3. 前向传播反向传播和计算图¶ Colab [火炬]在 Colab 中打开笔记本 Colab [mxnet] Open the notebook in Colab Colab
    的头像 发表于 06-05 15:43 1118次阅读
    <b class='flag-5'>PyTorch</b>教程-5.3. 前向<b class='flag-5'>传播</b>、<b class='flag-5'>反向</b><b class='flag-5'>传播</b>和计算图

    神经网络前向传播反向传播区别

    神经网络是一种强大的机器学习模型,广泛应用于各种领域,如图像识别、语音识别、自然语言处理等。神经网络的核心是前向传播反向传播算法。本文将详细介绍神经网络的前向传播
    的头像 发表于 07-02 14:18 806次阅读

    反向传播神经网络建模基本原理

    反向传播神经网络(Backpropagation Neural Network,简称BP神经网络)是一种多层前馈神经网络,通过反向传播算法进行训练。它在解决分类、回归、模式识别等问题上
    的头像 发表于 07-03 11:08 447次阅读

    神经网络反向传播算法的优缺点有哪些

    神经网络反向传播算法(Backpropagation Algorithm)是一种广泛应用于深度学习和机器学习领域的优化算法,用于训练多层前馈神经网络。本文将介绍反向传播算法的优缺点。
    的头像 发表于 07-03 11:24 929次阅读

    【每天学点AI】前向传播、损失函数、反向传播

    在深度学习的领域中,前向传播反向传播和损失函数是构建和训练神经网络模型的三个核心概念。今天,小编将通过一个简单的实例,解释这三个概念,并展示它们的作用。前向传播:神经网络的“思考”过
    的头像 发表于 11-15 10:32 633次阅读
    【每天学点AI】前向<b class='flag-5'>传播</b>、损失函数、<b class='flag-5'>反向</b><b class='flag-5'>传播</b>