🦁

tensorflow2.xで推論中、メモリ使用量がどんどん増えていく事象への対処

2023/05/07に公開

tensorflow2.xでModel.predictを実行するたびにメモリ使用量がもりもり増えていく現象が観測されている。
まず思いつくのはgc.collect()であるが、今回はこれが効かない。

参考:
https://github.com/keras-team/keras/issues/13118

1. Sequential APIの場合

まず、提案されている解決法がModel.predict_on_batch()を用いる方法。

for x in dataset:
    y_pred = model.predct_on_batch(x)

Functional APIを用いていると、

AttributeError: 'Functional' object has no attribute 'predct_on_batch'

とのエラーが出る。

2. セッションをクリアする方法

from keras import backend as K
import gc
...
y_pred = model.predict(x)
K.clear_session()
gc.collect()

参照:
https://github.com/keras-team/keras/issues/13118#issuecomment-1236169753

3. 上記が無理だった場合(特にFunctional API)

この場合は、直接__call__を呼べばいい
※バグあり。追記参照

for x in dataset:
    y_pred = model(x, training=False).numpy()

追記:
tensorflowのバージョンによってはmodel.predict()とmodel(x)の結果が異なるバグが生じている。(tensorflow-macos=2.12)
https://github.com/tensorflow/tensorflow/issues/32799

Discussion