多头注意力机制(Multi-Head Attention)原理与代码
多头注意力机制(Multi-Head Attention)原理与代码
在自然语言处理(NLP)领域,多头注意力机制是现代深度学习模型,尤其是Transformer架构的核心部分。它能够并行地关注输入的不同部分,从而提升模型的表达能力,改进对信息的提取与处理。接下来,我们将详细探讨多头注意力机制的原理,并提供对应的代码实现。
1. 多头注意力机制原理
**注意力机制(Attention Mechanism)**本质上是一种加权求和的操作,它通过计算查询(Query)、键(Key)和值(Value)之间的相似度来决定如何将不同部分的信息进行聚合。在单头注意力中,通常使用一个查询向量与所有键进行相似度计算,并根据计算出的权重值对对应的值进行加权求和,最终得到输出。
多头注意力机制则是将多个独立的注意力头进行并行计算,然后将每个头的输出进行拼接或加权平均。这样能够让模型在不同的“子空间”中学习到更多的信息,从而提升模型对复杂模式的学习能力。
1.1 公式表示
给定输入矩阵 QQ、KK、VV(分别表示查询、键和值),单头注意力的计算过程如下:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
其中:
- QQ、KK、VV 的维度分别为 [n,dk][n, d_k] 和 [n,dv][n, d_v];
- dkd_k 是键的维度,通常是一个常数;
- nn 是输入序列的长度。
多头注意力机制则将上述单头注意力操作并行化,多个头的输出拼接后通过一个线性层进行映射。公式为:
MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(head_1, head_2, ..., head_h)W^O
其中:
- headi=Attention(QWiQ,KWiK,VWiV)head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 是每个头的输出;
- WiQ,WiK,WiV,WOW_i^Q, W_i^K, W_i^V, W^O 是需要学习的权重矩阵;
- hh 是头的数量。
1.2 多头注意力机制的优势
- 信息丰富性:通过不同的头对输入的不同部分进行关注,可以获得更加多样化的信息表达。
- 捕获不同语义关系:每个头能够关注输入的不同方面,比如词语间的不同语法关系、语义关系等,从而增强模型的理解能力。
- 并行化计算:每个头的计算是独立的,因此多头注意力可以并行计算,大大提高计算效率。
2. 多头注意力机制的代码实现
以下是用PyTorch实现的多头注意力机制代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, heads):
super(MultiHeadAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
# 确保embed_size能被head_dim整除
assert embed_size % self.head_dim == 0, "Embedding size must be divisible by heads"
# 定义查询、键、值的权重矩阵
self.Wq = nn.Linear(embed_size, embed_size)
self.Wk = nn.Linear(embed_size, embed_size)
self.Wv = nn.Linear(embed_size, embed_size)
# 定义输出的线性变换
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query, mask=None):
# 获取批次大小
batch_size = query.shape[0]
# 将Q、K、V通过线性层映射到合适的维度
Q = self.Wq(query)
K = self.Wk(keys)
V = self.Wv(values)
# 分割为多个头
Q = Q.view(batch_size, -1, self.heads, self.head_dim)
K = K.view(batch_size, -1, self.heads, self.head_dim)
V = V.view(batch_size, -1, self.heads, self.head_dim)
# 转置Q, K, V的形状,使得可以在最后一个维度上进行矩阵乘法
Q = Q.permute(0, 2, 1, 3) # (batch_size, heads, seq_len, head_dim)
K = K.permute(0, 2, 1, 3)
V = V.permute(0, 2, 1, 3)
# 计算注意力权重
energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) # (batch_size, heads, seq_len, seq_len)
if mask is not None:
energy = energy.masked_fill(mask == 0, float('-inf')) # Mask掉不该关注的位置
attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=-1)
# 使用注意力权重加权值
out = torch.matmul(attention, V) # (batch_size, heads, seq_len, head_dim)
# 转置回原来的形状
out = out.permute(0, 2, 1, 3).contiguous()
# 将多个头的输出拼接起来
out = out.view(batch_size, -1, self.heads * self.head_dim)
# 通过输出的线性层映射
out = self.fc_out(out)
return out
2.1 代码说明
- 初始化部分:
embed_size
:表示输入的嵌入维度;heads
:表示注意力头的数量;head_dim
:每个头的维度,等于embed_size // heads
。Wq
、Wk
、Wv
:分别是查询、键、值的线性变换矩阵;fc_out
:是最终的输出线性层。
- 前向传播部分:
- 首先通过线性层将输入的查询、键和值转换为对应的向量;
- 然后对每个查询、键、值进行分割和转置,准备多头注意力的计算;
- 计算注意力权重,并使用这些权重对值进行加权求和,得到每个头的输出;
- 最后将多个头的输出拼接后通过一个线性层映射到输出空间。
3. 结语
多头注意力机制的引入极大提升了模型的表达能力,并通过并行化计算有效提高了训练效率。其在Transformer架构中的应用,推动了NLP领域的巨大突破。从原理到实现代码,掌握多头注意力机制不仅有助于理解现代深度学习模型的工作方式,还能帮助开发者更好地进行模型设计和优化。
4. 工作流程图
graph TD;
A[输入Q, K, V] --> B[线性变换Q, K, V];
B --> C[分割多个头];
C --> D[计算注意力权重];
D --> E[加权求和得到每个头的输出];
E --> F[拼接多个头的输出];
F --> G[通过线性层输出];
G --> H[得到最终的多头注意力结果];