大きなマルチモーダルLLMを分解して保存する話:Qwen3-Omniを例として
はじめに
こんにちは、株式会社エクサウィザーズWANDの大西です。
エクサウィザーズのAdvent Calendar 3日目です!
皆様はローカル環境でマルチモーダルLLMを活用したいと考えたことはありませんか?私はあります。特に、日本語の音声理解や音声合成がセットになったLLMを使いたいと日々感じています。
9月にAlibabaからQwenシリーズの最新マルチモーダルモデルであるQwen3-Omniが公開されました。発音にはまだ改善の余地がありますが、日本語も喋ることができる魅力的なモデルです。
しかし、Qwen3-Omniはパラメータサイズ30Bという巨大なモデルです。7Bや13Bのような小型モデルもリリースされていないため、個人のPCや小規模なサーバーで動かすのは難しいでしょう。(量子化モデルがあるため低リソースではそちらが選択肢になります)
このような状況だと、以下のように考えることもあるでしょう。
- Qwen3-Omniの画像理解などの一部分は無くても良いので少しでもモデルを軽量化したい
- 音声合成や音声エンコーダ部分だけ切り取って別モデルに転用したい
この記事ではそんな方々のために、Qwen3-Omniを例として既存のモデルの重みの一部分を分解して保存する方法について解説します。
Qwen3-Omniとは?
様々なサイトやブログで紹介されているので詳細は割愛しますが、Qwen3-Omniについて一言で言うと、画像・音声・動画・テキストを理解し音声合成まで可能な大規模言語モデル(LLM) の一つです。
このような動画も扱えるオープンなマルチモーダルLLMはまだ数が少なく、Qwen3-Omniはその中でも非常に優れた性能を持っています。
日本語能力を体感する上では、以下の公式動画や使ってみた方のXポストがわかりやすいです。また、実際にデモサイトを使って、自分で実験するとモデルの能力をより実感できるので是非試してみてください。
参考)Qwen3-Omni関連リンク
- コード:https://github.com/QwenLM/Qwen3-Omni
- HF上の実体コード:https://github.com/huggingface/transformers/tree/v4.57-release/src/transformers/models/qwen3_omni_moe
- 論文:https://arxiv.org/abs/2509.17765
- モデル一覧:https://huggingface.co/collections/Qwen/qwen3-omni
- デモサイト※:https://huggingface.co/spaces/Qwen/Qwen3-Omni-Demo
※デモではOSSとして配布されているQwen3-Omni-30B-A3B-Instructモデル等とは異なり、論文中で言及されている非公開のチューニングバージョンであるQwen3-Omni-Flashというモデルが使われているのでご注意ください(モデル指定がされている該当箇所)。
Qwen3-Omniを構成するモジュール

