0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

【FlashAttention-V4,非官方】FlashDecoding++

jf_pmFSk4VX 来源:GiantPandaCV 2023-11-14 15:41 次阅读

1. Introdcution

为了提高softmax并行性,之前方法(FlashAttention、FlashDecoding)将计算过程拆分,各自计算partial softmax结果,最后需要通过同步操作来更新partial softmax结果。例如FlashAttention每次计算partial softmax结果都会更新之前的结果,而FlashDecoding是在最后统一更新所有partial softmax结果。

本文在A100 GPU上分析了输入长度为1024的情况,这种同步partial softmax更新操作占Llama2-7B推理的注意力计算的18.8%。(本文没说是FlashAttention还是FlashDecoding的结果,个人认为FlashDecoding的同步更新代价并不大,应该远小于18.8%)

这是LLM推理加速的第一个挑战。此外,本文还提出了两个挑战:

在解码阶段,Flat GEMM操作的计算资源未得到充分利用。这是由于解码阶段是按顺序生成token(一次只生成一个token),GEMM操作趋于flat-shape,甚至batch size等1时变成了GEMV(General Matrix-Vector Multiplication),具体看论文Figure 2。当batch size较小时(e.g., 8),cublas和cutlass会将矩阵填充zeros以执行更大batchsize(e.g., 64)的GEMM,导致计算利用率不足50%。

动态输入和固定硬件配置影响了LLM推理的性能。例如,当batch size较小时,LLM推理的解码过程是memory-bounded,而当batch size较大时是compute-bounded。

针对这3个问题,本文分别提出了对应优化方法:

Asynchronized softmax with unified max value.FlashDecoding++为分块softmax计算设置了一个共享的最大值。这样可以独立计算partial softmax,无需同步更新。

Flat GEMM optimization with double buffering.FlashDecoding++只将矩阵大小填充到8,对比之前针对flat-shaped GEMM设计的为64,提高了计算利用率。论文指出,具有不同shape的flat GEMMs面临的瓶颈也不同,于是进一步利用双缓冲等技术提高kernel性能。

Heuristic dataflow with hardware resource adaption.FlashDecoding++同时考虑了动态输入和硬件配置,针对LLM推理时数据流进行动态kernel优化。

下图展示了以上3种方法的示意图:

e8d33a2a-828a-11ee-939d-92fbcf53809c.png

2. Backgrounds

LLM推理中的主要操作如下图所示:linear projection(①和⑤)、attention(②、③和④)和feedforward network(⑥)。为简单起见,这里忽略了position embedding、non-linear activation、mask等操作。本文将LLM推理时对Prompt的处理过程称为prefillphase,第二阶段预测过程称为decodephase。这两个阶段的算子基本一致,主要是输入数据的shape是不同的。由于decodephase一次只处理一个令牌(batch size=1,或batch size很小),因此输入矩阵是flat-shape matrices(甚至是vectors),参见下图Decode phase部分中和KV Cache拼接的红色向量。

e8efc816-828a-11ee-939d-92fbcf53809c.png

LLM推理中的另一个问题就是Softmax算子,其需要计算并存储所有全局数据,并且数据量随着数据长度成平方增长,存在内存消耗高和低并行性等问题。一般计算流程如下:

e919c08a-828a-11ee-939d-92fbcf53809c.png

3. Asynchronized Softmax with Unified Maximum Value如下

图b所示,FlashAttention和FlashDecoding对softmax操作进行了分块处理,但是块与块之间需要进行同步(主要是局部最大值)。本文发现这种同步操作的开销约为20%。因此,作者希望去除同步操作,也就是独立计算出partial softmax结果。

e92b4954-828a-11ee-939d-92fbcf53809c.png

e94fa128-828a-11ee-939d-92fbcf53809c.png

e9642c9c-828a-11ee-939d-92fbcf53809c.png

e97a9f7c-828a-11ee-939d-92fbcf53809c.png

e9a2dfa0-828a-11ee-939d-92fbcf53809c.png

4. Flat GEMM Optimization with Double Buffering

