Running the OpenAI Whisper Model with NxD Inference
Audience: anyone who wants to serve speech-to-text (STT) on AWS Trainium/Inferentia
Prerequisites: basic familiarity with AWS Neuron and with the Transformers library
📚 Learning about speech recognition
If you want to learn speech recognition systematically, from the basics through implementation, I recommend Hugging Face's free Audio Course - Chapter 5: Automatic Speech Recognition. It gives hands-on coverage of modern speech recognition, including Whisper. I know next to nothing about audio myself, so I have started working through it too.
Introduction
What is NxD Inference?
NxD Inference (NeuronX Distributed Inference) is the official distributed-inference library developed by AWS. It is designed to run large Transformer models efficiently on instances equipped with AWS Trainium and similar accelerators.
Why use NxD Inference for Whisper?
OpenAI's Whisper is an encoder-decoder Transformer known for high accuracy in speech-to-text (STT). NxD Inference puts you on top of a large-scale distributed-inference framework without having to implement ahead-of-time compilation, tensor parallelism, KV caching, and the like yourself, so when a model is supported it is well worth adopting. Whisper support appears to have landed in NxD Inference 0.7.0, so let's actually run it. (For models that are not yet supported, I would like to contribute an implementation modeled on the existing ones.)
Environment requirements
| Component | Version | Notes |
|---|---|---|
| Neuron SDK | 2.27+ | 2.28+ recommended |
| neuronxcc | 2.22+ | |
| NxD Inference | 0.7.0 | latest official release: v0.6.10598 |
| PyTorch | 2.5+ | 2.8.0+ recommended |
| transformers | 4.40+ | Whisper model support |
| openai-whisper | 20250625 | audio-processing utilities |
Architecture overview
Structure of NeuronApplicationWhisper
The NxD Inference Whisper implementation inherits from OpenAI's official Whisper model and layers Neuron-specific optimizations on top.
Reference implementation: whisper_nxd_model.py
NeuronApplicationWhisper is the Whisper model class implemented for NxD Inference.
```python
class NeuronApplicationWhisper(Whisper):
    """Whisper model optimized for AWS Neuron."""

    def __init__(self, model_path, config, *args, **kwargs):
        super().__init__(config.dims)
        # Split the model into separate Encoder and Decoder modules
        self.encoder = NeuronApplicationWhisperEncoder(...)
        self.decoder = NeuronApplicationWhisperDecoder(...)
```
Creating the configuration objects
NxD Inference-specific step: creating the NeuronConfig
```python
import torch

from neuronx_distributed_inference.models.config import NeuronConfig

# Neuron-specific runtime settings
neuron_config = NeuronConfig(
    batch_size=1,
    torch_dtype=torch.float16,
    tp_degree=8,
)
```
NxD-specific step: creating the WhisperInferenceConfig
```python
from neuronx_distributed_inference.models.whisper.modeling_whisper import WhisperInferenceConfig
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

# Combine the NeuronConfig with the Whisper model config
inference_config = WhisperInferenceConfig(
    neuron_config,
    load_config=load_pretrained_config("openai/whisper-large-v3"),
)
```
The role of WhisperInferenceConfig is to combine the NeuronConfig above with Whisper's model config.
Model initialization and compilation
NxD-specific step: initializing NeuronApplicationWhisper
```python
from neuronx_distributed_inference.models.whisper.modeling_whisper import NeuronApplicationWhisper

MODEL_PATH = "openai/whisper-large-v3"
COMPILED_PATH = "./whisper_large_v3_compiled"

# Instantiate the model
neuron_model = NeuronApplicationWhisper(
    MODEL_PATH,
    config=inference_config
)
```
The class-side implementation
This constructor first builds the standard architecture via Whisper.__init__(config.dims), then swaps the Encoder/Decoder out for their Neuron counterparts.
NxD-specific step: compilation
Because AWS Neuron currently requires ahead-of-time compilation, a compile() method is provided: it writes the compiled artifacts to compiled_model_path, and load() reuses that cache whenever a compiled model already exists. Be aware that changing certain settings invalidates the cache and forces a recompile.
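The cache behavior can be sketched roughly as follows. This is a hypothetical simplification for illustration only: `compile_if_needed` and `build_fn` are made-up names, not part of the NxD Inference API.

```python
from pathlib import Path

def compile_if_needed(compiled_path: str, build_fn) -> Path:
    """Run the expensive build step only when no cached artifact exists.

    `build_fn` stands in for the actual Neuron compilation and must
    write its output under `compiled_path`.
    """
    out = Path(compiled_path)
    if out.exists() and any(out.iterdir()):
        return out  # cache hit: reuse the previously compiled artifacts
    out.mkdir(parents=True, exist_ok=True)
    build_fn(out)
    return out
```

Because settings such as tp_degree or batch_size are baked into the compiled artifacts, a safe convention is to encode them in the directory name so that different configurations never share a cache directory.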
Loading the compiled model
NxD-specific step: loading the model
The load() method loads the compiled NEFF files onto the NeuronCores.
Audio preprocessing (generic, not Neuron-specific)
```python
import librosa
import torch

from whisper.audio import log_mel_spectrogram, pad_or_trim

# Load the audio file as 16 kHz mono
audio_path = "audio-sample.mp3"
audio_data, sr = librosa.load(audio_path, sr=16000, mono=True)

# Convert to a log-mel spectrogram (whisper-large-v3 uses 128 mel bins)
mel = log_mel_spectrogram(audio_data, n_mels=128)  # (128, n_frames), torch.Tensor
mel = pad_or_trim(mel, 3000)                       # fix time axis -> (128, 3000)
mel = mel.unsqueeze(0)                             # add batch dim -> (1, 128, 3000)
```
This is standard audio processing with no dependency on NxD, using utilities bundled with OpenAI's whisper package. (Full disclosure: I am still learning the audio side of things.)
sr=16000 requests the 16 kHz sample rate Whisper requires, and mono=True downmixes stereo to mono. pad_or_trim(mel, 3000) fixes the input at 30 seconds (16,000 Hz × 30 s = 480,000 samples → 3,000 frames): shorter audio is zero-padded and longer audio is truncated (so long recordings need to be split into chunks).
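The sample-to-frame arithmetic can be sanity-checked with plain NumPy. The function below mimics only the time-axis behavior of `pad_or_trim` (a sketch, not the whisper implementation; the hop length of 160 is Whisper's mel front-end constant):

```python
import numpy as np

SAMPLE_RATE = 16000   # Whisper expects 16 kHz input
HOP_LENGTH = 160      # hop size of Whisper's mel front end
CHUNK_FRAMES = 3000   # 30 s of audio: 480,000 samples / 160 = 3,000 frames

assert 30 * SAMPLE_RATE // HOP_LENGTH == CHUNK_FRAMES

def pad_or_trim_frames(mel: np.ndarray, length: int = CHUNK_FRAMES) -> np.ndarray:
    """Fix the last (time) axis to `length` frames: zero-pad short input,
    truncate long input."""
    if mel.shape[-1] >= length:
        return mel[..., :length]
    pad = length - mel.shape[-1]
    return np.pad(mel, [(0, 0)] * (mel.ndim - 1) + [(0, pad)])

short = pad_or_trim_frames(np.ones((128, 1200), dtype=np.float32))
long_ = pad_or_trim_frames(np.ones((128, 4500), dtype=np.float32))
print(short.shape, long_.shape)  # (128, 3000) (128, 3000)
```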
Running inference
NxD-specific step: the transcribe method
A sample that uses the NeuronApplicationWhisper described above is available, and I based my verification on it.
Reproduction steps
This section walks through running Kotoba Whisper v2.2 (a Japanese-specialized model) with NxD Inference on an inf2.xlarge instance. Every step is written as a heredoc, so you can copy and paste it straight into a shell.
See this article for an easy way to spin up an Inf2 or Trn2 instance as a test environment.
Recommended environment
- Instance type: inf2.xlarge, trn1.2xlarge, or trn2.3xlarge
- Neuron SDK: 2.28+
- neuronxcc: 2.22+
- NxD Inference: 0.7.0+ (development version)
- Model: kotoba-tech/kotoba-whisper-v2.2 (1550M parameters, Japanese-specialized)
Step 0: Preparation
```bash
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/pip install gTTS scipy soundfile
mkdir -p ~/whisper-kotoba-test && cd ~/whisper-kotoba-test
```
Step 1: Environment check
Verify that the Neuron SDK and NxD Inference are installed correctly.
```bash
python3 << 'EOF'
import sys

print("=" * 80)
print("Environment check")
print("=" * 80)

# PyTorch & Neuron
import torch
import torch_neuronx
print(f"✓ PyTorch: {torch.__version__}")
print(f"✓ torch_neuronx: {torch_neuronx.__version__}")

# NxD Inference
sys.path.insert(0, '/tmp/neuronx-distributed-inference/src')
import neuronx_distributed_inference as nxd
print(f"✓ NxD Inference: {nxd.__version__}")

# Whisper module
from neuronx_distributed_inference.models.whisper.modeling_whisper import (
    WhisperInferenceConfig,
    NeuronApplicationWhisper,
)
print("✓ Whisper module: Available")

print("=" * 80)
print("Environment check complete")
print("=" * 80)
EOF
```
Output (click to expand):
================================================================================
Environment check
================================================================================
✓ PyTorch: 2.9.0+cu128
✓ torch_neuronx: 2.9.0.2.11.19912+e48cd891
✓ NxD Inference: 0.7.0
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed/parallel_layers/layers.py:16: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
from .mappings import (
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed/modules/moe/blockwise.py:74: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
component, error = import_nki(config)
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronxcc/nki/_pre_prod_kernels/bwmm_mxfp4.py:564: SyntaxWarning: assertion is always true, perhaps remove parentheses?
assert(token_indices_2D.shape==(128, 1), f"Expect token_indices_2D to have shape (128, 1), got {token_indices_2D.shape}")
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed/modules/moe/blockwise.py:76: UserWarning: Warning: Failed to import blockwise_mm_baseline_shard_n_k1_while_2loops: No module named 'neuronxcc.nki._private.blockwise_matmul_while'
warnings.warn(f"Warning: {error}")
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed/modules/moe/moe_fused_tkg.py:49: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
component, error = import_nki(config)
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed_inference/modules/attention/utils.py:13: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
from neuronx_distributed_inference.modules.custom_calls import neuron_cumsum
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed_inference/modules/lora_serving/lora_checkpoint.py:9: DeprecationWarning: torch_neuronx.nki_jit is deprecated, use nki.jit instead.
from neuronx_distributed_inference.modules.attention.gqa import replicate_kv
✓ Whisper module: Available
================================================================================
Environment check complete
================================================================================
Step 2: Download the model
Download the Kotoba Whisper v2.2 model (1550M parameters, Japanese-specialized).
```bash
mkdir -p models
python3 << 'EOF'
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

print("=" * 80)
print("Downloading kotoba-tech/kotoba-whisper-v2.2")
print("=" * 80)

model_id = "kotoba-tech/kotoba-whisper-v2.2"
save_dir = "models/kotoba-whisper-v2.2"
print(f"Model: {model_id}")
print(f"Save dir: {save_dir}\n")

print("Downloading model...")
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, low_cpu_mem_usage=True)
model.save_pretrained(save_dir)
print("✓ Model saved")

print("Downloading processor...")
processor = AutoProcessor.from_pretrained(model_id)
processor.save_pretrained(save_dir)
print("✓ Processor saved")

print("\nDownload complete")
print(f"  Model: {model_id}")
print("  Capability: Japanese-specialized speech recognition")
print("=" * 80)
EOF
```
Output (click to expand):
================================================================================
Downloading kotoba-tech/kotoba-whisper-v2.2
================================================================================
Model: kotoba-tech/kotoba-whisper-v2.2
Save dir: models/kotoba-whisper-v2.2
Downloading model...
✓ Model saved
Downloading processor...
✓ Processor saved
Download complete
  Model: kotoba-tech/kotoba-whisper-v2.2
  Capability: Japanese-specialized speech recognition
================================================================================
Step 3: Create the sample script
```bash
cat > whisper_nxd.py << 'PYTHON_EOF'
"""Whisper with NxD Inference"""
import torch
import soundfile as sf
import numpy as np
import scipy.signal
from pathlib import Path
from transformers import AutoProcessor
from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.whisper.modeling_whisper import (
    WhisperInferenceConfig,
    NeuronApplicationWhisper,
)
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config


class WhisperNxD:
    def __init__(self, model_path, compiled_path, tp_degree=2, language="ja"):
        self.model_path = Path(model_path)
        self.compiled_path = Path(compiled_path)
        self.tp_degree = tp_degree
        self.language = language
        self.model = None
        self.processor = None

    def compile(self):
        if self.compiled_path.exists():
            print("Model already compiled")
            return
        print(f"Compiling Kotoba Whisper v2.2 (TP={self.tp_degree})...")
        config = WhisperInferenceConfig(
            NeuronConfig(batch_size=1, torch_dtype=torch.float16, tp_degree=self.tp_degree),
            load_config=load_pretrained_config(str(self.model_path)),
        )
        self.model = NeuronApplicationWhisper(str(self.model_path), config=config)
        self.compiled_path.mkdir(parents=True, exist_ok=True)
        self.model.compile(str(self.compiled_path))
        print("Compilation complete")

    def load(self):
        print("Loading model...")
        self.processor = AutoProcessor.from_pretrained(str(self.model_path))
        config = WhisperInferenceConfig(
            NeuronConfig(batch_size=1, torch_dtype=torch.float16, tp_degree=self.tp_degree),
            load_config=load_pretrained_config(str(self.model_path)),
        )
        self.model = NeuronApplicationWhisper(str(self.compiled_path), config=config)
        self.model.load(str(self.compiled_path))
        print("Model loaded")

    def transcribe(self, audio_path):
        # Load the audio file with soundfile (avoids the FP16 warning path)
        audio_data, sr = sf.read(audio_path)
        # Resample to 16 kHz if needed (scipy polyphase resampler)
        if sr != 16000:
            audio_data = scipy.signal.resample_poly(
                audio_data, 16000, sr
            ).astype(np.float32)
        # Downmix stereo to mono
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)
        # Ensure float32 dtype (required by the NxD transcribe path)
        audio_data = audio_data.astype(np.float32)
        audio_duration = len(audio_data) / 16000
        # Pass the numpy array directly (skips openai-whisper's file loading)
        result = self.model.transcribe(
            audio_data,
            language=self.language,
            verbose=False
        )
        return {'text': result['text'], 'duration': audio_duration}
PYTHON_EOF

echo "whisper_nxd.py created"
```
Output: just the echo line.
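The preprocessing inside `transcribe()` can be exercised in isolation with a synthetic 44.1 kHz stereo signal. `to_whisper_input` is an illustrative helper (it downmixes before resampling, whereas the script above resamples first; either order works for this purpose):

```python
import numpy as np
import scipy.signal

def to_whisper_input(audio: np.ndarray, sr: int) -> np.ndarray:
    """Convert arbitrary audio to the float32 16 kHz mono format
    expected by the transcribe() path."""
    if audio.ndim > 1:                     # stereo -> mono
        audio = audio.mean(axis=1)
    if sr != 16000:                        # polyphase resampling to 16 kHz
        audio = scipy.signal.resample_poly(audio, 16000, sr)
    return audio.astype(np.float32)

t = np.linspace(0, 1.0, 44100, endpoint=False)
stereo = np.stack([np.sin(2 * np.pi * 440 * t)] * 2, axis=1)  # 1 s @ 44.1 kHz
out = to_whisper_input(stereo, 44100)
print(out.shape, out.dtype)  # (16000,) float32
```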
Step 4: Compile the model
```bash
python3 << 'EOF'
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))
from whisper_nxd import WhisperNxD

print("=" * 80)
print("Step 4: Compile Kotoba Whisper v2.2")
print("=" * 80)

model = WhisperNxD(
    model_path="models/kotoba-whisper-v2.2",
    compiled_path="models/kotoba-whisper-v2.2-compiled-tp2",
    tp_degree=2,
    language="ja"  # Japanese
)
start = time.time()
model.compile()
print(f"Compile time: {time.time()-start:.1f}s")
EOF
```
Output (first run, click to expand):
================================================================================
Step 4: Compile Kotoba Whisper v2.2
================================================================================
Compiling Kotoba Whisper v2.2 (TP=2)...
INFO:Neuron:Saving the neuron_config to models/kotoba-whisper-v2.2-compiled-tp2/encoder/
INFO:Neuron:Generating HLOs for the following models: ['Encoder']
[2026-02-10 14:54:15.635: I neuronx_distributed/parallel_layers/parallel_state.py:630] > initializing tensor model parallel with size 2
[2026-02-10 14:54:15.635: I neuronx_distributed/parallel_layers/parallel_state.py:631] > initializing pipeline model parallel with size 1
[2026-02-10 14:54:15.635: I neuronx_distributed/parallel_layers/parallel_state.py:632] > initializing context model parallel with size 1
[2026-02-10 14:54:15.635: I neuronx_distributed/parallel_layers/parallel_state.py:633] > initializing data parallel with size 1
[2026-02-10 14:54:15.635: I neuronx_distributed/parallel_layers/parallel_state.py:634] > initializing world size to 2
[2026-02-10 14:54:15.635: I neuronx_distributed/parallel_layers/parallel_state.py:379] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=<PG_Group_Logic.LOGIC1: (<function ascending_ring_PG_group at 0x7f64cde03240>, 'Ascending Ring PG Group')>
[2026-02-10 14:54:15.636: I neuronx_distributed/parallel_layers/parallel_state.py:658] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1]]
[2026-02-10 14:54:15.636: I neuronx_distributed/parallel_layers/parallel_state.py:659] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1]]
[2026-02-10 14:54:15.636: I neuronx_distributed/parallel_layers/parallel_state.py:660] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1]]
[2026-02-10 14:54:15.636: I neuronx_distributed/parallel_layers/parallel_state.py:661] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1]]
[2026-02-10 14:54:15.636: I neuronx_distributed/parallel_layers/parallel_state.py:662] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1]]
[2026-02-10 14:54:15.636: I neuronx_distributed/parallel_layers/parallel_state.py:663] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1]]
INFO:Neuron:Generating 1 hlos for key: Encoder
INFO:Neuron:Minimal metadata will be added to HLO
INFO:Neuron:Started loading module Encoder
INFO:Neuron:Finished loading module Encoder in 0.08839869499206543 seconds
INFO:Neuron:generating HLO: Encoder, input example shape = torch.Size([1, 128, 3000])
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed/parallel_layers/layers.py:532: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=False):
INFO:Neuron:Finished generating HLO for Encoder in 0.9561893939971924 seconds, input example shape = torch.Size([1, 128, 3000])
INFO:Neuron:Generated all HLOs in 1.0960869789123535 seconds
INFO:Neuron:Can't find a priority model, skip marking weights
INFO:Neuron:Can't find a priority model, skip optimizing weight layout for other HLOs
INFO:Neuron:Starting compilation for all HLOs
INFO:Neuron:Neuron compiler flags: --model-type=transformer --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2' --internal-hlo2tensorizer-options='--verify-hlo=true' --auto-cast=none -O1 --verbose=35 --logfile=/tmp/nxd_model/Encoder/_tp0_bk0/log-neuron-cc.txt
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/libneuronxla/neuron_cc_wrapper.py:246: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. Refer to documentation of the Python subprocess module for details.
warnings.warn(SyntaxWarning(
.Roundtrip constructed a transpose sequence [rounds: 1; efficiency: 77]:
dve_j_optimized : Fix prefix (0, 1) and permute (2,) with (3, 4) / latency=13,460; shape=(10, 128, 128, 1, 3); dtype_size=2
Roundtrip constructed a transpose sequence [rounds: 1; efficiency: 60]:
dve_j_optimized : Fix prefix (0, 1, 2) and permute (3,) with (4, 5) / latency=134,603; shape=(10, 128, 10, 128, 1, 3); dtype_size=2
Completed run_backend_driver.
Compiler status PASS
2026-02-10 14:54:27.000992: 149811 [INFO]: Compilation Successfully Completed for model.MODULE_0ba55c089b6a22a674e2+c9c9fa2d.hlo_module.pb
INFO:Neuron:Finished Compilation for all HLOs in 11.266722202301025 seconds
INFO:Neuron:Can't find a priority model, falling back to the existing weight layout
INFO:Neuron:Finished building model in 12.598812341690063 seconds
INFO:Neuron:SKIPPING pre-sharding the checkpoints. The checkpoints will be sharded during load time.
INFO:Neuron:Saving the neuron_config to models/kotoba-whisper-v2.2-compiled-tp2/decoder/
INFO:Neuron:Generating HLOs for the following models: ['DecoderPrefill', 'DecoderDecode']
[2026-02-10 14:54:28.324: I neuronx_distributed/parallel_layers/parallel_state.py:630] > initializing tensor model parallel with size 2
[2026-02-10 14:54:28.324: I neuronx_distributed/parallel_layers/parallel_state.py:631] > initializing pipeline model parallel with size 1
[2026-02-10 14:54:28.324: I neuronx_distributed/parallel_layers/parallel_state.py:632] > initializing context model parallel with size 1
[2026-02-10 14:54:28.324: I neuronx_distributed/parallel_layers/parallel_state.py:633] > initializing data parallel with size 1
[2026-02-10 14:54:28.324: I neuronx_distributed/parallel_layers/parallel_state.py:634] > initializing world size to 2
[2026-02-10 14:54:28.324: I neuronx_distributed/parallel_layers/parallel_state.py:379] [rank_0_pp-1_tp-1_dp-1_cp-1] Chosen Logic for replica groups ret_logic=<PG_Group_Logic.LOGIC1: (<function ascending_ring_PG_group at 0x7f64cde03240>, 'Ascending Ring PG Group')>
[2026-02-10 14:54:28.324: I neuronx_distributed/parallel_layers/parallel_state.py:658] [rank_0_pp-1_tp-1_dp-1_cp-1] tp_groups: replica_groups.tp_groups=[[0, 1]]
[2026-02-10 14:54:28.325: I neuronx_distributed/parallel_layers/parallel_state.py:659] [rank_0_pp-1_tp-1_dp-1_cp-1] dp_groups: replica_groups.dp_groups=[[0], [1]]
[2026-02-10 14:54:28.325: I neuronx_distributed/parallel_layers/parallel_state.py:660] [rank_0_pp-1_tp-1_dp-1_cp-1] pp_groups: replica_groups.pp_groups=[[0], [1]]
[2026-02-10 14:54:28.325: I neuronx_distributed/parallel_layers/parallel_state.py:661] [rank_0_pp-1_tp-1_dp-1_cp-1] cp_groups: replica_groups.cp_groups=[[0], [1]]
[2026-02-10 14:54:28.325: I neuronx_distributed/parallel_layers/parallel_state.py:662] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_model_groups: replica_groups.ep_model_groups=[[0], [1]]
[2026-02-10 14:54:28.325: I neuronx_distributed/parallel_layers/parallel_state.py:663] [rank_0_pp-1_tp-1_dp-1_cp-1] ep_data_groups: replica_groups.ep_data_groups=[[0], [1]]
INFO:Neuron:Generating 1 hlos for key: DecoderPrefill
INFO:Neuron:Minimal metadata will be added to HLO
INFO:Neuron:Started loading module DecoderPrefill
INFO:Neuron:Finished loading module DecoderPrefill in 0.3752787113189697 seconds
INFO:Neuron:generating HLO: DecoderPrefill, input example shape = torch.Size([1, 448])
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/neuronx_distributed/parallel_layers/layers.py:532: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with torch.cuda.amp.autocast(enabled=False):
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:470: UserWarning: Received an input tensor that was unused or used in a non-static way when traced so the tensor will be ignored. (index=2, shape=torch.Size([1]), dtype=torch.int32). The non-static usage could happen when the traced function expects the input tensor's shape to change (i.e., using the shape to do index slicing), which is not allowed by inference trace expecting static input shapes.
warnings.warn(
INFO:Neuron:Finished generating HLO for DecoderPrefill in 0.12747669219970703 seconds, input example shape = torch.Size([1, 448])
INFO:Neuron:Generating 1 hlos for key: DecoderDecode
INFO:Neuron:Minimal metadata will be added to HLO
INFO:Neuron:Started loading module DecoderDecode
INFO:Neuron:Finished loading module DecoderDecode in 0.3645610809326172 seconds
INFO:Neuron:generating HLO: DecoderDecode, input example shape = torch.Size([1, 1])
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/torch_neuronx/xla_impl/hlo_conversion.py:470: UserWarning: Received an input tensor that was unused or used in a non-static way when traced so the tensor will be ignored. (index=1, shape=torch.Size([1, 1500, 1280]), dtype=torch.float16). The non-static usage could happen when the traced function expects the input tensor's shape to change (i.e., using the shape to do index slicing), which is not allowed by inference trace expecting static input shapes.
warnings.warn(
INFO:Neuron:Finished generating HLO for DecoderDecode in 0.10927104949951172 seconds, input example shape = torch.Size([1, 1])
INFO:Neuron:Generated all HLOs in 1.0258264541625977 seconds
INFO:Neuron:Can't find a priority model, skip marking weights
INFO:Neuron:Can't find a priority model, skip optimizing weight layout for other HLOs
INFO:Neuron:Starting compilation for all HLOs
INFO:Neuron:Neuron compiler flags: --model-type=transformer --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2' --internal-hlo2tensorizer-options='--verify-hlo=true' --auto-cast=none -O1 --verbose=35 --logfile=/tmp/nxd_model/DecoderPrefill/_tp0_bk0/log-neuron-cc.txt
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/libneuronxla/neuron_cc_wrapper.py:246: SyntaxWarning: str format compiler_flags is discouraged as its handling involves repeated joining and splitting, which can easily make mistakes if something is quoted or escaped. Use list[str] instead. Refer to documentation of the Python subprocess module for details.
warnings.warn(SyntaxWarning(
INFO:Neuron:Neuron compiler flags: --model-type=transformer --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2' --internal-hlo2tensorizer-options='--verify-hlo=true' --auto-cast=none -O1 --verbose=35 --logfile=/tmp/nxd_model/DecoderDecode/_tp0_bk0/log-neuron-cc.txt
..Completed run_backend_driver.
Compiler status PASS
2026-02-10 14:54:41.000741: 149811 [INFO]: Compilation Successfully Completed for model.MODULE_769453b82c87e72264dc+7ef8efc5.hlo_module.pb
Completed run_backend_driver.
Compiler status PASS
2026-02-10 14:54:42.000215: 149811 [INFO]: Compilation Successfully Completed for model.MODULE_22b66b0fb7daa9d253e2+f69ccc06.hlo_module.pb
INFO:Neuron:Finished Compilation for all HLOs in 12.869001388549805 seconds
INFO:Neuron:Can't find a priority model, falling back to the existing weight layout
INFO:Neuron:Finished building model in 15.251092433929443 seconds
INFO:Neuron:SKIPPING pre-sharding the checkpoints. The checkpoints will be sharded during load time.
Compilation complete
Compile time: 32.1s
Step 5: Generate test audio
Generate a Japanese test clip with gTTS.
```bash
mkdir -p test_audio
python3 << 'EOF'
from gtts import gTTS

# The text stays in Japanese: it is the input for a Japanese-specialized model
japanese_text = "こんにちは。これは音声認識のテストです。今日は良い天気ですね。"
tts = gTTS(text=japanese_text, lang='ja')
tts.save('test_audio/japanese.mp3')
print(f"Generated Japanese audio: {japanese_text}")
EOF
```
Output
Generated Japanese audio: こんにちは。これは音声認識のテストです。今日は良い天気ですね。
Step 6: Run inference
```bash
python3 << 'EOF'
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))
from whisper_nxd import WhisperNxD

print("=" * 80)
print("Step 6: Japanese speech recognition")
print("=" * 80)

model = WhisperNxD(
    model_path="models/kotoba-whisper-v2.2",
    compiled_path="models/kotoba-whisper-v2.2-compiled-tp2",
    tp_degree=2,
    language="ja"
)
model.load()

print("Running recognition (Japanese)...")
start = time.time()
result = model.transcribe("test_audio/japanese.mp3")
elapsed = time.time() - start

print("\nInference complete")
print(f"  Processing time: {elapsed:.3f}s")
print(f"  Audio length: {result['duration']:.2f}s")
print(f"  RTF: {elapsed/result['duration']:.3f}x")
print(f"\nRecognition result: {result['text']}")
print("=" * 80)
EOF
```
Output
================================================================================
Step 6: Japanese speech recognition
================================================================================
...
Model loaded
Running recognition (Japanese)...
/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/lib/python3.12/site-packages/whisper/transcribe.py:132: UserWarning: FP16 is not supported on CPU; using FP32 instead
  warnings.warn("FP16 is not supported on CPU; using FP32 instead")
100%|████████████████████████████████████████████████████████████████| 662/662 [00:00<00:00, 1134.37frames/s]
Inference complete
  Processing time: 0.669s
  Audio length: 6.62s
  RTF: 0.101x
Recognition result: こんにちはこれは音声認識のテストです今日は良い天気ですね
================================================================================
Looking good: the transcription finished in about 0.67 seconds, which is blazing fast! The warning above is something I will fix eventually.
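For reference, RTF (real-time factor) is simply processing time divided by audio duration; values below 1.0 mean faster-than-real-time transcription:

```python
def rtf(processing_seconds: float, audio_seconds: float) -> float:
    """Real-time factor: < 1.0 means transcription is faster than playback."""
    return processing_seconds / audio_seconds

print(round(rtf(0.669, 6.62), 3))  # 0.101
```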
Summary
We ran inference with OpenAI's Whisper via NxD Inference on an AWS custom-silicon instance. These instances are also on the cheaper side, so if you enjoy this kind of tinkering, why not tune aggressively and aim for inference serving that beats GPUs on both cost and performance? Meow.