0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

LLM的长度外推浅谈

深度学习自然语言处理 来源:NLP工作站 2023-07-28 17:37 次阅读

一、NBCE

NBCE:使用朴素贝叶斯扩展LLM的Context处理长度

苏神最早提出的扩展LLM的context方法,基于bayes启发得到的公式:

fd1b2440-2d29-11ee-815d-dac502259ad0.pngfd312d9e-2d29-11ee-815d-dac502259ad0.png

问答下实测确实不错,在较长context下的阅读理解还算好用。

局限性是,无序性,即无法识别Context的输入顺序,这在续写故事等场景可能表现欠佳,做一些依赖每个context生成答案,比如提取文档摘要,效果较差。

outputs=model(input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
use_cache=True,
past_key_values=past_key_values
)
past_key_values=outputs.past_key_values

#=====核心代码开始=====
beta=0.25
probas=torch.nn.functional.softmax(outputs.logits[:,-1],dim=-1)
logits=probas.log()
k=(probas*logits).sum(dim=-1)[1:].argmax()+1
logits_max=logits[k]
logits_uncond=logits[0]
logits=(1+beta)*logits_max-beta*logits_uncond
#=====核心代码结束=====

#构建分布,采样
tau=0.01#tau=1是标准的随机采样,tau->0则是贪心搜索
probas=torch.nn.functional.softmax(logits[None]/tau,dim=-1)
next_tokens=torch.multinomial(probas,num_samples=1).squeeze(1)

此处代码,图片,文本均选自科学空间。

二、线性内插

llama基于rotary embedding在2048长度上预训练,该方法通过将position压缩到0~2048之间,从而达到长度外推的目的。

longchat将模型微调为上下文长度外扩为16384,压缩比为 8。例如,position_ids = 10000 的 token 变为position_ids = 10000 / 8 = 1250,相邻 token 10001 变为 10001 / 8 = 1250.125

该方法的缺陷是需要进行一定量的微调,让模型来适应这种改变。

importtorch
importtransformers
importtransformers.models.llama.modeling_llama
fromeinopsimportrearrange

fromfunctoolsimportpartial

classCondenseRotaryEmbedding(torch.nn.Module):
def__init__(self,dim,ratio,max_position_embeddings=2048,base=10000,device=None):
super().__init__()
inv_freq=1.0/(base**(torch.arange(0,dim,2).float().to(device)/dim))
self.register_buffer("inv_freq",inv_freq)

#Buildheretomake`torch.jit.trace`work.
self.ratio=ratio
max_position_embeddings*=ratio
print(f"CondensingPositionalembeddingsfrom{max_position_embeddings}to{max_position_embeddings//ratio}")
self.max_seq_len_cached=max_position_embeddings
t=torch.arange(self.max_seq_len_cached,device=self.inv_freq.device,dtype=self.inv_freq.dtype)/ratio
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1)
dtype=torch.get_default_dtype()
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(dtype),persistent=False)

defforward(self,x,seq_len=None):
#x:[bs,num_attention_heads,seq_len,head_size]
#This`if`blockisunlikelytoberunafterwebuildsin/cosin`__init__`.Keepthelogicherejustincase.
ifseq_len>self.max_seq_len_cached:
self.max_seq_len_cached=seq_len
t=torch.arange(self.max_seq_len_cached,device=x.device,dtype=self.inv_freq.dtype)/self.ratio
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1).to(x.device)
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(x.dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(x.dtype),persistent=False)
return(
self.cos_cached[:,:,:seq_len,...].to(dtype=x.dtype),
self.sin_cached[:,:,:seq_len,...].to(dtype=x.dtype),
)

defreplace_llama_with_condense(ratio):
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding=partial(CondenseRotaryEmbedding,ratio=ratio)

三、NTK-Aware Scaled RoPE

RoPE是一种β进制编码//spaces.ac.cn/archives/9675

fd4b1808-2d29-11ee-815d-dac502259ad0.png

有意思的解释一下,RoPE 的行为就像一个时钟。12小时时钟基本上是一个维度为 3、底数为 60 的 RoPE。因此,每秒钟,分针转动 1/60 分钟,每分钟,时针转动 1/60。

现在,如果将时间减慢 4 倍,那就是二使用的线性RoPE 缩放。不幸的是,现在区分每一秒,因为现在秒针几乎每秒都不会移动。

