😸

StreamDiffusionとQuest3で視界をリアルタイム変換してみる

に公開

日本では発売されていないですが Ray-Ban Meta や今後発売される Meta Orion など、これから MRグラス + AIの組み合わせが盛り上がってきそうです。
そんな世界がどういったものかを考えてみたくて 最近リリースされた Meta Quest 3 パススルーAPI + 画像生成AIの組み合わせで視界をリアルタイムで変換する、みたいなことをやってみました。

概要

MetaからQuest3のパススルーAPIがリリースされました。Quest3では画像生成AIの動作は厳しいので、Quest3のパススルー映像をPCに送信、PC上で画像生成、それをまたQuest3に返信して表示する。
というのをやってみます。
StableDiffusionは生成に時間がかかるので早いと噂のStreamDiffusionを使います。

Quest3のパススルー映像の中心から512ピクセル×512ピクセルを切り出してStreamDiffusionに送信、返信も512ピクセル×512ピクセルで、これを目の前に表示します。
SD1.5のモデルに対応しています。

構成

表示結果

この記事では、Apache License 2.0 のもとで公開されている StreamDiffusionのコードを一部改変してテストを行った内容を紹介します。
元コードは以下のライセンスに基づいて提供されています:
Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0

環境

サーバー

  • PC
    • CPU:Intel Core i7-8700 CPU @ 3.20GHz
    • メモリ:32G
    • GPU:NVIDIA Gfource RTX 4070Ti (メモリ12G)
    • OS:Windows11
  • Cuda compilation tools, release 12.1, V12.1.66
  • cudnn-windows-x86_64-8.9.7.29_cuda12-archive
  • Python 3.10.11
    ※Cuda 12.1でないとダメそう

クライアント

  • Unity 2022.3.52f1
  • Meta Quest 3
    • OS:Meta Horizon OS v76.1024

手順

サーバー

StreamDiffusion のインストール

ソースDL、仮想環境作成

git clone https://github.com/cumulo-autumn/StreamDiffusion.git
cd StreamDiffusion
python -m venv venv
venv\Scripts\activate

Stream Diffusion セットアップ

python -m pip install --upgrade pip
pip3 install torch==2.1.0 torchvision==0.16.0 xformers --index-url https://download.pytorch.org/whl/cu121
pip install git+https://github.com/cumulo-autumn/StreamDiffusion.git@main#egg=streamdiffusion[tensorrt]
pip install "numpy<2.0"
pip install --upgrade diffusers transformers huggingface-hub
python -m streamdiffusion.tools.install-tensorrt
pip install pywin32
pip install peft

コード

StreamDiffusion\examples\img2img\single.py を改造しています。

single
# Licensed under the Apache License, Version 2.0
#   https://www.apache.org/licenses/LICENSE-2.0
# Modified by Toshimasa Shiraki

import os
import sys
from typing import Literal, Dict, Optional

##ネットワーク関連
import socket
import struct
from io import BytesIO
from PIL import Image

import fire

# グローバル変数
stream = None

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from utils.wrapper import StreamDiffusionWrapper

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


def main(
    input: str = os.path.join(CURRENT_DIR, "..", "..", "images", "inputs", "input.png"),
    output: str = os.path.join(CURRENT_DIR, "..", "..", "images", "outputs", "output.png"),
    model_id_or_path: str = "KBlueLeaf/kohaku-v2.1",
    lora_dict: Optional[Dict[str, float]] = None,
    prompt: str = "1girl with brown dog hair, thick glasses, smiling",
    negative_prompt: str = "low quality, bad quality, blurry, low resolution",
    width: int = 512,
    height: int = 512,
    acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
    use_denoising_batch: bool = True,
    guidance_scale: float = 1.2,
    cfg_type: Literal["none", "full", "self", "initialize"] = "self",
    seed: int = 2,
    delta: float = 0.5,
):
    """
    Initializes the StreamDiffusionWrapper.

    Parameters
    ----------
    input : str, optional
        The input image file to load images from.
    output : str, optional
        The output image file to save images to.
    model_id_or_path : str
        The model id or path to load.
    lora_dict : Optional[Dict[str, float]], optional
        The lora_dict to load, by default None.
        Keys are the LoRA names and values are the LoRA scales.
        Example: {'LoRA_1' : 0.5 , 'LoRA_2' : 0.7 ,...}
    prompt : str
        The prompt to generate images from.
    negative_prompt : str, optional
        The negative prompt to use.
    width : int, optional
        The width of the image, by default 512.
    height : int, optional
        The height of the image, by default 512.
    acceleration : Literal["none", "xformers", "tensorrt"], optional
        The acceleration method, by default "tensorrt".
    use_denoising_batch : bool, optional
        Whether to use denoising batch or not, by default True.
    guidance_scale : float, optional
        The CFG scale, by default 1.2.
    cfg_type : Literal["none", "full", "self", "initialize"],
    optional
        The cfg_type for img2img mode, by default "self".
        You cannot use anything other than "none" for txt2img mode.
    seed : int, optional
        The seed, by default 2. if -1, use random seed.
    delta : float, optional
        The delta multiplier of virtual residual noise,
        by default 1.0.
    """

    if guidance_scale <= 1.0:
        cfg_type = "none"

    global stream
    stream = StreamDiffusionWrapper(
        model_id_or_path=model_id_or_path,
        lora_dict=lora_dict,
        t_index_list=[22, 32, 45],
        frame_buffer_size=1,
        width=width,
        height=height,
        warmup=10,
        acceleration=acceleration,
        mode="img2img",
        use_denoising_batch=use_denoising_batch,
        cfg_type=cfg_type,
        seed=seed,
    )

    stream.prepare(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=50,
        guidance_scale=guidance_scale,
        delta=delta,
    )

    ##通信開始
    print("通信開始")
    start_server()


