标准自注意力(Self-Attention) 和因果注意力(Causal Attention) 的区别 作者:马育民 • 2026-01-21 09:39 • 阅读:10002 # 介绍 **因果注意力(Causal Attention)** 和**标准自注意力(Self-Attention)** 的核心区别,是理解Transformer **编码器**/**解码器 差异** 的关键。 # 核心区别 总结最关键的差异,帮你快速建立认知: | 维度 | 标准自注意力(Self-Attention) | 因果注意力(Causal Attention) | |------|--------------------------------|--------------------------------| | **信息访问范围** | 全序列可见(当前位置能看到所有位置,包括过去、现在、未来) | 仅单向可见(当前位置只能看到自身及之前的位置,未来位置被屏蔽) | | **核心目的** | 理解上下文(建模整个序列的双向依赖) | 自回归生成(保障时序因果性,杜绝未来信息泄露) | | **掩码使用** | 无掩码(或仅填充掩码 Padding Mask) | 强制使用下三角因果掩码(Causal Mask) | | **典型应用** | Transformer **编码器**(如BERT、ViT) | Transformer **解码器**(如GPT、TTS模型) | | **计算逻辑** | 注意力分数无限制 | 未来位置注意力分数置为负无穷,softmax后权重为0 | # 例子 ### 一、标准自注意力(双向):“上帝视角” 在做阅读理解,看一句话 `我喜欢吃苹果`,可以**同时看前后所有字** 来理解意思——这就是标准自注意力的逻辑。 #### 1. 可视化 假设有一个4个词的句子:`[我, 喜, 欢, 吃]`(对应位置1/2/3/4),标准自注意力的“可见范围”如下: | 正在处理的位置(行) | 能参考的位置(列) | 通俗解释 | |----------------------|--------------------|----------| | 位置1(我) | 1、2、3、4 | 看“我”时,能同时看到“我、喜、欢、吃” | | 位置2(喜) | 1、2、3、4 | 看“喜”时,能同时看到“我、喜、欢、吃” | | 位置3(欢) | 1、2、3、4 | 看“欢”时,能同时看到“我、喜、欢、吃” | | 位置4(吃) | 1、2、3、4 | 看“吃”时,能同时看到“我、喜、欢、吃” | **提示:** **没有任何限制,全可见**,就像读文章时可以回头看、也可以提前看后面的内容。 用“画图”的方式更简单(√=能看,×=不能看): ``` 位置1:√ √ √ √ 位置2:√ √ √ √ 位置3:√ √ √ √ 位置4:√ √ √ √ ``` **上图的意思:** - **行** = 「当前正在处理的位置」(比如你现在要生成第2个词,这一行就代表“第2个词”); - **列** = 「能看到/参考的位置」(比如生成第2个词时,能看第1个词、还是能看第3个词); - **√/×** = 「能不能参考」(√=能看,×=不能看); - **注意力权重** = 「参考的程度」(数值越大,越关注这个位置)。 ### 二、因果注意力(单向):“只能按顺序看” 比如在写作文时,手写一句话 `我喜欢吃苹果`,**写第2个字 `喜` 时,还没写 `欢` 和 `吃` ,所以只能看到已经写的 `我` 和 `喜`** ——这就是因果注意力的逻辑。 #### 1. 可视化(同样用“句子+表格”) 还是句子`[我, 喜, 欢, 吃]`(位置1/2/3/4),因果注意力的“可见范围”: | 正在处理的位置(行) | 能参考的位置(列) | 通俗解释 | |----------------------|--------------------|----------| | 位置1(我) | 1 | 写“我”时,只看到自己(还没写任何其他字) | | 位置2(喜) | 1、2 | 写“喜”时,只能看到已经写的“我”和刚写的“喜” | | 位置3(欢) | 1、2、3 | 写“欢”时,只能看到“我、喜、欢”(还没写“吃”) | | 位置4(吃) | 1、2、3、4 | 写“吃”时,能看到前面所有已写的字 | **提示:** **只能看“当前位置及左边(历史)”,右边(未来)全屏蔽**,这就是“下三角掩码”的本质——因为这个图案像一个“下三角”: ### 用“画图”的方式 ``` 位置1:√ × × × (只能看自己) 位置2:√ √ × × (能看1、2) 位置3:√ √ √ × (能看1、2、3) 位置4:√ √ √ √ (能看1、2、3、4) ``` 将上图简化如下(下三角区域全是√,上三角全是×,所以叫“下三角掩码”): ``` √ × × × √ √ × × √ √ √ × √ √ √ √ ``` **上图的意思:** - **行** = 「当前正在处理的位置」(比如你现在要生成第2个词,这一行就代表“第2个词”); - **列** = 「能看到/参考的位置」(比如生成第2个词时,能看第1个词、还是能看第3个词); - **√/×** = 「能不能参考」(√=能看,×=不能看); - **注意力权重** = 「参考的程度」(数值越大,越关注这个位置)。 --- # 代码分步拆解 (只看核心,去掉复杂计算) 把之前的代码拆成**3步**,只关注“掩码如何工作”,不纠结数学计算,新手也能看懂: ```python import torch # 第一步:先造一个“注意力分数矩阵”(假设还没加掩码) # 形状:[1, 4, 4] → 1个样本,4个位置,每个位置对4个位置的分数 scores = torch.tensor([[ [1.0, 2.0, 3.0, 4.0], # 位置1对1/2/3/4的分数 [1.0, 2.0, 3.0, 4.0], # 位置2对1/2/3/4的分数 [1.0, 2.0, 3.0, 4.0], # 位置3对1/2/3/4的分数 [1.0, 2.0, 3.0, 4.0] # 位置4对1/2/3/4的分数 ]]) print("第一步:原始注意力分数(没加掩码)") print(scores.squeeze()) # 去掉多余维度,方便看 print("="*50) # 第二步:生成因果掩码(下三角掩码) seq_len = 4 # 生成4x4的全1矩阵,然后取下三角(tril=triangular lower) causal_mask = torch.tril(torch.ones(seq_len, seq_len)) print("第二步:因果掩码(1=能看,0=不能看)") print(causal_mask) print("="*50) # 第三步:用掩码屏蔽未来位置(把0的位置换成负无穷) # masked_fill:把mask中为0的位置,替换成-∞ masked_scores = scores.masked_fill(causal_mask == 0, float('-inf')) print("第三步:加掩码后的分数(-∞=看不到)") print(masked_scores.squeeze()) ``` #### 输出结果(逐行解释): ``` 第一步:原始注意力分数(没加掩码) tensor([[1., 2., 3., 4.], # 位置1能看到所有位置的分数 [1., 2., 3., 4.], # 位置2能看到所有位置的分数 [1., 2., 3., 4.], # 位置3能看到所有位置的分数 [1., 2., 3., 4.]]) # 位置4能看到所有位置的分数 ================================================== 第二步:因果掩码(1=能看,0=不能看) tensor([[1., 0., 0., 0.], # 位置1只能看自己(1),其他都不能看(0) [1., 1., 0., 0.], # 位置2能看1、2(1),不能看3、4(0) [1., 1., 1., 0.], # 位置3能看1、2、3(1),不能看4(0) [1., 1., 1., 1.]]) # 位置4能看所有(1) ================================================== 第三步:加掩码后的分数(-∞=看不到) tensor([[ 1., -inf, -inf, -inf], # 位置1的未来位置(2/3/4)变成-∞ [ 1., 2., -inf, -inf], # 位置2的未来位置(3/4)变成-∞ [ 1., 2., 3., -inf], # 位置3的未来位置(4)变成-∞ [ 1., 2., 3., 4.]]) # 位置4没有未来位置,不变 ``` #### 最后一步:softmax转换为权重(为什么-∞会变成0) - `softmax`的作用:把分数转换成“0~1”的权重,总和为1; - `softmax(-∞) = 0`:负无穷的位置权重会变成0,相当于“完全不关注”; 比如位置2的分数`[1,2,-inf,-inf]`经过softmax后: ``` softmax([1,2,-inf,-inf]) = [e^1/(e^1+e^2), e^2/(e^1+e^2), 0, 0] ≈ [0.27, 0.73, 0, 0] ``` 👉 结果就是:位置2只关注位置1(0.27)和位置2(0.73),完全不关注3、4(0)——这就是“因果”的本质:只看历史,不看未来。 --- # 代码对比 下面用极简代码实现两种注意力,直观看到差异(基于PyTorch): ```python import torch import torch.nn.functional as F # 模拟输入:batch_size=1,seq_len=4,hidden_dim=8 batch_size, seq_len, d_k = 1, 4, 8 q = k = v = torch.randn(batch_size, seq_len, d_k) # 简化:q=k=v # ---------------------- 1. 标准自注意力 ---------------------- def standard_self_attention(q, k, v): # 计算注意力分数 scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) # 无掩码:所有位置都能关注 attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, v) return output, attn_weights # ---------------------- 2. 因果注意力 ---------------------- def causal_attention(q, k, v): # 计算注意力分数 scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32)) # 生成因果掩码:下三角为True(允许访问),上三角为False(屏蔽) causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool() # 屏蔽未来位置:将False的位置置为负无穷 scores = scores.masked_fill(~causal_mask, float('-inf')) # softmax后,负无穷位置权重为0 attn_weights = F.softmax(scores, dim=-1) output = torch.matmul(attn_weights, v) return output, attn_weights # 运行并打印注意力权重(直观看差异) _, standard_weights = standard_self_attention(q, k, v) _, causal_weights = causal_attention(q, k, v) print("=== 标准自注意力权重矩阵(全连接)===") print(standard_weights.squeeze().round(2)) # squeeze去掉batch维度 print("\n=== 因果注意力权重矩阵(下三角)===") print(causal_weights.squeeze().round(2)) ``` ### 输出结果示例(直观差异): ``` === 标准自注意力权重矩阵(全连接)=== tensor([[0.23, 0.27, 0.25, 0.25], # 位置1能关注所有4个位置 [0.24, 0.26, 0.25, 0.25], # 位置2能关注所有4个位置 [0.25, 0.25, 0.25, 0.25], # 位置3能关注所有4个位置 [0.26, 0.24, 0.25, 0.25]]) # 位置4能关注所有4个位置 === 因果注意力权重矩阵(下三角)=== tensor([[1.00, 0.00, 0.00, 0.00], # 位置1只能关注自己 [0.48, 0.52, 0.00, 0.00], # 位置2能关注1+2 [0.32, 0.34, 0.34, 0.00], # 位置3能关注1+2+3 [0.25, 0.25, 0.25, 0.25]]) # 位置4能关注1+2+3+4 ``` # 应用场景差异 1. **标准自注意力**: - 用于Transformer**编码器**(如BERT、RoBERTa、ViT); - 适合 **理解类** 任务:文本分类、命名实体识别、图像理解(无时序限制); - 核心价值:充分利用全上下文信息,提升理解精度。 2. **因果注意力**: - 用于Transformer**解码器**(如GPT、LLaMA、Transformer-TTS); - 适合 **生成类** 任务:文本生成、机器翻译(解码阶段)、语音合成、时间序列预测; - 核心价值:保障生成的时序逻辑,避免模型“提前看到未来信息”而作弊。 # 容易混淆 - 因果注意力**本质是带掩码的自注意力**:它不是独立于自注意力的新机制,而是自注意力的“时序限制版本”; - 填充掩码(Padding Mask)≠ 因果掩码:填充掩码是为了屏蔽padding的无效位置(两类注意力都可能用),因果掩码是为了屏蔽未来位置(仅因果注意力用); - 交叉注意力(Cross Attention)≠ 两者:交叉注意力是“查询(Q)来自解码器,键值(K/V)来自编码器”,和自注意力(Q/K/V都来自同一序列)是不同维度的概念。 # 总结 1. 核心差异:**信息访问范围**——标准自注意力是双向全序列可见,因果注意力是单向仅历史可见; 2. 用途差异:标准自注意力用于“理解”,因果注意力用于“生成”; 3. 实现差异:因果注意力仅比标准自注意力多了一步“下三角掩码屏蔽未来位置”的操作。 原文出处:http://www.malaoshi.top/show_1GW2dDHAuGtE.html