0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

基于对抗自注意力机制的预训练语言模型

深度学习自然语言处理 来源:ICLR 2022 作者:曾伟豪 2022-07-08 16:57 次阅读

Introduction

本文提出了 Adversarial Self-Attention 机制(ASA),利用对抗训练重构 Transformer 的注意力,使模型在被污染的模型结构中得到训练。 尝试解决的问题:

大量的证据表明,自注意力可以从 allowing bias 中获益,allowing bias 可以将一定程度的先验(如 masking,分布的平滑)加入原始的注意力结构中。这些先验知识能够让模型从较小的语料中学习有用的知识。但是这些先验知识一般是任务特定的知识,使得模型很难扩展到丰富的任务上。

adversarial training 通过给输入内容添加扰动来提升模型的鲁棒性。作者发现仅仅给 input embedding 添加扰动很难 confuse 到 attention maps. 模型的注意在扰动前后没有发生变化。

为了解决上述问题,作者提出了 ASA,具有以下的优势:

最大化 empirical training risk,在自动化构建先验知识的过程学习得到biased(or adversarial)的结构。

adversial 结构是由输入数据学到,使得 ASA 区别于传统的对抗训练或自注意力的变体。

使用梯度反转层来将 model 和 adversary 结合为整体。

ASA 天然具有可解释性。

Preliminary

表示输入的特征,在传统的对抗训练中, 通常是 token 序列或者是 token 的 embedding, 表示 ground truth. 对于由 参数化的模型,模型的预测结果可以表示为 。

2.1 Adversarial training

对抗训练的目的是旨在通过推近经过扰动的模型预测和目标分布之间的距离来提升模型的鲁棒性:

d5da9fe0-fe9b-11ec-ba43-dac502259ad0.png

其中 代表经过对抗扰动 扰动后的模型预测, 表示模型的目标分布。 对抗扰动 通过最大化 empirical training risk 获得:

d5ee5a76-fe9b-11ec-ba43-dac502259ad0.png

其中 是对 做出的约束,希望在 较小的情况下给模型造成较大的扰动。上述的两个表示展示的就是对抗的过程。

2.2General Self-Attention

定义自注意力的表达式为:

d5fd9c52-fe9b-11ec-ba43-dac502259ad0.png

在最普通的自注意力机制中 代表全等矩阵,而之前的研究中, 代表的是用来平滑注意力结构的输出分布的一定程度的先验知识。 作者在本文将 定义为元素为 的 binary 矩阵。

Adversarial Self-Attention Mechanism

3.1 Optimization

ASA 的目的是掩盖模型中最脆弱的注意力单元。这些最脆弱的单元取决于模型的输入,因此对抗可以表示为由输入学习到的“meta-knowledge”:,ASA 注意力可以表示为:

d619c8b4-fe9b-11ec-ba43-dac502259ad0.png

与对抗训练类似,模型用来最小化如下的 divergence:

d62c9c14-fe9b-11ec-ba43-dac502259ad0.png

通过最大化 empirical risk 估计得到 :

d63a855e-fe9b-11ec-ba43-dac502259ad0.png

其中 表示的是 的决策边界,用来防止 ASA 损害模型的训练。

考虑到 以 attention mask 的形式存在,因此更适合通过约束 masked units 的比例来约束。由于很难测量 。 的具体数值,因此将 hard constraint 转化为具有惩罚的 unconstraint:

d64eab74-fe9b-11ec-ba43-dac502259ad0.png

其中 t 用来控制对抗的程度。

3.2 Implementation

作者提出了 ASA 的简单且快速的实现。

d663af10-fe9b-11ec-ba43-dac502259ad0.png

对于第 自注意力层, 可以由输入的隐层状态获得。具体而言,使用线性层将隐层状态转化为 以及 ,通过点乘获得矩阵 ,再通过重参数化技巧将矩阵 binary 化。 由于对抗训练通常包括 inner maximization 以及 outer minimization 两个目标,因此至少需要两次 backward 过程。因此为了加速训练,作者采用了 Gradient Reversal Layer(GRL)将两个过程合并。

