Pythonで関数型プログラミングできるライブラリ作った
はじめに
関数チェーンを実現したかったので kette(chainのドイツ語[1])という名前のライブラリを作りました。
2021/10/31 にアップデートを行い、関数の部分適用が可能になりました。これにより、以下の関数やライブラリとよく似た機能が、より関数型言語ライクな書き方で可能になります。
- functools.partial: 関数の部分適用を可能にするメソッドです。
- toolz の Functoolz: 関数の部分適用や関数合成を提供するパッケージです。
- python-chain: 関数合成の演算子を提供するパッケージです。
以下のようなコードを
def add(x, y):
return x + y
def mul(x, y):
return x * y
ys = list(map(lambda x: mul(add(x, 1), 2), [1, 2, 3, 4])
こう書けます。
from kette import chain
from kette.curried import c_list
@chain
def add(x, y):
return x + y
@chain
def mul(x, y):
return x * y
# 例
ys = c_list * map | (mul(2) * add(1), [1, 2, 3, 4])
デモコードは以下です。
Install
$ pip install kette
Usage
モジュールの使い方、つまり何ができるかを最初に説明します。実装はあとで載せておきます。
1. 関数の部分適用と関数合成
関数チェーンを可能にするには以下のように kette
から chain
デコレータをインポートし、関数定義に @chain
をつければよいです。
from kette import chain
@chain
def add(x, y):
return x + y
@chain
def mul(x, y):
return x * y
@chain
def sub(x, y):
return x - y
@chain
デコレータにより直下で定義された関数が Chain
というクラスにラップされることで部分適用や関数結合が実現されます。以下は関数の部分適用を行うコードです。
# 関数の部分適用
# 以下のコードは lambda y: add(2, y) と同じような機能を持ちます
add(2)
print(add(2)(3))
# キーワード引数で部分適用
# 以下のコードは lambda x: sub(x, 1) と同じような機能を持ちます
sub(y=1)
print(sub(y=1)(2))
# キーワード引数名は引き継がれます
# よって以下のような指定が可能です
mul(y=2)(x=3)
print(mul(y=2)(x=3))
5
1
6
また、以下の文法で関数合成をサポートしています。
# 関数合成(add(1) を適用してから mul(2) を適用する)
f = add(1) >> mul(2)
y = f(3)
print(y)
# 直にこう書いてもいい
y = (add(1) >> mul(2))(3)
print(y)
# 関数合成(add(1) を適用してから mul(2) を適用する)
f = mul(2) * add(1)
y = f(3)
print(y)
# 直にこう書いてもいい
y = (mul(2) * add(1))(3)
print(y)
上記のいずれについても結果は同じになります。
8
8
8
8
関数をチェーンさせた場合もキーワード引数名は引き継がれ、もっとも入り口に近い関数の引数はキーワード引数で指定可能です。ただし、一度部分適用した引数を上書き指定することはできません。
# add 関数がもっとも入り口に近いため、その引数は
# キーワード引数で指定することが可能です。
f = add(1) >> mul(2)
y = f(y=3)
print(y)
8
関数はデフォルト値が指定されていない引数がすべて適用された段階で実行されます。関数適用時にデフォルト引数を上書きするのに十分な個数の引数を与えれば、デフォルト引数は上書きされます。また、キーワード引数を用いてデフォルト引数を上書きすることも可能です。
@chain
def fun(x, y, z, a=5):
return x, y, z, a
print(fun(1, 2, 3))
print(fun(1, 2, 3, 4))
print(fun(1, 2, 3, a=4))
print(fun(x=1, a=4)(2, 3))
print(fun(y=2, a=4)(1, 3))
(1, 2, 3, 5)
(1, 2, 3, 4)
(1, 2, 3, 4)
(1, 2, 3, 4)
(1, 2, 3, 4)
まだ適用されていない引数の情報は ._params
で取得することができます。
print(fun._params)
print(fun(1, 2)._params)
print(fun(1, 2, a=4)._params)
{'x': <Parameter "x">, 'y': <Parameter "y">, 'z': <Parameter "z">, 'a': <Parameter "a=5">}
{'z': <Parameter "z">, 'a': <Parameter "a=5">}
{'z': <Parameter "z">}
2. 通常の関数とのチェーン
最左項が Chain
クラスのインスタンスであれば通常の関数や無名関数もチェーンさせることが可能です。
g = add(1) >> sub(y=5) >> (lambda x: x * 3)
y = g(3)
print(y)
-3
3. 関数適用
もっとも誤解が生じない関数適用は、通常の関数適用 ()
を用いる方法です。
y = (mul(2) * add(1))(5)
print(y)
12
kette
ではさらに関数適用に &
と |
を用意しています。&
は通常の ()
と同様の振る舞いをし、右辺に左辺の関数を適用します。
# 以下のコードは y = (mul(2) * add(1))(5) と等価です
y = mul(2) * add(1) & 5
print(y)
12
|
は右辺が Mapping
なら **kwargs
に、そうでない場合でも Iterable
なら *args
に展開した上で左辺の関数を適用します。どちらでもない場合はそのまま適用します。
y = mul(2) * add(1) | (5, )
print(y)
y = mul(2) * add(1) | { 'y': 5 }
print(y)
12
12
また、単なる定数 x
は x : () -> a
の関数ともみなせるため、左から関数結合によって関数を適用することができます。
y = 5 >> add(1) >> mul(2)
print(y)
12
add(1) >> mul(2) >> 5
という関数合成を考えることも可能ですが、仮に成立するとしても 5
が無引数のため、その手前のチェーンは無意味です。したがって kette
ではこのようなチェーンはエラーとみなします。
add(1) >> mul(2) >> 5
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-25-29d1b871a5f4> in <module>()
----> 1 add1 >> mul2 >> 5
<ipython-input-2-ab8b2a1fbd55> in __rshift__(self, other)
65 if isinstance(other, Callable):
66 return Chain(self.function + [other])
---> 67 raise ValueError("uncallable object can't be chained.")
68
69 def __or__(self, other):
ValueError: uncallable object can't be chained.
|
と &
は左辺と右辺を入れ替えても関数適用の意味を持ちます。したがって、左辺に引数、右辺に適用したい関数を与えても、右辺に引数、左辺に適用したい関数を与えたときと同じ挙動になります。
x = mul(2) * add(1) & 5
y = 5 & mul(2) * add(1)
print(x == y)
x = (5, ) | mul(2) * add(1)
y = mul(2) * add(1) | (5, )
print(x == y)
x = { 'y': 5 } | mul(2) * add(1)
y = mul2 * add1 | { 'y': 5 }
print(x == y)
True
True
True
4. 演算子の優先順
ここまで紹介した演算子は以下のような優先順になります。
*
> >>
> &
> |
つまり
f >> g * h | x
x & f >> g * h
は、それぞれ
((f >> (g * h)) | x)
(x & (f >> (g * h)))
と等価です。
5. 戻り値が複数あるときの挙動
チェーンの中の関数が複数戻り値を持つ場合、すなわち戻り値がタプルのとき、次のチェーンに入る前に *args
相当の展開が行われます。したがってタプル (1, 2, 3)
を次のチェーンに渡したいときはリストで [1, 2, 3]
と渡すか、少々面倒ですが ((1, 2, 3), )
のように渡す必要があります(タプルなので一段階展開されて (1, 2, 3)
に戻ります)。
また、チェーンの中の関数の戻り値が (tuple, dict)
の二つ組になっているとき、次のチェーンに入るときに tuple
は *args
、dict
は **kwargs
相当の展開が行われます。
少しクセのあるルールですが、複数引数を持つ関数をチェーンさせるときに自然な書き方ができます。
@chain
def f(x, y):
# タプルで返すと次の関数に入る前に展開されるので、
# 2引数関数にチェーンさせることができます
return x + y, x - y
@chain
def g(v, w):
# タプルと辞書の組で返すと、タプルは *args に、
# 辞書は **kwargs に展開されます
return (v * w, ), {'a': v + w, 'b': v - w}
@chain
def h(z, a, b):
return z + a + b
fun = f >> g >> h
print(fun(2, 3))
5
6. Chainクラス
ここまでに紹介した機能は、すべて関数を Chain
クラスでラップすることで提供される機能です。通常の関数に @chain
デコレータをつけることで Chain
クラスにするのがもっとも簡単ですが、Chain
クラスを明示的に生成することも可能です。
from kette import Chain
Chain
クラスは関数[2]または関数のリスト[3]を引数に取ります。関数が与えられた場合は単にその関数をラップし、関数のリストが与えられた場合はリストの並び順に >>
で結合したのと同じ扱いになります。
f = Chain(lambda x: x * 2)
g = Chain([lambda x: x * 2, lambda x: x + 5])
print(f(1))
print(g(1))
2
7
Chain クラス自体も Callable なので、Chain クラスのインスタンスを Chain クラスのコンストラクタの引数に与えても問題はありません(実はこの仕組みを使って関数の部分適用を実現していたりします)。
7. ビルトイン関数の扱い
kette
の部分適用は inspect
モジュールの signature
関数によって関数のシグネチャ情報を読み取ることで実現されています。しかしビルトイン関数(list
やmap
など)はシグネチャの情報を持っておらず読み取ることができないため、Chain
クラスにラップされる段階(コンストラクタに代入されるときや関数結合のとき)でビルトイン関数だったら Chain
クラスで扱える関数にすり替わるようになっています。たとえば list
であれば
_c_list = lambda iterable=(): list(iterable)
と定義される関数 _c_list
に置き換えられています。これによって若干に違いが現れる可能性があることはご注意ください。たとえば map
関数は map(function, iterable, *iterables)
という定義を持っており、iterable, *iterable
を zip
しながら function
に適用していくような機能を持っているのですが、Chain
クラスは可変長引数を取り扱えないので
_c_map = lambda function, iterable: map(function, iterable)
のような定義になっています。
おおよそすべてのビルトイン関数は先頭に c_
をつけることで Chain
クラスにラップされたものを使用可能です。これらの関数群は curried
モジュールに定義されています。たとえば list
であれば
c_list = Chain(_c_list)
のようになっています。
実装
Python の演算子オーバーロードとデコレータを悪用した黒魔術活用したハックです。短いので ver 0.1.4
のソースコードを載せておきます。
import warnings
from collections import Callable, Iterable, Mapping
def _check_callable(x):
# None なら恒等関数にするのでよし
# あとは Callable か、Callable のリストでなければエラー
if x is None:
return
if isinstance(x, Callable):
return
elif isinstance(x, list):
for f in x:
if not isinstance(f, Callable):
raise ValueError(f"uncallable object '{f}' in the list, which must be 'Callable' or 'List[Callable]'.")
else:
raise ValueError("function must be 'Callable' or 'List[Callable]'.")
class Chain:
def __init__(self, function=None):
_check_callable(function)
# チェーンさせる関数をリストで保持しておく
if function is None:
self.function = [lambda x: x]
elif isinstance(function, list):
self.function = function
else:
self.function = [function]
def __call__(self, *args, **kwargs):
f_idx = 0
# 途中で None になったらそこでチェーンが切れているので警告
def check_arg(x):
if x is None:
warnings.warn(f"Chain broken with 'None' value between '{self.function[f_idx-1].__name__}' and '{self.function[f_idx].__name__}'", RuntimeWarning)
chain_args = self.function[0](*args, **kwargs)
for f_idx, f in enumerate(self.function[1:], 1):
check_arg(chain_args)
if isinstance(chain_args, tuple):
# 戻り値がタプルのときは展開
chain_args = f(*chain_args)
else:
chain_args = f(chain_args)
return chain_args
# f * g
def __mul__(self, other):
if isinstance(other, Chain):
return Chain(other.function + self.function)
if isinstance(other, Callable):
return Chain([other] + self.function)
raise ValueError("uncallable object can't be chained.")
# f >> g
def __rshift__(self, other):
if isinstance(other, Chain):
return Chain(self.function + other.function)
if isinstance(other, Callable):
return Chain(self.function + [other])
raise ValueError("uncallable object can't be chained.")
# f | x
def __or__(self, other):
return self(other)
# f & x
def __and__(self, other):
if isinstance(other, Mapping):
return self(**other)
if isinstance(other, Iterable):
return self(*other)
return self(other)
# 左辺が Chain でないときの f >> g
def __rrshift__(self, other):
return self(other)
# 左辺が Chain でないときの x | f
def __ror__(self, other):
return self(other)
# 左辺が Chain でないときの x & f
def __rand__(self, other):
if isinstance(other, Mapping):
return self(**other)
if isinstance(other, Iterable):
return self(*other)
return self(other)
# デコレータ
def chain(function):
return Chain(function)
Discussion