

新闻资讯
技术学院本文介绍一种无需显式循环即可从 pytorch 二维张量每行中按指定起始索引和固定长度提取子张量的方法,利用 `torch.arange` 与 `torch.gather` 实现全向量化索引。
在深度学习与科学计算中,常需对批量数据(如 N×D 的特征矩阵)按行进行变起点、定长度的切片操作。例如:给定一个形状为 (N, D) 的张量 data,以及长度为 N 的起始索引张量 start_idx,要求对第 i 行提取 data[i, start_idx[i]:start_idx[i] + L],其中 L 为统一子序列长度。若使用 Python 循环或列表推导,不仅低效,还破坏了张量计算的并行性。
PyTorch 提供了高效的向量化方案:构造索引张量 + gather 沿指定维度收集。核心思路是:
注意:start_idx 必须为整数类型(如 torch.long),浮点型索引不被支持;且所有子序列长度必须一致(L 固定),否则无法构成规则索引张量。
以下是完整可运行示例:
import torch def gather_rows_by_range(data: torch.Tensor, start_idx: torch.Tensor, length: int, dim: int = 1) -> torch.Tensor: """ 从 data 的每行(若 dim=1)或每列(若 dim=0)中提取长度为 length 的连续子序列, 起始位置由 start_idx 指定(按行/列对齐)。 Args: data: 输入张量,形状 (N, D) start_idx: 起始索引,形状 (N,),dtype=torch.long length: 子序列固定长度(标量) dim: 沿哪一维采样(默认 1,即按行取列) Returns: 输出张量,形状 (N, length) """ # 为每行生成 [s, s+1, ..., s+length-1] ranges = torch.stack([ torch.arange(s, s + length, device=data.device, dtype=torch.long) for s in start_idx ]) return data.gather(dim, ranges) # 示例数据 data = torch.tensor([[ 1., 2., 3., 4., 5.], [ 6., 7., 8., 9., 10.], [11., 12., 13., 14., 15.]]) start_idx = torch.tensor([0, 3, 1], dtype=torch.long) result = gather_rows_by_range(data, start_idx, length=2, dim=1) print(result) # 输出: # tensor([[ 1., 2.], # [ 9., 10.], # [12., 13.]])
✅ 优势总结:
⚠️ 注意事项:
# 更高效的索引张量构造(推荐用于大数据量) index_tensor = start_idx.unsqueeze(1) + torch.arange(length, device=data.device) result = data.gather(1, index_tensor)
该方法是 PyTorch 中实现“行级动态切片”的标准实践,在 Transformer 的 sliding window attention、时序模型的 patching 等场景中广泛应用。