在深度学习中,我们经常使用 CNN 或 RNN 对序列进行编码。现在考虑到注意力机制,想象一下将一系列标记输入注意力机制,这样在每个步骤中,每个标记都有自己的查询、键和值。在这里,当在下一层计算令牌表示的值时,令牌可以(通过其查询向量)参与每个其他令牌(基于它们的键向量进行匹配)。使用完整的查询键兼容性分数集,我们可以通过在其他标记上构建适当的加权和来为每个标记计算表示。因为每个标记都关注另一个标记(不同于解码器步骤关注编码器步骤的情况),这种架构通常被描述为自注意力模型 (Lin等。, 2017 年, Vaswani等人。, 2017 ),以及其他地方描述的内部注意力模型 ( Cheng et al. , 2016 , Parikh et al. , 2016 , Paulus et al. , 2017 )。在本节中,我们将讨论使用自注意力的序列编码,包括使用序列顺序的附加信息。
11.6.1。自注意力
给定一系列输入标记 x1,…,xn任何地方 xi∈Rd(1≤i≤n), 它的self-attention输出一个相同长度的序列 y1,…,yn, 在哪里
根据 (11.1.1)中attention pooling的定义。使用多头注意力,以下代码片段计算具有形状(批量大小、时间步数或标记中的序列长度, d). 输出张量具有相同的形状。
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
11.6.2。比较 CNN、RNN 和自注意力
让我们比较一下映射一系列的架构n标记到另一个等长序列,其中每个输入或输出标记由一个d维向量。具体来说,我们将考虑 CNN、RNN 和自注意力。我们将比较它们的计算复杂度、顺序操作和最大路径长度。请注意,顺序操作会阻止并行计算,而序列位置的任意组合之间的较短路径可以更容易地学习序列内的远程依赖关系 (Hochreiter等人,2001 年)。
考虑一个卷积层,其内核大小为k. 我们将在后面的章节中提供有关使用 CNN 进行序列处理的更多详细信息。现在,我们只需要知道,因为序列长度是n,输入和输出通道的数量都是 d, 卷积层的计算复杂度为 O(knd2). 如图11.6.1 所示,CNN 是分层的,因此有O(1) 顺序操作和最大路径长度是
评论
查看更多