大模型原理:多头注意力 作者:马育民 • 2026-01-21 08:49 • 阅读:10000 # 介绍 把先前实现的因果注意力类扩展到多个头上。这也被称为 **多头注意力** ### 多头 是将注意力机制分成 **多个“头”**,**每个“头”独立工作**。在这种情况下,单个因果注意力模块可以被看作单头注意力,因为它只有一组注意力权重按顺序处理输入。 # 叠加多个单头注意力层 实现多头注意力需要构建多个自注意力机制的实例如下图,每个实例都有其独立的权重,然后将这些输出进行合成  下图展示了多头注意力模块的结构,它是由上图所示的多个单头注意力模块依次叠加在一起组成的,即:**两个堆叠在一起的单头注意力模块**。 因此,不能使用单一的矩阵 $$W\_v$$ 来计算值矩阵,而是在一个 **有两个头的多头注意模块中,有两个值权重矩阵: $$W\_{v1}$$ 和 $$W\_{v2}$$ 。** 同样适用于其他的权重矩阵,比如 $$W\_q$$ 和 $$W\_k$$ 。得到了两组上下文向量 $$Z\_1$$ 和 $$Z\_2$$ ,然后 **合并成一个单一的上下文向量矩阵 $$Z$$ **  多头注意力的主要思想是 **多次(并行)运行注意力机制**,每次使用学到的不同的线性投影——这些投影是通过将输入数据(比如注意力机制中的查询向量、键向量和值向量)乘以权重矩阵得到的。 # 实现 ### 准备代码 ``` import torch import torch.nn as nn class CausalAttention(nn.Module): def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False): """ 初始化 :param d_in: 输入嵌入维度 :param d_out: 输出嵌入维度 :param context_length: 设置掩码矩阵 :param dropout: dropout忽略概率 :param qkv_bias: """ super().__init__() self.d_out = d_out self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) # 增加 Dropout 层 self.dropout = nn.Dropout(dropout) # 将上三角掩码注册为缓冲区,不需要梯度更新,但需要随模型一起保存/加载、迁移设备 self.register_buffer( 'mask', torch.triu(torch.ones(context_length, context_length), diagonal=1) ) def forward(self, x): b, num_tokens, d_in = x.shape keys = self.W_key(x) queries = self.W_query(x) values = self.W_value(x) # 将维度1和2转置,将批维度保持在第一个位置(0) attn_scores = queries @ keys.transpose(1, 2) # 直接修改传入的矩阵,而不是返回副本,减少内存 attn_scores.masked_fill_( self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) attn_weights = torch.softmax( attn_scores / keys.shape[-1]**0.5, dim=-1 ) # 对注意力权重矩阵进行dropout操作 attn_weights = self.dropout(attn_weights) # 计算上下文向量 context_vec = attn_weights @ values return context_vec ``` ### 多头注意力封装类 在代码中,可以通过实现一个简单的 `MultiHeadAttentionWrapper` 类来达到这一目标,`MultiHeadAttentionWrapper` 类堆叠了多个之前实现的`CausalAttention` 模块实例 ``` class MultiHeadAttentionWrapper(nn.Module): def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False): """ 初始化 :param d_in: 输入嵌入维度 :param d_out: 输出嵌入维度 :param context_length: 设置掩码矩阵 :param dropout: dropout忽略概率 :param num_heads: 多头注意力数量 :param qkv_bias: """ super().__init__() self.heads = nn.ModuleList( [CausalAttention( d_in, d_out, context_length, dropout, qkv_bias ) for _ in range(num_heads)] ) def forward(self, x): return torch.cat([head(x) for head in self.heads], dim=-1) ``` # 使用 如下图,使用 `MultiHeadAttentionWrapper`,指定了 **注意力头的数量 `num_heads=2` **,就会得到一个具有 **两组上下文向量矩阵的张量**。 在每个上下文向量矩阵中: - 行表示对应于词元的上下文向量 - 列则对应于通过 `d_out=4` 指定的嵌入维度,沿着 **列维度 连接 这些上下文向量矩阵** 由于有 **两个注意力头**,并且 **嵌入维度为`2`**,因此最终的嵌入维度是 `2×2=4`  ### 准备代码 ``` # 输入嵌入维度d_in=3 d_in = inputs.shape[1] # 输出嵌入维度d_out=2 d_out = 2 # 模拟批量输入,因为真正执行时,是通过数据加载器,批量读取的 batch = torch.stack((inputs, inputs), dim=0) print("批量数据:\n", batch) print("批量数据的shape:", batch.shape) ``` ### 使用 ``` # 计算上下文向量 print("\n\b计算上下文向量--------") # 为了复现,设置随机种子 torch.manual_seed(123) # 词元的数量,用于设置掩码矩阵 context_length = batch.shape[1] print("context_length:", context_length) # 传入参数:实例化 mha = MultiHeadAttentionWrapper( d_in, d_out, context_length, 0.0, num_heads=2 ) # 计算上下文向量 context_vecs = mha(batch) print("上下文向量:\n", context_vecs) print("上下文向量.shape:", context_vecs.shape) ``` 执行结果: ``` 上下文向量: tensor([[[-0.4519, 0.2216, 0.4772, 0.1063], [-0.5874, 0.0058, 0.5891, 0.3257], [-0.6300, -0.0632, 0.6202, 0.3860], [-0.5675, -0.0843, 0.5478, 0.3589], [-0.5526, -0.0981, 0.5321, 0.3428], [-0.5299, -0.1081, 0.5077, 0.3493]], [[-0.4519, 0.2216, 0.4772, 0.1063], [-0.5874, 0.0058, 0.5891, 0.3257], [-0.6300, -0.0632, 0.6202, 0.3860], [-0.5675, -0.0843, 0.5478, 0.3589], [-0.5526, -0.0981, 0.5321, 0.3428], [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=) 上下文向量.shape: torch.Size([2, 6, 4]) ``` 结果中的 `context_vecs` 张量维度 `[2, 6, 4]`的解释: - 第一维:因为 **有两个输入文本**(输入文本是重复的,所以这些上下文向量完全相同) - 第二维:表示每个输入中的 `6` 个词元 - 第三维:表示每个词元的四维嵌入 原文出处:http://www.malaoshi.top/show_1GW2dIcfTKjw.html