🐨

Unity Sentisを使ってUnityランタイム上でAIモデルを使えるようにする

2024/11/03に公開

Unity Sentisとは

Unityランタイム上でAIモデル実行を可能とするサービスです。
2023年6月に登場しました。

https://unity.com/ja/products/sentis

クラウドにアクセスする必要がないため、ネットワークが通ってない or 弱い箇所でもAIモデルを実行することができて、コストも無料です。
また対応するプラットフォームも、Unityがサポートしているもの全てで動かすことができます。

使えるモデルはONNX規格のモデルであれば使用することができます。

つくるもの

9ヶ月前に作ったアプリをUnity Sentisで再現してみます。

https://zenn.dev/headwaters/articles/2dd879438294c1

↑の記事では以下のサービスを使いました。

  • GPT-4...Azure OpenAI Service
  • STT...Azure AI Speech Service
  • TTS...Azure AI Speech Service

音声テキスト変換するのも、回答を生成するのも、音声合成するのも全てクラウドを通してました。
Unity Sentisを使って全てローカル実行に変えてみます。

環境構築

前提

使用するAIモデルはHugging Faceにアップされているものを使います。
Models画面でunity-sentisで絞り込まれたモデルを活用することができます。

https://huggingface.co/models?library=unity-sentis


1. 環境構築

バージョンは2023.2.18f1を使用します。(バージョン2023であれば何でも大丈夫かと。)


「Assets」フォルダ直下に「Scripts」フォルダと「StreamingAssets」フォルダを作成。


2. パッケージ導入

「Windows」タブ→「Package Manager」を開く。
左上の「+」ボタンから「Install package by name...」をクリックする。


Nameには「com.unity.sentis」を入力
versionには「1.5.0-pre.2」を入力してインストール


もう一つ、「com.unity.nuget.newtonsoft-json」を入力してインストールしてください。

音声テキスト変換の組み込み

1. モデルを導入

音声テキスト変換には「sentis-whisper-tiny」モデルを使用します。

https://huggingface.co/unity/sentis-whisper-tiny


「Files and versions」に移動して、以下のファイルをダウンロードします。

  • AudioDecoder_Tiny.sentis
  • AudioEncoder_Tiny.sentis
  • LogMelSepctro.sentis
  • vocab.json

「StreamingAssets」フォルダ直下に「Whisper」フォルダを作成して、ダウンロードしたファイルを格納します。


ロジックは以下の記事が非常に参考になりましたので、使わせてもらいます。

https://note.com/ugee/n/n383d583cef52

「Scripts」フォルダに↑の記事で実装されている3つのファイルを格納。


2. ロジック修正

WhisperEX.cs

StreamingAssetsフォルダ内にサブフォルダを作っているので、変数を修正します。

WhisperEX.cs
public class WhisperEx
{
    // Assets/StreamingAssetsにあるSentisモデルを取得
    public const string LogMelSpectroModelName = "Whisper/LogMelSepctro.sentis"; // Log Melスペクトログラムモデル名
    public const string EncoderModelName = "Whisper/AudioEncoder_Tiny.sentis"; // オーディオエンコーダーモデル名
    public const string DecoderModelName = "Whisper/AudioDecoder_Tiny.sentis"; // オーディオデコーダーモデル名
    public const string VocabName = "Whisper/vocab.json"; // 語彙ファイル名

    public const int MaxTokens = 100; // 最大トークン数
    public const int EndOfText = 50257; // テキストの終了を示すトークン
    public const int StartOfTranscript = 50258; // テキストの開始を示すトークン
    public const int Transcribe = 50359; // 特定の言語で音声をテキストに変換するトークン
    public const int Translate = 50358; // 英語に翻訳するトークン
    public const int NoTimeStamps = 50363; // タイムスタンプを削除するトークン
    public const int StartTime = 50364; // タイムスタンプの開始を示すトークン

    public enum Language
    {
        English = 50259, // 英語
        Korean = 50264, // 韓国語
        Japanese = 50266 // 日本語
    }
}


WhisperPresenter.cs

  1. WhisperModelとPhi15コンポーネントを読み込む
  2. Buttonをクリックしてレコーディングの開始・停止するためのPublic関数を用意
  3. Update関数は不要になるので、コメントアウト
  4. ProcessVoiceInputで文字起こしされたテキストを引数に渡して、Phi15の関数を実行
WhisperPresenter.cs
using UnityEngine;

public class WhisperPresenter : MonoBehaviour
{
    [SerializeField] private WhisperModel whisperModel;
    [SerializeField] private RunPhi15 runPhi15;

