为了提高softmax并行性,之前方法(FlashAttention、FlashDecoding)将计算过程拆分,各自计算partial softmax结果,最后需要通过同步操作来更新partial softmax结果。例如FlashAttention每次计算partial softmax结果都会更新之前的结果,而FlashDecoding是在最后统一更新所有partial softmax结果。
本文在A100 GPU上分析了输入长度为1024的情况,这种同步partial softmax更新操作占Llama2-7B推理的注意力计算的18.8%。(本文没说是FlashAttention还是FlashDecoding的结果,个人认为FlashDecoding的同步更新代价并不大,应该远小于18.8%)
这是LLM推理加速的第一个挑战。此外,本文还提出了两个挑战:
在解码阶段,Flat GEMM操作的计算资源未得到充分利用。这是由于解码阶段是按顺序生成token(一次只生成一个token),GEMM操作趋于flat-shape,甚至batch size等1时变成了GEMV(General Matrix-Vector Multiplication),具体看论文Figure 2。当batch size较小时(e.g., 8),cublas和cutlass会将矩阵填充zeros以执行更大batchsize(e.g., 64)的GEMM,导致计算利用率不足50%。
动态输入和固定硬件配置影响了LLM推理的性能。例如,当batch size较小时,LLM推理的解码过程是memory-bounded,而当batch size较大时是compute-bounded。
针对这3个问题,本文分别提出了对应优化方法:
Asynchronized softmax with unified max value.FlashDecoding++为分块softmax计算设置了一个共享的最大值。这样可以独立计算partial softmax,无需同步更新。
Flat GEMM optimization with double buffering.FlashDecoding++只将矩阵大小填充到8,对比之前针对flat-shaped GEMM设计的为64,提高了计算利用率。论文指出,具有不同shape的flat GEMMs面临的瓶颈也不同,于是进一步利用双缓冲等技术提高kernel性能。
Heuristic dataflow with hardware resource adaption.FlashDecoding++同时考虑了动态输入和硬件配置,针对LLM推理时数据流进行动态kernel优化。
下图展示了以上3种方法的示意图:
2. Backgrounds
LLM推理中的主要操作如下图所示:linear projection(①和⑤)、attention(②、③和④)和feedforward network(⑥)。为简单起见,这里忽略了position embedding、non-linear activation、mask等操作。本文将LLM推理时对Prompt的处理过程称为prefillphase,第二阶段预测过程称为decodephase。这两个阶段的算子基本一致,主要是输入数据的shape是不同的。由于decodephase一次只处理一个令牌(batch size=1,或batch size很小),因此输入矩阵是flat-shape matrices(甚至是vectors),参见下图Decode phase部分中和KV Cache拼接的红色向量。
LLM推理中的另一个问题就是Softmax算子,其需要计算并存储所有全局数据,并且数据量随着数据长度成平方增长,存在内存消耗高和低并行性等问题。一般计算流程如下:
3. Asynchronized Softmax with Unified Maximum Value如下
图b所示,FlashAttention和FlashDecoding对softmax操作进行了分块处理,但是块与块之间需要进行同步(主要是局部最大值)。本文发现这种同步操作的开销约为20%。因此,作者希望去除同步操作,也就是独立计算出partial softmax结果。
4. Flat GEMM Optimization with Double Buffering
Decoding阶段的过程主要由GEMV(batch size=1)或flat GEMM(batch size>1)。GEMV/GEMM运算可以用M、N、K来表示,其中两个相乘矩阵的大小分别为M × K和K × N。一般LLM推理引擎利用Tensor Core使用cuBLAS和CUTLASS等库来加速。尽管Tensor Core适合处理M = 8的GEMM,但这些库为了隐藏memory latency,通常将M维度平铺到64。然而,decodephase的GEMV或flat GEMM的M通远小于64,于是填充0到64,导致计算利用率低下。
为了隐藏memory access latency,本文引入了double buffering技术。具体来说就是在共享内存中分配两个buffer,一个buffer用于执行当前tile的GEMM计算,同时另一个buffer则加载下一个tile GEMM所需的数据。这样计算和内存访问是重叠的,本文在N较大时采取这种策略,下图为示意图。
5. Heuristic Dataflow with Hardware Resource Adaption
影响LLM推理性能的因素有很多:(a)动态输入。batch size和输入序列长度的变化造成了工作负载变化。(b)模型多样性。主要指模型结构和模型大小。(c)GPU能力不同。例如内存带宽、缓存大小和计算能力。(d)工程优化。
虽然这些因素构建了一个很大的搜索空间,但LLM中不同layer的同质性大大减少了算子优化的搜索空间。例如,prefillphase和decodephase中有4个GEMV/GEMM操作(K、Q、V投影、O投影、2个FFN),都可以表示为[M, K]和N x K,对应了四种[N, K]组合,如下图所示。此外,prefillphase的M与输入序列长度和batch size有关,decodephase的M只与batch size有关。
本文根据不同的M, K, N选取FastGEMV、flat GEMM(本文方法)、CUTLASS。
个人总结
这篇文章没有FlashAttention和FlashDecoding惊艳,个人觉得FlashDecoding的同步处理代价不大,而且本文中动态调整softmax方法也引入了判断、终止和分支跳转等操作。另一个Double Buffering就是内存优化常用的乒乓buffer,也没什么新东西。
不过话说回来,如今在tranformer架构不变的情况,LLM加速只能靠这些工程手段去优化,的确也有不错效果。还是很有价值的。
-
数据
+关注
关注
8文章
7067浏览量
89114 -
gpu
+关注
关注
28文章
4743浏览量
128995 -
LLM
+关注
关注
0文章
290浏览量
351
原文标题:【FlashAttention-V4,非官方】FlashDecoding++
文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。
发布评论请先 登录
相关推荐
评论