pytorch api文档:张量的 .stride() 方法-查看步长,判断是否连续 作者:马育民 • 2026-01-21 19:45 • 阅读:10000 # 介绍 张量的 `.stride()` 方法,是理解张量内存布局(连续/非连续)的“底层钥匙”——直接告诉我们:遍历张量时,每跳过一个维度的元素,需要在内存中移动多少步(元素个数)。 ### 作用 返回一个元组,元组中第 `i` 个元素表示“遍历张量第 `i` 维时,每移动一个位置,需要在内存中跳过的元素数量”(步长)。 - 元组长度 = 张量的维度数(比如2维张量返回2个值,3维返回3个值); - 步长的单位是“元素个数”(不是字节),PyTorch 会自动根据元素类型(如 float32=4字节)换算成实际内存字节数; - 该方法是**只读检测**,不修改张量,无任何内存开销。 # 语法 ``` Tensor.stride() → Tuple[int, ...] ``` # 例子 ```python # 2维张量:shape [2, 3],内存中存储为 [1,2,3,4,5,6] x = torch.tensor([[1, 2, 3], [4, 5, 6]]) print("张量shape:", x.shape) # (2, 3) print("张量stride:", x.stride()) # (3, 1) ``` 对 `stride=(3,1)` 的解读: - **维度1(列维度)步长=1**:要从第0列跳到第1列(如1→2),只需在内存中往后跳 `1` 个元素; - **维度0(行维度)步长=3**:要从第0行(`[1,2,3]`)跳到第1行(`[4,5,6]`),需要在内存中往后跳 `3` 个元素(从1→2→3→4,共3步); 这是**连续张量的典型步长**:满足 `前一维步长 = 后一维步长 × 后一维大小`,如下: ``` 3 = 1 × 3 // 维度0步长 = 维度1步长 × 维度1大小 ``` # 例子:步长与张量内存的关系 通过不同场景的示例,直观理解步长的含义,以及它如何决定张量是否连续: ### 示例1:基础2维张量(连续) ```python x = torch.tensor([[1, 2, 3], [4, 5, 6]]) print("=== 基础2维连续张量 ===") print("shape:", x.shape) # (2, 3) print("stride:", x.stride()) # (3, 1) print("是否连续:", x.is_contiguous()) # True # 验证遍历逻辑:按维度遍历的每个元素,对应内存中的位置 # 遍历顺序:(0,0)→(0,1)→(0,2)→(1,0)→(1,1)→(1,2) # 内存位置:0→1→2→3→4→5(步长3和1完美匹配) indices = [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)] memory_pos = [i * x.stride()[0] + j * x.stride()[1] for i,j in indices] print("遍历元素的内存位置:", memory_pos) # [0,1,2,3,4,5] ``` ### 示例2:转置后的张量(非连续) ```python x_trans = x.transpose(0, 1) # 交换行和列,shape变为(3,2) print("\n=== 转置后的非连续张量 ===") print("shape:", x_trans.shape) # (3, 2) print("stride:", x_trans.stride()) # (1, 3) print("是否连续:", x_trans.is_contiguous()) # False # 验证遍历逻辑:按新维度遍历,内存位置跳着走 # 遍历顺序:(0,0)→(0,1)→(1,0)→(1,1)→(2,0)→(2,1) # 内存位置:0→3→1→4→2→5(步长1和3导致跳着读) indices_trans = [(0,0), (0,1), (1,0), (1,1), (2,0), (2,1)] memory_pos_trans = [i * x_trans.stride()[0] + j * x_trans.stride()[1] for i,j in indices_trans] print("遍历元素的内存位置:", memory_pos_trans) # [0,3,1,4,2,5] ``` ### 示例3:3维张量的步长(连续 vs 非连续) ```python # 3维连续张量:shape [2,3,4] x_3d = torch.randn(2, 3, 4) print("\n=== 3维连续张量 ===") print("shape:", x_3d.shape) # (2, 3, 4) print("stride:", x_3d.stride()) # (12, 4, 1) # 验证连续规则:12=4×3,4=1×4 → 满足 # permute交换维度1和2,变为非连续 x_3d_perm = x_3d.permute(0, 2, 1) print("\n=== 3维permute后非连续张量 ===") print("shape:", x_3d_perm.shape) # (2, 4, 3) print("stride:", x_3d_perm.stride()) # (12, 1, 4) # 验证连续规则:12≠1×4 → 不满足,非连续 ``` ### 示例4:切片操作导致的步长变化(非连续) ```python # 隔列取数:步长改变,变为非连续 x_slice = x[:, ::2] # shape (2,2),取第0、2列 print("\n=== 切片后的非连续张量 ===") print("shape:", x_slice.shape) # (2, 2) print("stride:", x_slice.stride()) # (3, 2) print("是否连续:", x_slice.is_contiguous()) # False ``` # 应用场景 ### 判断张量是否连续(底层逻辑) `.is_contiguous()` 的本质就是检查步长是否满足「连续规则」: - 对 N 维张量,从最后一维往前推: 第 `i` 维步长 = 第 `i+1` 维步长 × 第 `i+1` 维大小; - 最后一维的步长必须为 1(因为遍历最后一维是逐个元素读)。 比如: - 连续张量 `x.stride()=(3,1)`:3=1×3,最后一维步长=1 → 满足; - 非连续张量 `x_trans.stride()=(1,3)`:1≠3×2,最后一维步长=3≠1 → 不满足。 ### 解释 .view() 报错的原因 `.view()` 要求张量内存连续,本质是要求步长符合连续规则——非连续张量的步长混乱,无法按新形状“线性解读”内存,因此报错。 ### 优化内存访问效率 步长越小,内存访问越连续,计算效率越高: - 连续张量的步长是“紧凑的”(最后一维步长=1,前一维步长=后一维大小×后一维步长); - 非连续张量的步长会导致“内存跳跃”,缓存命中率低,计算速度慢。 # 注意 1. **步长与内存共享**:`transpose()`/`permute()`/切片等操作只会修改步长(解读规则),不会改变内存中元素的存储顺序,因此和原张量共享内存; 2. **`.contiguous()` 会重置步长**:非连续张量调用 `.contiguous()` 后,会重新排列内存,步长恢复为连续规则(如 `x_trans.contiguous().stride()` 变为 `(2,1)`); 3. **步长与张量类型无关**:无论 float32、int64 还是 CUDA 张量,步长的计算逻辑完全一致(单位都是“元素个数”,不是字节); 4. **空张量/标量的步长**:标量(0维张量)的 `.stride()` 返回空元组,空张量的步长根据形状适配(如 `torch.empty(2,0).stride()` 返回 `(0,1)`)。 # 总结 1. `.stride()` 返回张量各维度的遍历步长,是理解张量内存布局的核心指标; 2. 连续张量的步长满足「前一维步长 = 后一维步长 × 后一维大小」,最后一维步长=1; 3. 非连续张量(转置/切片/permute)的步长打破上述规则,导致内存访问跳跃,`.view()` 报错; 4. 步长是 `.is_contiguous()` 和 `.contiguous()` 的底层判断依据,掌握步长就能彻底理解张量的内存逻辑。 原文出处:http://www.malaoshi.top/show_1GW2dNDAEfI1.html