前言
标准的Transformer Block并不简介,每个block由attention, MLP, skip connection, normalization各子模块构成。一些看似微小的修改可能导致模型训练速度下降,甚至导致模型无法收敛。
在本篇工作中,我们探索了Transformer Block精简的方式。结合了信号传播理论以及一些经验性的观察,我们在不损失训练速度的前提下,移除了skip connection, out project, value project, normalization操作 以及串行组织block的形式。在Decoder-only和Encoder-only两类模型上,我们减少了15%可训练参数,并提高了15%的训练速度。
官方仓库:
bobby-he/simplified_transformers
论文:Simplifying Transformer Blocks.
一些标记注解:
每个transformer block如上述公式组成,每个子模块都配备了一个系数,这个后续会使用到
Removing Skip Connection
作者先前的一项工作Deep Transformers without Shortcuts: Modifying Self-attention for Faithful Signal Propagation 删除了残差连接,提出的操作Value-SkipInit,将自注意力相关操作修改为:
其中I代表的是一个Identity操作,A(X)表示原始注意力操作。这两个操作各自有一个可训练标量 和 ,初始化为 , 。
这个设计的insight是每个token在训练前期更多的是关注自身相关性,类似的如Pre-LN操作,在Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks这项工作发现,Pre-LN相当于把 skip-branch 权重提高,降低residual-branch权重,以在较深的神经网络里仍然有良好的信号传播。
而The Shaped Transformer: Attention Models in the Infinite Depth-and-Width Limit 该工作里提出了Shape Attention,也是收到信号传播理论的启发,将注意力公式更改为:
相比之下多了一个C矩阵,这是个常量矩阵(论文称其为centering matrix),不参与训练。他的值被设置为当 querykey dot 为0时候,A(x)的值,那么我们回去看A(x)公式,就剩一个mask值,因此代码里是这么写的:
# Centered attention, from https://arxiv.org/abs/2306.17759 uniform_causal_attn_mat = torch.ones( (max_positions, max_positions), dtype=torch.float32 ) / torch.arange(1, max_positions + 1).view(-1, 1) self.register_buffer( "uniform_causal_attn_mat", torch.tril( uniform_causal_attn_mat, ).view(1, 1, max_positions, max_positions), persistent=False, )
对于CausalLM来说,MASK是个下三角矩阵,形状为(S, S)的矩阵,第i行,只有前i个位置有值,经过softmax后,1.0概率被平分到有值的位置,这就是为什么它要做一个 ones / arange 的操作,一段示例代码为:
import torch max_positions = 32 mask = torch.tril(torch.ones(max_positions, max_positions)) + torch.triu(torch.ones(max_positions, max_positions), 1) * -65536 print(torch.softmax(mask, -1)) tensor([[1.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.5000, 0.5000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [0.3333, 0.3333, 0.3333, ..., 0.0000, 0.0000, 0.0000], ..., [0.0333, 0.0333, 0.0333, ..., 0.0333, 0.0000, 0.0000], [0.0323, 0.0323, 0.0323, ..., 0.0323, 0.0323, 0.0000], [0.0312, 0.0312, 0.0312, ..., 0.0312, 0.0312, 0.0312]])
而新的可训练标量 = ,以保证初始化时,
其中这些可训练标量如果改成headwise,即每个注意力头独立,则性能有部分提升。当然作者还是强调其中的一个重要的点是,显式的将MLP Block的系数降低:
论文里针对18层Transformer,设置为0.1
Recovering Training Speed
在引入shape attention并移除残差连接后,训是没问题了,但是会导致收敛变慢:
经过前面的修改,那么对于Attention模块里,在训练初期其实就简化成X和Vproject矩阵和OutProject矩阵做矩阵乘操作。
众所周知,这种没有残差连接的网络训练是要比带残差结构的网络要慢的。我们从别的工作也可以得知,Pre-LN操作,是会降低残差分支的占比系数,相当于降低了学习率,也缩减了线性层里参数更新的scale
X matmul W,那么计算X的梯度公式有一项就是W嘛
这促使我们开始引入重参数化操作思考V矩阵和OutProject矩阵
作者针对Vproject和Outproject两个矩阵乘操作,给残差分支和跳跃分支各引入一个可训练参数 , ,通过实验发现,大部分层最终系数比值 收敛到了0
这意味着 和 两个矩阵是一个Identity矩阵,因此作者将这两个参数移除掉,并称为Simplified Attention Sub-block (SAS),使用SAS比原始Pre-LN block收敛更快了:
REMOVING THE MLP SUB-BLOCK SKIP CONNECTION
在这部分实验里,作者把目光投向了GPT-J里提出的Parallel Block,其移除了MLP的残差分支,保留了另外一个残差分支:
对应公式为:
作者直接将SAS Block进行替换,得到Parallel形式的 SAS-P Block。我们比较下和原始串行的实现:
在训练初期,Attention部分是Identity输出,因此两种形式的SAS Block在训练初期是等价的。
REMOVING NORMALISATION LAYERS
最后作者尝试将Norm层给移除,得到
作者的idea来自于,先前PreLN的作用(如把 skip-branch 权重提高,降低residual-branch权重)已经通过前面的一系列修改实现了,因此可以直接删除Norm层
当然还是得看实验效果,回到这张图,可以看到移除了Norm对收敛还是有一定影响的。作者猜测在信号传播理论范围之外,Norm层能加速训练收敛,如Scaling Vision Transformers to 22 Billion Parameters
引入了更多LayerNorm层,将ViT缩放至22B参数量上
因此作者还是主张保留PreLN结构:
最后实验
作者也补充了一些训练速度benchmark,模型准确率,以及收敛趋势的实验:
总结
作者对Transformer Block移除了各种参数,减少了15%参数量,提高了15%的训练速度,各个环节都有做充分的实验,但一些经验性得到的结论也并没有直接回答一些问题(如LN为什么影响收敛速度)。
实验规模并不大,而标准的TransformerBlock还是在各个Scale里得到广泛验证的,期待有人进一步试验
你说的对,但我还是套LLAMA结构
审核编辑:黄飞
评论
查看更多