Decoding阶段的过程主要由GEMV(batch size=1)或flat GEMM(batch size>1)。GEMV/GEMM运算可以用M、N、K来表示,其中两个相乘矩阵的大小分别为M × K和K × N。一般LLM推理引擎利用Tensor Core使用cuBLAS和CUTLASS等库来加速。尽管Tensor Core适合处理M = 8的GEMM,但这些库为了隐藏memory latency,通常将M维度平铺到64。然而,decodephase的GEMV或flat GEMM的M通远小于64,于是填充0到64,导致计算利用率低下。

e9ca7236-828a-11ee-939d-92fbcf53809c.png

e9e9159c-828a-11ee-939d-92fbcf53809c.png

ea01d4e2-828a-11ee-939d-92fbcf53809c.png

ea189fce-828a-11ee-939d-92fbcf53809c.png

ea4097e0-828a-11ee-939d-92fbcf53809c.png

ea52f958-828a-11ee-939d-92fbcf53809c.png

为了隐藏memory access latency,本文引入了double buffering技术。具体来说就是在共享内存中分配两个buffer,一个buffer用于执行当前tile的GEMM计算,同时另一个buffer则加载下一个tile GEMM所需的数据。这样计算和内存访问是重叠的,本文在N较大时采取这种策略,下图为示意图。

ea730982-828a-11ee-939d-92fbcf53809c.png

5. Heuristic Dataflow with Hardware Resource Adaption

影响LLM推理性能的因素有很多:(a)动态输入。batch size和输入序列长度的变化造成了工作负载变化。(b)模型多样性。主要指模型结构和模型大小。(c)GPU能力不同。例如内存带宽、缓存大小和计算能力。(d)工程优化。

虽然这些因素构建了一个很大的搜索空间,但LLM中不同layer的同质性大大减少了算子优化的搜索空间。例如,prefillphase和decodephase中有4个GEMV/GEMM操作(K、Q、V投影、O投影、2个FFN),都可以表示为[M, K]和N x K,对应了四种[N, K]组合,如下图所示。此外,prefillphase的M与输入序列长度和batch size有关,decodephase的M只与batch size有关。

eaa4a406-828a-11ee-939d-92fbcf53809c.png

本文根据不同的M, K, N选取FastGEMV、flat GEMM(本文方法)、CUTLASS。

eab90a0e-828a-11ee-939d-92fbcf53809c.png

个人总结

这篇文章没有FlashAttention和FlashDecoding惊艳,个人觉得FlashDecoding的同步处理代价不大,而且本文中动态调整softmax方法也引入了判断、终止和分支跳转等操作。另一个Double Buffering就是内存优化常用的乒乓buffer,也没什么新东西。

不过话说回来,如今在tranformer架构不变的情况,LLM加速只能靠这些工程手段去优化,的确也有不错效果。还是很有价值的。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 数据
    +关注

    关注

    8

    文章

    6581

    浏览量

    87959
  • gpu
    gpu
    +关注

    关注

    27

    文章

    4513

    浏览量

    127586
  • LLM
    LLM
    +关注

    关注

    0

    文章

    224

    浏览量

    252

原文标题:【FlashAttention-V4,非官方】FlashDecoding++

