🐍

Pythonのデコレータの基礎と応用

2022/01/02に公開

アドベントカレンダー「ほぼ横浜の民」の3日目の記事です。
今年は Python のデコレータについて書いています。かなり雑に説明すると、デコレータとは関数についている @staticmethod@classmethod のことです。OSS を見ているとカスタムのデコレータもあって、これってどのように機能しているんだろう?と気になることが多くなってきたので少し勉強してみました。

この記事を読んでわかること

  • そもそもデコレータって何?
  • デコレータってどうやって定義するの?
  • 引数付きのデコレータはどうやって定義する?
  • 引数付きのデコレータはなぜネストしている?
  • なぜ functools.wraps を使ってデコレータを定義する必要があるの?
  • 機械学習関連の OSS ではどのように使われている?

デコレータとは?

デコレータは、関数をラップすることで別の関数を返却する関数です。実際にはクラスに対しても同様の概念がありますが、関数に適用することが多いため以下では関数デコレータについて説明します。以下の my_decorator がデコレータです。@staticmethod@classmethod の形で見たことがあるのではないでしょうか。

@my_decorator
def hoge():
    pass

ここで、my_decorator は関数であり、@ + デコレータ という記法で対象の関数 hoge に適用できます。この記法で行われている変換は以下とほぼ等価になります。

hoge = my_decorator(hoge)

デコレータを使うことにより、関数呼び出しの前後で処理を追加できるようになります。ではどのように処理を追加できるのでしょうか?より具体的な例を見ていきましょう。

最もシンプルなデコレータ

最もシンプルな例として、関数の位置引数を numpy.float16 に変換するデコレータ auto_fp16 を作成してみます。本当は kwargs に対しても処理を書いたり、入力の型を確認したりしないといけないのですが、今回は簡単のために省略しています。

def auto_fp16(func):
    def wrapper(*args, **kwargs):
        new_args = [np.float16(a) for a in args]    # 追加処理
        result = func(*new_args, **kwargs)          # 対象関数の処理
        return result
    return wrapper

最初に言ったように、デコレータは関数をラップして別の関数を返します。したがって、auto_fp16 の内部で適用対象の関数 func をラップした関数 wrapper が定義されており、これがデコレータ auto_fp16 の返り値となっています。関数 wrapper の内部では追加処理が記述されており、ここで位置引数が numpy.float16 に変換されて元々の関数 func の引数に渡されています。

ちなみに wrapper(*args, **kwargs) の時点でなんじゃこりゃ?となってしまった方はこちらを参照してください。簡単に言ってしまえば、こう書くことで任意の引数を受け取ることができるようになります。

では実際の出力を確認してみましょう。

import numpy as np

@auto_fp16
def add(x, y):
    return x + y
    
ret = add(np.float32(1.0), np.float32(2.0))
print(f'output: {ret} ({ret.dtype})')

# output: 3.0 (float16)

関数の入力として numpy.float32 型で渡されていた2変数が numpy.float16 に変換されて計算されていることが確認できました。最初にも説明した通り、デコレータ auto_fp16 は以下の変換とほぼ等価です。

add = auto_fp16(add)

引数付きのデコレータ

さっきのデコレータに引数を追加するだけではだめ?

次に、デコレータに引数を渡すことでその挙動を制御してみます。今回は、出力を numpy.float32 に戻すかどうかの論理値 out_fp32 をオプションに設定します。イメージとしては以下の変換を行うことになります。

add = auto_fp16(out_fp32=True)(add)

さて、デコレータはどのように定義できるでしょうか?先ほどまでの例にキーワード引数がついただけじゃないかと思ってこう考えた方もいたかもしれません。しかし、これはうまく動作しません。

def auto_fp16(func, out_fp32=False):
    def wrapper(*args, **kwargs):
        new_args = [np.float16(a) for a in args]
        result = func(*new_args, **kwargs)
        return result
    return wrapper

実際に動かしてみると func が渡されていないよと怒られてしまいます。

@auto_fp16(out_fp32=True)
def add(x, y):
    return x + y

TypeError: auto_fp16() missing 1 required positional argument: 'func'

上記のデコレータの定義だと、以下のように定義すれば期待通りに動作します(冷静に考えれば普通ですよね)。ただ、デコレータとして @ を用いた記法が使えなくなるのは致命的ですね。

add = auto_fp16(add, out_fp32=True)

デコレータの記法を使うためには、def auto_fp16(func) の部分では追加の変数を渡さない、つまりは関数オブジェクトのみを渡すようにする必要があります。したがって、デコレータのコア部分はそのままにして、外側から追加の変数を渡してあげる必要があるわけです。こんなわざとらしい言い方をしたのでお気づきの方もいるかと思いますが、デコレータ関数をクロージャ化すればOKです。クロージャって何?と思った方のために次で少し補足します。

クロージャを活用して引数を渡す

少し話は逸れますが、クロージャについて説明するために以下のような関数を考えます。