    private void OnValidate()
    {
#if UNITY_EDITOR
        if (!whisperModel) whisperModel = GetComponent<WhisperModel>();
        if (!runPhi15) runPhi15 = GetComponent<RunPhi15>();
#endif
    }

    public void startRecording() {
        whisperModel.StartRecording();
    }

    public void stopRecording() {
        ProcessVoiceInput();
    }
    // private void Update()
    // {
    //     if (Input.GetKeyDown(KeyCode.LeftCommand))
    //     {
    //         whisperModel.StartRecording(); // 左Ctrlキーが押されたときに録音を開始
    //     }
    //     else if (Input.GetKeyUp(KeyCode.LeftCommand))
    //     {
    //         ProcessVoiceInput(); // 左Ctrlキーが離されたときに音声入力を処理
    //     }
    // }

    private async void ProcessVoiceInput()
    {
        var result = await whisperModel.StopRecording(); // 録音を停止し、文字起こしを開始
        Debug.Log($"{result}");
        runPhi15.StartInference(result); // 文字起こし結果を入力として、生成テキストを生成
    }
}

生成AIの組み込み

1. モデルを導入

生成AIには「Phi-1.5」モデルを使用します。

https://huggingface.co/unity/sentis-phi-1_5


「Files and versions」に移動して、以下のファイルをダウンロードします。

  • phi15.sentis
  • merges.txt
  • RunPhi15.cs
  • vocab.json

「StreamingAssets」フォルダ直下に「Phi15」フォルダを作成して、RunPhi15.csファイル以外を格納します。
RunPhi15.csファイルは「Scripts」フォルダに格納。


2. ロジック修正

RunPhi15.cs

  1. TMProをインポート
  2. RunJetsコンポーネントを読み込む
  3. OutputStringに入ってる文字を削除
  4. ファイルを読み込むパスを修正
  5. WhisperPresenter.csから実行される関数「StartInference」を定義
  6. 「RunInference」関数内で文字生成が終わったタイミングでRunJetsのTTS関数を実行
  7. 文字列を結合するところで、UI上に反映させるためにtextMeshPro変数に文字列を結合
RunPhi15.cs
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.IO;
using System.Text;
using FF = Unity.Sentis.Functional;
using TMPro;

/*
 *              Phi1.5 Inference Code
 *              ===========================
 *  
 *  Put this script on the Main Camera
 *  
 *  In Assets/StreamingAssets put:
 *  
 *  phi15.sentis (or put in asset folder)
 *  vocab.json
 *  merges.txt
 * 
 *  Install package com.unity.nuget.newtonsoft-json from packagemanger
 *  Install package com.unity.sentis
 * 
 */


public class RunPhi15: MonoBehaviour
{
    [SerializeField] private RunJets runJets;
    public TextMeshProUGUI textMeshPro;

    //Drop the tinystories.sentis or onnx file on here if using an asset:
    //public ModelAsset asset;
    const BackendType backend = BackendType.GPUCompute;

    string outputString = "";

    // This is how many tokens you want. It can be adjusted.
    const int maxTokens = 100;

    //Make this smaller for more randomness
    const float predictability = 5f;

    //Special tokens
    const int END_OF_TEXT = 50256;

    //Store the vocabulary
    string[] tokens;

    IWorker engine;

    int currentToken = 0;
    int[] outputTokens = new int[maxTokens];

    // Used for special character decoding
    int[] whiteSpaceCharacters = new int[256];
    int[] encodedCharacters = new int[256];

    bool runInference = false;


    //stop after this many tokens
    const int stopAfter = 100;

    int totalTokens = 0;

    string[] merges;
    Dictionary<string, int> vocab;

    void Start()
    {
        SetupWhiteSpaceShifts();

        LoadVocabulary();

        var model1 = ModelLoader.Load(Path.Join(Application.streamingAssetsPath , "Phi15/phi15.sentis"));

        int outputIndex = model1.outputs.Count - 1;
        //var model1 = ModelLoader.Load(asset);
        //Create a new model to select the random token:
        var model2 = FF.Compile(
            (input, currentToken) =>
            {
                var row = FF.Select(model1.Forward(input)[outputIndex], 1, currentToken);
                return FF.Multinomial(predictability * row, 1);
            },
            (model1.inputs[0], InputDef.Int(new TensorShape()))
        );

        engine = WorkerFactory.CreateWorker(backend, model2);
    }

    public void StartInference(string generatedText)
    {
        string prompt = "Answer the following questions briefly:" + generatedText;
        DecodePrompt(generatedText);
        runInference = true;
    }

    // Update is called once per frame
    void Update()
    {
        if (runInference)
        {
            RunInference();
        }
    }

