👻

PythonのEnumが他の型と不意に比較されたことをmypyで検知したい

2021/12/15に公開

例えば

from enum import Enum

class Color(str, Enum):
    RED = "red"
    BLUE = "blue"

こんなのがあったとして

if __name__ == "__main__":
    print(Color.RED == Color.BLUE)
    print(Color.RED == Color.RED)
    print(Color.RED == "red") # こういうのを型エラーで弾きたい
    print(Color.RED is "red") # こういうのも型エラーで弾きたい

こういう話。mypy入れても何も言ってくれないんですよね。

なので

from enum import Enum
from typing import TypeVar

class Color(str, Enum):
    RED = "red"
    BLUE = "blue"

class Color2(str, Enum):
    RED = "red"
    BLUE = "blue"

T = TypeVar("T")

class Comparer(Generic[T]):
    def compare(self, a: T, b: T) -> bool:
        return a is b
    
if __name__ == "__main__":
    comparer = Comparer[Color]()
    print(comparer.compare(Color.RED, Color.BLUE))
    print(comparer.compare(Color.RED, Color.RED))
    print(comparer.compare(Color.RED, "red"))
    print(comparer.compare(Color.RED, Color2.RED2))

みたいにすれば、最後2行はmypyで弾ける。

error: Argument 2 to "compare" of "Comparer" has incompatible type "str"; expected "Color"
error: Argument 2 to "compare" of "Comparer" has incompatible type "Color2"; expected "Color"

一瞬、 __eq__ をオーバーライドすれば良い気もしたのですが、 __eq__ は性質上何でも受け取らないといけないし、そもそも上位でobject型が振ってあったので、とりあえずこんな感じで回避するしかなさげにみえました。

あと

T = TypeVar("T", bound=Enum)
def compare(a: T, b: T) -> bool:
    return a is b

とかしてもenumが入らないときは検知できますが、Enumを継承したクラスが入ってくると途端に検知できなくなる感じのようでした。(上記の例で言えば、"red"の混入は防げるが、Color2は防げない)

Discussion