因此,如果有人给你两个不同的时间,仅相差一秒,你将无法从远处区分它们。NTK-Aware RoPE 扩展不会减慢时间。一秒仍然是一秒,但它会使分钟减慢 1.5 倍,将小时减慢 2 倍。

这样,您可以将 90 分钟容纳在一个小时中,将 24 小时容纳在半天中。

所以现在你基本上有了一个可以测量 129.6k 秒而不是 43.2k 秒的时钟。由于在查看时间时不需要精确测量时针,因此与秒相比,更大程度地缩放小时至关重要。

不想失去秒针的精度,但可以承受分针甚至时针的精度损失。

importtransformers

old_init=transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
defntk_scaled_init(self,dim,max_position_embeddings=2048,base=10000,device=None):

#Themethodisjustthesethreelines
max_position_embeddings=16384
a=8#Alphavalue
base=base*a**(dim/(dim-2))#Basechangeformula

old_init(self,dim,max_position_embeddings,base,device)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__=ntk_scaled_init

四、Dynamically Scaled RoPE

fd56100a-2d29-11ee-815d-dac502259ad0.png

对于上面的方法二、三,都涉及到一个超参数α,用于调节缩放比例,该方法是通过序列长度动态选择正确的比例参数,效果可以看上图。

对于线性插值,前 2k 上下文的精确位置值,然后在模型逐个生成标记时重新计算每个新序列长度的位置向量。本质上,将比例设置为原始模型上下文长度/当前序列长度。

对于动态 NTK,α 的缩放设置为 (α * 当前序列长度 / 原始模型上下文长度) - (α - 1)。随着序列长度的增加动态缩放超参数。

importmath
importtorch

classLlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
def__init__(self,dim,max_position_embeddings=2048,base=10000,ntk=False,device=None):
super().__init__()
self.ntk=ntk
self.base=base
self.dim=dim
self.max_position_embeddings=max_position_embeddings
inv_freq=1.0/(base**(torch.arange(0,dim,2).float().to(device)/dim))
self.register_buffer("inv_freq",inv_freq)

#Buildheretomake`torch.jit.trace`work.
self.max_seq_len_cached=max_position_embeddings
t=torch.arange(self.max_seq_len_cached,device=self.inv_freq.device,dtype=self.inv_freq.dtype)
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1)
dtype=torch.get_default_dtype()
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(dtype),persistent=False)

defforward(self,x,seq_len=None):
#x:[bs,num_attention_heads,seq_len,head_size]
#This`if`blockisunlikelytoberunafterwebuildsin/cosin`__init__`.Keepthelogicherejustincase.
ifseq_len>self.max_seq_len_cached:
self.max_seq_len_cached=seq_len
ifself.ntk:
base=self.base*((self.ntk*seq_len/self.max_position_embeddings)-(self.ntk-1))**(self.dim/(self.dim-2))
inv_freq=1.0/(base**(torch.arange(0,self.dim,2).float().to(x.device)/self.dim))
self.register_buffer("inv_freq",inv_freq)
t=torch.arange(self.max_seq_len_cached,device=x.device,dtype=self.inv_freq.dtype)
ifnotself.ntk:
t*=self.max_position_embeddings/seq_len
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1).to(x.device)
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(x.dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(x.dtype),persistent=False)
return(
self.cos_cached[:,:,:seq_len,...].to(dtype=x.dtype),
self.sin_cached[:,:,:seq_len,...].to(dtype=x.dtype),
)

五、consistent of Dynamically Scaled RoPE

fd799f3e-2d29-11ee-815d-dac502259ad0.png

方法四存在一个问题是,因为α是动态的,因为解码是有cache的,所以,在生成第100个token时,算的α和第200个token时,算的α时不一致的。fd9f8a78-2d29-11ee-815d-dac502259ad0.png

query和key的rotation base不一致,正确的应该时这样

fda853ec-2d29-11ee-815d-dac502259ad0.png

importmath
fromtypingimportList,Optional,Tuple,Union

importtorch
importtorch.nn.functionalasF
importtorch.utils.checkpoint
fromtorchimportnn
fromtransformers.models.llama.modeling_llamaimportrepeat_kv,apply_rotary_pos_emb
fromtransformers.models.llama.modeling_llamaimportLlamaAttention

defforward(
self,
hidden_states:torch.Tensor,
attention_mask:Optional[torch.Tensor]=None,
position_ids:Optional[torch.LongTensor]=None,
past_key_value:Optional[Tuple[torch.Tensor]]=None,
output_attentions:bool=False,
use_cache:bool=False,
)->Tuple[torch.Tensor,Optional[torch.Tensor],Optional[Tuple[torch.Tensor]]]:
bsz,q_len,_=hidden_states.size()

