Zenn
📑

リストモナドをPythonで再現したい

2025/03/26に公開

Haskell にはリストモナドというものがあるらしいです。

instance Monad [] where
    return x = [x]
    xs >>= f = concat (map f xs)

リストの各要素に計算を施しながらその結果をどんどん結合していく感じです。

組み合わせを列挙したりする際に便利らしいです。

例えば、二つのリストから要素を取り出してペアを作っていく関数などに使えます。

pairs :: [Int] -> [Int] -> [(Int, Int)]
pairs xs ys = do
    x <- xs
    y <- ys
    return (x, y)

これはdo記法というものを使っており、>>= をそのまま使うとこうなります(見づらい)。

pairs xs ys = xs >>= \x -> ys >>= \y -> return (x, y)

こいつを実行する場合、例えばこうなります。

pairs [1, 2] [3, 4]

実行結果はこうです。

[(1, 3), (1, 4), (2, 3), (2, 4)]

これをPythonで実装して「お気持ち」を掴みたいと思います。

Python で実装したい

全体像

とりあえず表面的な実装を真似しました。

def foldr(f, z, xxs):
    if xxs == []:
        return z

    x = xxs[0]
    xs = xxs[1:]
    return f(x, foldr(f, z, xs))


def concat(xxs):
    plus = lambda x, y: x + y
    return foldr(plus, [], xxs)

def ret(a):
    """
    return :: a -> m a の代わり
    """
    return [a]


def bind(ma, a_to_mb):
    """
    (>>=) :: m a -> ( a -> m b) -> m b の代わり
    """
    return concat(list(map(a_to_mb, ma)))

pairs 関数も一応再現できています。

def pairs(xs, ys):
    return bind(
        xs, lambda x: bind(
            ys, lambda y: ret((x, y))
        )
    )

if __name__ == "__main__":
    bar = pairs([1,2], [3,4])
    print(bar)
[(1, 3), (1, 4), (2, 3), (2, 4)]

説明

モナドは returnbind>>=)を実装する必要があります。

return の実装(?)

まず、returnret という関数で再現することにします。

return x = [x]

つまりこうです。

def ret(a):
    return [a]

こうすると実際に値を返すとき return ret(x) みたいになって冗長ですが、return 自体を上書きするような表現が思いつかなかったので、一旦こんな感じにします。

bind の実装

次に bind ですが、Haskell側の実装を見ると、先に concat を実装しないといけません。

xs >>= f = concat (map f xs)

concat の定義はこうです。

concat :: [[a]] -> [a]
concat = foldr (++) []

foldr が出てきました。

foldr :: (a -> b -> b) -> b -> [a] -> b
foldr _ acc []     = acc
foldr f acc (x:xs) = f x (foldr f acc xs)

これも Python で実装しないといけません。
順番にやっていきます。

foldr の実装

foldr です。

foldr :: (a -> b -> b) -> b -> [a] -> b
foldr _ acc []     = acc
foldr f acc (x:xs) = f x (foldr f acc xs)

例えばこう使います。

Prelude> foldr (+) 1 [2, 3]
6
Prelude> foldr (+) 1 [2, 3, 4]
10
Prelude> foldr (*) 2 [2, 3]
12
Prelude> foldr (*) 2 [2, 3, 4]
48

第三引数のリストの先頭を取り出し、それを第二引数と一緒に第一引数の関数に渡します。
最初の例では 1 + 2 + 3 を計算しており、3つ目の例では 2 * 2 * 3 を計算しています。

Pythonでそのまま再帰関数を使って再現します。

def foldr(f, z, xxs):
    if xxs == []:
        return z

    x = xxs[0]
    xs = xxs[1:]
    return f(x, foldr(f, z, xs))

使ってみます。

>>> plus = lambda x, y: x + y
>>> foldr(plus, 1, [2, 3])
6
>>> mult = lambda x, y: x * y
>>> foldr(mult, 2, [2, 3, 4])
48

大丈夫みたいです。

concat の実装

次は concat です。