##ネットワーク関連
def handle_client(conn):
    while True:
        try:
            # 1. Unityからの画像サイズ取得
            size_data = conn.recv(4)
            if not size_data:
                break
            size = struct.unpack("<I", size_data)[0]

            # 2. Unityからの画像データ受信
            img_bytes = b''
            while len(img_bytes) < size:
                img_bytes += conn.recv(size - len(img_bytes))

            # 3. Pillowで画像処理
            input_img = Image.open(BytesIO(img_bytes))
            print("画像受信:", input_img.size)

            # 4. i2i処理(ここに生成AIを追加)
            #output_img = input_img.transpose(Image.FLIP_LEFT_RIGHT)  # 仮の変換処理
            image_tensor = stream.preprocess_image(input_img)
            output_image = stream(image=image_tensor)
            print("画像生成:", output_image.size)

            # 5. 結果をJPEGで送信
            buf = BytesIO()
            output_image.save(buf, format="JPEG")
            result_bytes = buf.getvalue()

            conn.send(struct.pack("<I", len(result_bytes)))
            conn.send(result_bytes)
            print("画像送信:")


        except Exception as e:
            print("通信エラー:", e)
            break

def start_server(host="0.0.0.0", port=5001):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind((host, port))
    s.listen(1)
    print(f"待機中... ポート {port}")
    conn, addr = s.accept()
    print(f"接続: {addr}")
    handle_client(conn)
    conn.close()

##メイン
if __name__ == "__main__":
    fire.Fire(main)

実行

python single.py --model_id_or_path "./kohaku-v2.1.safetensors" --prompt "best quality,ultra detailed,3d model,realistic" --negative_prompt "worst quality,out of focus,JPEG artifacts,low resolution,error" --width 512 --height 512 --acceleration xformers --cfg_type initialize

※モデルはローカルを指定しないと動かないようです。パラメーターでプロンプト、ネガティブプロンプトも指定可能です。
実行すると通信待機になります。

クライアント

Unityのアプリ

サンプルプロジェクトの「CameraViewer」サンプルを流用します
https://github.com/oculus-samples/Unity-PassthroughCameraApiSamples

コード

以下を作成します。

ImageSendAndReceive
using System.Collections;
using UnityEngine;
using UnityEngine.UI;
using System.Net.Sockets;
using System.Threading;
using System.IO;

public class ImageSendAndReceive : MonoBehaviour
{
    public RawImage displayImage;  // 入力画像
    public RawImage display;       // 出力画像(結果表示)
    public string serverIP = "192.168.0.5";
    public int port = 5001;

    TcpClient client;
    NetworkStream stream;
    private UnityMainThreadDispatcher dispatcher;
    private Texture2D reusableTex;

    void Start()
    {
        dispatcher = UnityMainThreadDispatcher.Instance();
        if (dispatcher == null)
        {
            UnityEngine.Debug.LogError("UnityMainThreadDispatcher がシーンに存在しません!");
            return;
        }

        client = new TcpClient(serverIP, port);
        stream = client.GetStream();

        reusableTex = new Texture2D(512, 512, TextureFormat.RGB24, false);
        StartCoroutine(SendAndReceiveCoroutine());

        // displayのサイズを固定(512x512)
        if (display != null)
        {
            display.rectTransform.sizeDelta = new Vector2(512, 512);
        }
    }

