0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

FlashAttenion-V3: Flash Decoding详解

jf_pmFSk4VX 来源:GiantPandaCV 2023-10-31 16:18 次阅读

Flash Attention V1和V2的作者又推出了Flash Decoding,真是太强了!

Flash-Decoding借鉴了FlashAttention的优点,将并行化维度扩展到keys/values序列长度。这种方法几乎不收序列长度影响(这对LLM模型能力很重要),可以充分利用GPU,即使在batch size较小时(inference特点),也可以极大提高了encoding速度。

相关背景知识先推荐阅读:

FlashAttention图解(如何加速Attention)

FlashAttention2详解(性能比FlashAttention提升200%)

Motivation

最近,像ChatGPT或Llama这样的LLM模型受到了空前的关注。然而,它们的运行成本却非常高昂。虽然单次回复的成本约为0.01美元(例如在AWS 8块A100上运行几秒钟),但是当扩展到数十亿用户的多次交互时,成本会迅速上升。而且一些场景的成本更高,例如代码自动补全,因为只要用户输入一个新字符就会执行。由于LLM应用非常广泛且还在迅速增长,即使稍微提升其运行效率也会产生巨大的收益。

LLM inference(或称为decoding)是一个迭代的过程:预测的tokens是逐个生成的。如果生成的句子有N个单词,那么模型需要进行N次forward。一个常用的优化技巧是KV Cache,该方法缓存了之前forward的一些中间结果,节约了大部分运算(如MatMul),但是attention操作是个例外。随着输出tokens长度增加,attention操作的复杂度也极具上升。

然而我们希望LLM能处理长上下文。增加了上下文长度,LLM可以输出更长的文档、跟踪更长的对话,甚至在编写代码之前处理整个代码库。例如,2022年大多数LLM的上下文长度最多为2k(如GPT-3),但现在LLM上下文长度可以扩展到32k(Llama-2-32k),甚至最近达到了100k(CodeLlama)。在这种情况下,attention操作在推理过程中占据了相当大的时间比例。此外,当batch size增加时,即使在相对较小的上下文中,attention操作也可能成为瓶颈。这是因为该操作需要对内存的访问会随着batch size增加而增加,而模型中其他操作只和模型大小相关。

因此,本文提出了Flash-Decoding,可以推理过程中显著加速attention操作(例如长序列生成速度提高8倍)。其主要思想是最大化并行加载keys和values的效率,通过重新缩放组合得到正确结果。

Multi-head attention for decoding

在decoding过程中,每个生成的新token需要与先前的tokens合并后,才能继续执行attention操作,即936fb5aa-77c1-11ee-939d-92fbcf53809c.png。Attention操作在训练过程的瓶颈主要卡在访问内存读写中间结果(例如93895640-77c1-11ee-939d-92fbcf53809c.png)的带宽,相关加速方案可以参考FlashAttention和FlashAttention2。

然而,上述优化不适合直接应用于推理过程。因为在训练过程中,FlashAttention对batch size和query length进行了并行化加速。而在推理过程中,query length通常为1,这意味着如果batch size小于GPU上的SM数量(例如A100上有108个SMs),那么整个计算过程只使用了GPU的一小部分!特别是当上下文较长时,通常会减小batch size来适应GPU内存。例如batch size = 1时,FlashAttention对GPU利用率小于1%!

下面展示了FlashAttention的计算示意图,该示例将keys和values分为了2个block:

93a173e2-77c1-11ee-939d-92fbcf53809c.png

FlashAttention示意图

对应的计算公式:

93b5acae-77c1-11ee-939d-92fbcf53809c.png

FlashAttention示意图对应的计算公式

注意93bdf760-77c1-11ee-939d-92fbcf53809c.png的计算过程依赖93c63aba-77c1-11ee-939d-92fbcf53809c.png,从下图也可以看出,FlashAttention是按顺序更新output的,其实当时我在看FlashAttention这篇文章时就觉得这个顺序操作可以优化的,因为反正都要rescale,不如最后统一rescale,没必要等之前block计算完(为了获取上一个block的max值)

93d525ac-77c1-11ee-939d-92fbcf53809c.jpg

flashattention计算过程

A faster attention for decoding: Flash-Decoding

上面提到FlashAttention对batch size和query length进行了并行化加速,Flash-Decoding在此基础上增加了一个新的并行化维度:keys/values的序列长度。即使batch size很小,但只要上下文足够长,它就可以充分利用GPU。与FlashAttention类似,Flash-Decoding几乎不用额外存储大量数据到全局内存中,从而减少了内存开销。

93e66074-77c1-11ee-939d-92fbcf53809c.gif

flashdecoding计算过程

Flash Decoding主要包含以下三个步骤(可以结合上图来看):

将keys和values分成较小的block

使用FlashAttention并行计算query与每个block的注意力(这是和FlashAttention最大的区别)。对于每个block的每行(因为一行是一个特征维度),Flash Decoding会额外记录attention values的log-sum-exp(标量值,用于第3步进行rescale)

对所有output blocks进行reduction得到最终的output,需要用log-sum-exp值来重新调整每个块的贡献