ifself.pretraining_tp>1:
key_value_slicing=(self.num_key_value_heads*self.head_dim)//self.pretraining_tp
query_slices=self.q_proj.weight.split((self.num_heads*self.head_dim)//self.pretraining_tp,dim=0)
key_slices=self.k_proj.weight.split(key_value_slicing,dim=0)
value_slices=self.v_proj.weight.split(key_value_slicing,dim=0)

query_states=[F.linear(hidden_states,query_slices[i])foriinrange(self.pretraining_tp)]
query_states=torch.cat(query_states,dim=-1)

key_states=[F.linear(hidden_states,key_slices[i])foriinrange(self.pretraining_tp)]
key_states=torch.cat(key_states,dim=-1)

value_states=[F.linear(hidden_states,value_slices[i])foriinrange(self.pretraining_tp)]
value_states=torch.cat(value_states,dim=-1)

else:
query_states=self.q_proj(hidden_states)
key_states=self.k_proj(hidden_states)
value_states=self.v_proj(hidden_states)

query_states=query_states.view(bsz,q_len,self.num_heads,self.head_dim).transpose(1,2)
key_states=key_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)
value_states=value_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)

kv_seq_len=key_states.shape[-2]
ifpast_key_valueisnotNone:
kv_seq_len+=past_key_value[0].shape[-2]
cos,sin=self.rotary_emb(value_states,seq_len=kv_seq_len)

ifpast_key_valueisnotNone:
#reusekw/oRoPE
key_states=torch.cat([past_key_value[0],key_states],dim=2)

#applyRoPEafterretrievingallkeysandqueries
query_states,rotated_key_states=apply_rotary_pos_emb(query_states,key_states,cos,sin,position_ids)

ifpast_key_valueisnotNone:
#reusev,self_attention
value_states=torch.cat([past_key_value[1],value_states],dim=2)

past_key_value=(key_states,value_states)ifuse_cacheelseNone#cachethekeyw/oRoPE

