Chta VectorでLlama-3.2-VisionにLlama-3.1をマージして日本語対応させる
要約
- ChatVectorを用いて、Llama-3.1-Swallow-8B-v0.1の日本語能力をLlama-3.2-11B-Visionに付加した
今回作成したモデルはこちら
目的
meta-llama/Llama-3.2-11B-Vision-Instructのモデルカードに
Llama 3.2-Vision is built on top of Llama 3.1 text-only model, which is an auto-regressive language model that uses an optimized transformer architecture.
Llama 3.2-Visionは、Llama 3.1のテキスト専用モデルの上に構築されており、最適化された変換器アーキテクチャを使用する自動回帰型言語モデルです。
と記載されているように、Llama-3.2-VisionはLlama-3.1をベースに作られているようです。
そこで、以前拝見したこちらの記事を思い出しました。
Llama-3.2-VisionとLlama-3.1がLlavaとLlamaの関係性に近いのであれば、同様のことができるのではないかと考えました。
実施内容
使用するモデル
VLM: meta-llama/Llama-3.2-11B-Vision-Instruct
英語モデル: meta-llama/Llama-3.1-8B
日本語モデル: tokyotech-llm/Llama-3.1-Swallow-8B-v0.1
Llama-3.2-Visionについて
Llama-3.2-Visionでは、self attention層の間にcross attention層というものが導入されています。
11Bモデルはvision_model側にself_attnが32層、
cross_attnが8層の計40層で構成されています。
8Bモデルは32層のモデルなので、おそらくはこのself_attnがLlama-3.1-8Bなのではないかと考えました。
実際に2つのモデルで対応する層の重みを見てみるの、かなり近しい値になっているのが分かりました。
マージ
似たようなことを考えている人がいるだろうと同様の事例がないか探していると、次のモデルを見つけました。
こちらのモデルはLlama-3.2-90BにHermesの重みを加算することで得られたモデルのようで、作成に使ったコードを公開してくださっていました。
今回はこちらのコードを少しだけ手を加えてマージを行いました。
主な変更点
- ロードするモデルを3つにする
multimodal_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"
text_ja_model_path = "tokyotech-llm/Llama-3-Swallow-8B-v0.1"
text_en_model_path = "meta-llama/Meta-Llama-3-8B"
save_path = "models/merged_model"
- モデルの重みが加算されている部分を、chat vectorの加算に変更
state_dict_multimodal[key_multimodal][:state_dict_text_ja[key_text].shape[0], :].copy_(state_dict_multimodal[key_multimodal][:state_dict_text_ja[key_text].shape[0], :] + state_dict_text_ja[key_text] - state_dict_text_en[key_text])
結果
完成したのがこちらのモデルになります。
日本にまつわる画像でこのモデルの出力を確認してみます。
プロンプト: 作品名とシーンの説明をしてください。
モデル | 出力 |
---|---|
meta-llama/Llama-3.2-11B-Vision-Instruct | この作品は、漫画家の宮崎吾朗による日本の漫画作品です。 この作品は、世界を荒廃させた巨大な生物を巡る戦いを描いた、SF漫画です。 この作品のあらすじは以下の通りです。 地球が荒廃した世界で、人間は生物の群れに追い詰められていました。生物たちは人間を食料として捕まえ、食べていました。人間は生物を追い詰めるために、生物の群れを倒すために戦いました。人間は生物の群れ |
Kendamarron/Llama-3.2-11B-Vision-Instruct-Swallow-8B-Merge | 『千と千尋の神隠し』は、2001年に公開された宮崎駿監督によるアニメーション映画です。物語は、両親と一緒に引っ越した神隠しの町で、奇妙な体験をする少女千尋の姿を描いたものです。千尋は、町の神様に化け物が住む「油屋」という湯屋を経営する家に引き取られます。千尋は、神様に化け物を洗濯する役割を担いますが、化け物の正体 |
プロンプト: これはなんですか?
モデル | 出力 |
---|---|
meta-llama/Llama-3.2-11B-Vision-Instruct | この画像は、日本の奈良県の奈良市にある大仏殿(だいぶつでん)であると推測されます。画像から大きな仏像が見えるため、仏教の施設であると推測されます。画像の背景には、奈良公園(ならこうえん)と奈良市内を走る国道308号が見えます。国道308号は、奈良市内を走る国道24号の支線で、奈良公園の南東に位置しています。国道308号は、奈良 |
Kendamarron/Llama-3.2-11B-Vision-Instruct-Swallow-8B-Merge | これは奈良県奈良市にある東大寺の鐘楼です。鐘楼は鐘を吊るすために建てられた塔で、東大寺の鐘楼は高さ約50メートル、直径約5メートルあります。鐘楼の鐘は国の重要文化財に指定されています。鐘楼の鐘は、東大寺の僧侶が毎日打ち鳴らすことで、仏教の教えを広める役割を果たしています。鐘楼は東大寺のシンボル的な建物として知 |
鐘楼というのは間違っていそうですが、元のモデルと比べると日本の文化にまつわる質問にも回答できるようになっていますね!
まとめ
今回は、Chat Vectorを用いて日本語理解のできるLlama-3.2-Visionモデルを作ってみました。
Llama-3.1の継続事前学習モデルをLlama-3.2のマルチモーダルモデルに活かせるのは可能性が広がりそうですね
Discussion
pythonコードを共有してもらえないでしょうか?