🎧

Command-R (35B)をMac(MLX)でローカルで動かしてAPIサーバにする

2024/04/07に公開

Command-R

Cohere Command-RというLLMが話題です。
35Bモデルと104Bモデル(Command-R+)があるようです。
128kのコンテキストウインドウで、ツールやRAGでの使用を想定したモードもあるようです。

実験用であればローカルで使えそうです。

https://huggingface.co/CohereForAI/c4ai-command-r-v01

これをmacOS(AppleSilicon)でローカルで使ってみます。

MLXで使う

MLXというAppleSilicon用のMLフレームワークで使えるモデルがあります。
今回は4bit量子化モデルを使ってみます。

Command-R 4bit
https://huggingface.co/mlx-community/c4ai-command-r-v01-4bit
Command-R 2bit
https://huggingface.co/mlx-community/c4ai-command-r-v01-2bit

Command-R+ 4bit
https://huggingface.co/mlx-community/c4ai-command-r-plus-4bit

使い方はシンプルです。

pip install mlx-lm

次にサンプルを少し修正したsample.pyを作成

import argparse
from mlx_lm import load, generate

def main():
    parser = argparse.ArgumentParser(description='Generate text using Command-R(MLX LM)')
    parser.add_argument('prompt', type=str, help='Input prompt for text generation')
    args = parser.parse_args()

    model, tokenizer = load("mlx-community/c4ai-command-r-v01-4bit")

    prompt = []
    prompt.append({'role': 'user', 'content': args.prompt})

    inputs = tokenizer.apply_chat_template(prompt,
                                           tokenize=False,
                                           add_generation_prompt=True)
    
    response = generate(
        model, 
        tokenizer,
        prompt=inputs,
        verbose=True,
        temp=0.2,
        max_tokens=12800,
    )

if __name__ == "__main__":
    main()

実行してみます。

python sample.py "こんにちは!元気ですか?"

もうこれだけでダウンロードされて動きます。

python sample.py "こんにちは!元気ですか?"
Fetching 13 files: 100%|████████████████████| 13/13 [00:00<00:00, 136999.88it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
==========
Prompt: <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>こんにちは!元気ですか?<|END_OF_TURN_TOKEN|>
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
はい、元気です!おっしゃる通りですね。日本語でお話しできるなんて素敵ですね。今日はどんなことがおきましたか?お仕事や勉強で忙しい日でしたか?
==========
Prompt: 26.384 tokens-per-sec
Generation: 13.637 tokens-per-sec

実用的な速度(M1 Max)ですが、長時間回すととても熱くなります…。

簡単なAPIサーバにしてみる

これだけでは芸がないので、実験用にFlaskでローカルで動くかんたんなAPIサーバにしてみました。
(ChatGPTが7割書いてくれました…)

https://github.com/romot-co/command-r-mlx-simple-api

生成(1回): /generate
チャット: /chat
ツール: /tool
RAG: /rag

これを元にまずは実験用に使用していきたいです。

Discussion