    void RunInference()
    {
        using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
        using var index = new TensorInt(currentToken);

        engine.Execute(new Dictionary<string, Tensor> { {"input_0", tokensSoFar },  { "input_1", index }});

        var probs = engine.PeekOutput() as TensorInt;
        //Debug.Log(probs.shape);

        probs.CompleteOperationsAndDownload();

        int ID = probs[0];

        //shift window down if got to the end
        if (currentToken >= maxTokens - 1)
        {
            for (int i = 0; i < maxTokens - 1; i++) outputTokens[i] = outputTokens[i + 1];
            currentToken--;
        }

        outputTokens[++currentToken] = ID;
        totalTokens++;

        if (ID == END_OF_TEXT || totalTokens >= stopAfter)
        {
            runInference = false;
            Debug.Log(outputString);
            runJets.TextToSpeech(outputString);

        }
        // else if (ID < 0 || ID >= tokens.Length)
        // {
        //     // Really we should use the added_tokens.json for this
        //     outputString += " ";
        // }
        else {
            string newWord = GetUnicodeText(tokens[ID]);
            outputString += newWord;
            textMeshPro.text += newWord;
        }
    }

    void DecodePrompt(string text)
    {
        var inputTokens = GetTokens(text);

        for(int i = 0; i < inputTokens.Count; i++)
        {
            outputTokens[i] = inputTokens[i];
        }
        currentToken = inputTokens.Count - 1;
    }
   
    void LoadVocabulary()
    {
        var jsonText = File.ReadAllText(Path.Join(Application.streamingAssetsPath , "Phi15/vocab.json"));
        vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
        tokens = new string[vocab.Count];
        foreach (var item in vocab)
        {
            tokens[item.Value] = item.Key;
        }

        merges = File.ReadAllLines(Path.Join(Application.streamingAssetsPath , "Phi15/merges.txt"));
    }

    // Translates encoded special characters to Unicode
    string GetUnicodeText(string text)
    {
        var bytes = Encoding.GetEncoding("ISO-8859-1").GetBytes(ShiftCharacterDown(text));
        return Encoding.UTF8.GetString(bytes);
    }
    string GetASCIIText(string newText)
    {
        var bytes = Encoding.UTF8.GetBytes(newText);
        return ShiftCharacterUp(Encoding.GetEncoding("ISO-8859-1").GetString(bytes));
    }

    string ShiftCharacterDown(string text)
    {
        string outText = "";
        foreach (char letter in text)
        {
            outText += ((int)letter <= 256) ? letter :
                (char)whiteSpaceCharacters[(int)(letter - 256)];
        }
        return outText;
    }

    string ShiftCharacterUp(string text)
    {
        string outText = "";
        foreach (char letter in text)
        {
            outText += (char)encodedCharacters[(int)letter];
        }
        return outText;
    }

    void SetupWhiteSpaceShifts()
    {
        for (int i = 0, n = 0; i < 256; i++)
        {
            encodedCharacters[i] = i;
            if (IsWhiteSpace(i))
            {
                encodedCharacters[i] = n + 256;
                whiteSpaceCharacters[n++] = i;
            }
        }
    }

    bool IsWhiteSpace(int i)
    {
        //returns true if it is a whitespace character
        return i <= 32 || (i >= 127 && i <= 160) || i == 173;
    }

    List<int> GetTokens(string text)
    {
        text = GetASCIIText(text);

        // Start with a list of single characters
        var inputTokens = new List<string>();
        foreach(var letter in text)
        {
            inputTokens.Add(letter.ToString());
        }

        ApplyMerges(inputTokens);

        //Find the ids of the words in the vocab
        var ids = new List<int>();
        foreach(var token in inputTokens)
        {
            if (vocab.TryGetValue(token, out int id))
            {
                ids.Add(id);
            }
        }

        return ids;
    }

    void ApplyMerges(List<string> inputTokens)
    {
        foreach(var merge in merges)
        {
            string[] pair = merge.Split(' ');
            int n = 0;
            while (n >= 0)
            {
                n = inputTokens.IndexOf(pair[0], n);
                if (n != -1 && n < inputTokens.Count - 1 && inputTokens[n + 1] == pair[1])
                {
                    inputTokens[n] += inputTokens[n + 1];
                    inputTokens.RemoveAt(n + 1);
                }
                if (n != -1) n++;
            }
        }
    }

    private void OnDestroy()
    {
        engine?.Dispose();
    }
    
}

音声合成の組み込み

1. モデルを導入

音声合成には「sentis-jets-text-to-speech」モデルを使用します。

