Numpy数组索引的一些理解

最近在学Pytorch,因切片规则和Numpy相同,特在此记录。

多维数组

因为Pytorch中tensor的切片规则和Numpy相同,所以代码中我们用Pytorch来演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 随机构建一个三维张量,每个维度有三个元素
x = torch.rand(3,3,3)
'''
tensor([[[0.7181, 0.7586, 0.0087],
[0.2987, 0.6424, 0.0378],
[0.8481, 0.4940, 0.9429]],

[[0.9045, 0.6331, 0.2378],
[0.7245, 0.3930, 0.2547],
[0.7040, 0.7866, 0.5593]],

[[0.0584, 0.6290, 0.1482],
[0.6745, 0.9922, 0.1043],
[0.5660, 0.7678, 0.0531]]])
'''
# 取第一二三维度上索引均为1的元素
x[1,1,1]
# tensor(0.3930)

# 第一个维度取所有的元素,第二个维度取索引为1的元素
x[:, 1]
'''
tensor([[0.2987, 0.6424, 0.0378],
[0.7245, 0.3930, 0.2547],
[0.6745, 0.9922, 0.1043]])
'''

总结来说,就是张量x有几个维度,切片操作时x[]方括号里面就可以有几个下标。:代表取所有,后面的下标不写就默认是全选,相当于:

参考