PyTorch と in-place operation のエラー
本記事で伝えたいこと
-
PyTorchで
RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
という Error が出たら、in-place operationが原因。 -
in-place operation とは、
x.add_()
,x += 0.5
,x[mask] = 0.5
のような、
テンソルの値を直接書き換える演算。 -
上記のErrorは in-place operation を以下の方法で置き換えることで解決
-
x.add_()
はx = x.add()
とする。 -
x += 1
はx = x + 1
とする。 -
mask を使う場合は、
元の値 + mask * (変更後 - 元の値)
とする>>> import torch >>> x = torch.tensor([0.2, 0.4, 0.6, 0.8], requires_grad=True) >>> mask = x > 0.5 >>> # x[mask] = 0.0 # Error >>> x = x + mask * (torch.zeros_like(x) - x) >>> print(x) tensor([0.2000, 0.4000, 0.0000, 0.0000], grad_fn=<AddBackward0>)
-
slice や index を使用したい場合は、slice や index から mask を作成する。
# indice -> mask indice = torch.tensor([2, 3], dtype=torch.long) mask = torch.zeros_like(x).bool().scatter_(0, indice, torch.ones_like(indice).bool()) # slice -> mask mask = torch.zeros_like(x).bool() mask[2:] = True
-
in-place operation とは
in-place operation とは、新しくコピーを作ることなく、オブジェクトの中身を変更する演算のことである。in-place operation を実行するための演算子を in-place operator という。(意訳)
通常の代入演算 x = x + 1
と、in-place operator による演算 x += 1
の比較。
python では、iadd()
や +=
などを使うと in-place operation になる。
PyTorch で in-place operation
PyTorch で in-place operation をする場合は以下のような方法がある。(他にもあるかも。)
-
x.add_()
,x.mul_()
などの通常のメソッドに_
を付けたメソッドを使用する。 -
x.data
を使う。(正確には in-place operation とは異なりそう。) - indexやmaskを使用する
-
+=
,*=
などを使う
x.add_()
を使う
1. >>> import torch
>>> x = torch.tensor([1.])
>>> x.add(1)
>>> print(x)
tensor([1.])
通常の add
を使用した場合は、x
は変更されない。
x.add(1)
という新しい変数が確保されているイメージ。
>>> import torch
>>> x = torch.tensor([1.])
>>> x.add_(1) # in-place operation
>>> print(x)
tensor([2.])
add_
(ハイフン有り)を使用した場合は、x
の値が変更される。
x.data
を使う(正確にはin-placeではなさそう)
2. import torch
def my_add_one(input):
input = input + 1
x = torch.tensor([1.])
my_add_one(x)
print(x)
# tensor([1.])
pythonの関数の引数は参照渡しなので、input
と x
は id
が同じだが、
input = input + 1
の左側の変数 input
は id
は上で述べた2つと異なる。
import torch
def my_add_one(input):
input.data = input + 1
x = torch.tensor([1.])
my_add_one(x)
print(x)
# tensor([2.])
input.data = input + 1
とすると x
の中身が書き換えられる。
3. indexやmaskを使用する
>>> import torch
>>> x = torch.tensor([1, 2, 3])
>>> mask = x <= 2
>>> x[mask] = 10
>>> print(x)
tensor([10, 10, 3])
x[0] = 2
や x[mask] = 2
のように、index、slice、maskなどを使用して選択したテンソルの一部分に代入すると、値が置き換えられる。
in-place operation を勾配計算では避ける
発生するエラー
PyTorchで requires_grad=True
としたテンソルに対して、勾配計算のために何らかの in-place な計算処理を行うと以下のようなエラーが発生する。
例えば、以下のようにして requires_grad=True
としたテンソルに in-place operation をすると、上述のエラーが発生する。
import torch
x = torch.rand(20, requires_grad=True)
x.add_(1) # Error
x[1] = 0 # Error
x += 1 # Error
x.data = x + 1 # これはErrorがでない
PyTorchの公式によると、勾配計算のためのforwardの計算が壊れてしまうため(意訳)in-place operation による Error が発生すると書いてある。また、Errorが発生しない場合もあるとも書かれている(よくわからない。。。)。
対処法(再掲)
-
x.add_()
はx = x.add()
とする。 -
x += 1
はx = x + 1
とする。 -
mask を使用したい場合は、例えば下の式に示すように
とa を新しい値b とa' で置き換えたい場合、変更後の値から元の値を引いた値に対して mask との要素積をとり、変更前の値に足せば良い。(b' と示したのはどうでも良い値)?
>>> import torch
>>> x = torch.tensor([0.2, 0.4, 0.6, 0.8], requires_grad=True)
>>> mask = x > 0.5
>>> # x[mask] = 0.0 # Error
>>> x = x + mask * (torch.zeros_like(x) - x)
>>> print(x)
tensor([0.2000, 0.4000, 0.0000, 0.0000], grad_fn=<AddBackward0>)
上のコードは、テンソルの値が
# indice -> mask
indice = torch.tensor([2, 3], dtype=torch.long)
mask = torch.zeros_like(x).bool().scatter_(0, indice, torch.ones_like(indice).bool())
# slice -> mask
mask = torch.zeros_like(x).bool()
mask[2:] = True
slice や index を使用したい場合は、slice や index から mask を作成し、上述の方法で対処すれば良い。
参考文献
Discussion