Pythonのデコレータの基礎と応用
アドベントカレンダー「ほぼ横浜の民」の3日目の記事です。
今年は Python のデコレータについて書いています。かなり雑に説明すると、デコレータとは関数についている @staticmethod
や @classmethod
のことです。OSS を見ているとカスタムのデコレータもあって、これってどのように機能しているんだろう?と気になることが多くなってきたので少し勉強してみました。
この記事を読んでわかること
- そもそもデコレータって何?
- デコレータってどうやって定義するの?
- 引数付きのデコレータはどうやって定義する?
- 引数付きのデコレータはなぜネストしている?
- なぜ
functools.wraps
を使ってデコレータを定義する必要があるの? - 機械学習関連の OSS ではどのように使われている?
デコレータとは?
デコレータは、関数をラップすることで別の関数を返却する関数です。実際にはクラスに対しても同様の概念がありますが、関数に適用することが多いため以下では関数デコレータについて説明します。以下の my_decorator
がデコレータです。@staticmethod
や @classmethod
の形で見たことがあるのではないでしょうか。
@my_decorator
def hoge():
pass
ここで、my_decorator
は関数であり、@
+ デコレータ
という記法で対象の関数 hoge
に適用できます。この記法で行われている変換は以下とほぼ等価になります。
hoge = my_decorator(hoge)
デコレータを使うことにより、関数呼び出しの前後で処理を追加できるようになります。ではどのように処理を追加できるのでしょうか?より具体的な例を見ていきましょう。
最もシンプルなデコレータ
最もシンプルな例として、関数の位置引数を numpy.float16
に変換するデコレータ auto_fp16
を作成してみます。本当は kwargs
に対しても処理を書いたり、入力の型を確認したりしないといけないのですが、今回は簡単のために省略しています。
def auto_fp16(func):
def wrapper(*args, **kwargs):
new_args = [np.float16(a) for a in args] # 追加処理
result = func(*new_args, **kwargs) # 対象関数の処理
return result
return wrapper
最初に言ったように、デコレータは関数をラップして別の関数を返します。したがって、auto_fp16
の内部で適用対象の関数 func
をラップした関数 wrapper
が定義されており、これがデコレータ auto_fp16
の返り値となっています。関数 wrapper
の内部では追加処理が記述されており、ここで位置引数が numpy.float16
に変換されて元々の関数 func
の引数に渡されています。
ちなみに wrapper(*args, **kwargs)
の時点でなんじゃこりゃ?となってしまった方はこちらを参照してください。簡単に言ってしまえば、こう書くことで任意の引数を受け取ることができるようになります。
では実際の出力を確認してみましょう。
import numpy as np
@auto_fp16
def add(x, y):
return x + y
ret = add(np.float32(1.0), np.float32(2.0))
print(f'output: {ret} ({ret.dtype})')
# output: 3.0 (float16)
関数の入力として numpy.float32
型で渡されていた2変数が numpy.float16
に変換されて計算されていることが確認できました。最初にも説明した通り、デコレータ auto_fp16
は以下の変換とほぼ等価です。
add = auto_fp16(add)
引数付きのデコレータ
さっきのデコレータに引数を追加するだけではだめ?
次に、デコレータに引数を渡すことでその挙動を制御してみます。今回は、出力を numpy.float32
に戻すかどうかの論理値 out_fp32
をオプションに設定します。イメージとしては以下の変換を行うことになります。
add = auto_fp16(out_fp32=True)(add)
さて、デコレータはどのように定義できるでしょうか?先ほどまでの例にキーワード引数がついただけじゃないかと思ってこう考えた方もいたかもしれません。しかし、これはうまく動作しません。
def auto_fp16(func, out_fp32=False):
def wrapper(*args, **kwargs):
new_args = [np.float16(a) for a in args]
result = func(*new_args, **kwargs)
return result
return wrapper
実際に動かしてみると func
が渡されていないよと怒られてしまいます。
@auto_fp16(out_fp32=True)
def add(x, y):
return x + y
TypeError: auto_fp16() missing 1 required positional argument: 'func'
上記のデコレータの定義だと、以下のように定義すれば期待通りに動作します(冷静に考えれば普通ですよね)。ただ、デコレータとして @
を用いた記法が使えなくなるのは致命的ですね。
add = auto_fp16(add, out_fp32=True)
デコレータの記法を使うためには、def auto_fp16(func)
の部分では追加の変数を渡さない、つまりは関数オブジェクトのみを渡すようにする必要があります。したがって、デコレータのコア部分はそのままにして、外側から追加の変数を渡してあげる必要があるわけです。こんなわざとらしい言い方をしたのでお気づきの方もいるかと思いますが、デコレータ関数をクロージャ化すればOKです。クロージャって何?と思った方のために次で少し補足します。
クロージャを活用して引数を渡す
少し話は逸れますが、クロージャについて説明するために以下のような関数を考えます。
def func():
x = 2
def double(y):
return y * x
return double
f = func()
f(3)
# 6
結論から言うと、この double
がクロージャです。関数の中の関数であること(グローバルスコープ以外で定義されていること)によって、定義時の自身を含むスコープの情報を記憶することができます。すなわち、func
を通った後でも変数 x
の情報を保持したまま関数 double
の処理を実行することができます。
では、話を戻してデコレータをクロージャ化してみます。
def auto_fp16(out_fp32=False):
def auto_fp16_wrapper(func):
def wrapper(*args, **kwargs):
new_args = [np.float16(a) for a in args]
result = func(*new_args, **kwargs)
if out_fp32:
result = np.float32(result)
return result
return wrapper
return auto_fp16_wrapper
上記のようにデコレータの外側で auto_fp16
を定義して追加変数を渡します。これにより auto_fp16_wrapper
を auto_fp16
のスコープの情報を記憶するクロージャにできました。auto_fp16_wrapper
はクロージャであるため、auto_fp16
のスコープを抜けた後も引数 out_fp32
の情報を保持することができます。最終的に、この関数 auto_fp16
が実行しているメインの処理部分は auto_fp16_wrapper
であり、これは引数なしの auto_fp16
と同じになります。複雑化したようでやっていることは変わらないんですね。では実際の挙動を見てみましょう。
@auto_fp16(out_fp32=True)
def add(x, y):
return x + y
ret = add(np.float32(1.0), np.float32(2.0))
print(f'output: {ret} ({ret.dtype})')
# output: 3.0 (float32)
out_fp32
のオプションが有効化されたことにより、add
関数の出力が numpy.float32
に変換されていることが確認できます。
functools.wraps によるデコレータの定義
ここまでで関数デコレータの基本的な説明が終わりました。しかし、実はこのままでは問題があります。add
関数の名前を出力してみましょう。
print(add.__name__)
# wrapper
忘れてしまいそうですが、 add
関数は wrapper
関数で置き換えられてしまっており、正しい情報を参照することができません。ここでの置き換えとは、add
関数のメタデータが wrapper
関数のメタデータに置き換えられていることを指します。そのため例えば以下のような問題が存在します。
- 組み込み関数
help
が機能しない
help(add)
# Help on function wrapper in module __main__:
#
# wrapper(*args, **kwargs)
- オブジェクトシリアライザーは元の関数の位置を決定できずにエラーになる
import pickle
pickle.dumps(add)
# Traceback ...
# AttributeError: Can't pickle local object 'auto_fp16.<locals>.auto_fp16_wrapper.<locals>.wrapper'
元の関数を正しく参照できないことによる問題は、特にデバッガ機能では致命的です。このような問題を引き起こさないためには、デコレータを適用する関数 (add
) のメタデータを、デコレータが返却する関数 (wrapper
) にコピーする必要があります。これ自分でやるの?と思った方もいらっしゃるかと思いますが、Python では wraps
という便利なヘルパー関数が準備されています。使い方は以下のようになります。
from functools import wraps
def auto_fp16(out_fp32=False):
def auto_fp16_wrapper(func):
@wraps(func) # 追加
def wrapper(*args, **kwargs):
new_args = [np.float16(a) for a in args]
result = func(*new_args, **kwargs)
if out_fp32:
result = np.float32(result)
return result
return wrapper
return auto_fp16_wrapper
たった1行追加するだけで、add
関数のメタデータが wrapper
関数にもコピーされるようになります。さきほどの問題も以下のように解決されます(もう一度 add
関数を定義しなおす必要がありますがここでは省略します)。以下のように期待通りの動作になりました。
help(add)
# Help on function add in module __main__:
#
# add(x, y)
import pickle
pickle.dumps(add)
# b'\x80\x03c__main__\nadd\nq\x00.'
OSS におけるデコレータの実装を覗いてみる
最後に、ここまで説明してきた知識を活かして、OSSにおけるデコレータ実装を見てみます。ここでは物体検出フレームワークである MMDetection のデコレータを見てみます(正確には MMCV で定義されています)。
auto_fp16
1. このデコレータは物体検出モデルの forward
メソッド等に適用することで float16 での学習を可能にします(この部分は理解できなくてもOKです)。使うときはこんな感じです。
class MyModule(nn.Module)
# Convert x and y to fp16
@auto_fp16()
def forward(self, x, y):
pass
では、このデコレータのソースコードを見てみましょう。以下では説明のために一部のみを抜粋しています。
def auto_fp16(apply_to=None, out_fp32=False):
def auto_fp16_wrapper(old_func):
@functools.wraps(old_func)
def new_func(*args, **kwargs):
# 適用対象が float16 変換に対応しているか確認
# float16 に変換する変数名のリスト args_to_cast を取得
# ...(省略)...
# 位置引数 args を必要に応じて float16 に変換
new_args = []
if args:
arg_names = args_info.args[:len(args)]
for i, arg_name in enumerate(arg_names):
if arg_name in args_to_cast:
new_args.append(
cast_tensor_type(args[i], torch.float, torch.half))
else:
new_args.append(args[i])
# キーワード引数 kwargs を必要に応じて float16 に変換
# ...(省略)...
# 変換された引数を元々の関数に適用
#(Pytorch version >= 1.6.0 では torch.cuda.amp.autocast を利用)
if (TORCH_VERSION != 'parrots' and
digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs)
else:
output = old_func(*new_args, **new_kwargs)
# out_fp32 の値に応じて出力を float32 に変換
if out_fp32:
output = cast_tensor_type(output, torch.half, torch.float)
return output
return new_func
return auto_fp16_wrapper
おや?と思われた方多いんじゃないでしょうか。なんだか今回の記事で作ってきたデコレータにそっくりですね。実は、OSS の実例を交えた説明にしたいという思いから、MMDetection のデコレータを基に例を作成していました。細かい処理の分岐などに違いはありますが、ベースは標準的な関数デコレータとして実装されていることがわかると思います。複雑そうに見えても理解してしまえば作りは至って単純ですね。
register_module
2. 任意のクラスオブジェクトを MMDetection のレジストリに登録する際にこのデコレータを使います。例えば、ユーザーが作成した新しいモデルや損失関数などを既存レジストリに追加登録することが可能です。この機能により、MMDetection 内の既存モジュールと同様にカスタムモジュールも簡単に管理できます。まずは使用例を見てみましょう。
MODELS = Registry('models')
@MODELS.register_module()
Class MyNet:
pass
mynet = MODELS.build(dict(type='MyNet'))
こんな感じでカスタムモデルをモデルレジストリに登録し、共通のAPIからモデルインスタンスを作成できるようになりました。ここで何か違和感があるのではないでしょうか?そうです、上の例では初めてデコレータがクラスに対して適用されています。一番最初に「実際にはクラスに対しても同様の概念があります」と書いた伏線がここで回収されることになります。ただ考え方としてはこれまでと同じで、以下のようにクラスをラップしているイメージになります。
MyNet = MODELS.register_module(MyNet)
実際のコードはやや複雑なので、デコレータ部分が実装されている Registry
クラスの簡略化版を以下に示します。
class Registry:
def __init__(self, name, ...):
self._name = name
self._module_dict = dict()
def register_module(self):
def _register(cls):
name = cls.__name__
self._module_dict[name] = cls
return cls
return _register
# ...(省略)...
register_module
メソッドがクラスに対するデコレータになっています。やっていることはいたってシンプルで、対象レジストリクラスの _module_dict
プロパティに追加したいクラスを登録しているだけです。デコレータはこのようにクラスや関数の登録にも利用することができ、モジュールとしての汎用的なインターフェースを整える助けにもなります。
せっかくなので MMDetection のレジストリがどのように実装されているのかもう少し見てみます。MMDetection では、レジストリからのインスタンス作成向けに Registry.build
メソッドを準備しており、ユーザーが独自の build
メソッドを定義しない場合には build_from_cfg
関数が呼び出されます。この関数が基本的な Registry.build
メソッドにあたるので、この関数を例に実装を見てみましょう。以下にレジストリからのインスタンス作成に関する簡略化したコードを示しています。
class Registry:
# ...(省略)...
def get(self, key):
return self._module_dict[key]
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def build_from_cfg(cfg, registry, default_args=None):
# cfg, default_args により渡されたパラメータを辞書型変数 args に格納
# ...(省略)...
obj_type = args.pop('type')
obj_cls = registry.get(obj_type)
return obj_cls(**args)
build_from_cfg
では、レジストリに保存されている _module_dict
を参照してキー名 (obj_type
)からクラスオブジェクト (obj_cls
) を呼び出し、必要な引数を与えてインスタンスを返却しています。これはデコレータを用いたレジストリ実装のひな型にできそうなので今後使ってみたいですね。本題のデコレータからは少し話が逸れましたが、デコレータを用いたレジストリ機能により、汎用的な実装を実現していることを確認できました。
まとめ
- デコレータは関数をラップすることで別の関数を返却する関数
- デコレータの内部にラップ関数が配置される形で定義できる
- デコレータをクロージャ化することで引数を渡す
- 元々の関数と引数を同時に渡すことで
@
記法が使いにくくなるためクロージャ化する - 元々の関数の情報を保持するために
functools.wraps
を使ってデコレータを定義する - OSS におけるデコレータは入力の型変換や関数・クラスの登録等に使われている
所感
調べる前までは正直どこで使いどころがあるの?と疑問に思っていましたが、OSS の例などを通して、実際の使いどころのイメージがわいてきました。他にも関数の処理速度のベンチマークを取るために使われることも多いみたいですね。具体的な利用例については引き続き調査を継続しつつ、今後1年以内にどこかで使ってみたいです。
Discussion