pytorch api文档:张量的 .masked_fill_()方法-根据布尔类型的张量填充 作者:马育民 • 2026-01-21 16:57 • 阅读:10000 # 介绍 `.masked_fill_()` 方法,这个方法是 `.masked_fill()` 的**原地操作版本**,核心作用和用法与 `.masked_fill()` 一致,但会 **直接修改原张量** 而非返回新张量,在内存优化场景中常用。 ### 作用 与 `.masked_fill()` 逻辑完全相同——根据布尔型 `mask` 张量,将原张量中 `mask` 为 `True` 的位置替换为 `value`,`False` 位置保留原值。 ### 区别 - `.masked_fill()`:返回 **新张量**,原张量保持不变; - `.masked_fill_()`:**原地修改原张量**,返回修改后的原张量(无新内存分配)。 ### 命名规则 PyTorch 中所有以下划线 `_` 结尾的方法都是 **原地操作(in-place)**,会直接修改调用者本身,而非创建副本。 # 语法 ``` Tensor.masked_fill_(mask, value) → Tensor ``` **参数说明**(与 `.masked_fill()` 完全一致): - `mask`:布尔型张量,形状需与原张量广播兼容; - `value`:要填充的数值,需与原张量数据类型兼容。 # 例子 通过示例直观对比 `.masked_fill()` 和 `.masked_fill_()` 的差异,理解“原地修改”的含义: ```python import torch # 1. 准备基础数据 # 模拟注意力分数(batch_size=1, seq_len=3) attn_scores = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) # 生成布尔掩码(True=需要屏蔽的位置) mask = torch.tensor([[False, True, True], [False, False, True], [False, False, False]]) ``` ### 非原地操作:masked_fill(返回新张量,原张量不变) ``` attn_scores_new = attn_scores.masked_fill(mask, -1e9) print("=== 非原地操作(masked_fill)===") print("原张量(未修改):") print(attn_scores) print("\n返回的新张量:") print(attn_scores_new) # 输出: # 原张量(未修改): # tensor([[1., 2., 3.], # [4., 5., 6.], # [7., 8., 9.]]) # 返回的新张量: # tensor([[ 1.0000e+00, -1.0000e+09, -1.0000e+09], # [ 4.0000e+00, 5.0000e+00, -1.0000e+09], # [ 7.0000e+00, 8.0000e+00, 9.0000e+00]]) ``` ### 原地操作:masked\_fill\_(直接修改原张量) ``` # 先重置张量(避免受上一步影响) attn_scores = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) attn_scores.masked_fill_(mask, -1e9) # 无新张量,直接修改原张量 print("\n=== 原地操作(masked_fill_)===") print("原张量(已被修改):") print(attn_scores) # 输出(与新张量内容一致,但原张量已变): # tensor([[ 1.0000e+00, -1.0000e+09, -1.0000e+09], # [ 4.0000e+00, 5.0000e+00, -1.0000e+09], # [ 7.0000e+00, 8.0000e+00, 9.0000e+00]]) ``` # 适用场景 `.masked_fill_()` 主要用于**内存敏感场景**(如大模型训练/推理): - 处理超大张量(如 `[1024, 1024]` 的注意力分数)时,原地操作可避免创建副本,节省内存; - 临时张量复用:多次修改同一张量时,原地操作无需重复分配内存。 # 注意(原地操作的坑) 原地操作虽节省内存,但有严格使用限制,新手需重点注意: ### 梯度计算风险 若张量参与梯度计算(`requires_grad=True`),原地操作可能破坏梯度流,导致反向传播报错或结果错误。 ❌ 错误示例: ```python x = torch.tensor([1.0, 2.0], requires_grad=True) x.masked_fill_(torch.tensor([False, True]), 0.0) # 原地修改梯度张量,可能导致梯度错误 ``` ✅ 正确做法: 对需要梯度的张量,优先使用 `.masked_fill()`(非原地),仅对无梯度的张量(如掩码、缓冲区)使用 `.masked_fill_()`。 ### 张量共享内存风险 若多个变量指向同一张量,原地修改会同时改变所有变量的值: ```python a = torch.tensor([1.0, 2.0]) b = a # b与a共享内存 a.masked_fill_(torch.tensor([True, False]), 0.0) print(b) # 输出:tensor([0., 2.]) → b也被修改 ``` ### 与模型缓冲区的配合 模型中注册的缓冲区(`register_buffer`)无梯度,适合用 `.masked_fill_()` 原地修改(如动态更新掩码): ```python class MyModel(nn.Module): def __init__(self): super().__init__() self.register_buffer("mask", torch.ones(3, 3)) def forward(self, x): self.mask.masked_fill_(torch.triu(torch.ones(3,3), 1).bool(), -1e9) # 原地修改缓冲区 return x * self.mask ``` # 总结 1. `.masked_fill_()` 是 `.masked_fill()` 的原地操作版本,核心逻辑一致,区别仅在于是否修改原张量; 2. 原地操作节省内存,但对梯度张量使用会破坏梯度流,新手优先用 `.masked_fill()`; 3. 仅对无梯度的张量(如缓冲区、临时掩码)使用 `.masked_fill_()`,避免踩梯度和内存共享的坑。 原文出处:http://www.malaoshi.top/show_1GW2dK2sR135.html