🧩

【AWS Trainium 50本ノック #1】いきなりTrainiumに触れてみよう編

に公開

第 1 章 いきなり Trainium に触れてみよう編

本章では以下を仮定します。

  • AWS のアカウントを所有している
  • シェルの基本的な操作の理解
  • PyTorch の基本的な理解

問題 (1〜7)

まずは、Trainium が何であるかを理解するために、実際に触ってみましょう。

  1. AWS EC2 から trn1.2xlarge のインスタンスを立ち上げてください。ただし、以下の AMI(マシンイメージ)を選択してください。

    • Deep Learning AMI Neuron (Ubuntu 22.04)

      • AMI の名前に Neuron と入っているものを検索すると、上記 Ubuntu 22.04 版とは別に Amazon Linux 2023 版のものも出てきます。どちらを選ぶかで、以下作業に細かな差異が生じる可能性があります(以下では Ubuntu22.04 版を前提とします)。

    名前の先頭に trn とついているインスタンスには、Trainiumチップが搭載されています。trn1.2xlarge インスタンスは、その中では最も安価なインスタンスで、$1.34/h で借りられます(us-east-2でオンデマンド利用の場合・2025-05-09調べ)。利用可能なリージョンは限られています(2025-05-09時点では日本リージョンにはありません)。

  2. 立ち上げたインスタンスにsshログインしてください。

    • ユーザー名は ubuntu です(ec2-user だと入れません)。
  3. このマシンに確かにTrainiumチップが搭載されていることを確認するために、neuron-top コマンドを実行してください。

    成功するとこんな画面が出てきます。この画面で、Trainiumチップ(を構成するコアたち)の稼働状況をリアルタイムで閲覧できます。このマシンに搭載されている、NC0NC1 という2つのコアの稼働状況が表示されています。q キーを押すと終了できます。

  4. 以下の手順で、Python仮想環境をアクティベートしてください(以下で2通り紹介していますが、どちらでも構いません)。

    Deep Learning AMI Neuron には最新の Neuron ドライバ、仮想環境がプリインストールされているので、フレームワーク、ワークロードに応じた仮想環境を利用可能です。

    • 学習向け環境 (PyTorch 2.7 + NxD Training ライブラリ)

      • 下記コマンドで、仮想環境をアクティベートし、残りのセットアップとして setup_nxdt.sh を実行して下さい。

        source /opt/aws_neuronx_venv_pytorch_2_7_nxd_training/bin/activate
        setup_nxdt.sh
        
    • 推論向け環境 (PyTorch 2.7 + NxD Inference ライブラリ)

      source /opt/aws_neuronx_venv_pytorch_2_7_nxd_inference/bin/activate
      

    Deep Learning AMI Neuron は Neuron SDK の最新版に対応すべく随時更新されています。最新版 Neuron SDKリリース直後は、AMI のリリース日を確認頂き、既に更新済みかどうかを確認して下さい。

    インストール済みのドライバ、ライブラリが最新のものかどうかを確認するには、Neuron ドキュメントの内容 (https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/releasecontent.html) と以下のコマンドで確認できるインストール済みのドライバ、ライブラリのバージョンとを比較します。

    dpkg -l | grep neuron
    pip list | grep -e neuron -e torch
    
  5. Python仮想環境に入っていることを確認した後、python コマンドで Python の対話コンソールを起動して、以下スクリプトを実行してください。

    まずは CPU で計算してみましょう。

    >>> import torch
    >>> x1 = torch.arange(6).reshape(2, 3).to(dtype=torch.bfloat16)
    >>> y1 = (x1 @ x1.T).flatten()
    >>> z1 = y1[1:]
    >>> z1
    

    tensor([14., 14., 50.], dtype=torch.bfloat16) と表示されるはずです。次に、同じ計算を Trainium 上でやってみましょう。上に続けて、以下を実行してください。

    >>> import torch_xla
    >>> x2 = x1.to("xla")
    >>> y2 = (x2 @ x2.T).flatten()
    >>> z2 = y2[1:]
    >>> z2
    
    • もし以下の警告が出た場合は、無視してOKです。(REF: 公式Docs
      警告の内容を表示
      2025-08-21 07:25:18.566356: W neuron/nrt_adaptor.cc:53] nrt_tensor_write_hugepage() is not available, will fall back to nrt_tensor_write().
      2025-08-21 07:25:18.566388: W neuron/nrt_adaptor.cc:62] nrt_tensor_read_hugepage() is not available, will fall back to nrt_tensor_read().
      2025-Aug-21 07:25:18.0568 3598:4732 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):213 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
      2025-Aug-21 07:25:18.0573 3598:4732 [1] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
      2025-Aug-21 07:25:18.0577 3598:4732 [1] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed
      2025-Aug-21 07:25:18.0582 3598:4732 [1] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?
      

    .to("xla") により、x2 はTrainium上のテンソルとなります。y2, z2 についても同様にTrainium上のテンソルとなります。以下のように、コンパイラメッセージが表示された後、数秒の待ち時間の末に、最終的に z2 の値が表示されます。(初回実行時だけ、数秒ではなく数分の待ち時間が生じる場合があります)

    2025-05-15 07:32:34.000963:  42146  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/d98fe98d-019a-4d6d-b127-0efdd43b88f2/model.MODULE_6654042255826386093+e30acd3a.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir/d98fe98d-019a-4d6d-b127-0efdd43b88f2/model.MODULE_6654042255826386093+e30acd3a.neff --target=trn1 --verbose=35
    .Completed run_backend_driver.
    
    Compiler status PASS
    tensor([14., 14., 50.], device='xla:0', dtype=torch.bfloat16)
    

    Trn上での計算は「遅延評価」されます。すなわち、実際の計算は直ちには行われず、値が本当に必要になったタイミングで、計算グラフがコンパイルされ、その後、実際の計算が走ります。
    上の例では、y2 = (x2 @ x2.T).flatten()z2 = y2[1:] の実行時点では y2z2 の値は計算されていません。ただし、後で実際の値を計算するために必要となる、計算の履歴(計算グラフ)の情報だけが保持されています。最後の行( >>> z2 )で、z2 の値を画面に表示する必要が生じたため、計算グラフのコンパイル(=計算内容をTrnチップ上で実際に実行される命令列に変換する)が走り、それに成功(Compiler status PASS)したのちに、実際の計算が走りました。

  6. 「どのようなタイミングでコンパイル・遅延評価が走る/走らないか」を理解することは、モデルをTrnに移植する際に非常に重要となってきます。以下を試し、どのタイミングでコンパイルメッセージが表示されるかに注目してください。

    >>> import torch
    >>> import torch_xla
    >>> a = torch.randn(4, dtype=torch.bfloat16).to("xla")
    >>> b = a * a * a
    >>> b
    >>> c = torch.randn(4, dtype=torch.bfloat16).to("xla")
    >>> d = c * c * c
    >>> d
    >>> e = torch.randn(4, dtype=torch.bfloat16).to("xla")
    >>> f = e * e
    >>> f
    >>> g = torch.randn(5, dtype=torch.bfloat16).to("xla")
    >>> h = g * g
    >>> h
    

    b, f, h の表示時にコンパイルが走り、d の表示時にはコンパイルが走らないことを確認できます。この例は、以下のことを意味しています:

    • d は c を 3 回掛けることで計算されますが、この計算は「a から b を計算するときの計算グラフ」と同一で、すでにコンパイル済みです。そのため、d の表示時にはコンパイルは走りません。
    • f は e を 2 回掛けることで計算されますが、この計算は「a から b を計算するときの計算グラフ」とは異なり、まだコンパイルされていません。そのため、f の表示時にはコンパイルが走ります。
    • h は g を 2 回掛けることで計算されます。この計算は「e から f を計算するときの計算グラフ」と一見同じに見えますが、e の shape と g の shape は異なります。 計算に登場するテンソルの shape が異なると、それらは異なる計算グラフとして扱われます。 そのためコンパイルが走ります。

    【補足】遅延評価 (lazy mode) について: lazy mode と eager mode の違いについてはこちらを参照ください:

    以上のような計算のクセがあるにも関わらず、なぜわざわざTrainiumを利用するのでしょうか?それは、並列計算のコストパフォーマンスが非常に良い ためです。

  7. trn1.2xlarge のインスタンスを停止・削除(終了)してください。

    これを忘れると、$1.34/h(※ us-east-2・2025-05-15時点)がいつまでも課金され続けることになりますので、必ず実施してください。

脚注
  1. 本50本ノックの内容を監修くださった AWS の常世様に、感謝を申し上げます。 ↩︎

KARAKURI Techblog

Discussion