https://huggingface.co/unity/sentis-jets-text-to-speech


「FIles and versions」に移動して、以下のファイルをダウンロードします。

  • jets-text-to-speech.sentis
  • phoneme_dict.txt
  • RunJets.cs

「StreamingAssets」フォルダ直下に「Jets」フォルダを作成して、RunJets.csファイル以外を格納します。
RunJets.csファイルは「Scripts」フォルダに格納。


2. ロジック修正

RunJets.cs

  1. Start関数とUpdate関数内の「TextToSpeech」関数を削除
  2. ファイルを読み込むパスを修正
  3. TextToSpeech関数が別の箇所から実行できるようにPublicにする
  4. TextToSpeech関数は引数を受け取るようにして、その文字列を音声合成するようにする
RunJets.cs
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.IO;

//                      Jets Text-To-Speech Inference
//                      =============================
//
// This file implements the Jets Text-to-speech model in Unity Sentis
// The model uses phenomes instead of raw text so you have to convert it first.
// Place this file on the Main Camera
// Add an audio source
// Change the inputText
// When running you can press space bar to play it again

public class RunJets : MonoBehaviour
{
    public string inputText = "Once upon a time, there lived a girl called Alice. She lived in a house in the woods.";
    //string inputText = "The quick brown fox jumped over the lazy dog";
    //string inputText = "There are many uses of the things she uses!";

    //Set to true if we have put the phoneme_dict.txt in the Assets/StreamingAssets folder
    bool hasPhenomeDictionary = true;

    readonly string[] phonemes = new string[] { 
        "<blank>", "<unk>", "AH0", "N", "T", "D", "S", "R", "L", "DH", "K", "Z", "IH1", 
        "IH0", "M", "EH1", "W", "P", "AE1", "AH1", "V", "ER0", "F", ",", "AA1", "B", 
        "HH", "IY1", "UW1", "IY0", "AO1", "EY1", "AY1", ".", "OW1", "SH", "NG", "G", 
        "ER1", "CH", "JH", "Y", "AW1", "TH", "UH1", "EH2", "OW0", "EY2", "AO0", "IH2", 
        "AE2", "AY2", "AA2", "UW0", "EH0", "OY1", "EY0", "AO2", "ZH", "OW2", "AE0", "UW2", 
        "AH2", "AY0", "IY2", "AW2", "AA0", "\"", "ER2", "UH2", "?", "OY2", "!", "AW0", 
        "UH0", "OY0", "..", "<sos/eos>" };

    readonly string[] alphabet = "AE1 B K D EH1 F G HH IH1 JH K L M N AA1 P K R S T AH1 V W K Y Z".Split(' ');

    //Can change pitch and speed with this for a slightly different voice:
    const int samplerate = 22050;

    Dictionary<string, string> dict = new ();

    IWorker engine;

    AudioClip clip;

    void Start()
    {
        LoadModel();
        ReadDictionary();
        // TextToSpeech();
    }

