使用 DALL-E 3 生成的图像
Transformers改变了我们的世界。
它们是当今几乎所有顺序任务的主导 AI 架构。
Transformer 架构如此出色的原因之一是其自注意力机制,该机制允许同时处理标记,而不是像以前的架构(例如RNN、LSTM和GRU)那样按顺序处理。
Transformer 可视化(图片来自作者的书《100 幅图像中的 AI》)
但 Transformer 并不完美。
它们在序列长度上具有二次计算复杂度。
这意味着随着输入序列的长度增加,所需的计算量会呈二次方增长(序列长度的平方)。
这是因为它们有自我注意力机制,序列中的每个标记都会关注其他每个标记以理解上下文。
这在计算资源有限的环境下处理长序列时极大地限制了 Transformers。
为了解决这个问题,研究人员在最近的 ArXiv 预印本中对传统 LSTM 和 GRU 的内部结构进行了调整,形成了称为 minLSTM 和 minGRU 的最小版本。
这些网络使用的参数少得多,可以并行训练,而且比传统网络快得多。
这绝对令人难以置信!
这里我们深入探讨了 RNN、LSTM 和 GRU 是什么,以及如何修改它们以克服当前流行的 Transformer 架构的局限性。
我们为什么首先需要 RNN?
多层感知器 (MLP)或前馈神经网络无法处理序列数据。
因此,循环神经网络(RNN)应运而生,用来解决这一问题。
这些神经网络使用内部隐藏状态或“记忆”来处理序列信息并保存有关先前输入的信息。
利用这种内存,他们可以捕获由时间步骤分隔的不同序列元素之间的依赖关系(称为时间依赖性)。
循环神经网络 (RNN) 可视化(图片来自作者的书《100 幅图像中的人工智能》)
RNN 使用时间反向传播(BPTT)进行训练。
这是用于前馈神经网络的标准反向传播算法的扩展。
BPTT 涉及随着时间的推移展开网络并将每个时间步骤视为前馈神经网络中的一层。
前向传递步骤处理输入序列。
在输出层计算误差,并将得到的梯度从最后一个时间步反向传播到第一个时间步,更新 RNN 的参数。
时间反向传播(BPPT)(图片来自Wikimedia Commons)
由于消失梯度问题, RNN 很难学习长时间依赖性。
这时,随着时间的反向传播,梯度会变得非常小,导致无法学习。
消失梯度问题的可视化(图片来自作者的书《100 幅图像中的 AI》)
另一方面,事情也可能出错,梯度变得太大,导致训练不稳定。
这被称为梯度爆炸问题。
梯度爆炸问题的可视化(图片来自作者的书《100 幅图像中的 AI》)
1997 年, LSTM(长短期记忆)架构被引入来解决这些挑战。
我们接下来讨论一下。
LSTM 的诞生
LSTM 架构是 RNN 的一种修改,它可以保留长时间信息,而不会受到梯度消失/爆炸问题的影响。
LSTM 架构可视化(图片来自作者的书《100 幅图像中的 AI》)
LSTM 由以下部分组成:
- Cell state——存储长期信息
- Hidden state——承载当前时间步的短期输出
- 三个门(输入门、遗忘门、输出门)
在每个步骤中,LSTM 都会根据多个数学运算和门控来决定要忘记多少信息、要向其单元状态添加多少信息以及要为下一步输出多少信息。
详细的 LSTM 架构(图片来自 Wikimedia Commons 上的 Guillaume Chevalier)
LSTM 模块包含O(4d(h)(d(x) + d(h)))参数,其中d(h)是隐藏状态的大小,d(x)是输入向量x(t)的大小。
但 LSTM 还能进一步改进吗?
GRU的崛起
门循环单元 (GRU) 架构于 2014 年推出,简化了 LSTM。
GRU 可视化(图片来自作者的书《100 幅图像中的 AI》)
它不使用 LSTM 的三个门和两个状态,而是使用两个门和一个状态。
在 GRU 中,LSTM 的遗忘门和输入门被合并为一个更新门。该门决定应保留多少过去信息以及应添加多少新信息。
LSTM 的输出门被 GRU 中的重置门取代。该门决定在添加新信息之前应该“重置”或遗忘多少过去信息。
这些变化将网络的参数减少到O(3d(h)(d(x) + d(h))),其中d(h)是隐藏状态的大小,d(x)是输入向量的大小x(t)。
这使得训练和推理时间比 LSTM 更快。
LSTM 和 GRU 都使用时间反向传播(BPTT)进行顺序训练。
这需要线性训练时间,这限制了它们扩展到长输入序列长度的能力。
Transformers的规则
2017年,Transformers彻底接管了顺序处理任务领域。
他们的自我注意力机制允许序列中的每个标记同时(而不是顺序地)关注每个其他标记,以理解上下文。
这种方法使得架构可并行化。
注意力(自我注意力)可视化(图片来自作者的书《100 幅图像中的 AI》)
自注意力有其优点,但它在序列长度上引入了二次时间复杂度,限制了 Transformer 扩展到长上下文的能力。
如果我们能够找到能够解决 Transformers 这一限制的模型会怎样?
答案就在 RNN 本身
研究人员表明,可以通过从不同的门中删除许多隐藏状态依赖关系来简化 LSTM 和 GRU(并进行一些其他更改,这将在下一节中讨论)。
这可以消除这些架构对使用时间反向传播(BPTT)进行训练的依赖性。
如果不是 BPTT,那是什么?
可以使用一种称为并行前缀扫描算法的算法来训练这些修改版本的 LSTM 和 GRU。
该算法使用结合运算符(例如加法或乘法)有效地计算一系列数据点上的前缀运算。
通过将问题分解为更小的部分,该算法可以并行解决这些部分。
从数学上来说,该算法可以描述如下:
并行前缀扫描算法的公式(作者提供图片)
其中是应用于元素u(i)和位置的结合二元运算符(即加法或乘法) k。
该算法以并行方式计算结果y(k),时间复杂度为,O(log(N))而不是顺序计算,这会花费处理器的O(N)时间N。
这可以用来并行解决以下类型的方程。
由 v(t)、a(t)、v(t-1) 和 b(t) 表示的实数组成的线性方程(作者提供图片)
有趣的是,这个方程是 LSTM 和 GRU 架构的一部分。
接下来我们来讨论一下如何做。
将 GRU 简化为“minGRU”
在 minGRU 中,更新门和候选隐藏状态对先前隐藏状态的依赖被删除。
这允许在结果方程上应用并行前缀扫描算法。
此外,在 GRU 中,tanh(双曲正切)激活函数用于将候选隐藏状态的值限制在范围内(-1, 1)。
这种限制通过防止隐藏状态变得过大并避免消失梯度问题来稳定训练。
在 minGRU 中,这种tanh激活被删除,从而使训练变得更简单、更快速。
这两项变化如下所示。
GRU 和 minGRU 内部的区别(图片来自原始研究论文)
O(2 * d(h) * d(x))因此,与 GRU 相比,minGRU 只需要参数O(3 * d(h) (d(x) + d(h))),其中d(h)和d(x)分别代表输入和隐藏状态的大小。
除此之外,它现在还可以并行训练。
将 LSTM 简化为“minLSTM”
在 minLSTM 中,单元状态、遗忘门和输入门与先前状态的依赖关系被移除。
这使得其方程可以使用并行前缀扫描算法并行化。
与minGRU类似,tanh从细胞状态和隐藏状态的计算中去除了激活函数。
LSTM 架构中还有最后一个变化,即将其转变为 minLSTM。
在传统的 LSTM 中,输入门和遗忘门是独立计算的。因此,无法保证它们对综合效应的综合影响在时间步骤中保持稳定。
这引入了与时间相关的缩放,即隐藏状态的缩放可以随时间而变化。
为了解决这个问题,对遗忘门和输入门进行规范化,以确保它们的总和始终等于 1。
规范化遗忘门和输入门,以确保 f'(t) + i'(t) = 1 (作者提供图片)
为了确保隐藏状态在尺度上与时间无关,输出门和隐藏状态直接与细胞状态相联系。
这也会降低原始细胞状态,因为隐藏状态接管了它的角色。
请注意,GRU 不需要此步骤,因为它们的输出在规模上已经与时间无关。
LSTM 和 minLSTM 内部的区别(图片来自原始研究论文)
O(3 * d(h) * d(x))因此,与 LSTM 相比,minLSTM 只需要参数O(4 * d(h) (d(x) + d(h))),其中d(h)和d(x)分别代表输入和隐藏状态的大小。
同时,它现在也可以并行进行训练。
这些 RNN 的表现如何?
运行时性能
minLSTM 和 minGRU 的运行时间明显快于传统方法。
值得注意的是,与 GRU 和 LSTM 相比,对于 4096 的序列长度,minGRU 和 minLSTM 分别快 1324 倍和 1361 倍。
通俗地说,minGRU 只需一天即可完成训练,而其传统对手 GRU 则可能需要三年以上的时间!
与状态空间模型 Mamba相比,minLSTM 和 minGRU 具有相似的运行性能。
不同模型的运行时性能,其中 minGRU、minLSTM 和 Mamba 的线条重叠。(图片来自原始研究论文)
内存使用
与传统方法相比,minGRU 和 minLSTM 需要多 88% 的内存。
这是因为他们使用并行前缀扫描算法,从而产生更大的计算图。
Mamba 也是如此,它需要的内存比 minGRU 多 56% 左右。
这并不令人担心,因为在训练期间,RNN 的瓶颈通常是运行时间。
不同模型使用的内存(图片来自原始研究论文)
加速
与传统模型相比,minLSTM 模型需要的参数更少,在 T4 GPU 上对 512 序列长度的训练中,minGRU 模型实现了 175 倍的加速,而 minLSTM 模型则实现了 235 倍的加速。
minGRU 和 minLSTM 相对于传统算法的加速(图片来自原始研究论文)
训练稳定性
minGRU 在训练过程中比 minLSTM 更稳定。
原因是 minGRU 只有一组参数(即更新门)需要更新,而 minLSTM 则有两组参数(即遗忘门和输入门)。
这使得 minGRU 更容易优化。
强化学习中的表现
在D4RL 基准的MuJoCo(带接触的多关节动力学)运动任务上进行评估时,这些架构表现相当出色。
它们比Decision S4更出色,并且可与Aaren、Decision Mamba和Decision Transformer架构相媲美。
不同D4RL数据集上的强化学习结果。这里显示了专家标准化的回报,分数越高表示性能越好。
语言建模中的表现
为了进行本次评估,我们基于Andrej Karpathy的NanoGPT 框架,针对莎士比亚的作品训练了一个字符级 GPT 。
minLSTM 和 minGRU 在这里表现非常出色。
两者以及Mamba和 Transformers 都取得了类似的测试损失。
值得注意的是,minGRU、minLSTM、Mamba 和 Transformers 的可比测试损失分别为 1.548、1.555、1.575 和 1.547。
此外,minGRU 和 minLSTM 的训练速度比 Transformers 快 2.5 倍,但达到了相当的性能。
这是因为,与 Transformer 的二次复杂度相比,它们的序列长度时间复杂度具有线性。
不同模型的交叉熵损失学习曲线比较(图片来自原始研究论文)
这些令人惊讶的发现让我们怀疑“RNN 就是我们所需要的吗?”。
我对重新引入这些高效 RNN 之后将会取得的进步感到非常兴奋。
参考:
https://levelup.gitconnected.com/rnns-are-coming-back-to-take-over-transformers-yes-for-real-51697943bc67