3.3 Training

训练目标如下所示:

d677006a-fe9b-11ec-ba43-dac502259ad0.png

表示 task- specific 损失, 表示加上 ASA 对抗后的损失, 表示对于对于 的约束。

Experiments

4.1Result

d697f5f4-fe9b-11ec-ba43-dac502259ad0.png

从上表可以看出,在微调方面,ASA 支持的模型始终在很大程度上超过了原始的BERT 和 RoBERTa. 可以看到,ASA 在小规模数据集比如说 STS-B,DREAM 上表现优异(一般认为这些小规模数据集上更容易过拟合)同时在更大规模的数据集上如 MNLI,QNLI 以及 QQP 上仍然有较好的提升,说明了 ASA 在提升模型泛化能力的同时能提升模型的语言表示能力。 如下表所示,ASA 在提升模型鲁棒性上具有较大的作用。

d6b2e4c2-fe9b-11ec-ba43-dac502259ad0.png

4.2 分析实验

1. VS. Naive smoothing 将 ASA 与其他注意力平滑方式进行比较。

d6c547e8-fe9b-11ec-ba43-dac502259ad0.png

2. VS. Adversial training 将 ASA 与其他对抗训练方式进行比较

d6d7050a-fe9b-11ec-ba43-dac502259ad0.png

4.3Visualization

1. Why ASA improves generalization 对抗能够减弱关键词的注意力而让非关键词接受更多的注意力。ASA 阻止了模型的懒惰预测,但敦促它从被污染的线索中学习,从而提高了泛化能力。

d6efa628-fe9b-11ec-ba43-dac502259ad0.png

2. Bottom layers are more vulnerable 可以看到 masking 占比随着层数由底层到高层逐渐降低,更高的 masking 占比意味着层的脆弱性更高。

d715222c-fe9b-11ec-ba43-dac502259ad0.png

Conclusion

本文提出了 Adversarial Self-Attention mechanism(ASA)来提高预训练语言模型的泛化性和鲁棒性。大量实验表明本文提出的方法能够在预训练和微调阶段提升模型的鲁棒性。

·审核编辑 :李倩

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 自动化
    +关注

    关注

    28

    文章

    5246

    浏览量

    78133
  • 语言模型
    +关注

    关注

    0

    文章

    466

    浏览量

    10171

原文标题:ICLR2022 | 基于对抗自注意力机制的预训练语言模型

