电子发烧友App

硬声App

0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示
创作
电子发烧友网>电子资料下载>电子资料>PyTorch教程11.6之自注意力和位置编码

PyTorch教程11.6之自注意力和位置编码

2023-06-05 | pdf | 0.33 MB | 次下载 | 免费

资料介绍

深度学习中,我们经常使用 CNN 或 RNN 对序列进行编码。现在考虑到注意力机制,想象一下将一系列标记输入注意力机制,这样在每个步骤中,每个标记都有自己的查询、键和值。在这里,当在下一层计算令牌表示的值时,令牌可以(通过其查询向量)参与每个其他令牌(基于它们的键向量进行匹配)。使用完整的查询键兼容性分数集,我们可以通过在其他标记上构建适当的加权和来为每个标记计算表示。因为每个标记都关注另一个标记(不同于解码器步骤关注编码器步骤的情况),这种架构通常被描述为自注意力模型 Lin等。, 2017 年 Vaswani等人。, 2017 ),以及其他地方描述的内部注意力模型 ( Cheng et al. , 2016 , Parikh et al. , 2016 , Paulus et al. , 2017 )在本节中,我们将讨论使用自注意力的序列编码,包括使用序列顺序的附加信息

import math
import torch
from torch import nn
from d2l import torch as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l

11.6.1。自注意力

给定一系列输入标记 x1,…,xn任何地方 xi∈Rd(1≤i≤n), 它的self-attention输出一个相同长度的序列 y1,…,yn, 在哪里

(11.6.1)yi=f(xi,(x1,x1),…,(xn,xn))∈Rd

根据 (11.1.1)中attention pooling的定义。使用多头注意力,以下代码片段计算具有形状(批量大小、时间步数或标记中的序列长度, d). 输出张量具有相同的形状。

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
        (batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
        (batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)

batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
                      training=False)[0][0],
        (batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                  num_hiddens, num_heads, 0.5)

batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
        (batch_size, num_queries, num_hiddens))

11.6.2。比较 CNN、RNN 和自注意力

让我们比较一下映射一系列的架构n标记到另一个等长序列,其中每个输入或输出标记由一个d维向量。具体来说,我们将考虑 CNN、RNN 和自注意力。我们将比较它们的计算复杂度、顺序操作和最大路径长度。请注意,顺序操作会阻止并行计算,而序列位置的任意组合之间的较短路径可以更容易地学习序列内的远程依赖关系 Hochreiter等人,2001 年

https://file.elecfans.com/web2/M00/AA/44/pYYBAGR9OB2AYW27AAGoqLUwK-4826.svg

图 11.6.1比较 CNN(省略填充标记)、RNN 和自注意力架构。

考虑一个卷积层,其内核大小为k. 我们将在后面的章节中提供有关使用 CNN 进行序列处理的更多详细信息。现在,我们只需要知道,因为序列长度是n,输入和输出通道的数量都是 d, 卷积层的计算复杂度为 O(knd2). 如图11.6.1 所示,CNN 是分层的,因此有O(1) 顺序操作和最大路径长度是

下载该资料的人也在下载 下载该资料的人还在阅读
更多 >

评论

查看更多

下载排行

本周

  1. 1使用单片机实现七人表决器的程序和仿真资料免费下载
  2. 2.96 MB   |  44次下载  |  免费
  3. 2联想E46L DAOLL6笔记本电脑图纸
  4. 1.10 MB   |  2次下载  |  5 积分
  5. 3MATLAB绘图合集
  6. 27.12 MB   |  2次下载  |  5 积分
  7. 4PR735,使用UCC28060的600W交错式PFC转换器
  8. 540.03KB   |  1次下载  |  免费
  9. 5UCC38C42 30W同步降压转换器参考设计
  10. 428.07KB   |  1次下载  |  免费
  11. 6DV2004S1/ES1/HS1快速充电开发系统
  12. 2.08MB   |  1次下载  |  免费
  13. 7模态分解合集matlab代码
  14. 3.03 MB   |  1次下载  |  2 积分
  15. 8美的电磁炉维修手册大全
  16. 1.56 MB   |  1次下载  |  5 积分

本月

  1. 1使用单片机实现七人表决器的程序和仿真资料免费下载
  2. 2.96 MB   |  44次下载  |  免费
  3. 2UC3842/3/4/5电源管理芯片中文手册
  4. 1.75 MB   |  15次下载  |  免费
  5. 3DMT0660数字万用表产品说明书
  6. 0.70 MB   |  13次下载  |  免费
  7. 4TPS54202H降压转换器评估模块用户指南
  8. 1.02MB   |  8次下载  |  免费
  9. 5STM32F101x8/STM32F101xB手册
  10. 1.69 MB   |  8次下载  |  1 积分
  11. 6HY12P65/HY12P66数字万用表芯片规格书
  12. 0.69 MB   |  6次下载  |  免费
  13. 7华瑞昇CR216芯片数字万用表规格书附原理图及校正流程方法
  14. 0.74 MB   |  6次下载  |  3 积分
  15. 8华瑞昇CR215芯片数字万用表原理图
  16. 0.21 MB   |  5次下载  |  3 积分

总榜

  1. 1matlab软件下载入口
  2. 未知  |  935119次下载  |  10 积分
  3. 2开源硬件-PMP21529.1-4 开关降压/升压双向直流/直流转换器 PCB layout 设计
  4. 1.48MB  |  420061次下载  |  10 积分
  5. 3Altium DXP2002下载入口
  6. 未知  |  233084次下载  |  10 积分
  7. 4电路仿真软件multisim 10.0免费下载
  8. 340992  |  191367次下载  |  10 积分
  9. 5十天学会AVR单片机与C语言视频教程 下载
  10. 158M  |  183335次下载  |  10 积分
  11. 6labview8.5下载
  12. 未知  |  81581次下载  |  10 积分
  13. 7Keil工具MDK-Arm免费下载
  14. 0.02 MB  |  73807次下载  |  10 积分
  15. 8LabVIEW 8.6下载
  16. 未知  |  65987次下载  |  10 积分