#repeatk/vheadsifn_kv_heads< n_heads
    rotated_key_states = repeat_kv(rotated_key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, rotated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.pretraining_tp >1:
attn_output=attn_output.split(self.hidden_size//self.pretraining_tp,dim=2)
o_proj_slices=self.o_proj.weight.split(self.hidden_size//self.pretraining_tp,dim=1)
attn_output=sum([F.linear(attn_output[i],o_proj_slices[i])foriinrange(self.pretraining_tp)])
else:
attn_output=self.o_proj(attn_output)

ifnotoutput_attentions:


attn_weights=None

returnattn_output,attn_weights,past_key_value


defreplace_llama_attn_with_consistent_ntk_rope():
LlamaAttention.forward=forward





审核编辑:刘清

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 解码器
    +关注

    关注

    9

    文章

    1141

    浏览量

    40713
  • LLM
    LLM
    +关注

    关注

    0

    文章

    285

    浏览量

    325

原文标题:浅谈LLM的长度外推

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    对比解码在LLM上的应用

    为了改进LLM的推理能力,University of California联合Meta AI实验室提出将Contrastive Decoding应用于多种任务的LLM方法。实验表明,所提方法能有效改进LLM的推理能力。让我们走进
    发表于 09-21 11:37 620次阅读
    对比解码在<b class='flag-5'>LLM</b>上的应用

    饿了么确认收购百度外卖!最快本周收购,百度外卖为何会变成百度的弃子?

     饿了么和百度外度一直就以竞争对手的形式出现,然而两者之间相斗争总有败的一方,近日饿了么和百度外卖又一次上了热搜榜。
    发表于 08-21 14:52 847次阅读

    饿了么正式宣布收购百度外卖 内部邮件曝光

    8月24日下午消息,饿了么刚刚正式宣布收购百度外卖,随后,百度外卖内部邮件曝光。邮件表示,合并后,百度外卖仍以独立的品牌和运营体系发展,包括管理层在内的人员架构保持不变。
    发表于 08-24 16:59 807次阅读

    LLM性能的主要因素

    现在是2023年5月,截止目前,网络上已经开源了众多的LLM,如何用较低的成本,判断LLM的基础性能,选到适合自己任务的LLM,成为一个关键。 本文会涉及以下几个问题: 影响LLM性能
    的头像 发表于 05-22 15:26 1695次阅读
    <b class='flag-5'>LLM</b>性能的主要因素

    中国研究人员提出StructGPT,提高LLM对结构化数据的零样本推理能力

    尽管结构化数据的体量往往非常巨大,但不可能容纳输入提示中的所有数据记录(例如,ChatGPT 的最大上下文长度为 4096)。将结构化数据线性化为 LLM 可以轻松掌握的语句是解决此问题的简单方法。工具操作技术激励他们增强 LLM
    的头像 发表于 05-24 16:02 2981次阅读
    中国研究人员提出StructGPT,提高<b class='flag-5'>LLM</b>对结构化数据的零样本推理能力

    使用MLC-LLM支持RWKV-5推理的过程思考

    LLM的理解比较有限,从代码实现的角度来说,RWKV的状态和KV Cache不同,不依赖序列长度,这让RWKV模型在各种长度下运行内存和运行速度都是趋于稳定的,所以我感觉工程价值是比基于Transformer架构比如Llama
    的头像 发表于 11-19 15:58 974次阅读
    使用MLC-<b class='flag-5'>LLM</b>支持RWKV-5推理的过程思考

    如何利用位置编码实现长度外

    无论是缩放位置索引还是修改基地,所有token都变得彼此更接近,这将损害LLM区分相近token的位置顺序的能力。结合他们对RoPE的波长的观察,存在一些波长比预训练的上下文窗口长的维度,NTK-by-parts插值的作者建议完全不插值较高的频率维度。
    发表于 01-08 09:58 453次阅读
    如何利用位置编码实现<b class='flag-5'>长度外</b><b class='flag-5'>推</b>?

    LLM推理加速新范式!推测解码(Speculative Decoding)最新综述

    低下(->每个token的生成都需要重复读写LLM的巨量参数),并且序列的生成时间随着序列长度的增加而线性增加。
    的头像 发表于 01-29 15:54 2742次阅读
    <b class='flag-5'>LLM</b>推理加速新范式!推测解码(Speculative Decoding)最新综述

    什么是LLMLLM的工作原理和结构

    随着人工智能技术的飞速发展,大型语言模型(Large Language Model,简称LLM)逐渐成为自然语言处理(NLP)领域的研究热点。LLM以其强大的文本生成、理解和推理能力,在文本
    的头像 发表于 07-02 11:45 7443次阅读

    LLM模型的应用领域

    在本文中,我们将深入探讨LLM(Large Language Model,大型语言模型)的应用领域。LLM是一种基于深度学习的人工智能技术,它能够理解和生成自然语言文本。近年来,随着计算能力的提高
    的头像 发表于 07-09 09:52 559次阅读

    llm模型有哪些格式

    LLM(Large Language Model,大型语言模型)是一种深度学习模型,主要用于处理自然语言处理(NLP)任务。LLM模型的格式多种多样,以下是一些常见的LLM模型格式
    的头像 发表于 07-09 09:59 591次阅读

    CS1-U DC/AC5-240V磁性开关长度要求

    磁性开关的长度要求并非固定不变,而是需要根据具体的应用场景和安装环境进行灵活选择。在选择磁性开关时,除了考虑其长度外,还需要关注其技术参数、工作环境要求以及安装间距等因素,以确保其能够正常工作并满足实际需求。
    的头像 发表于 10-12 18:07 155次阅读

    LLM和传统机器学习的区别

    在人工智能领域,LLM(Large Language Models,大型语言模型)和传统机器学习是两种不同的技术路径,它们在处理数据、模型结构、应用场景等方面有着显著的差异。 1. 模型结构
    的头像 发表于 11-08 09:25 387次阅读

    LLM技术对人工智能发展的影响

    随着人工智能技术的飞速发展,大型语言模型(LLM)技术已经成为推动AI领域进步的关键力量。LLM技术通过深度学习和自然语言处理技术,使得机器能够理解和生成自然语言,极大地扩展了人工智能的应用范围
    的头像 发表于 11-08 09:28 338次阅读

    什么是LLMLLM在自然语言处理中的应用

    随着人工智能技术的飞速发展,自然语言处理(NLP)领域迎来了革命性的进步。其中,大型语言模型(LLM)的出现,标志着我们对语言理解能力的一次飞跃。LLM通过深度学习和海量数据训练,使得机器能够以前
    的头像 发表于 11-19 15:32 501次阅读