文章出处:【微信号:zenRRan,微信公众号:深度学习自然语言处理】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    【大规模语言模型:从理论到实践】- 阅读体验

    再次感谢电子发烧友提供的书籍试读机会。今天来分享下我在学习大模型训练注意力机制 的心得体会。 虽然注意力
    发表于 06-07 14:44

    语言模型:原理与工程时间+小白初识大语言模型

    语言模型进行训练,此处训练为自然语言处理领域的
    发表于 05-12 23:57

    【大语言模型:原理与工程实践】大语言模型的应用

    。 关于大语言模型是否具备与人类“系统2”相似的能力,存在广泛的争议。然而,随着模型参数量的增加和大规模训练的实施,大
    发表于 05-07 17:21

    【大语言模型:原理与工程实践】大语言模型训练

    语言模型的核心特点在于其庞大的参数量,这赋予了模型强大的学习容量,使其无需依赖微调即可适应各种下游任务,而更倾向于培养通用的处理能力。然而,随着学习容量的增加,对
    发表于 05-07 17:10

    【大语言模型:原理与工程实践】大语言模型的基础技术

    模型仍以Transformer为基础进行训练。Transformer是一种基于注意力机制的编码器-解码器结构,其核心由编码器和解码器组成,
    发表于 05-05 12:17

    【大语言模型:原理与工程实践】核心技术综述

    训练和微调,直到模型的部署和性能评估。以下是对这些技术的综述: 模型架构: LLMs通常采用深层的神经网络架构,最常见的是Transformer网络,它包含多个
    发表于 05-05 10:56

    【大语言模型:原理与工程实践】揭开大语言模型的面纱

    Transformer架构,利用注意力机制对文本进行编码,通过训练、有监督微调和强化学习等阶段,不断提升性能,展现出强大的
    发表于 05-04 23:55

    【大语言模型:原理与工程实践】探索《大语言模型原理与工程实践》

    处理中训练架构Transformer,以及这些技术在现实世界中的如何应用。通过具体案例的分析,作者展示了大语言模型在解决实际问题中的强大能力,同时也指出了当前技术面临的挑战和局限性。
    发表于 04-30 15:35

    盘点一下史上最全大语言模型训练中的网络技术

    人工智能的基础设施在大语言模型训练和推理过程中发挥了关键的作用。随着大语言模型规模不断增大,其对计算和通信的需求也在不断增加。高
    的头像 发表于 03-27 17:24 641次阅读
    盘点一下史上最全大<b class='flag-5'>语言</b><b class='flag-5'>模型</b><b class='flag-5'>训练</b>中的网络技术

    名单公布!【书籍评测活动NO.30】大规模语言模型:从理论到实践

    榜销售TOP1的桂冠,可想大家对本书的认可和支持! 这本书为什么如此受欢迎?它究竟讲了什么?下面就给大家详细~~ 本书主要内容 本书围绕大语言模型构建的四个主要阶段——训练、有监督
    发表于 03-11 15:16

    模型与人类的注意力视角下参数规模扩大与指令微调对模型语言理解的作用

    近期的大语言模型(LLM)在自然语言理解和生成上展现出了接近人类的强大能力,远远优于先前的BERT等预训练模型(PLM)。
    的头像 发表于 01-04 14:06 226次阅读
    <b class='flag-5'>模型</b>与人类的<b class='flag-5'>注意力</b>视角下参数规模扩大与指令微调对<b class='flag-5'>模型</b><b class='flag-5'>语言</b>理解的作用

    全新近似注意力机制HyperAttention:对长上下文友好、LLM推理提速50%

    已经成功应用于自然语言处理、计算机视觉和时间序列预测等领域的各种学习任务。虽然取得了成功,但这些模型仍面临着严重的可扩展性限制,原因是对其注意力层的精确计算导致了二次(在序列长度上)运行时和内存复杂性。这对将 Transfor
    的头像 发表于 11-20 09:15 382次阅读
    全新近似<b class='flag-5'>注意力</b><b class='flag-5'>机制</b>HyperAttention:对长上下文友好、LLM推理提速50%

    语言模型(LLM)预训练数据集调研分析

    语言模型涉及数据的通常有有多个阶段(Aligning language models to follow instructions [1] ):pre-train、sft(supervised
    的头像 发表于 09-19 10:00 704次阅读
    大<b class='flag-5'>语言</b><b class='flag-5'>模型</b>(LLM)预<b class='flag-5'>训练</b>数据集调研分析

    训练语言模型带来的硬件挑战

    生成式AI和大语言模型(LLM)正在以难以置信的方式吸引全世界的目光,本文简要介绍了大语言模型训练这些
    的头像 发表于 09-01 17:14 1235次阅读
    <b class='flag-5'>训练</b>大<b class='flag-5'>语言</b><b class='flag-5'>模型</b>带来的硬件挑战

    详细介绍​注意力机制中的掩码

    注意力机制的掩码允许我们发送不同长度的批次数据一次性的发送到transformer中。在代码中是通过将所有序列填充到相同的长度,然后使用“attention_mask”张量来识别哪些令牌是填充的来做到这一点,本文将详细介绍这个掩码的原理和
    的头像 发表于 07-17 16:46 532次阅读
    详细介绍​<b class='flag-5'>注意力</b><b class='flag-5'>机制</b>中的掩码