def func():
    x = 2
    def double(y):
        return y * x
    return double
    
f = func()
f(3)

# 6

結論から言うと、この double がクロージャです。関数の中の関数であること(グローバルスコープ以外で定義されていること)によって、定義時の自身を含むスコープの情報を記憶することができます。すなわち、func を通った後でも変数 x の情報を保持したまま関数 double の処理を実行することができます。

では、話を戻してデコレータをクロージャ化してみます。

def auto_fp16(out_fp32=False):
    def auto_fp16_wrapper(func):
        def wrapper(*args, **kwargs):
            new_args = [np.float16(a) for a in args]
            result = func(*new_args, **kwargs)
            if out_fp32:
                result = np.float32(result)
            return result
        return wrapper
    return auto_fp16_wrapper

上記のようにデコレータの外側で auto_fp16 を定義して追加変数を渡します。これにより auto_fp16_wrapperauto_fp16 のスコープの情報を記憶するクロージャにできました。auto_fp16_wrapper はクロージャであるため、auto_fp16 のスコープを抜けた後も引数 out_fp32 の情報を保持することができます。最終的に、この関数 auto_fp16 が実行しているメインの処理部分は auto_fp16_wrapper であり、これは引数なしの auto_fp16 と同じになります。複雑化したようでやっていることは変わらないんですね。では実際の挙動を見てみましょう。

@auto_fp16(out_fp32=True)
def add(x, y):
    return x + y
    
ret = add(np.float32(1.0), np.float32(2.0))
print(f'output: {ret} ({ret.dtype})')

# output: 3.0 (float32)

out_fp32 のオプションが有効化されたことにより、add 関数の出力が numpy.float32 に変換されていることが確認できます。

functools.wraps によるデコレータの定義

ここまでで関数デコレータの基本的な説明が終わりました。しかし、実はこのままでは問題があります。add 関数の名前を出力してみましょう。

print(add.__name__)

# wrapper

忘れてしまいそうですが、 add 関数は wrapper 関数で置き換えられてしまっており、正しい情報を参照することができません。ここでの置き換えとは、add 関数のメタデータが wrapper 関数のメタデータに置き換えられていることを指します。そのため例えば以下のような問題が存在します。

  1. 組み込み関数 help が機能しない
help(add)

# Help on function wrapper in module __main__:
#
# wrapper(*args, **kwargs)
  1. オブジェクトシリアライザーは元の関数の位置を決定できずにエラーになる
import pickle
pickle.dumps(add)

# Traceback ...
# AttributeError: Can't pickle local object 'auto_fp16.<locals>.auto_fp16_wrapper.<locals>.wrapper'

元の関数を正しく参照できないことによる問題は、特にデバッガ機能では致命的です。このような問題を引き起こさないためには、デコレータを適用する関数 (add) のメタデータを、デコレータが返却する関数 (wrapper) にコピーする必要があります。これ自分でやるの?と思った方もいらっしゃるかと思いますが、Python では wraps という便利なヘルパー関数が準備されています。使い方は以下のようになります。

from functools import wraps

def auto_fp16(out_fp32=False):
    def auto_fp16_wrapper(func):
        @wraps(func)    # 追加
        def wrapper(*args, **kwargs):
            new_args = [np.float16(a) for a in args]
            result = func(*new_args, **kwargs)
            if out_fp32:
                result = np.float32(result)
            return result
        return wrapper
    return auto_fp16_wrapper

たった1行追加するだけで、add 関数のメタデータが wrapper 関数にもコピーされるようになります。さきほどの問題も以下のように解決されます(もう一度 add 関数を定義しなおす必要がありますがここでは省略します)。以下のように期待通りの動作になりました。

help(add)

# Help on function add in module __main__:
#
# add(x, y)
import pickle
pickle.dumps(add)

# b'\x80\x03c__main__\nadd\nq\x00.'

OSS におけるデコレータの実装を覗いてみる

最後に、ここまで説明してきた知識を活かして、OSSにおけるデコレータ実装を見てみます。ここでは物体検出フレームワークである MMDetection のデコレータを見てみます(正確には MMCV で定義されています)。

1. auto_fp16

このデコレータは物体検出モデルの forward メソッド等に適用することで float16 での学習を可能にします(この部分は理解できなくてもOKです)。使うときはこんな感じです。

class MyModule(nn.Module)

    # Convert x and y to fp16
    @auto_fp16()
    def forward(self, x, y):
        pass

では、このデコレータのソースコードを見てみましょう。以下では説明のために一部のみを抜粋しています。

