🐍

Pythonの型ヒントと共に進化するコード(#17: Callable と ParamSpec)

に公開
これまでの連載記事


前回は 型変数(TypeVar)を用いて Generics を導入し、入力と同じ型を返す関係を型レベルで表現する手法を紹介しました。

前回の最後で予告した通り、今回は Generics の基礎を踏まえてデコレータを型安全に書く課題に取り組みます。

アプリケーションが成長するにつれて、複数の箇所で必要になる共通の処理(例えばロギングやキャッシュ、性能計測など)が現れます。こうした横断的な関心事を Python ではデコレータを使って実装することがよくあります。

今回は型ヒントを用いて堅牢なデコレータを実装します。

今回の課題:型情報を破壊するデコレータ

API クライアントの各メソッドがどれくらいの時間をかけて実行されたかを知りたいという要求が来たとします。デコレータを作るのにうってつけのシナリオです。

素朴なデコレータはこのように書けるかもしれません。

時間計測デコレータの素朴な実装
import time
from typing import Any
import requests
from models import Headers  # type Headers = dict[str, str]

def measure_time(func):
    """時間計測デコレータ"""
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"'{func.__name__}' took {end - start:.4f} seconds")
        return result
    return wrapper

# http_client.py に適用してみる
class RequestsHttpClient:
    @measure_time
    def post(self, url: str, json: dict[str, Any], headers: Headers | None) -> requests.Response:
        return requests.post(url, json=json, headers=headers)

このデコレータはうまく動きます。しかし静的型付けの世界ではこのコードは深刻な問題を引き起こします。

型チェッカーの視点から見ると @measure_time でデコレートされた後の post メソッドはもはや元の post メソッドではありません。wrapper 関数に置き換えられています。


赤枠内、 post メソッドの型に注目
デコレート後の型が Callable[..., Any] に潰れているのが問題

そして型注釈のないデコレータは「なんでも受け取って、なんでも返す関数」として扱われます。つまりデコレート後の関数の型は Callable[..., Any] になってしまい、元のシグネチャ情報は完全に失われます(Callable については後ほど説明します)。

結果として何が起きるでしょうか。

型情報が失われる

RequestsHttpClientpost メソッドの本来のシグネチャが完全に失われます。

IDE の補完が効かなくなる

client.post( とタイプしても引数 urljson のヒントは表示されません。

Protocol の整合性が壊れる

現在のコードでは HttpClient という Protocol を定義していました。

http_client.py
class HttpClient(Protocol):
    def post(self, url: str, json: dict[str, Any], headers: Headers | None) -> requests.Response:
        ...

この Protocol に対して RequestsHttpClient は実装クラスとして機能していました。

デコレータを適用する前
class RequestsHttpClient:
    def post(self, url: str, json: dict[str, Any], headers: Headers | None) -> requests.Response:
        return requests.post(url, json=json, headers=headers)

# この時点では RequestsHttpClient は HttpClient Protocol に適合している
client: HttpClient = RequestsHttpClient()  # OK

ところが、型注釈のないデコレータを適用すると状況が変わります。

型注釈のないデコレータを適用
# 型注釈のないデコレータ
def measure_time(func):
    def wrapper(*args, **kwargs):
        # ...
        return func(*args, **kwargs)
    return wrapper

# デコレータを適用
class RequestsHttpClient:
    @measure_time  # 型注釈がない
    def post(self, url: str, json: dict[str, Any], headers: Headers | None) -> requests.Response:
        return requests.post(url, json=json, headers=headers)

# post メソッドのシグネチャが Callable[..., Any] になってしまう
# Protocol との不一致が検出されなくなる
client: HttpClient = RequestsHttpClient() # 型チェッカーによっては検出されない

RequestsHttpClient.post のシグネチャが Callable[..., Any] になってしまったため、型チェッカーから見ると任意の引数を受け取り、Any を返す関数となります。

こうなると Protocol との不一致が検出されなくなり、型チェックが実質無効化されます。

型安全性の観点から見ると、これはかなりよろしくありません。

client = RequestsHttpClient()
response = client.post("http://example.com", {}, None)
# responseの型は Any になり、型チェックが無効化される
reveal_type(response)  # Revealed type is "Any"

戻り値が Any になることで、以降のコードで response に対してどんな操作をしても型チェッカーは何も検証してくれなくなります。

素朴なデコレータはこれまで積み上げてきた型安全性を破壊してしまいます...

処方箋:ParamSpec によるシグネチャの透過

この問題を解決するには、デコレータが元の関数の引数の型情報をそのまま保持し透過させる必要があります。この仕組みを 3 つのステップで実装していきます。

ステップ 1:関数の型を表す Callable

まず基本となるのが typing.Callable です。これは関数そのものの型を表現します。

Callable[[引数の型...], 戻り値の型] という形で記述します。

from typing import Any, Callable

# strとdictを引数に取り、requests.Responseを返す関数の型
def my_func(url: str, json: dict[str, Any]) -> requests.Response: ...

# my_funcの型は Callable[[str, dict[str, Any]], requests.Response]
f: Callable[[str, dict[str, Any]], requests.Response] = my_func

上記の例だと、my_func の型は Callable[[str, dict[str, Any]], requests.Response] として表現されます。

しかしこれだけでは「あらゆる関数」を受け取れる汎用的なデコレータは書けません。引数の数や型が固定されてしまうからです。

ステップ 2:TypeVar で戻り値の型を保持する

前回(16 日目)で紹介した TypeVar を使えばデコレータの戻り値の型を保持できます。

# Rは任意の戻り値の型を表す
def decorator[R](func: Callable[..., R]) -> Callable[..., R]:
    ...

しかし引数部分が ... のままでは、まだ引数の型情報は失われてしまいます。

ステップ 3:パラメータの仕様をキャプチャする ParamSpec

最後のピースを埋めるのが typing.ParamSpec です。

ParamSpecCallable の引数仕様(パラメータの仕様)全体をキャプチャし、それを別の Callable に転送するための Generics の一種です。Python 3.12 以降の新しい構文では **P のように ** を付けて宣言します。

これら全てを組み合わせると型安全なデコレータが完成します。

import time
from functools import wraps
from typing import Callable

# [**P, R] の部分が Generics の宣言
# **P が任意の引数リストを、R が任意の戻り値を表す
def measure_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
    """
    funcは引数P、戻り値Rを持つ任意の関数。
    このデコレータが返す関数も全く同じ引数Pと戻り値Rを持つ。
    """
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # P.argsとP.kwargsでキャプチャした引数を展開して渡す
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        print(f"'{func.__name__}' took {end - start:.4f} seconds")
        return result
    return wrapper

この構文でデコレータが元の関数のシグネチャを完全に受け継ぐことを保証します。

  • **P が引数全体の型を捉え
  • R が戻り値の型を捉え

最終的に Callable[P, R] として寸分違わず再現してくれるのです。

コードの進化:型安全なデコレータの導入

ParamSpec デコレータの適用

decorators.py を新規作成し、型安全な measure_time デコレータを実装します。そして http_client.pyRequestsHttpClient.post メソッドにこのデコレータを適用します。

decorators.py(新規作成)

# decorators.py
from __future__ import annotations

import time
from functools import wraps
from typing import Callable

# Python 3.12+ のGenerics構文を使用
def measure_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
    """関数の実行時間を計測し、標準出力に表示するデコレータ"""
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"Finished '{func.__name__}' in {end_time - start_time:.4f} secs")
        return result
    return wrapper

