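In PyTorch and NumPy, indexing works differently from a Python list. A tensor can be indexed along several dimensions at once with comma-separated indices (tensor[:2, 1]). Chained brackets (tensor[:2][1]) instead apply one indexing step after another, which is how a list behaves.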
tensor[:2]
stands for the first and second samples (indices 0 and 1 along the first dimension).
import torch
tensor = torch.ones([5,32,32,1])
lists = list(range(0,5120))
print(tensor[:2].shape) # torch.Size([2, 32, 32, 1])
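As a quick check that tensor[:2] really is samples 0 and 1, the sketch below uses NumPy, which follows the same slicing rules. Filling the array with np.arange instead of ones is an assumption made only so that equality checks are meaningful.

```python
import numpy as np

# Same shape as the PyTorch example; arange gives every element a distinct value
# (with all-ones data, any two slices of equal shape would compare equal).
a = np.arange(5 * 32 * 32).reshape(5, 32, 32, 1)

first_two = a[:2]                 # slice along dimension 0 only
stacked = np.stack([a[0], a[1]])  # samples 0 and 1, built explicitly

print(first_two.shape)                     # (2, 32, 32, 1)
print(np.array_equal(first_two, stacked))  # True
```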
tensor[:2,1]
stands for the second row of the first two samples.
print(tensor[:2, 1].shape) # torch.Size([2, 32, 1])
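The comma form is a single indexing operation in which each index applies to its own dimension. A minimal NumPy sketch (again with np.arange data as an illustrative assumption) shows that the same result can be reached by chaining only if the first dimension is explicitly kept with ":".

```python
import numpy as np

a = np.arange(5 * 32 * 32).reshape(5, 32, 32, 1)

# One indexing operation: ":2" applies to dim 0, "1" applies to dim 1.
combined = a[:2, 1]
# Chained equivalent: keep every sample of a[:2] with ":", then take index 1 of dim 1.
chained = a[:2][:, 1]

print(combined.shape)                     # (2, 32, 1)
print(np.array_equal(combined, chained))  # True
```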
tensor[:2][1]
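stands for the second sample of the first two samples, i.e. the same as tensor[1]: the [1] indexes the first dimension of the result of tensor[:2], not the second dimension of tensor.
print(tensor[:2][1].shape) # torch.Size([32, 32, 1])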
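To contrast the two forms directly, here is a small NumPy sketch (np.arange data assumed for distinct values): the chained form collapses to a single sample, while the comma form keeps both samples and drops a row dimension.

```python
import numpy as np

a = np.arange(5 * 32 * 32).reshape(5, 32, 32, 1)

chained = a[:2][1]   # [1] indexes dim 0 of the (2, 32, 32, 1) intermediate result
direct = a[1]        # the second sample of the original array
combined = a[:2, 1]  # for contrast: index 1 of dim 1 within the first two samples

print(chained.shape)                    # (32, 32, 1)
print(np.array_equal(chained, direct))  # True
print(combined.shape)                   # (2, 32, 1)
```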
lists[:2][1]
stands for the second scalar of the first two scalars, i.e. lists[1].
print(lists[:2][1]) # 1
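A Python list only supports the chained form: putting a comma inside the brackets passes a tuple, which lists reject. The sketch below shows this and how the tensor-style a[:2, 1] has to be written out for a nested list; the nested data and the name first_two_at_1 are hypothetical, chosen only for illustration.

```python
lists = list(range(0, 5120))

# A list is one-dimensional: lists[:2] is [0, 1], and [1] picks its second element.
print(lists[:2][1])  # 1

# A comma inside the brackets passes the tuple (slice(None, 2), 1), which lists reject.
try:
    lists[:2, 1]
except TypeError as err:
    print(err)  # list indices must be integers or slices, not tuple

# Hypothetical 5x3 nested list; emulate a[:2, 1] by indexing each row explicitly.
nested = [[r * 10 + c for c in range(3)] for r in range(5)]
first_two_at_1 = [row[1] for row in nested[:2]]
print(first_two_at_1)  # [1, 11]
```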