fp16_utils.py
def auto_fp16(apply_to=None, out_fp32=False):

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):

	    # 適用対象が float16 変換に対応しているか確認
	    # float16 に変換する変数名のリスト args_to_cast を取得
	    # ...(省略)...
	    
            # 位置引数 args を必要に応じて float16 に変換
            new_args = []
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float, torch.half))
                    else:
                        new_args.append(args[i])
	    
	    # キーワード引数 kwargs を必要に応じて float16 に変換
	    # ...(省略)...

	    # 変換された引数を元々の関数に適用
	    #(Pytorch version >= 1.6.0 では torch.cuda.amp.autocast を利用)
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)

	    # out_fp32 の値に応じて出力を float32 に変換
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper

https://github.com/open-mmlab/mmcv/blob/fb486b96fd9932637a16f23a5dc60904c12bea7d/mmcv/runner/fp16_utils.py

おや?と思われた方多いんじゃないでしょうか。なんだか今回の記事で作ってきたデコレータにそっくりですね。実は、OSS の実例を交えた説明にしたいという思いから、MMDetection のデコレータを基に例を作成していました。細かい処理の分岐などに違いはありますが、ベースは標準的な関数デコレータとして実装されていることがわかると思います。複雑そうに見えても理解してしまえば作りは至って単純ですね。

2. register_module

任意のクラスオブジェクトを MMDetection のレジストリに登録する際にこのデコレータを使います。例えば、ユーザーが作成した新しいモデルや損失関数などを既存レジストリに追加登録することが可能です。この機能により、MMDetection 内の既存モジュールと同様にカスタムモジュールも簡単に管理できます。まずは使用例を見てみましょう。

MODELS = Registry('models')

@MODELS.register_module()
Class MyNet:
    pass
    
mynet = MODELS.build(dict(type='MyNet'))

こんな感じでカスタムモデルをモデルレジストリに登録し、共通のAPIからモデルインスタンスを作成できるようになりました。ここで何か違和感があるのではないでしょうか?そうです、上の例では初めてデコレータがクラスに対して適用されています。一番最初に「実際にはクラスに対しても同様の概念があります」と書いた伏線がここで回収されることになります。ただ考え方としてはこれまでと同じで、以下のようにクラスをラップしているイメージになります。

MyNet = MODELS.register_module(MyNet)

実際のコードはやや複雑なので、デコレータ部分が実装されている Registry クラスの簡略化版を以下に示します。

registry.py
class Registry:

    def __init__(self, name, ...):
        self._name = name
	self._module_dict = dict()

    def register_module(self):	
	def _register(cls):
	    name = cls.__name__
	    self._module_dict[name] = cls
	    return cls
	return _register
	
    # ...(省略)...

https://github.com/open-mmlab/mmcv/blob/fb486b96fd9932637a16f23a5dc60904c12bea7d/mmcv/utils/registry.py

register_module メソッドがクラスに対するデコレータになっています。やっていることはいたってシンプルで、対象レジストリクラスの _module_dict プロパティに追加したいクラスを登録しているだけです。デコレータはこのようにクラスや関数の登録にも利用することができ、モジュールとしての汎用的なインターフェースを整える助けにもなります。

せっかくなので MMDetection のレジストリがどのように実装されているのかもう少し見てみます。MMDetection では、レジストリからのインスタンス作成向けに Registry.build メソッドを準備しており、ユーザーが独自の build メソッドを定義しない場合には build_from_cfg 関数が呼び出されます。この関数が基本的な Registry.build メソッドにあたるので、この関数を例に実装を見てみましょう。以下にレジストリからのインスタンス作成に関する簡略化したコードを示しています。

registry.py
class Registry:
    # ...(省略)...
    
    def get(self, key):
        return self._module_dict[key]
    
    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

def build_from_cfg(cfg, registry, default_args=None):

    # cfg, default_args により渡されたパラメータを辞書型変数 args に格納
    # ...(省略)...
    
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    return obj_cls(**args)

build_from_cfg では、レジストリに保存されている _module_dict を参照してキー名 (obj_type)からクラスオブジェクト (obj_cls) を呼び出し、必要な引数を与えてインスタンスを返却しています。これはデコレータを用いたレジストリ実装のひな型にできそうなので今後使ってみたいですね。本題のデコレータからは少し話が逸れましたが、デコレータを用いたレジストリ機能により、汎用的な実装を実現していることを確認できました。

まとめ

  • デコレータは関数をラップすることで別の関数を返却する関数
  • デコレータの内部にラップ関数が配置される形で定義できる
  • デコレータをクロージャ化することで引数を渡す
  • 元々の関数と引数を同時に渡すことで @ 記法が使いにくくなるためクロージャ化する
  • 元々の関数の情報を保持するために functools.wraps を使ってデコレータを定義する
  • OSS におけるデコレータは入力の型変換や関数・クラスの登録等に使われている

所感

調べる前までは正直どこで使いどころがあるの?と疑問に思っていましたが、OSS の例などを通して、実際の使いどころのイメージがわいてきました。他にも関数の処理速度のベンチマークを取るために使われることも多いみたいですね。具体的な利用例については引き続き調査を継続しつつ、今後1年以内にどこかで使ってみたいです。

参考文献

Discussion