http_client.py(変更後)

# http_client.py
from __future__ import annotations

import requests
from typing import Protocol, Any

from models import Headers
from decorators import measure_time


class HttpClient(Protocol):
    def post(self, url: str, json: dict[str, Any], headers: Headers | None) -> requests.Response:
        ...


class RequestsHttpClient:
    @measure_time
    def post(self, url: str, json: dict[str, Any], headers: Headers | None) -> requests.Response:
        return requests.post(url, json=json, headers=headers)

この変更で RequestsHttpClient.post メソッドは時間計測の機能を追加しつつも、型シグネチャは完全に維持されます。HttpClient プロトコルとの整合性も保たれ、型チェッカーはエラーを報告しません。

得られたもの

今回のリファクタリングで型安全性を損なうことなく横断的な関心事を分離する高度な実装ができました。

ParamSpec を使いこなすことでビジネスロジックを汚すことなく、ロギング・キャッシュ・性能計測といった機能を再利用可能なデコレータとして安全に利用することができます。

次回予告

実は、7 日目の記事内でしれっと以下の cast を紛れ込ませていました。

main.py の一部
def fetch_and_format_address(...):
    ...
    response = http_client.post(api_url, json={"zipcode": zipcode}, headers=headers)
    ...
    payload = cast(Mapping[str, Any], response.json())  # 👈 これ
    address = Address.unmarshal_payload(payload)
    ...

これ、実は危険です。

response.json() の戻り値は Any なので、実際に何が返ってくるかは実行時までわかりません。検証なしで cast してしまうと外部由来の壊れたデータがそのまま内部ロジックに侵入するリスクがあります。

次回は Type Narrowing(型の絞り込み)を扱います。境界で実行時に型を検証し、その結果を型チェッカーにも伝える仕組みです。isinstanceTypeGuardTypeIs を使って cast よりも安全な方法を取り入れます。

処方後のコードはこちら

decorators.py(新規作成)

# decorators.py
from __future__ import annotations

import time
from functools import wraps
from typing import Callable

# 👉 Python 3.12+ のGenerics構文を使った型安全なデコレータ
def measure_time[**P, R](func: Callable[P, R]) -> Callable[P, R]:
    """関数の実行時間を計測し、標準出力に表示するデコレータ"""
    @wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"Finished '{func.__name__}' in {end_time - start_time:.4f} secs")
        return result
    return wrapper

typings.py

# typings.py
from __future__ import annotations

from collections.abc import Mapping, Sequence


def first[T](items: Sequence[T]) -> T | None:
    """シーケンスの最初の要素を返す。空なら None を返す。"""
    return items[0] if items else None


def get_or[K, V](d: Mapping[K, V], key: K, default: V) -> V:
    """マッピングから値を取得する。キーがなければデフォルト値を返す。"""
    return d.get(key, default)

