大模型原理:多头注意力-优化 作者:马育民 • 2026-01-21 15:34 • 阅读:10001 # 优化 在 [大模型原理:多头注意力](https://www.malaoshi.top/show_1GW2dIcfTKjw.html "大模型原理:多头注意力") 中,实现多头注意力,这是通过实例化并组合多个CausalAttention对象来完成的 维护两个单独的类 `MultiHeadAttentionWrapper` 和 `CausalAttention` ,不如将这两个概念合并成一个 `MultiHeadAttention` 类。 此外,还会进行一些其他调整,以更 **高效地实现多头注意力**。 ### 对比 在 `MultiHeadAttentionWrapper` 中,多头机制通过创建 `CausalAttention` 对象的列表( `self.heads` )来实现,每个对象代表一个独立的注意力头,如下图:  `CausalAttention` 类单独执行注意力机制,每个头的结果会被拼接。 相比之下,`MultiHeadAttention` 类会将 **多头功能整合到一个类内**。通过重新调整投影后的查询张量、键张量和值张量的形状,将输入分为多个头,然后在计算注意力后合并这些头的结果。 # 实现封装类 ``` import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_in, d_out, context_length ,dropout, num_heads, qkv_bias=False): """ :param d_in: :param d_out: :param context_length: :param dropout: :param num_heads: :param qkv_bias: """ super().__init__() assert (d_out % num_heads == 0),\ "d_out must be divisible by num_heads" self.d_out = d_out self.num_heads = num_heads # 减少投影维度以匹配所需的输出维度 self.head_dim = d_out // num_heads self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) # 使用一个线性层来组合头的输出 self.out_proj = nn.Linear(d_out, d_out) self.dropout = nn.Dropout(dropout) self.register_buffer( "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1) ) def forward(self, x): b, num_tokens, d_in = x.shape keys = self.W_key(x) queries = self.W_query(x) values = self.W_value(x) # 通过添加一个 num_heads维度来隐式地分隔矩阵。然后展开最后一个维度:(b,numtokens,dout) -> (b,num_tokens,num_heads,head_dim) keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) values = values.view(b, num_tokens, self.num_heads, self.head_dim) queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) # 从形状(b,num_tokens,num_heads,head_dim) 转换到(b,num_heads,num_tokens,head_dim) keys= keys.transpose(1, 2) queries = queries.transpose(1, 2) values = values.transpose(1, 2) # 计算每个头的注意力分数(点积) attn_scores = queries @ keys.transpose(2, 3) # 被截断为词元数量的掩码 mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # 使用掩码来填充注意力分数 attn_scores.masked_fill_(mask_bool, -torch.inf) attn_weights = torch.softmax( attn_scores / keys.shape[-1]**0.5, dim=-1) attn_weights = self.dropout(attn_weights) # 张量形状:(b, num_tokens,n_heads,head_dim) context_vec = (attn_weights @ values).transpose(1, 2) # 组合头,其中 self.d_out= self.num_heads *self.head_dim context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # 添加一个可选的线性投影 context_vec = self.out_proj(context_vec) return context_vec ``` ### 解释 在 `MultiHeadAttention` 类中,初始化了一个更大的权重矩阵 $$W\_q$$,并只与输入矩阵进行一次矩阵乘法操作,得到一个查询矩阵 $$Q$$,然后将查询矩阵分割成了 $$Q\_1$$ 和 $$Q\_2$$ (下)。 对键矩阵和值矩阵的操作与之类似,为了减少视觉混乱,这里没有展示 在PyTorch中,通过 `.view()` 方法进行张量重塑,使用 `.transpose()` 方法进行张量转置,实现了对查询张量、键张量和值张量的分割。输入首先经过线性层进行变换(针对查询矩阵、键矩阵和值矩阵),然后被重塑为多个头。  关键操作是将 `d_out` 维度分割为 `num_heads` 和 `head_dim` ,其中 `head_dim = d_out / num_heads`。 此分割通过 `.view` 方法来实现:维度为 `(b,num_tokens, d_out)` 的张量被重塑后的维度为 `(b, num_tokens,num_heads, head_dim)` 然后转置张量,使 `num_heads` 维度置于 `num_tokens` 维度之前,从而形成一个 `(b, num_heads, num_tokens, head_dim)` 的形状。这种转置对于正确对齐不同头的查询矩阵、键矩阵和值矩阵,以及有效地执行批处理矩阵乘法至关重要 # 使用 ### 准备代码 ``` # 输入的嵌入向量 inputs = torch.tensor( [[0.43, 0.15, 0.89], # Your (x^1) [0.55, 0.87, 0.66], # journey (x^2) [0.57, 0.85, 0.64], # starts (x^3) [0.22, 0.58, 0.33], # with (x^4) [0.77, 0.25, 0.10], # one (x^5) [0.05, 0.80, 0.55]] # step (x^6) ) # 模拟批量输入,因为真正执行时,是通过数据加载器,批量读取的 batch = torch.stack((inputs, inputs), dim=0) print("批量数据:\n", batch) print("批量数据的shape:", batch.shape) ``` ### 使用 ``` # 计算上下文向量 print("\n\b计算上下文向量--------") # 为了复现,设置随机种子 torch.manual_seed(123) """ context_length:词元的数量,用于设置掩码矩阵 d_in:输入嵌入维度 """ batch_size, context_length, d_in = batch.shape # 输出嵌入维度d_out=2 d_out = 2 print("batch_size:", batch_size) print("context_length:", context_length) print("d_in:", d_in) # 传入参数:实例化 mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2) # 计算上下文向量 context_vecs = mha(batch) print("上下文向量:\n", context_vecs) print("上下文向量.shape:", context_vecs.shape) ``` 执行结果: ``` 上下文向量: tensor([[[0.3190, 0.4858], [0.2943, 0.3897], [0.2856, 0.3593], [0.2693, 0.3873], [0.2639, 0.3928], [0.2575, 0.4028]], [[0.3190, 0.4858], [0.2943, 0.3897], [0.2856, 0.3593], [0.2693, 0.3873], [0.2639, 0.3928], [0.2575, 0.4028]]], grad_fn=) 上下文向量.shape: torch.Size([2, 6, 2]) ``` 原文出处:http://www.malaoshi.top/show_1GW2dNTdlgrD.html