👻

mlxでStable Diffusionを動かす

2023/12/14に公開

こんにちは、noppeです。
先日Apple シリコンに最適化されたMLフレームワーク、MLXが発表されました。

https://github.com/ml-explore/mlx

自分は門外漢なので、MLが何が何だか分かりませんがとりあえず触ってみましょう。


A fox jumping on the sea.

これは、今回MLXのStable Diffusionを使って生成した画像です。
通常のStable Diffusion同様に画像が生成されているように見えます。
余談ですが、Stable Diffusionの狐はこの毛色で生成される事が多い気がします。

mlxにはいくつかのサンプルが用意されているので、試すだけであれば特に難しいことはありませんでした。

gh repo clone ml-explore/mlx-examples
cd mlx-examples/stable_diffusion
pip install -r requirements.txt
python txt2image.py "A fox jumping on the sea." --n_images 1 --n_rows 1

特にハマることもなく動作しました。
モデルは初回の実行時に取得されます。

mlxのREADMEには

MLX has a Python API that closely follows NumPy.

と書かれていて、NumPyのような演算系の処理もフレームワークに含まれているようです。
実際、今回触れたDiffusionのサンプルコードでもsimplifyやconcatenateといった演算がmlxを使って行われていました。

for x_t in tqdm(latents, total=args.steps):
    mx.simplify(x_t)
    mx.simplify(x_t)
    mx.eval(x_t)

# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
    decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
    mx.eval(decoded[-1])

# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])

Discussion