博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
LSTM改善RNN梯度弥散和梯度爆炸问题
阅读量:4506 次
发布时间:2019-06-08

本文共 439 字,大约阅读时间需要 1 分钟。

我们给定一个三个时间的RNN单元,如下:

我们假设最左端的输入 S_0 为给定值, 且神经元中没有激活函数(便于分析), 则前向过程如下:

S_1 = W_xX_1 + W_sS_0 + b_1 \qquad \qquad \qquad O_1 = W_oS_1 + b_2 \\ S_2 = W_xX_2 + W_sS_1 + b_1 \qquad \qquad \qquad O_2 = W_oS_2 + b_2 \\ S_3 = W_xX_3 + W_sS_2 + b_1 \qquad \qquad \qquad O_3 = W_oS_3 + b_2 \\

在 t=3 时刻, 损失函数为 L_3 = \frac{1}{2}(Y_3 - O_3)^2 ,那么如果我们要训练RNN时, 实际上就是是对 W_x, W_s, W_o,b_1,b_2 求偏导, 并不断调整它们以使得 L_3 尽可能达到最小(参见反向传播算法与梯度下降算法)。

那么我们得到以下公式:

\frac{\delta L_3}{\delta W_0} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta W_0} \\ \frac{\delta L_3}{\delta W_x} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta W_x} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta W_x} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_x} \\ \frac{\delta L_3}{\delta W_s} = \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta W_s} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta W_s} + \frac{\delta L_3}{\delta O_3} \frac{\delta O_3}{\delta S_3} \frac{\delta S_3}{\delta S_2} \frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_s} \\

将上述偏导公式与第三节中的公式比较,我们发现, 随着神经网络层数的加深对 W_0 而言并没有什么影响, 而对 W_x, W_s 会随着时间序列的拉长而产生梯度消失和梯度爆炸问题。

根据上述分析整理一下公式可得, 对于任意时刻t对 W_x, W_s 求偏导的公式为:

\frac{\delta L_t}{\delta W_x } = \sum_{k=0}^t \frac{\delta L_t}{\delta O_t} \frac{\delta O_t}{\delta S_t}( \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} ) \frac{ \delta S_k }{\delta W_x} \\ \frac{\delta L_t}{\delta W_s } = \sum_{k=0}^t \frac{\delta L_t}{\delta O_t} \frac{\delta O_t}{\delta S_t}( \prod_{j=k+1}^t \frac{\delta S_j}{\delta S_{j-1}} ) \frac{ \delta S_k }{\delta W_s}

由 以上可知,RNN 中总的梯度是不会消失的。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和便不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。

参考:

 

转载于:https://www.cnblogs.com/USTC-ZCC/p/11159658.html

你可能感兴趣的文章
docker入门3:基础操作(2)
查看>>
WC2019退役失败记
查看>>
Centos6.6下安装nginx1.6.3
查看>>
iOS开发之多线程
查看>>
[算法竞赛]第七章_暴力求解法
查看>>
关于全局替换空格,制表符,换行符
查看>>
MorkDown 常用语法总结
查看>>
自定义python web框架
查看>>
sqlserver生成随机数 2011-12-21 15:47 QQ空间
查看>>
jQuery禁止鼠标右键
查看>>
查询linux计算机的出口ip
查看>>
解决Android的ListView控件滚动时背景变黑
查看>>
laravel 多检索条件列表查询
查看>>
Java_基础—finally关键字的特点及作用
查看>>
SQLServer 日期函数大全
查看>>
Linux常用网络命令
查看>>
激活webstorm11
查看>>
mysql 行转列 和 列转行
查看>>
[Leetcode]
查看>>
再谈vertical-align与line-height
查看>>