models.py

# models.py
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, replace
from typing import Any, ClassVar, Final, NewType, ReadOnly, Self, TypedDict

ZipCode = NewType("ZipCode", str)
type Headers = dict[str, str]

@dataclass(frozen=True, slots=True)
class Address:
    API_PATH: ClassVar[Final[str]] = "/v1/address"

    zipcode: str
    prefecture: str
    prefecture_kana: str
    city: str
    city_kana: str
    town: str
    town_kana: str

    def full_address(self) -> str:
        """都道府県・市区町村・町域を結合したフル住所を返す"""
        return self.prefecture + self.city + self.town

    def full_address_kana(self) -> str:
        """フル住所のカナ表記を返す"""
        return self.prefecture_kana + self.city_kana + self.town_kana

    @classmethod
    def unmarshal_payload(cls, payload: Mapping[str, Any]) -> Address:
        """APIレスポンスからAddressオブジェクトを生成する"""
        return cls(
            zipcode=str(payload["zipcode"]),
            prefecture=str(payload["prefecture"]),
            prefecture_kana=str(payload["prefecture_kana"]),
            city=str(payload["city"]),
            city_kana=str(payload["city_kana"]),
            town=str(payload["town"]),
            town_kana=str(payload["town_kana"]),
        )

class FormattedAddressDict(TypedDict):
    zipcode: ReadOnly[str]
    full_address: ReadOnly[str]
    prefecture: ReadOnly[str]
    city: ReadOnly[str]
    town: ReadOnly[str]

class FormattedAddressWithKanaDict(FormattedAddressDict):
    full_address_kana: ReadOnly[str]

@dataclass(frozen=True, slots=True)
class AddressFormatter:
    _address: Address | None = None
    _include_kana: bool = False

    def with_address(self, address: Address) -> Self:
        return replace(self, _address=address)

    def with_kana(self, include: bool = True) -> Self:
        return replace(self, _include_kana=include)

    def build(self) -> FormattedAddressDict | FormattedAddressWithKanaDict:
        if self._address is None:
            raise ValueError("Address must be set before building.")

        base: FormattedAddressDict = {
            "zipcode": self._address.zipcode,
            "full_address": self._address.full_address(),
            "prefecture": self._address.prefecture,
            "city": self._address.city,
            "town": self._address.town,
        }

        if self._include_kana:
            with_kana: FormattedAddressWithKanaDict = {
                **base,
                "full_address_kana": self._address.full_address_kana(),
            }
            return with_kana

        return base

http_client.py

# http_client.py
from __future__ import annotations

import requests as requests_lib
from typing import Protocol

from models import Headers
# 👉 新しく作ったデコレータをインポート
from decorators import measure_time

type JsonObject = dict[str, object]


class HttpResponse(Protocol):
    @property
    def status_code(self) -> int: ...
    def json(self) -> object: ...


class HttpClient(Protocol):
    def post(self, url: str, json: JsonObject, headers: Headers | None = None) -> HttpResponse: ...


class RequestsResponse:
    def __init__(self, response: requests_lib.Response) -> None:
        self._response = response

    @property
    def status_code(self) -> int:
        return self._response.status_code

    def json(self) -> object:
        return self._response.json()


class RequestsHttpClient:
    def __init__(self) -> None:
        self._session = requests_lib.Session()

    # 👉 型安全なデコレータを適用
    @measure_time
    def post(self, url: str, json: JsonObject, headers: Headers | None = None) -> RequestsResponse:
        response = self._session.post(url, json=json, headers=headers)
        return RequestsResponse(response)

main.py

# main.py
from __future__ import annotations

import json
from collections.abc import Mapping
from typing import Any, Final, cast

from models import (
    ZipCode,
    Headers,
    Address,
    AddressFormatter,
)
from http_client import HttpClient, RequestsHttpClient

# 定数
BASE_URL: Final[str] = "https://api.zipcode-jp.example"
HTTP_OK: Final[int] = 200

def fetch_and_format_address(
    zipcode: ZipCode,
    include_kana: bool,
    http_client: HttpClient,
    headers: Headers | None = None,
) -> str | None:
    """郵便番号から住所を取得し、整形して返す"""

    api_url = f"{BASE_URL}{Address.API_PATH}"

    try:
        response = http_client.post(api_url, json={"zipcode": zipcode}, headers=headers)
        if response.status_code != HTTP_OK:
            print(f"Error: Failed to fetch address. Status: {response.status_code}")
            return None

        payload = cast(Mapping[str, Any], response.json())
        address = Address.unmarshal_payload(payload)

        formatter = AddressFormatter()
        result = formatter.with_address(address).with_kana(include_kana).build()

        return json.dumps(result, indent=2, ensure_ascii=False)

    except Exception as e:
        print(f"An error occurred: {e}")
        return None

# 実行例
if __name__ == "__main__":
    http_client = RequestsHttpClient()
    zipcode = ZipCode("1000001")
    result = fetch_and_format_address(
        zipcode, include_kana=True, http_client=http_client
    )
    if result is not None:
        print(result)

Discussion