concat :: [[a]] -> [a]
concat = foldr (++) []

++ は配列の結合です。

Prelude> "A" ++ "B"
"AB"
Prelude> [1] ++ [2]
[1,2]

配列の各要素を取り出して結合しています。

Prelude> concat [["1"], ["2"], ["3"]]
["1","2","3"]
Prelude> concat [[1], [2], [3]]
[1,2,3]
Prelude> concat ["1", "2", "3"]
"123"

これも、上で実装した foldr を使って concat を実装します。

def concat(xxs):
    plus = lambda x, y: x + y
    return foldr(plus, [], xxs)

ただし、これはリストの結合が + で書けることに全力で依存しています。
しかも、[[1], [2], [3]][1, 2, 3] の形式にすることしかできません。

>>> concat([["1"], ["2"], ["3"]])
["1", "2", "3"]
>>> concat([[1], [2], [3]])
[1, 2, 3]
>>> concat(["1", "2", "3"])
TypeError: can only concatenate str (not "list") to str

なぜかというと、Haskell では文字列含めいろんな階層の配列の初期値(?)が [] と書けるのに対し、Python はそうじゃない(文字列なら ""、リストなら [])からです。

例えば Haskell はこういうことができます。

Prelude> [] ++ "A"
"A"
Prelude> [] ++ ["A"]
["A"]
Prelude> [] ++ [["A"]]
[["A"]]

Python ではもちろん二つ目しかできません。

条件分けして云々…は面倒なのでやらないことにしました。

再び bind の実装

改めて、bind の演算はこうなります。

xs >>= f = concat (map f xs)

動きとしては、こうなります。

Prelude> [1, 2] >>= \x -> [x, x*3]
[1,3,2,6]

先ほどまでで concat が実装できたので、あとはそのまま Python で置き換えます。

def bind(ma, a_to_mb):
    return concat(list(map(a_to_mb, ma)))

変数名も対応させるとこうでしょうか。

def bind(xs, f):
    return concat(list(map(f, xs)))

実際に使えるかどうかみてみます。

>>> bind([1, 2], lambda x: [x, x * 3])
[1, 3, 2, 6]

大丈夫そうです。

用例も考える

使用例としてこんなものを考えます。冒頭で触れたやつそのままです。

pairs :: [Int] -> [Int] -> [(Int, Int)]
pairs xs ys = do
    x <- xs
    y <- ys
    return (x, y)

これも冒頭でも触れた通りdo記法を使わずに >>= を使うと、こうなります。

pairs xs ys = xs >>= \x -> ys >>= \y -> return (x, y)

実行結果はこうです。

Prelude> pairs [1, 2] [3, 4]
[(1, 3), (1, 4), (2, 3), (2, 4)]

これをそのまま Python で書きます。

def pairs(xs, ys):
    return bind(
        xs, lambda x: bind(
            ys, lambda y: ret((x, y))
        )
    )

そのままです。
実行結果も再現できています。

>>> pairs([1,2], [3,4])
[(1, 3), (1, 4), (2, 3), (2, 4)]

do記法もやりたいが…

do記法だと見た目がスッキリするので、これもPythonで表現できたらカッコ良さそうです。

pairs :: [Int] -> [Int] -> [(Int, Int)]
pairs xs ys = do
    x <- xs
    y <- ys
    return (x, y)

見た目だけ似せると、こんな感じでしょうか。

def pairs(xs, ys):
    x = hogehoge
    y = fugafuga
    return ret((x, y))

でも、欲しいのはリストなので、return ret((x, y)) のままではダメです。
ループして結合する必要があります。

なので、何かデコレータを使って外からループしてやる感じになるのでしょうか。

わかりません!!!!!!!!!!!!

モナド則を満たすのか?

モナドは以下を満たしていないといけません。

  • return a >>= ff a は等価
  • m >>= returnm は等価
  • (m >>= f) >>= gm >>= (\x -> f x >>= g) は等価

次回はこれを満たすことを確認したいと思います。

Discussion

ログインするとコメントできます