📚

Pythonで関数型プログラミングできるライブラリ作った

2021/09/30に公開

はじめに

関数チェーンを実現したかったので kette(chainのドイツ語[1])という名前のライブラリを作りました。

https://github.com/wsuzume/kette

2021/10/31 にアップデートを行い、関数の部分適用が可能になりました。これにより、以下の関数やライブラリとよく似た機能が、より関数型言語ライクな書き方で可能になります。

  • functools.partial: 関数の部分適用を可能にするメソッドです。
  • toolz の Functoolz: 関数の部分適用や関数合成を提供するパッケージです。
  • python-chain: 関数合成の演算子を提供するパッケージです。

以下のようなコードを

def add(x, y):
    return x + y

def mul(x, y):
    return x * y

ys = list(map(lambda x: mul(add(x, 1), 2), [1, 2, 3, 4])

こう書けます。

from kette import chain
from kette.curried import c_list

@chain
def add(x, y):
    return x + y

@chain
def mul(x, y):
    return x * y

# 例
ys = c_list * map | (mul(2) * add(1), [1, 2, 3, 4])

デモコードは以下です。

https://colab.research.google.com/drive/1RCvQvuKrOXXNTsb4O8zzGhp80kQ9CUnk?usp=sharing

Install

$ pip install kette

Usage

モジュールの使い方、つまり何ができるかを最初に説明します。実装はあとで載せておきます。

1. 関数の部分適用と関数合成

関数チェーンを可能にするには以下のように kette から chain デコレータをインポートし、関数定義に @chain をつければよいです。

from kette import chain

@chain
def add(x, y):
    return x + y

@chain
def mul(x, y):
    return x * y

@chain
def sub(x, y):
    return x - y

@chain デコレータにより直下で定義された関数が Chain というクラスにラップされることで部分適用や関数結合が実現されます。以下は関数の部分適用を行うコードです。

# 関数の部分適用
# 以下のコードは lambda y: add(2, y) と同じような機能を持ちます
add(2)
print(add(2)(3))

# キーワード引数で部分適用
# 以下のコードは lambda x: sub(x, 1) と同じような機能を持ちます
sub(y=1)
print(sub(y=1)(2))

# キーワード引数名は引き継がれます
# よって以下のような指定が可能です
mul(y=2)(x=3)
print(mul(y=2)(x=3))
output
5
1
6

また、以下の文法で関数合成をサポートしています。

# 関数合成(add(1) を適用してから mul(2) を適用する)
f = add(1) >> mul(2)
y = f(3)
print(y)

# 直にこう書いてもいい
y = (add(1) >> mul(2))(3)
print(y)

# 関数合成(add(1) を適用してから mul(2) を適用する)
f = mul(2) * add(1)
y = f(3)
print(y)

# 直にこう書いてもいい
y = (mul(2) * add(1))(3)
print(y)

上記のいずれについても結果は同じになります。

output
8
8
8
8

関数をチェーンさせた場合もキーワード引数名は引き継がれ、もっとも入り口に近い関数の引数はキーワード引数で指定可能です。ただし、一度部分適用した引数を上書き指定することはできません。

# add 関数がもっとも入り口に近いため、その引数は
# キーワード引数で指定することが可能です。
f = add(1) >> mul(2)
y = f(y=3)
print(y)
output
8

関数はデフォルト値が指定されていない引数がすべて適用された段階で実行されます。関数適用時にデフォルト引数を上書きするのに十分な個数の引数を与えれば、デフォルト引数は上書きされます。また、キーワード引数を用いてデフォルト引数を上書きすることも可能です。

@chain
def fun(x, y, z, a=5):
    return x, y, z, a

print(fun(1, 2, 3))
print(fun(1, 2, 3, 4))
print(fun(1, 2, 3, a=4))
print(fun(x=1, a=4)(2, 3))
print(fun(y=2, a=4)(1, 3))
output
(1, 2, 3, 5)
(1, 2, 3, 4)
(1, 2, 3, 4)
(1, 2, 3, 4)
(1, 2, 3, 4)

まだ適用されていない引数の情報は ._params で取得することができます。

print(fun._params)
print(fun(1, 2)._params)
print(fun(1, 2, a=4)._params)
{'x': <Parameter "x">, 'y': <Parameter "y">, 'z': <Parameter "z">, 'a': <Parameter "a=5">}
{'z': <Parameter "z">, 'a': <Parameter "a=5">}
{'z': <Parameter "z">}

2. 通常の関数とのチェーン

最左項が Chain クラスのインスタンスであれば通常の関数や無名関数もチェーンさせることが可能です。

g = add(1) >> sub(y=5) >> (lambda x: x * 3)
y = g(3)
print(y)
output
-3

3. 関数適用

もっとも誤解が生じない関数適用は、通常の関数適用 () を用いる方法です。

y = (mul(2) * add(1))(5)
print(y)
output
12

kette ではさらに関数適用に &| を用意しています。& は通常の () と同様の振る舞いをし、右辺に左辺の関数を適用します。

# 以下のコードは y = (mul(2) * add(1))(5) と等価です
y = mul(2) * add(1) & 5
print(y)
output
12

| は右辺が Mapping なら **kwargs に、そうでない場合でも Iterable なら *args に展開した上で左辺の関数を適用します。どちらでもない場合はそのまま適用します。

y = mul(2) * add(1) | (5, )
print(y)

y = mul(2) * add(1) | { 'y': 5 }
print(y)
output
12
12

また、単なる定数 xx : () -> a の関数ともみなせるため、左から関数結合によって関数を適用することができます。

y = 5 >> add(1) >> mul(2)
print(y)
output
12

add(1) >> mul(2) >> 5 という関数合成を考えることも可能ですが、仮に成立するとしても 5 が無引数のため、その手前のチェーンは無意味です。したがって kette ではこのようなチェーンはエラーとみなします。

add(1) >> mul(2) >> 5
output
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-25-29d1b871a5f4> in <module>()
----> 1 add1 >> mul2 >> 5

<ipython-input-2-ab8b2a1fbd55> in __rshift__(self, other)
     65         if isinstance(other, Callable):
     66             return Chain(self.function + [other])
---> 67         raise ValueError("uncallable object can't be chained.")
     68 
     69     def __or__(self, other):

ValueError: uncallable object can't be chained.

|& は左辺と右辺を入れ替えても関数適用の意味を持ちます。したがって、左辺に引数、右辺に適用したい関数を与えても、右辺に引数、左辺に適用したい関数を与えたときと同じ挙動になります。

x = mul(2) * add(1) & 5
y = 5 & mul(2) * add(1)
print(x == y)

x = (5, ) | mul(2) * add(1)
y = mul(2) * add(1) | (5, )
print(x == y)

x = { 'y': 5 } | mul(2) * add(1)
y = mul2 * add1 | { 'y': 5 }
print(x == y)
output
True
True
True

4. 演算子の優先順

ここまで紹介した演算子は以下のような優先順になります。

* > >> > & > |

つまり

f >> g * h | x
x & f >> g * h

は、それぞれ

((f >> (g * h)) | x)
(x & (f >> (g * h)))

と等価です。

5. 戻り値が複数あるときの挙動

チェーンの中の関数が複数戻り値を持つ場合、すなわち戻り値がタプルのとき、次のチェーンに入る前に *args 相当の展開が行われます。したがってタプル (1, 2, 3) を次のチェーンに渡したいときはリストで [1, 2, 3] と渡すか、少々面倒ですが ((1, 2, 3), ) のように渡す必要があります(タプルなので一段階展開されて (1, 2, 3) に戻ります)。

また、チェーンの中の関数の戻り値が (tuple, dict) の二つ組になっているとき、次のチェーンに入るときに tuple*argsdict**kwargs 相当の展開が行われます。

少しクセのあるルールですが、複数引数を持つ関数をチェーンさせるときに自然な書き方ができます。

@chain
def f(x, y):
    # タプルで返すと次の関数に入る前に展開されるので、
    # 2引数関数にチェーンさせることができます
    return x + y, x - y

@chain
def g(v, w):
    # タプルと辞書の組で返すと、タプルは *args に、
    # 辞書は **kwargs に展開されます
    return (v * w, ), {'a': v + w, 'b': v - w}

@chain
def h(z, a, b):
    return z + a + b

fun = f >> g >> h
print(fun(2, 3))
output
5

6. Chainクラス

ここまでに紹介した機能は、すべて関数を Chain クラスでラップすることで提供される機能です。通常の関数に @chain デコレータをつけることで Chain クラスにするのがもっとも簡単ですが、Chain クラスを明示的に生成することも可能です。

from kette import Chain

Chain クラスは関数[2]または関数のリスト[3]を引数に取ります。関数が与えられた場合は単にその関数をラップし、関数のリストが与えられた場合はリストの並び順に >> で結合したのと同じ扱いになります。

f = Chain(lambda x: x * 2)
g = Chain([lambda x: x * 2, lambda x: x + 5])

print(f(1))
print(g(1))
output
2
7

Chain クラス自体も Callable なので、Chain クラスのインスタンスを Chain クラスのコンストラクタの引数に与えても問題はありません(実はこの仕組みを使って関数の部分適用を実現していたりします)。

7. ビルトイン関数の扱い

kette の部分適用は inspect モジュールの signature 関数によって関数のシグネチャ情報を読み取ることで実現されています。しかしビルトイン関数(listmapなど)はシグネチャの情報を持っておらず読み取ることができないため、Chain クラスにラップされる段階(コンストラクタに代入されるときや関数結合のとき)でビルトイン関数だったら Chain クラスで扱える関数にすり替わるようになっています。たとえば list であれば

_c_list = lambda iterable=(): list(iterable)

と定義される関数 _c_list に置き換えられています。これによって若干に違いが現れる可能性があることはご注意ください。たとえば map 関数は map(function, iterable, *iterables) という定義を持っており、iterable, *iterablezip しながら function に適用していくような機能を持っているのですが、Chain クラスは可変長引数を取り扱えないので

_c_map = lambda function, iterable: map(function, iterable)

のような定義になっています。

おおよそすべてのビルトイン関数は先頭に c_ をつけることで Chain クラスにラップされたものを使用可能です。これらの関数群は curried モジュールに定義されています。たとえば list であれば

c_list = Chain(_c_list)

のようになっています。

実装

Python の演算子オーバーロードとデコレータを悪用した黒魔術活用したハックです。短いので ver 0.1.4 のソースコードを載せておきます。

import warnings
from collections import Callable, Iterable, Mapping

def _check_callable(x):
    # None なら恒等関数にするのでよし
    # あとは Callable か、Callable のリストでなければエラー
    if x is None:
        return
    if isinstance(x, Callable):
        return
    elif isinstance(x, list):
        for f in x:
            if not isinstance(f, Callable):
                raise ValueError(f"uncallable object '{f}' in the list, which must be 'Callable' or 'List[Callable]'.")
    else:
        raise ValueError("function must be 'Callable' or 'List[Callable]'.")

class Chain:
    def __init__(self, function=None):
        _check_callable(function)

        # チェーンさせる関数をリストで保持しておく
        if function is None:
            self.function = [lambda x: x]
        elif isinstance(function, list):
            self.function = function
        else:
            self.function = [function]
    
    def __call__(self, *args, **kwargs):
        f_idx = 0
        
	# 途中で None になったらそこでチェーンが切れているので警告
        def check_arg(x):
            if x is None:
                warnings.warn(f"Chain broken with 'None' value between '{self.function[f_idx-1].__name__}' and '{self.function[f_idx].__name__}'", RuntimeWarning)
        
        chain_args = self.function[0](*args, **kwargs)
        for f_idx, f in enumerate(self.function[1:], 1):
            check_arg(chain_args)
            if isinstance(chain_args, tuple):
	        # 戻り値がタプルのときは展開
                chain_args = f(*chain_args)
            else:
                chain_args = f(chain_args)
        return chain_args
    
    # f * g
    def __mul__(self, other):
        if isinstance(other, Chain):
            return Chain(other.function + self.function)
        if isinstance(other, Callable):
            return Chain([other] + self.function)
        raise ValueError("uncallable object can't be chained.")
    
    # f >> g
    def __rshift__(self, other):
        if isinstance(other, Chain):
            return Chain(self.function + other.function)
        if isinstance(other, Callable):
            return Chain(self.function + [other])
        raise ValueError("uncallable object can't be chained.")
    
    # f | x
    def __or__(self, other):
        return self(other)

    # f & x
    def __and__(self, other):
        if isinstance(other, Mapping):
            return self(**other)
        if isinstance(other, Iterable):
            return self(*other)
        return self(other)
    
    # 左辺が Chain でないときの f >> g
    def __rrshift__(self, other):
        return self(other)

    # 左辺が Chain でないときの x | f
    def __ror__(self, other):
        return self(other)

    # 左辺が Chain でないときの x & f
    def __rand__(self, other):
        if isinstance(other, Mapping):
            return self(**other)
        if isinstance(other, Iterable):
            return self(*other)
        return self(other)

# デコレータ
def chain(function):
    return Chain(function)
脚注
  1. 有名どころだとケッテンクラート(kettenkrad)の ketten(鎖、履帯のこと)と同じ意味。 ↩︎

  2. 正確にはCallableオブジェクト。 ↩︎

  3. 正確にはCallableオブジェクトのリスト。 ↩︎

Discussion