🤗

Swallow v0.1とMLXでローカルチャットbotを動かしてみた

2024/05/06に公開

はじめに

Swallow instruct-v0.1シリーズ(7B, 13B, 70B)をMLXを使ってローカルのMacBook Proで動かして遊んでみました。

今回も、ローカルで動作するチャットbotをクイックに作ってみました。

環境

  • Apple M3 MAX (128GB)
    • 推論中のpythonプロセスのUnified Memory消費量はざっくり最大で以下のとおりでした
      • 70B 4bit : 41GB
      • 70B 8bit : 73GB
  • Python 3.10
    • 前回の記事のとおり、3.9以前では動作しないと思われます

ライブラリ

以下を使っています。現状(2024/05/05)、特にバージョン指定しないpipインストールで問題なく動作します。

mlx_lm

https://github.com/ml-explore/mlx
https://huggingface.co/docs/hub/mlx

MLXはAppleが提供する機械学習(特にDeep Learning)用のフレームワークです。このフレームワーク上でモデルを動作させることで、Unified MemoryとGPUを活用し高速に学習・推論できます。mlx_lmはMLXを用いてhugging faceのLLMを動かしてくれます。

gradio

https://www.gradio.app
クイックにチャットbotを構築できます。

使用モデル

Swallow v0.1シリーズ

https://huggingface.co/tokyotech-llm
https://x.com/chokkanorg/status/1783799621841760734
Llama 2ベースのSwallowモデルについて、指示追従性能を高めたモデルが先月公開されました。

リリースされたばかりということもあり、まだMLX-communityにはモデルが登録されていなかったため、私で一通り量子化を行い、登録しておきましたので、お使いください。
https://huggingface.co/mlx-community/Swallow-7b-instruct-v0.1-4bit
https://huggingface.co/mlx-community/Swallow-7b-instruct-v0.1-8bit
https://huggingface.co/mlx-community/Swallow-13b-instruct-v0.1-4bit
https://huggingface.co/mlx-community/Swallow-13b-instruct-v0.1-8bit
https://huggingface.co/mlx-community/Swallow-70b-instruct-v0.1-4bit
https://huggingface.co/mlx-community/Swallow-70b-instruct-v0.1-8bit

動作チェック

とりあえず、7B-8bitモデルで遊んでみましょう。mlx_lmをインストールした仮想環境をactivateし、以下のコードを実行するととりあえず動きます。(プロンプト生成部分はSwallowのサンプルコードを参照し、要約文はMLXの概要説明を入力しました)

simple_run.py
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Swallow-7b-instruct-v0.1-8bit")

PROMPT_DICT = {
    "prompt_input": (
        "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"

    ),
    "prompt_no_input": (
        "以下に、あるタスクを説明する指示があります。"
        "リクエストを適切に完了するための回答を記述してください。\n\n"
        "### 指示:\n{instruction}\n\n### 応答:"
    ),
}

def create_prompt(instruction, input=None):
    """
    Generates a prompt based on the given instruction and an optional input.
    If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
    If no input is provided, it uses the 'prompt_no_input' template.

    Args:
        instruction (str): The instruction describing the task.
        input (str, optional): Additional input providing context for the task. Default is None.

    Returns:
        str: The generated prompt.
    """
    if input:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input)
    else:
        # Use the 'prompt_no_input' template when no additional input is provided
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)

# Example usage
instruction = "あなたは日本語を話す優秀なアシスタントです。以下のインプットの文章を日本語の箇条書きで要約してください。"
input = """
MLX is a NumPy-like array framework designed for efficient and flexible machine learning on Apple silicon, brought to you by Apple machine learning research.

The Python API closely follows NumPy with a few exceptions. MLX also has a fully featured C++ API which closely follows the Python API.

The main differences between MLX and NumPy are:

Composable function transformations: MLX has composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.
Lazy computation: Computations in MLX are lazy. Arrays are only materialized when needed.
Multi-device: Operations can run on any of the supported devices (CPU, GPU, …)
The design of MLX is inspired by frameworks like PyTorch, Jax, and ArrayFire. A notable difference from these frameworks and MLX is the unified memory model. Arrays in MLX live in shared memory. Operations on MLX arrays can be performed on any of the supported device types without performing data copies. Currently supported device types are the CPU and GPU."""

prompt = create_prompt(instruction, input)

response = generate(
    model, tokenizer, prompt,
    max_tokens=500,
    verbose=True
    )

以下の通り、なぜか英語で回答してきました。

出力
### 応答:
MLX is a NumPy-like array framework designed for efficient and flexible machine learning on Apple silicon. The Python API closely follows NumPy with a few exceptions, and a fully featured C++ API is also available.

The main differences between MLX and NumPy are composable function transformations, lazy computation, multi-device support, and a unified memory model. MLX has composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization. MLX also has lazy computation, which means arrays are only materialized when needed.

