torch.roll(input, shifts, dims=None)?
這個函數(shù)是用來移位的,是順移。input是咱們要移動的tensor向量,shifts是要移動到的位置,要移動去哪兒,dims是值在什么方向上(維度)去移動。比如2維的數(shù)據(jù),那就兩個方向,橫著或者豎著。最關(guān)鍵的一句話,所有操作針對的是第一行或者第一列,下面舉例子給大家做解釋,自己慢慢體會
import torch
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print("")
print(x)
y = torch.roll(x, 1, 0)
print("")
print(y)
輸出:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[7, 8, 9],
[1, 2, 3],
[4, 5, 6]])
torch.roll(x, 1, 0) 這行代碼的意思就是把x的第一行(0維度)移到1這個位置上,其他位置的數(shù)據(jù)順移。
x——咱們要移動的向量
1——第一行向量要移動到的最終位置
0——從行的角度去移動
import torch
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print("")
print(x)
y = torch.roll(x, -1, 1)
print("")
print(y)
輸出:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[2, 3, 1],
[5, 6, 4],
[8, 9, 7]])
torch.roll(x, -1, 1) 意思就是把x的第一列(1維度)移到-1這個位置(最后一個位置)上,其他位置的數(shù)據(jù)順移。
shifts和dims可以是元組,其實就是分步驟去移動
import torch
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print("")
print(x)
y = torch.roll(x, (0,1), (1,1))
print("")
print(y)
輸出:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[3, 1, 2],
[6, 4, 5],
[9, 7, 8]])
torch.roll(x, (0,1), (1,1)) :
首先,針對元組第一個元素,把x的第一列(1維度)移到0這個位置(已經(jīng)在0這個位置,因此原地不動)上,其他位置的數(shù)據(jù)順移。(所有數(shù)據(jù)原地不動)文章來源:http://www.zghlxwxcb.cn/news/detail-856848.html
然后,針對元組第二個元素,把a的第一列(1維度)移到1這個位置上,其他位置的數(shù)據(jù)順移。文章來源地址http://www.zghlxwxcb.cn/news/detail-856848.html
到了這里,關(guān)于pytorch中torch.roll用法說明的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!