Pytorch學(xué)習(xí)筆記3:索引與切片測(cè)試代碼
#添加到學(xué)習(xí)筆記2末尾,直接運(yùn)行。代碼意義可以看注釋。
print('——————————索引與切片——————————')
idx=torch.rand(4,3,28,28)
print('tensor shape:',idx[0,0].shape)#取第1張圖片,第1個(gè)通道的shape
print('tensor shape:',idx[0,0,27,27].shape)#注意索引是從0開(kāi)始的
print('tensor shape:',idx[0,0,27,27])#取第1張圖片,第1個(gè)通道,第28行,第28列的元素,是一個(gè)標(biāo)量
print('tensor shape:',idx[:1,:2,:,:].shape)#:表示一個(gè)箭頭,:1表示0到1不包含1,:1表示0到2不包含2,單獨(dú)的冒號(hào)表示取全部
print('tensor shape:',idx[:1,1:,:,:].shape)#:表示一個(gè)箭頭,1:表示從1開(kāi)始到最后,包含1
print('tensor shape:',idx[:1,-1:,:,:].shape)#:表示一個(gè)箭頭,-1:表示從-1開(kāi)始到最后
print('tensor shape:',idx[:1,:-1,:,:].shape)#:表示一個(gè)箭頭,:-1表示到-1,不包含-1
print('tensor shape:',idx[:,:,0:28:2,0:28:2].shape)#:隔行采樣,不包含28
print('tensor shape:',idx[:,:,0:28:1,0:28:1].shape)#:逐行采樣,不包含28
tmp=torch.tensor([0,2])
print('tensor shape:',idx.index_select(0,torch.arange(2)).shape)
tmp1=torch.arange(8)#生成0-7,一共8個(gè)數(shù)字的1維tensor
print(tmp1)
print('tensor shape:',idx.index_select(0,tmp).shape)#index_select第二個(gè)參數(shù)必須是tensor,不能是list
print('tensor shape:',idx[0,...,::2].shape)#...表示取中間全部維度,::2表示各行采樣,0表示取第0張圖片
tmp2=torch.randn(4,3)#用mask取大于0.5的數(shù)值
print(tmp2)
mask=tmp2.ge(0.5)
print(mask)
tmp3=torch.masked_select(tmp2,mask)
print(tmp3)#輸出是1維向量
tmp4=torch.randn(2,3)#把tensor打平成1維,然后根據(jù)給出的索引取值
print(tmp4)
tmp5=torch.take(tmp4,torch.tensor([1,3,5]))
print(tmp5)
print('——————————索引與切片——————————')