多头注意力机制(Multi-Head Attention)原理与代码

多头注意力机制(Multi-Head Attention)原理与代码

在自然语言处理(NLP)领域,多头注意力机制是现代深度学习模型,尤其是Transformer架构的核心部分。它能够并行地关注输入的不同部分,从而提升模型的表达能力,改进对信息的提取与处理。接下来,我们将详细探讨多头注意力机制的原理,并提供对应的代码实现。

1. 多头注意力机制原理

**注意力机制(Attention Mechanism)**本质上是一种加权求和的操作,它通过计算查询(Query)、键(Key)和值(Value)之间的相似度来决定如何将不同部分的信息进行聚合。在单头注意力中,通常使用一个查询向量与所有键进行相似度计算,并根据计算出的权重值对对应的值进行加权求和,最终得到输出。

多头注意力机制则是将多个独立的注意力头进行并行计算,然后将每个头的输出进行拼接或加权平均。这样能够让模型在不同的“子空间”中学习到更多的信息,从而提升模型对复杂模式的学习能力。

1.1 公式表示

给定输入矩阵 QQKKVV(分别表示查询、键和值),单头注意力的计算过程如下:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中:

  • QQKKVV 的维度分别为 [n,dk][n, d_k][n,dv][n, d_v]
  • dkd_k 是键的维度,通常是一个常数;
  • nn 是输入序列的长度。

多头注意力机制则将上述单头注意力操作并行化,多个头的输出拼接后通过一个线性层进行映射。公式为:

MultiHead(Q,K,V)=Concat(head1,head2,...,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(head_1, head_2, ..., head_h)W^O

其中:

  • headi=Attention(QWiQ,KWiK,VWiV)head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 是每个头的输出;
  • WiQ,WiK,WiV,WOW_i^Q, W_i^K, W_i^V, W^O 是需要学习的权重矩阵;
  • hh 是头的数量。

1.2 多头注意力机制的优势

  • 信息丰富性:通过不同的头对输入的不同部分进行关注,可以获得更加多样化的信息表达。
  • 捕获不同语义关系:每个头能够关注输入的不同方面,比如词语间的不同语法关系、语义关系等,从而增强模型的理解能力。
  • 并行化计算:每个头的计算是独立的,因此多头注意力可以并行计算,大大提高计算效率。

2. 多头注意力机制的代码实现

以下是用PyTorch实现的多头注意力机制代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        # 确保embed_size能被head_dim整除
        assert embed_size % self.head_dim == 0, "Embedding size must be divisible by heads"
        
        # 定义查询、键、值的权重矩阵
        self.Wq = nn.Linear(embed_size, embed_size)
        self.Wk = nn.Linear(embed_size, embed_size)
        self.Wv = nn.Linear(embed_size, embed_size)
        
        # 定义输出的线性变换
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
    
    def forward(self, values, keys, query, mask=None):
        # 获取批次大小
        batch_size = query.shape[0]
        
        # 将Q、K、V通过线性层映射到合适的维度
        Q = self.Wq(query)
        K = self.Wk(keys)
        V = self.Wv(values)
        
        # 分割为多个头
        Q = Q.view(batch_size, -1, self.heads, self.head_dim)
        K = K.view(batch_size, -1, self.heads, self.head_dim)
        V = V.view(batch_size, -1, self.heads, self.head_dim)
        
        # 转置Q, K, V的形状,使得可以在最后一个维度上进行矩阵乘法
        Q = Q.permute(0, 2, 1, 3)  # (batch_size, heads, seq_len, head_dim)
        K = K.permute(0, 2, 1, 3)
        V = V.permute(0, 2, 1, 3)
        
        # 计算注意力权重
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2))  # (batch_size, heads, seq_len, seq_len)
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('-inf'))  # Mask掉不该关注的位置
        
        attention = torch.softmax(energy / (self.head_dim ** (1 / 2)), dim=-1)
        
        # 使用注意力权重加权值
        out = torch.matmul(attention, V)  # (batch_size, heads, seq_len, head_dim)
        
        # 转置回原来的形状
        out = out.permute(0, 2, 1, 3).contiguous()
        
        # 将多个头的输出拼接起来
        out = out.view(batch_size, -1, self.heads * self.head_dim)
        
        # 通过输出的线性层映射
        out = self.fc_out(out)
        
        return out

2.1 代码说明

  1. 初始化部分
    • embed_size:表示输入的嵌入维度;
    • heads:表示注意力头的数量;
    • head_dim:每个头的维度,等于 embed_size // heads
    • WqWkWv:分别是查询、键、值的线性变换矩阵;
    • fc_out:是最终的输出线性层。
  2. 前向传播部分
    • 首先通过线性层将输入的查询、键和值转换为对应的向量;
    • 然后对每个查询、键、值进行分割和转置,准备多头注意力的计算;
    • 计算注意力权重,并使用这些权重对值进行加权求和,得到每个头的输出;
    • 最后将多个头的输出拼接后通过一个线性层映射到输出空间。

3. 结语

多头注意力机制的引入极大提升了模型的表达能力,并通过并行化计算有效提高了训练效率。其在Transformer架构中的应用,推动了NLP领域的巨大突破。从原理到实现代码,掌握多头注意力机制不仅有助于理解现代深度学习模型的工作方式,还能帮助开发者更好地进行模型设计和优化。


4. 工作流程图

graph TD;
    A[输入Q, K, V] --> B[线性变换Q, K, V];
    B --> C[分割多个头];
    C --> D[计算注意力权重];
    D --> E[加权求和得到每个头的输出];
    E --> F[拼接多个头的输出];
    F --> G[通过线性层输出];
    G --> H[得到最终的多头注意力结果];
THE END