🙌

WebGPUで動くstable-diffusionのGUIをtauriで作る

2023/12/24に公開

この記事はRust Advent Calendar 2023の25日目の記事です。

はじめに

今年は生成AIが盛り上がりまくりましたが、
生成AIの盛り上がりの一つの理由にChatGPTのようにWeb上のAPI経由で汎用的なAIモデルを実行できるようにしたことで、様々なプラットフォームでAIアプリの開発が可能になったことが挙げられるのではないかと思っています。
AIアプリを様々なプラットフォームで実現する上で、WebAPI化というのも一つの方法なのですが、ChatGPTのようなインターネットを必要とするAIを使いたくない場合には、ローカルで動かせて、かつ特定のハードウェアデバイスに依らない、マルチプラットフォームなフレームワークを用いることも選択肢に入ってくるかと思います。

WebGPUとは?

https://www.w3.org/TR/webgpu/

深層学習でGPUといえば、NVidiaのCUDAがよく使われています。よく使用される深層学習フレームワークなどはNVidia以外のAMDやM1などのGPUも使用できますが、それぞれのGPUでのSDKが異なるため、それに対応したバックエンドに入れ替えが必要であったり、一部の機能が未対応ということがあったりします。
WebGPUはもともとブラウザ上でグラフィクス処理のために使われていたWebGLの後継となるAPIで、さまざまなGPUを共通のAPIで扱うことができます。
WebGPUによってGPU毎の異なるSDKに対応するといった機能の開発を減らすことができるようになります。

tauriとは?

https://tauri.app/

tauriはrustでデスクトップアプリケーションを作成するためのフレームワークです。
マルチプラットフォームに対応したアプリケーションを開発することができ、現在アルファ版ではありますが、iOS/Androidといったモバイルに対応したアプリケーションも開発可能です。

WebGPUに対応したRustの深層学習ライブラリ「burn」

https://github.com/tracel-ai/burn

burnはWebGPUに対応したRust製の深層学習ライブラリです。
WebGPUの他にもtorchのrust版のtch-rsやCPUで動かしたい場合はNdArrayをバックエンドに選択できます。
ここ最近になって、stable-diffusionを含め有名モデルをburnに移植したものがいくつかgithub上に挙げられており(以下)、個人的に今かなりアツいフレームワークです。

まずはWebGPUでstable-diffusionを動かす

まずは上記のstable-diffusion-burnをWebGPUで動かしたいと思います。
上記のリポジトリは最新のburnに対応できていないのと、tchがデフォルトのバックエンドになっているので、はじめからWebGPUで動くようにフォークして修正しました。

https://github.com/neka-nat/stable-diffusion-burn-wgpu

これを取ってきて、デモを動かします。

git clone https://github.com/neka-nat/stable-diffusion-burn-wgpu.git
cd stable-diffusion-burn-wgpu
cargo run --release --bin sample 7.5 20 "An ancient mossy stone." img

私の環境だとGPUのRAMが1000MiBほど使用されており、数十秒から1分くらいで苔の生えた石の画像が生成されました。

tauriとstable-diffusionを繋げる

それではtauriとstable-diffusionを繋げていきます。
今回はシンプルに、プロンプトの入力部分と生成ボタンと画像表示部分からなるUIを生成していきます。
tauriのアプリケーションを作成します。

cargo install create-tauri-app --locked
cargo create-tauri-app

今回はフロントエンドはReactで作成しました。
ほぼ最初のテンプレートのままで、画像の出力とstable-diffusion周りの初期化を追加した形になります。
画像はRGB画像をpngに変換して、base64にエンコードしてフロントに送っています。以下が該当コードです。

struct SDState {
    sd: StableDiffusion<Wgpu<AutoGraphicsApi, f32, i32>>,
    tokenizer: stablediffusion_wgpu::tokenizer::SimpleTokenizer,
}

fn convert_rgb_to_png(data: Vec<u8>, width: u32, height: u32) -> Vec<u8> {
    let img: RgbImage =
        ImageBuffer::from_raw(width, height, data).expect("Failed to create image buffer");

    let mut png_data = Vec::new();
    let mut cursor = Cursor::new(&mut png_data);
    img.write_to(&mut cursor, image::ImageOutputFormat::Png)
        .expect("Failed to write PNG buffer");

    cursor.into_inner().to_vec()
}

#[tauri::command]
fn generate(prompt: &str, state: State<SDState>) -> Result<String, ()> {
    let image = generate_image::generate_image(&state.sd, &state.tokenizer, 2.5, 20, prompt);
    let image = image[0].clone();
    let png_data = convert_rgb_to_png(image, 512, 512);
    Ok(base64::encode(png_data))
}

フロントエンド側は以下のようになります。

import { useState } from "react";
import { invoke } from "@tauri-apps/api/tauri";
import "./App.css";

function App() {
  const [imageData, setImageData] = useState("");
  const [prompt, setPrompt] = useState("");

  const handleGenerateImage = async () => {
    try {
      const data = await invoke('generate', { prompt });
      setImageData(`data:image/png;base64,${data}`);
    } catch (error) {
      console.error('Error generating image:', error);
    }
  };

  return (
    <div className="container">
      <h1>Stable diffusion WebGPU</h1>

      <form
        className="row"
        onSubmit={(e) => {
          e.preventDefault();
          handleGenerateImage();
        }}
      >
        <input
          id="generate-input"
          onChange={(e) => setPrompt(e.currentTarget.value)}
          placeholder="Enter a prompt..."
        />
        <button type="submit">Generate</button>
      </form>

      {imageData && <img src={imageData} alt="Generated" />}
    </div>
  );
}

export default App;

立ち上げて、実際に画像生成してみます。

yarn tauri dev

テキストを入力してGenerateボタンを押して数十秒待つと画像が表示されます。

表示画像に合わせたUIの調整など、細かな修正は色々ありそうですが、一通り動きました。
最終的なコードは以下に挙げています。
https://github.com/neka-nat/stable-diffusion-tauri-ui

まとめ

今回はWebGPUでstable-diffusionを動かしつつ、tauriを使ってアプリケーション化してみました。
WebGPUもtauriもマルチプラットフォームで使用できる技術であり、モバイルへの適用も可能というのもあって、今後も注目していきつつ、いろいろ触ってみたいと思います。

Discussion