实际应用中,第1步中的数据分块不涉及GPU操作(因为不需要在物理上分开),只需要对第2步和第3步执行单独的kernels。虽然最终的reduction操作会引入一些额外的计算,但在总体上,Flash-Decoding通过增加并行化的方式取得了更高的效率。

Benchmarks on CodeLlama 34B

作者对CodeLLaMa-34b的decoding throughput进行了基准测试。该模型与Llama 2具有相同的架构。作者在各种序列长度(从512到64k)上测试了decoding速度,并比较了多种attention计算方法:

PyTorch:使用纯PyTorch primitives运行注意力计算(不使用FlashAttention)。

FlashAttention v2(v2.2之前的版本)。

FasterTransformer:使用FasterTransformer attention kernel

Flash-Decoding

将从内存中读取整个模型和KV Cache所需的时间作为上限

940efbf6-77c1-11ee-939d-92fbcf53809c.png

Untitled

从上图可以看出,Flash-Decoding在处理非常大的序列时速度可以提高8倍,并且比其他方法具有更好的可扩展性。所有方法在处理small prompts时表现相似,但随着序列长度从512增加到64k,其他方法的性能都变差了,而Flash-Decoding对序列长度的增加并不敏感(下图也是很好的证明)

9422813a-77c1-11ee-939d-92fbcf53809c.png

micro-benchmark on A100

Using Flash-Decoding

作者还通了Flash-Decoding使用方式:

基于FlashAttention package ,从版本2.2开始。

xFormers,在版本0.0.22中提供了xformers.ops.memory_efficient_attention模块

作者也提供了LLaMa v2/CodeLLaMa的repo1和xFormers repo2。此外,作者还提供了一个针对LLaMa v1/v2的最小示例。

个人总结

Flash-Decoding对LLM在GPU上inference进行了显著加速(尤其是batch size较小时),并且在处理长序列时具有更好的可扩展性。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • gpu
    gpu
    +关注

    关注

    28

    文章

    4673

    浏览量

    128557
  • 模型
    +关注

    关注

    1

    文章

    3112

    浏览量

    48646
  • LLM
    LLM
    +关注

    关注

    0

    文章

    263

    浏览量

    297

原文标题:FlashAttenion-V3: Flash Decoding详解

文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    Flash基本操作——Flash基础(1)#多媒体技术

    FlaSh
    未来加油dz
    发布于 :2023年05月24日 10:43:53

    Flash基本操作——Flash工具1(3)#多媒体技术

    FlaSh
    未来加油dz
    发布于 :2023年05月24日 10:46:17

    Flash基本操作——Flash工具2(3)#多媒体技术

    FlaSh
    未来加油dz
    发布于 :2023年05月24日 10:48:11

    Flash基本操作——Flash工具3(1)#多媒体技术

    FlaSh
    未来加油dz
    发布于 :2023年05月24日 10:49:01

    Flash基本操作——Flash工具3(2)#多媒体技术

    FlaSh
    未来加油dz
    发布于 :2023年05月24日 10:49:44

    Flash基本操作——Flash工具3(3)#多媒体技术

    FlaSh
    未来加油dz
    发布于 :2023年05月24日 10:50:22

    Necessary to disable "Above 4G Decoding" for View with vGPU?

    /grid-vgpu-deployment-guide.pdf 在第17页,它为几个服务器制造商提供了BIOS建议。 它建议禁用SuperMicro的“Above 4G Decoding”。 对于Dom0为32位
    发表于 09-04 15:36

    3~25V与10安3~15V电压可调电压电路原理图详解

    3~25V与10安3~15V电压可调电压电路原理图详解
    发表于 04-16 20:47

    模电Flash动画详解

    模电Flash动画详解,一共有161个!
    发表于 09-27 08:15

    Flash Magic V2.45

    Flash Magic V2.45 Flash Magic V2.45软件
    发表于 05-10 11:24 8次下载

    基于MSP430功能模块详解系列之——FLASH存储器

    基于MSP430功能模块详解系列之——FLASH存储器
    发表于 10-12 15:27 11次下载
    基于MSP430功能模块<b class='flag-5'>详解</b>系列之——<b class='flag-5'>FLASH</b>存储器

    MP3-FLASH-16P 使用说明书 V1.0

    蓝板MP3-FLASH-16P使用说明书 V1.0 MP3-FLASH-16P 是一个提供串口的语音模块,完美的集成了 MP3、WAV 的硬解码。同时软件支持工业级别的串口通信协议,以
    发表于 11-28 14:08 24次下载

    【转载】keil将程序装入外部FLASH详解

    【转载】keil将程序装入外部FLASH详解
    发表于 12-01 20:21 14次下载
    【转载】keil将程序装入外部<b class='flag-5'>FLASH</b><b class='flag-5'>详解</b>

    开源软件-Morse_Encoding_Decoding摩斯密码工具

    ./oschina_soft/Morse_Encoding_Decoding.zip
    发表于 06-28 11:52 1次下载
    开源软件-Morse_Encoding_<b class='flag-5'>Decoding</b>摩斯密码工具

    瑞萨Flash程序员V3 发布说明

    电子发烧友网站提供《瑞萨Flash程序员V3 发布说明.pdf》资料免费下载
    发表于 02-19 09:37 1次下载
    瑞萨<b class='flag-5'>Flash</b>程序员<b class='flag-5'>V3</b> 发布说明