深度学习PyTorch笔记(3):Tensor的索引
这是《动手学深度学习》(PyTorch版)(Dive-into-DL-PyTorch)的学习笔记,里面有一些代码是我自己拓展的。
其他笔记在专栏 深度学习 中。
1.2.2 索引
裁剪
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
torch.clamp(x, 2, 7) #对x进行在2和7之间的裁剪
tensor([[2, 2, 3],
[4, 5, 6],
[7, 7, 7]])
x[行切片,列切片]
x = torch.tensor([[1,2,3,4], [3,4,5,6], [0,9,0,1], [8,2,1,3]])
print(x)
tensor([[1, 2, 3, 4],
[3, 4, 5, 6],
[0, 9, 0, 1],
[8, 2, 1, 3]])
print(x[1:,]) #逗号前面是行切片,索引为1的行开始切片到最后,当对列没有修改时,逗号可省略。x[1:]=x[1:,]=x[1:,:]
tensor([[3, 4, 5, 6],
[0, 9, 0, 1],
[8, 2, 1, 3]])
print(x[1:3,]) #只切片索引为1的行
tensor([[3, 4, 5, 6],
[0, 9, 0, 1]])
print(x[1:3,1:3]) #切片索引为1、2和行和列
tensor([[4, 5],
[9, 0]])
print(x[:,1:3]) #但是这里如果是x[,1:3]会报错
tensor([[2, 3],
[4, 5],
[9, 0],
[2, 1]])
print(x[::2, ::3]) #跳着访问,第0行和第2行,第0列和第3列
tensor([[1, 4],
[0, 1]])
print(x[-1]) #可以用负索引
tensor([8, 2, 1, 3])
x[1, 2] = 10 #根据索引更改
print(x)
x[0:2, :] = 12
print(x)
tensor([[ 1, 2, 3, 4],
[ 3, 4, 10, 6],
[ 0, 9, 0, 1],
[ 8, 2, 1, 3]])
tensor([[12, 12, 12, 12],
[12, 12, 12, 12],
[ 0, 9, 0, 1],
[ 8, 2, 1, 3]])
可以用切片来修改:
x = torch.tensor([[1,2,3,4], [3,4,5,6], [0,9,0,1], [8,2,1,3]])
y = x[0,]
print(x)
print(y)
y += 10
print(y)
print(x[0])
x[0] -= 5
print(x[0])
print(y)
tensor([[1, 2, 3, 4],
[3, 4, 5, 6],
[0, 9, 0, 1],
[8, 2, 1, 3]])
tensor([1, 2, 3, 4])
tensor([11, 12, 13, 14])
tensor([11, 12, 13, 14])
tensor([6, 7, 8, 9])
tensor([6, 7, 8, 9])
- 这里需要注意!!!索引出来的结果与原数据内存共享,修改了一个,另一个也修改,所以后面的结果是同步的
PyTorch还提供了一些高级的选择函数:
#index_select(input, dim, index):在指定维度dim上选取(dim=0,按行选取;dim=1,按列选取),index是索引的序号。
x = torch.tensor([[1,2,3,4], [3,4,5,6], [0,9,0,1], [8,2,1,3]])
print(x)
y = torch.index_select(x, 0, torch.tensor([0, 2])) #按行选取,索引第0行和第2行
print(y)
z = torch.index_select(x, 1, torch.tensor(1)) #按列选取,第1列
print(z)
tensor([[1, 2, 3, 4],
[3, 4, 5, 6],
[0, 9, 0, 1],
[8, 2, 1, 3]])
tensor([[1, 2, 3, 4],
[0, 9, 0, 1]])
tensor([[2],
[4],
[9],
[2]])
#masked_select(input, mask):mask取出的是布尔值索引(掩码)(即真为1,假为0),然后根据取出的非0掩码从中取值
x = torch.tensor([[0,2,4], [1,3,5]])
print(x)
y = torch.masked_select(x, x<5) #x<5时,布尔值索引是[1,1,1,1,1,0],所以取出[0, 2, 4, 1, 3]
print(y)
tensor([[0, 2, 4],
[1, 3, 5]])
tensor([0, 2, 4, 1, 3])
#nonzero(input):取出非0元素的下标
x = torch.tensor([[0,2,4], [1,0,5]])
print(torch.nonzero(x))
tensor([[0, 1],
[0, 2],
[1, 0],
[1, 2]])
#gather(input, dim, index):根据index,在dim维度上选取数据,输出的size与index一样
x = torch.tensor([[0,2,4], [1,3,5], [2,1,0]])
print(x)
index = torch.tensor([[2,0], [1,2], [0,1]])
print(index)
a = torch.gather(x, 0, index)
b = torch.gather(x, 1, index)
print(a)
print(b)
tensor([[0, 2, 4],
[1, 3, 5],
[2, 1, 0]])
tensor([[2, 0],
[1, 2],
[0, 1]])
tensor([[2, 2],
[1, 1],
[0, 3]])
tensor([[4, 0],
[3, 5],
[2, 1]])
a中,dim=0,表示在行上取数据。那么就以列作为取值的基准。index中第一列的[2,1,0]表示a的第0列是x的第0列中,行号为[2,1,0]的数,以此类推。
由于index只有两列,所以a的结果不涉及x的第3列。
b中,dim=1,表示在列上取数据。那么就以行作为取值的基准。index中第一行的[2,0]表示b的第0行是x的第0行中,列号为[2,0]的数,以此类推。
评论(0)
您还未登录,请登录后发表或查看评论