    IEnumerator SendAndReceiveCoroutine()
    {
        Stopwatch stopwatch = new Stopwatch();

        while (true)
        {
            if (displayImage.texture == null)
            {
                yield return new WaitForSeconds(0.03f);
                continue;
            }

            stopwatch.Restart();

            int width = displayImage.texture.width;
            int height = displayImage.texture.height;

            // 中央512×512の座標を計算
            int x = Mathf.Max((width - 512) / 2, 0);
            int y = Mathf.Max((height - 512) / 2, 0);

            // RenderTextureを使って切り出し
            RenderTexture currentRT = RenderTexture.active;
            RenderTexture tmpRT = RenderTexture.GetTemporary(width, height);

            Graphics.Blit(displayImage.texture, tmpRT);
            RenderTexture.active = tmpRT;

            reusableTex.ReadPixels(new Rect(x, y, 512, 512), 0, 0);
            reusableTex.Apply();

            RenderTexture.active = currentRT;
            RenderTexture.ReleaseTemporary(tmpRT);

            byte[] jpg = reusableTex.EncodeToJPG();

            try
            {
                byte[] size = System.BitConverter.GetBytes(jpg.Length);
                stream.Write(size, 0, 4);
                stream.Write(jpg, 0, jpg.Length);

                byte[] recvSizeBytes = new byte[4];
                stream.Read(recvSizeBytes, 0, 4);
                int recvSize = System.BitConverter.ToInt32(recvSizeBytes, 0);

                byte[] recvImg = new byte[recvSize];
                int total = 0;
                while (total < recvSize)
                    total += stream.Read(recvImg, total, recvSize - total);

                dispatcher.Enqueue(() =>
                {
                    Texture2D rtex = new Texture2D(512, 512);
                    rtex.LoadImage(recvImg);
                    display.texture = rtex;
                });
            }
            catch (IOException e)
            {
                UnityEngine.Debug.LogError("通信エラー: " + e.Message);
                break;
            }

            stopwatch.Stop();
            float elapsed = stopwatch.ElapsedMilliseconds / 1000f;
            float fps = 1f / elapsed;
            UnityEngine.Debug.Log($"送受信FPS: {fps:F2} ({elapsed:F3}秒)");

            yield return null;
        }
    }

    void OnDestroy()
    {
        stream?.Close();
        client?.Close();
    }
}
UnityMainThreadDispatcher
using System;
using System.Collections.Generic;
using UnityEngine;

public class UnityMainThreadDispatcher : MonoBehaviour
{
    private static UnityMainThreadDispatcher _instance;
    private static readonly Queue<Action> _executionQueue = new Queue<Action>();

    public static UnityMainThreadDispatcher Instance()
    {
        if (_instance == null)
        {
            _instance = FindObjectOfType<UnityMainThreadDispatcher>();
            if (_instance == null)
            {
                Debug.LogError("UnityMainThreadDispatcher is not present in the scene. Please add it manually.");
            }
        }
        return _instance;
    }

    void Update()
    {
        lock (_executionQueue)
        {
            while (_executionQueue.Count > 0)
            {
                var action = _executionQueue.Dequeue();
                action?.Invoke();
            }
        }
    }

    public void Enqueue(Action action)
    {
        if (action == null) return;

        lock (_executionQueue)
        {
            _executionQueue.Enqueue(action);
        }
    }
}

サンプルのカメラを表示するRawImageを複製して、受信用のRawImageを作成します


受信用のRawImageは邪魔なので見えない位置に適当に動かします(真後ろとか)。


サイズを512ピクセル×512ピクセルにします。


空のGameObjectを作成して(ここではImageSendAndReceive)ImageSendAndReceive.csとUnityMainThreadDispatcher.csをアタッチします。


インスペクタのImageSendAndReceiveに以下を設定します

  • Display Image:カメラを表示するRawImage
  • Display:受信用のRawImage
  • Server IP:サーバーのIP
  • Port:サーバーのポート番号
    ※PC側で設定したポートを使用できるようにしておく必要があります

ビルドしてQuest3にインストールしてください。

実行

先にサーバーを実行して、待機中になった後、Quest3側のアプリを起動します。

結果

こんな感じ


この表示のサーバー側コマンド

python single.py --model_id_or_path "./ghostmix_v20Bakedvae.safetensors" --prompt "score_9, score_8_up, score_7_up, score_6_up, high resolution, minimalistic style, psychedelic mushrooms" --negative_prompt "score_6, score_5, score_4, censored, 3d, monochrome, watermark, muscular female, bad hands" --width 512 --height 512 --acceleration xformers --cfg_type initialize

自分の環境では7fpsくらい出ます
RTX2070だと2fpsくらいでした

ホロラボのテックブログ

Discussion