🌊

重箱の隅をつつくクラスのあれこれ

2023/07/25に公開

はじめに

本記事は筆者がPythonの学習をするうえで得た知識の備忘録です。明示的にも暗黙的にも全てが正しいとは保証できないのでご容赦ください(間違いの指摘は大歓迎です)。今回はクラスとその周りについて触れます。クラスの基礎は分かっている前提で、知っていると少しお得かもって程度の内容です。

基本的な特殊メソッド

自分の定義したクラスに足し算や引き算等を定義できます。アンダースコア二つで名前が囲まれています(ex: __add__, __sub__, ...)。以下は複素数を簡易的に実装したものです。実部と虚部をインスタンス変数に持ちます。

from typing import Union, TYPE_CHECKING

class Complex:

    __slots__ = (
        "real",
        "img"
    )

    if TYPE_CHECKING:
        real: Union[int, float]
        img: Union[int, float]

    def __init__(
        self,
        real: Union[int, float],
        img: Union[int, float, None] = 0
    ) -> None:
        self.real = real
        self.img = img


    def __add__(self, other: Union[int, float, Complex]) -> Complex:
        """足し算を定義 (self + otherのときに呼び出される)"""
        if isinstance(other, (int, float))
            real = other
            img = 0
        else:
            real = other.real
            img = other.img

        return Complex(self.real + real, self.img + img)

何気なく使っていた加算演算子ですが、x + yというのはtype(x).__add__(x, y)のショートカットだったということが分かります。その他にイコールや絶対値、strインスタンス化したときに挙動も定義できます。

import math

...

    def __eq__(self, other) -> bool:
        """x == y の定義"""
        return (
            isinstance(other, Complex)
            and self.real == other.real
            and self.img == other.img
        )

    def __abs__(self) -> float:
        """abs(x)が呼ばれたときの挙動"""
        return math.sqrt(self.real**2 + self.img**2)

    def __str__(self) -> str:
        """str(x)が呼ばれた時の挙動""""
        return "{}{+:}*j".format(self.real, self.img)

応用

特殊メソッドは使うと結構便利なのは言うまでもありませんが、デコレータと合わせるとキャッシュとか実装できて面白いです。わざわざ自分で実装する機会は多くはないと思いますが、あくまで余談です。

__get__メソッドは特殊メソッドのなかでもかなり低級の機能を提供します。これはオーナークラス or インスタンスクラスの属性が呼び出されたときに呼ばれるメソッドです。以下はお馴染みpropertyクラスを実装したものです。settergetter等は割愛します。

from __future__ import annotations
from typing import (
    Any,
    Callable,
    TypeVar,
    Generic,
    overload,
    Type,
    Optional
)

T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)

class _property(Generic[T, T_co]):

    def __init__(self, func: Callable[[T], T_co]) -> None:
        self.func = func

    @overload
    def __get__(self, instance: T, owner = Type[T]) -> T_co:
        ...
    @overload
    def __get__(self, instance: None, owner = Type[T]) -> _property[T, T_co]:
        ...
    def __get__(self, instance: Optional[T], owner: Type[T]) -> Any:
        if instance is None:
            return self
        return self.func(instance)

class Sample:

    def __init__(self, name) -> None:
        self._name = name

    @_property
    def name(self) -> str:
        return self._name

s = Sample('sample')
print(s.name) # sample
print(type(Sample.name)) # <class '__main__._property'>

開発しやすいように型ヒントをごちゃごちゃ書いていますが、細かいことは気にしないで結構です。重要なのは

    def __get__(self, instance: Optional[T], owner: Type[T]) -> Any:
        if instance is None:
            return self
        return self.func(instance)

の部分です。propertyデコレータで定義したメソッドの引数がselfのみなのは、propertyが呼ばれたときに予めデコレータを付けられた関数の引数としてインスタンス変数を渡しているからだということが分かります。

これを少しいじればキャッシュの実装もできます。要は初回は関数の処理をして結果を保存し、二度目以降は保存した値を返せばOKなので以下のようなコードになるかと思います。

class CachedProperty(Generic[T, T_co]):

    def __init__(self, func: Callable[[T], T_co]) -> None:
        self.func = func

    @overload
    def __get__(self, instance: T, owner = Type[T]) -> T_co:
        ...
    @overload
    def __get__(self, instance = None, owner = Type[T]) -> CachedProperty[T, T_co]:
        ...
    def __get__(self, instance: Optional[T], owner: Type[T]) -> Any:
        if instance is None:
            return self
        value = self.func(instance)
        print('called')
        setattr(instance, self.func.__name__, value)
        return value


class Sample2:

    def __init__(self) -> None:
        pass

    @CachedProperty
    def test(self) -> list[int]:
        return list(range(5000000))

if __name__ == '__main__':
    import time
    s = Sample2()
    start = time.time()
    s.test
    print(time.time() - start)
    start2 = time.time()
    s.test
    print(time.time() - start2)

以下は実行結果例です。分かりやすいように途中でprintをはさんでいます。

called
0.06731486320495605
0.0

キャッシュを使わない場合と比較してみましょう。

class Sample2:

    def __init__(self) -> None:
        pass

    def test(self) -> list[int]:
        return list(range(5000000))

if __name__ == '__main__':
    import time
    s = Sample2()
    start = time.time()
    s.test()
    print(time.time() - start)
    start2 = time.time()
    s.test()
    print(time.time() - start2)

結果

0.10685992240905762
0.10961365699768066

キャッシュを使った方が明らかに速いですね。ぶっちゃけfunctoolsでキャッシュがサポートされているので自分で実装するのはこのくらいしかありませんが笑

GitHubで編集を提案

Discussion