MLX supports multi-device operations, and operations can run on any of the supported devices (CPU, GPU, ...). The design of MLX is inspired by frameworks like PyTorch, Jax, and ArrayFire, but with a unified memory model that allows operations to be performed on any of the supported device types without performing data copies.

Currently, MLX supports the CPU and GPU as supported device types.
==========
Prompt: 679.914 tokens-per-sec
Generation: 41.425 tokens-per-sec

同じプロンプトで、13B-8bitだと以下のとおりで、日本語で出力してくれはいるものの、訳しているだけにみえ、要約してくれてません

出力
### 応答:
MLX は Apple の機械学習研究チームによって開発された NumPy のような配列フレームワークです。MLX は Apple silicon で効率的で柔軟な機械学習を実現するために設計されています。

Python API は NumPy に似ていますが、いくつかの例外があります。MLX には完全に機能が搭載された C++ API もあり、Python API と同じように動作します。

MLX と NumPy の主な違いは次の通りです。

* 関数の可換的な変換: MLX には自動微分、自動ベクトル化、計算グラフ最適化などの可換的な関数変換があります。
* 遅延計算: MLX の計算は遅延されます。配列は必要なときにのみマテリアライズされます。
* マルチデバイス: MLX の操作は CPU、GPU、... などのサポートされているデバイスのいずれでも実行できます。

MLX の設計は PyTorch、Jax、ArrayFire などのフレームワークに影響を受けています。MLX とこれらのフレームワークとの主な違いは共有メモリモデルです。MLX の配列は共有メモリに存在します。MLX 配列の操作はサポートされているデバイスのいずれでも実行でき、データコピーを行う必要はありません。現在サポートされているデバイスタイプは CPU と GPU です。

同じプロンプトで、70B-8bitだと以下のとおりで、要望どおりに妥当に要約して回答してくれているようにみえます

出力
* MLX は NumPy に似た Apple silicon 用の効率的で柔軟な機械学習用配列フレームワークです。
* Python API は NumPy に非常に近いですが、いくつかの例外があります。MLX には完全に機能する C++ API もあり、Python API に非常に近いです。
* MLX と NumPy の主な違いは、自動微分、自動ベクトル化、計算グラフ最適化などのコンポーザブル関数変換、遅延計算、マルチデバイスです。
* MLX のデザインは PyTorch、Jax、ArrayFire などのフレームワークにインスパイアされています。MLX との主な違いは、共有メモリモデルです。MLX 配列は共有メモリに存在し、MLX 配列に対する操作はサポートされているデバイスタイプのいずれかで実行できます。現在サポートされているデバイスタイプは CPU と GPU です。

チャットbotをクイックに

前の記事と同様、クイックにチャットbotを作ってみました。

チャットbotのコードは以下のとおりです。

app.py
import gradio as gr
from mlx_lm import load, generate

# 定数まわり
model_name = "mlx-community/Swallow-70b-instruct-v0.1-8bit"
# Instructionの定義
instruction = "あなたは日本語を話す優秀なアシスタントです。以下のインプットの文章を日本語の箇条書きで要約してください。"


def create_prompt(instruction, input=None):
    """
    Generates a prompt based on the given instruction and an optional input.
    If input is provided, it uses the 'prompt_input' template from PROMPT_DICT.
    If no input is provided, it uses the 'prompt_no_input' template.

    Args:
        instruction (str): The instruction describing the task.
        input (str, optional): Additional input providing context for the task. Default is None.

    Returns:
        str: The generated prompt.
    """

    PROMPT_DICT = {
        "prompt_input": (
            "以下に、あるタスクを説明する指示があり、それに付随する入力が更なる文脈を提供しています。"
            "リクエストを適切に完了するための回答を記述してください。\n\n"
            "### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### 応答:"

        ),
        "prompt_no_input": (
            "以下に、あるタスクを説明する指示があります。"
            "リクエストを適切に完了するための回答を記述してください。\n\n"
            "### 指示:\n{instruction}\n\n### 応答:"
        ),
    }

    if input:
        # Use the 'prompt_input' template when additional input is provided
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input=input)
    else:
        # Use the 'prompt_no_input' template when no additional input is provided
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)

model, tokenizer = load(model_name)

def generate_response(input_text):
    prompt = create_prompt(instruction, input_text)
    response = generate(
        model, tokenizer, prompt,
        max_tokens=512,
        verbose=True
    )
    return response

gr.Interface(fn=generate_response, inputs="text", outputs="text").launch()

上記を実行すると、以下のようにチャットbotが起動し、プロンプトに応答してくれます。

【参考】量子化のやりかた

以下のコマンドで、量子化を行いつつ、communityにモデルをアップロードできます(mlx-communityのトップページにスニペットがあります)

terminal入力
python -m mlx_lm.convert --hf-path tokyotech-llm/Swallow-70b-instruct-v0.1 --q-bits 8 -q --upload-repo mlx-community/Swallow-70b-instruct-v0.1-8bit

Discussion