Qwen3-Omniモデル全体像, 引用: https://arxiv.org/abs/2509.17765
上記の図はテクニカルレポートから引用したQwen3-Omniのモデル全体像です。
Qwen3-Omniは、ThinkerとTalkerという2つの大きなモジュールで構成されています。さらに、それぞれのモジュールの中には複数の小モジュールが存在しています。
細分化した主要モジュールを簡単にまとめると以下のような表の構成になっています。それぞれ役割が異なるため、必要に応じて一部のモジュールだけを切り出して活用する価値があります。
簡単に活用しやすいものとしては、Thinkerモジュール内の音声エンコーダや画像エンコーダです。音声基盤モデルあるいは画像基盤モデルとして分類タスクに転用することが考えられます。
| モジュール名(コード上のクラス名を略記) | 役割 |
|---|---|
| AudioEncoder | 音声入力を処理し、音声特徴量を抽出する |
| VisionEncoder | 画像入力を処理し、画像特徴量を抽出する |
| ThinkerTextModel | テキスト入力を処理し、言語特徴を抽出するとともに音声特徴量・画像特徴量を統合し、返答するテキストを生成する |
| TalkerModel | Thinkerの情報を受け取り音声合成のための特徴を生成する |
| TalkerCodePredictorModel | TalkerModelの情報から音情報の埋め込みである音声コードを予測する |
| Code2Wav | 音声コードから音声を合成する |
事前準備
それでは早速Qwen3-Omniモデルの分解を始めていきましょう。
動作環境として、Qwen3-Omni-30B-A3B-Instructモデルを分解するには、ディスク容量は100GB程度の空きを確保しておいた方が良く、モデル全体をメモリ上に展開して試す場合はRAMも128GB程度は欲しいところです。
RAMが足りない場合でも、後述する方法で分解自体は可能ですが、動作確認などを行う際に少し不便になるので注意してください。
まずは必要なライブラリとモデルをダウンロードします。
※以降uvを用いる想定で記載します。
uv init qwen3_omni_decomposition
cd qwen3_omni_decomposition
uv python pin 3.13
uv add transformers==4.57.3 torch==2.9.1 accelerate==1.12.0 torchvision==0.24.1
uv run hf download Qwen/Qwen3-Omni-30B-A3B-Instruct --local-dir ./Qwen3-Omni-30B-A3B-Instruct
モデルのダウンロードが完了すれば、準備は整います。RAMに余裕のある方は、Pythonをインタラクティブモードで実行して、モデルがロードできるか確認すると良いでしょう。
uv run python
>>> from transformers import Qwen3OmniMoeForConditionalGeneration
>>> MODEL_PATH = "./Qwen3-Omni-30B-A3B-Instruct"
>>> model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
MODEL_PATH, device_map="cpu", low_cpu_mem_usage=True,
)
>>> model
Qwen3OmniMoeForConditionalGeneration(
(thinker): Qwen3OmniMoeThinkerForConditionalGeneration(
(audio_tower): Qwen3OmniMoeAudioEncoder(
...以降省略...
RAMが潤沢な環境向けの方法
インタラクティブモードでモジュールを確認する
モデルが正しくロードできることを確認できたら、モデルの構造を確認していきます。
model.named_children()メソッドを使うと、該当クラスの直下にあるモジュールを一覧で確認できます。先ほどの表示では情報量が多く全体像を把握しづらかったかもしれませんが、このように一つずつ階層を辿っていくことで、モデル全体の構成を把握しやすくなります。
>>> for name, module in model.named_children():
... print(f"Module name: {name}, Module type: {type(module)}")
Module name: thinker, Module type: <class 'transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerForConditionalGeneration'>
Module name: talker, Module type: <class 'transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeTalkerForConditionalGeneration'>
Module name: code2wav, Module type: <class 'transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeCode2Wav'>
thinker, talker, code2wavというモジュールがあることが分かったので、さらにthinkerモジュールの中身をnamed_children()で確認してみると、Audio Encoderや、Vision Encoderがあることが分かります。
>>> for name, module in model.thinker.named_children():
... print(f"Module name: {name}, Module type: {type(module)}")
Module name: audio_tower, Module type: <class 'transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeAudioEncoder'>
Module name: visual, Module type: <class 'transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeVisionEncoder'>
Module name: model, Module type: <class 'transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextModel'>
Module name: lm_head, Module type: <class 'torch.nn.modules.linear.Linear'>
今回はこの中からAudio Encoderモジュールに関連する部分を切り出して保存する例を紹介します。
分解したモジュールの保存①: Transformersのクラスの場合
Qwen3-OmniモデルはHugging Face TransformersライブラリとTorchで実装されており、各モジュールはクラスで定義されています。これらのクラスを利用して、必要なモジュールだけを保存することができます。
特にtransformersの組み込みクラスで定義されている場合は、save_pretrainedを呼ぶだけで簡単に保存できます。
下のコードは音声エンコーダ部分だけを切り出して保存する例です。
>>> model.thinker.audio_tower.save_pretrained("./Qwen3-Omni-AudioEncoder")
これで音声エンコーダ部分だけが保存されます。保存されたフォルダ内には、モデルの重みを保存したmodel.safetensorsファイルと、モデルの構成情報を保存したconfig.jsonファイルが含まれています。
モデルファイルのサイズは約1.3GB程度と、Qwen3-Omni全体の70GB程度と比べると大幅に小さくなっていることが分かります。
保存されたモデルは以下のようにロードして利用できます。
>>> AUDIO_MODEL_PATH = "./Qwen3-Omni-AudioEncoder"
>>> from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeAudioEncoder
>>> audio_encoder = Qwen3OmniMoeAudioEncoder.from_pretrained(AUDIO_MODEL_PATH)
>>> audio_encoder
Qwen3OmniMoeAudioEncoder(
(positional_embedding): SinusoidsPositionEmbedding()
...以降省略...
分解したモジュールの保存②: Torchクラスの場合
Torchのnn.Moduleクラスレベルで分解したい場合はState Dictを利用する必要があります。例としてAudio Encoder内にある一部の畳み込みモジュールだけを切り出して保存したい場合は以下のようにします。
>>> from safetensors.torch import save_file
>>> conv2d1 = model.thinker.audio_tower.conv2d1
>>> save_file(conv2d1.state_dict(), "./Qwen3-Omni-AudioConv2d1.safetensors")
これでconv2d1という畳み込みモジュールだけが保存されます。保存されたモジュールは同じネットワーク構成のクラスを用意することで、ロードすることができます。
>>> from safetensors.torch import load_file
>>> from torch import nn
>>> conv2d1_loaded = nn.Conv2d(1, 480, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
>>> state_dict = load_file("./Qwen3-Omni-AudioConv2d1.safetensors")
>>> conv2d1_loaded.load_state_dict(state_dict)
<All keys matched successfully>
分解したモジュールを活用する
保存したモジュールは、他のモデルに組み込んだり転移学習に利用したりすることができます。例えば、先ほど保存した音声エンコーダを用いて、10クラス音声分類モデルを構築する場合は以下のように活用できます。
もちろん、音声エンコーダ部分だけを微調整したり、全体を微調整したりすることも可能です。
import torch
import torch.nn as nn
from transformers import WhisperFeatureExtractor
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeAudioEncoder
class CustomAudioModel(nn.Module):
def __init__(self, preprocessor, audio_encoder):
super(CustomAudioModel, self).__init__()
self.preprocessor = preprocessor
self.audio_encoder = audio_encoder
self.classifier = nn.Linear(2048, 10) # 10クラス分類
def forward(self, x):
x = self.preprocessor(x, sampling_rate=16000, return_tensors="pt", padding=False)
x = x["input_features"][0].to(self.audio_encoder.dtype)
x_len = torch.tensor([x.shape[-1]], dtype=torch.long)
x = self.audio_encoder(x, x_len)
x = x.last_hidden_state.mean(dim=0)
x = self.classifier(x)
return x
audio_encoder = Qwen3OmniMoeAudioEncoder.from_pretrained("./Qwen3-Omni-AudioEncoder")
preprocessor = WhisperFeatureExtractor.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct")
custom_model = CustomAudioModel(preprocessor, audio_encoder)
...以降省略(custom_modelを用いた学習コードを書く)...
RAMが潤沢でない環境向けの方法
RAM容量が限られている場合、モデル全体を一度にロードすることは困難です。そのため、コードやインデックスファイルから必要なモジュール構成を把握し、重みファイルから必要な部分だけを抽出してロードする必要があります。
ここでは、その方法について解説します。
モデル構成をコードやJSONから把握する
RAMが少ない状況や、きちんと理解したい場合は、実際にコードを辿って該当箇所を確認し、モデルの構成や全体像を把握するのが最善です。
また、重みに着目した全体像把握の方法として、モデルが格納されているフォルダにあるインデックスファイル(model.safetensors.index.json)内のweight_mapからも、どのようなモジュール構成になっているかを確認することが可能です。
このsafetensorsのインデックスファイルは、モデルの各重みがどのファイルに格納されているかを示すJSON形式のファイルです。weight_mapには、各重みの名前とそれが格納されているsafetensorsファイル名が対応付けられています。
インデックスファイルの一部を以下に示します。
{
"metadata": {
"total_parameters": 35259818545,
"total_size": 70519637090
},
"weight_map": {
"code2wav.code_embedding.weight": "model-00015-of-00015.safetensors",
"code2wav.decoder.0.conv.bias": "model-00015-of-00015.safetensors",
"code2wav.decoder.0.conv.weight": "model-00015-of-00015.safetensors",
"code2wav.decoder.1.block.0.alpha": "model-00015-of-00015.safetensors",
"code2wav.decoder.1.block.0.beta": "model-00015-of-00015.safetensors",
"code2wav.decoder.1.block.1.conv.bias": "model-00015-of-00015.safetensors",
"code2wav.decoder.1.block.1.conv.weight": "model-00015-of-00015.safetensors",
...以降省略...
必要なモジュールの重みのみを直接ロードして保存する
メモリを節約してモデルを読み込む方法として、必要な重みのみをsafetensorsからsafe_openを活用して直接ロードする方法があります。
以下のコードは、音声エンコーダの重みをロードして保存する例です。index内のthinker.audio_tower.で始まるキーを持つ重みのみを見つけファイルから抽出し、重みをまとめて新しいsafetensorsファイルとして保存しています。
この方法は、モデル全体をRAMにロードする必要がないため、RAMが少ない環境でも利用可能です。ただし、重みのみの保存となるため、Configファイル(config.json)は元のモデルから手動でコピーし、必要に応じて編集する必要があります。
import argparse
import os
import json
from safetensors import safe_open
from safetensors.torch import save_file
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeAudioEncoder
def extract_submodule(
index_path: str,
output_path: str,
prefix_to_extract: str = "thinker.audio_tower.",
):
# 1. インデックスファイルを読み込む
base_dir = os.path.dirname(index_path)
with open(index_path, "r") as f:
index_data = json.load(f)
weight_map = index_data.get("weight_map", {})
# 2. 必要なキーとそれが含まれるファイルを特定する
files_to_process = {}
print(f"Searching for keys starting with '{prefix_to_extract}'...")
target_keys_count = 0
for key, filename in weight_map.items():
if key.startswith(prefix_to_extract):
if filename not in files_to_process:
files_to_process[filename] = []
files_to_process[filename].append(key)
target_keys_count += 1
if target_keys_count == 0:
print("対象のキーが見つかりませんでした。")
return
print(f"Found {target_keys_count} keys across {len(files_to_process)} files.")
# 3. 必要なデータだけを抽出して辞書に格納
extracted_state_dict = {}
for filename, keys in files_to_process.items():
file_path = os.path.join(base_dir, filename)
print(f"Processing {filename} ...")
with safe_open(file_path, framework="pt", device="cpu") as f:
for key in keys:
tensor = f.get_tensor(key)
new_key = key[len(prefix_to_extract):]
extracted_state_dict[new_key] = tensor
# 4. 新しい safetensors ファイルとして保存
print(f"Saving to {output_path} ...")
save_file(extracted_state_dict, output_path)
print("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extract submodule weights from a safetensors index file.")
parser.add_argument("--index_path", type=str, default="./Qwen3-Omni-30B-A3B-Instruct/model.safetensors.index.json", help="Path to the safetensors index JSON file.")
parser.add_argument("--output_path", type=str, default="./Qwen3-Omni-AudioEncoder/model.safetensors", help="Path to save the extracted safetensors file.")
parser.add_argument("--prefix_to_extract", type=str, default="thinker.audio_tower.", help="Prefix of keys to extract.")
args = parser.parse_args()
extract_submodule(
index_path=args.index_path,
output_path=args.output_path,
prefix_to_extract=args.prefix_to_extract,
)
おわりに
この記事では、Qwen3-Omniモデルを例に、既存の大規模マルチモーダルLLMから特定のモジュールを抽出して保存する方法について解説しました。これにより、必要な機能だけを取り出して効率的に利用することが可能になります。
ぜひ皆様もQwen3-Omniや他の大規模モデルを分解して、自分のプロジェクトに活用してみてください!
Discussion