文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    X-CUBE-CRYPTOLIB V4库文件无法添加,链接错误的原因?

    官方介绍,V4版本的静态库是支持多个编译器的,芯片是stm32h743,我用的是arm-none-eabi-gcc,版本是gcc version 13.2.1 20231009 (Arm GNU
    发表于 05-31 07:54

    将stm32_eth_lib的以太网程序移植到非官方版开发板一直不成功怎么解决?

    我最近将官方原版的uip协议族下面的以太网协议移植到官方的开发板(如格兰瑞SupeARm STM32F107VX,或者神舟107系列),一直没有成功,请问有哪位大神研究过,或者有相关的经验,帮助小弟呀,小弟跪谢了
    发表于 05-17 12:32

    为什么无法安装CUBE-MX-NFC6扩展包?

    在官网下载CUBE-MX-NFC6扩展包以使用NUCLEO-NFC08A1扩展板,但是CUBE MX提示非官方扩展包
    发表于 03-18 06:10

    开发者发布自制非官方YouTube应用,Vision Pro头显可用

    据了解,该第三方应用可完整实现原生的 YouTube 功能,包括手势操控和保持原始视频宽高比等功能,同时支持浏览播放列表,甚至在观看过程中出现的 YouTube 广告。据开发者表示,此举旨在免受谷歌声讨。
    的头像 发表于 02-03 10:53 378次阅读

    有没有大佬用过银河麒麟,进来聊聊

    因需要,要在银河麒麟上装一个open-vm-dkms,但是yum库的官方源里面没有这个软件包,想讨论讨论还有没有别的非官方的靠谱源,或者其他工具可以安装这个软件的,顺便在问一下,银河麒麟不是基于ubuntu18么,为什么里面没有集成apt?
    发表于 01-23 14:57

    OpenOCD在线下载调试报错的原因?

    由于我使用的是非官方的Flash,所以我需要重新编译OpenOCD,在编译完成过后,替换OpenOCD文件以后进行在线下载调试出现了如下的问题,我以为是我Flash命令的问题,知道我更换回原本的官方
    发表于 01-15 07:58

    LTC3305串联蓄电池组的各个节点不能与芯片的V1-V4引脚直接相连吗?

    。 另外对比了下您发给我的官方PCB设计电路,主要区别是:官方设计中在蓄电池与芯片的连接回路中加入了保险丝以及5Ω电阻。 串联蓄电池组的各个节点不能与芯片的V1-V4引脚直接相连吗? 芯片尚未使能,并未处
    发表于 12-26 06:38

    LTC7545具体的增益计算公式是什么?

    LTC7545按照下面的电路连接,这个是非官方标准的接法,想问下具体的增益计算公式是什么,VOUT和VIN以及数字code值之间的关系。
    发表于 12-04 07:27

    请问官方有max9288GTM/V的linux驱动吗?

    请问官方有max9288GTM/V的linux驱动吗?
    发表于 11-28 19:32

    FlashAttention2详解(性能比FlashAttention提升200%)

    GPU performance characteristics. GPU主要计算单元(如浮点运算单元)和内存层次结构。大多数现代GPU包含专用的低精度矩阵乘法单元(如Nvidia GPU的Tensor Core用于FP16/BF16矩阵乘法)。
    的头像 发表于 11-24 16:21 764次阅读
    <b class='flag-5'>FlashAttention</b>2详解(性能比<b class='flag-5'>FlashAttention</b>提升200%)

    移植E203到非官方开发板,32.768KHz的时钟无法通过IP核分频得出怎么处理?

    VIVADO 的官方IP核最少分频出4MHz多,而32.768KHz太小了,难道只能自己写分频器吗? 谢谢。
    发表于 08-12 07:03

    在使用NucleiStudio环境下进行,代码导入调试时无法连接到开发板的原因?

    使用非官方开发板平头哥200t开发板,完成E203综合,以及xdc约束文件修改。但是在使用NucleiStudio环境下进行,代码导入调试时,无法连接到开发板。 驱动没有问题(绿灯已亮) 连接如下: 想请假如何配置,有相关文章说更改OpenOCD,但是不清楚如何更改。 报错信息如下:
    发表于 08-12 06:44

    E203使用非官方调试器jlink连接时报错The connected J-Link does not support selecting another hart/如何解决?

    我把蜂鸟E203的软核输出给了外部的IO口然后连接了非官方的jtag调试器。在用jlink连接时报错: The connected J-Link does not support selecting
    发表于 08-12 06:29

    为什么无法安装CUBE-MX-NFC6扩展包

    在官网下载CUBE-MX-NFC6扩展包以使用NUCLEO-NFC08A1扩展板,但是CUBE MX提示非官方扩展包
    发表于 08-07 06:43

    让Attention提速9倍!FlashAttention燃爆显存,Transformer上下文长度史诗级提升

    FlashAttention新升级!斯坦福博士一人重写算法,第二代实现了最高9倍速提升。 继超快且省内存的注意力算法FlashAttention爆火后,升级版的2代来了
    的头像 发表于 07-24 16:55 637次阅读
    让Attention提速9倍!<b class='flag-5'>FlashAttention</b>燃爆显存,Transformer上下文长度史诗级提升