torchのnarrowメソッド

torchのnarrowメソッドがパッと見よくわからんので試してみた結果のメモ。
[Tensor] narrow(dim, index, size)

↑のサイトを参考にした。
narrowメソッドは何をするかというと、テンソル内部のある次元のある部分を切り取って返す関数である。ちなみに参照渡しなので注意。

[Tensor] narrow(dim, index, size)の3つの引数はそれぞれ、
・dim:切り取る次元(行列(2次テンソル)なら行方向で切るか(1)列方向で切るか(2)の2種類 (ただし、dim > 0)
・index:切り取る部分の始点 (ただし、1 < index <= Tensorsize(dim))
・size:切り取る部分のサイズ(始点からどこまで切るか) (ただし、index+size-1 <= Tensorsize(dim) )
※ここで、Tensorsize(dim)はテンソルのdim次元の最大のindexの長さとする

例(参考サイトから引用)

th> x = torch.Tensor(5,6):zero()
                                                                      [0.0001s]
th> print(x)
 0  0  0  0  0  0
 0  0  0  0  0  0
 0  0  0  0  0  0
 0  0  0  0  0  0
 0  0  0  0  0  0
[torch.DoubleTensor of size 5x6]

                                                                      [0.0003s]
th> y = x:narrow(1,2,3)
                                                                      [0.0001s]
th> print(y)
 0  0  0  0  0  0
 0  0  0  0  0  0
 0  0  0  0  0  0
[torch.DoubleTensor of size 3x6]

                                                                      [0.0003s]
th> y:fill(1)
 1  1  1  1  1  1
 1  1  1  1  1  1
 1  1  1  1  1  1
[torch.DoubleTensor of size 3x6]

                                                                      [0.0003s]
th> print(x)
 0  0  0  0  0  0
 1  1  1  1  1  1
 1  1  1  1  1  1
 1  1  1  1  1  1
 0  0  0  0  0  0
[torch.DoubleTensor of size 5x6]

                                                                      [0.0003s]

これは、2次元テンソルつまり行列の例で、y=x:narrow(1,2,3)は「行方向に2行目からスタートして2+3-1行目まで切ったものをyに代入する」という意味になる。
これは参照渡しになっているので、y:fill(1)するともとのxもそれに応じて変化する。

試しに列方向にも切ってみる。

th> r = x:narrow(2,3,4)
                                                                      [0.0001s]
th> r
 0  0  0  0
 1  1  1  1
 1  1  1  1
 1  1  1  1
 0  0  0  0
[torch.DoubleTensor of size 5x4]

                                                                      [0.0003s]
th> r:fill(2)
 2  2  2  2
 2  2  2  2
 2  2  2  2
 2  2  2  2
 2  2  2  2
[torch.DoubleTensor of size 5x4]

                                                                      [0.0003s]
th> x
 0  0  2  2  2  2
 1  1  2  2  2  2
 1  1  2  2  2  2
 1  1  2  2  2  2
 0  0  2  2  2  2
[torch.DoubleTensor of size 5x6]

                                                                      [0.0003s]

これは「列方向に3行目からスタートして3+4-1行目まで切ったものをrに代入する」という操作に対応する。エラーが起きないように注意が必要な操作だ・・・