    void LoadModel()
    {
        var model = ModelLoader.Load(Path.Join(Application.streamingAssetsPath ,"Jets/jets-text-to-speech.sentis"));
        engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, model);
    }

    public void TextToSpeech(string generatedText)
    {
        string ptext;
        if (hasPhenomeDictionary)
        {
            ptext = TextToPhonemes(generatedText);
            Debug.Log(ptext);
        }
        else
        {
            //If we have no phenome dictionary we can use one of these examples:
            ptext = "DH AH0 K W IH1 K B R AW1 N F AA1 K S JH AH1 M P S OW1 V ER0 DH AH0 L EY1 Z IY0 D AO1 G .";
            //ptext = "W AH1 N S AH0 P AA1 N AH0 T AY1 M , AH0 F R AA1 G M EH1 T AH0 P R IH1 N S EH0 S . DH AH0 F R AA1 G K IH1 S T DH AH0 P R IH1 N S EH0 S AH0 N D B IH0 K EY1 M AH0 P R IH1 N S .";
            //ptext = "D UW1 P L AH0 K EY2 T";
        }
        DoInference(ptext);
    }

    void ReadDictionary()
    {
        if (!hasPhenomeDictionary) return;
        string[] words = File.ReadAllLines(Path.Join(Application.streamingAssetsPath,"Jets/phoneme_dict.txt"));
        for (int i = 0; i < words.Length; i++)
        {
            string s = words[i];
            string[] parts = s.Split();
            if (parts[0] != ";;;") //ignore comments in file
            {
                string key = parts[0];
                dict.Add(key, s.Substring(key.Length + 2));
            }
        }
        // Add codes for punctuation to the dictionary
        dict.Add(",", ",");
        dict.Add(".", ".");
        dict.Add("!", "!");
        dict.Add("?", "?");
        dict.Add("\"", "\"");
        // You could add extra word pronounciations here e.g.
        //dict.Add("somenewword","[phonemes]");
    }

    public string ExpandNumbers(string text)
    {
        return text
            .Replace("0", " ZERO ")
            .Replace("1", " ONE ")
            .Replace("2", " TWO ")
            .Replace("3", " THREE ")
            .Replace("4", " FOUR ")
            .Replace("5", " FIVE ")
            .Replace("6", " SIX ")
            .Replace("7", " SEVEN ")
            .Replace("8", " EIGHT ")
            .Replace("9", " NINE ");
    }

    public string TextToPhonemes(string text)
    {
        string output = "";
        text = ExpandNumbers(text).ToUpper();

        string[] words = text.Split();
        for (int i = 0; i < words.Length; i++)
        {
            output += DecodeWord(words[i]);
        }
        return output;
    }

    //Decode the word into phenomes by looking for the longest word in the dictionary that matches
    //the first part of the word and so on. 
    //This works fairly well but could be improved. The original paper had a model that
    //dealt with guessing the phonemes of words
    public string DecodeWord(string word)
    {
        string output = "";
        int start = 0;
        for (int end = word.Length; end >= 0 && start < word.Length ; end--)
        { 
            if (end <= start) //no matches
            {
                start++;
                end = word.Length + 1;
                continue;
            }
            string subword = word.Substring(start, end - start);
            if (dict.TryGetValue(subword, out string value))
            {
                output += value + " ";
                start = end;
                end = word.Length + 1;
            }
        }
        return output;
    }
   
    int[] GetTokens(string ptext)
    {
        string[] p = ptext.Split();
        var tokens = new int[p.Length];
        for (int i = 0; i < tokens.Length; i++)
        {
            tokens[i] = Mathf.Max(0, System.Array.IndexOf(phonemes, p[i])); 
        }
        return tokens;
    }

    public void DoInference(string ptext)
    {      
        int[] tokens = GetTokens(ptext);

        using var input = new TensorInt(new TensorShape(tokens.Length), tokens);
        var result = engine.Execute(input);

        var output = result.PeekOutput("wav") as TensorFloat;
        output.CompleteOperationsAndDownload();
        var samples = output.ToReadOnlyArray();

        Debug.Log($"Audio size = {samples.Length / samplerate} seconds");

        clip = AudioClip.Create("voice audio", samples.Length, 1, samplerate, false);
        clip.SetData(samples, 0);

        Speak();
    }
    private void Speak()
    {
        AudioSource audioSource = GetComponent<AudioSource>();
        if (audioSource != null)
        {
            audioSource.clip = clip;
            audioSource.Play();
        }
        else
        {
            Debug.Log("There is no audio source");
        }
    }

    void Update()
    {
        if (Input.GetKeyDown(KeyCode.Space))
        {
            // TextToSpeech();
        }
    }

    private void OnDestroy()
    {
        engine?.Dispose();
    }
}

UI作成

Canvasコンポーネントを作成して、以下の三つのコンポーネントを中に作成します。


Text(TMP)
生成AIの回答結果を表示


Button
レコーディング開始に使用


Button
レコーディング停止に使用

UIとロジックを連携

MainCameraに以下の四つをコンポーネント追加します。
1. RunJets

2. RunPhi15
「RunJets」には↑で追加したRunJetsコンポーネントを設定
「TextMeshPro」にはCanvas内のText(TMP)コンポーネントを設定

3. WhisperModel
「Speaker Language」はEnglishのままにします。
音声合成が日本語に対応してないので...

4. WhisperPresenter
「WhisperModel」には↑で追加したWhisperModelコンポーネントを設定
「RunPhi15」にはRunPhi15コンポーネントを設定


Buttonの「OnClick」にはMainCamera内の「WhisperPresenter.startRecording」関数を設定


もう一つのButtonの「OnClick」にはMainCamera内の「WhisperPresenter.stopRecording」関数を設定

検証

ネットワークをオフにした状態で動作させることができました!
生成AIと音声合成はまだ実用には難しそうですが、音声テキスト変換(Whisper-Tiny)は精度・速度ともにかなりいい感じでした。

ちょっと気になるのが、HuggingFaceに公開されているSentisで使えるモデルがしばらく更新されてないことです...
ほとんどのモデルが最新のバージョン対応してないので、まだあまりメジャーになってないのかもです。

https://youtu.be/Ef_FlA9ROHU

ヘッドウォータース

Discussion