👿

【Python】PrefectでSLURMのジョブを管理する〜HPC環境での機械学習ワークフロー構築に向けて〜

2024/11/11に公開

はじめに

こんにちは。わいけい(@yk_llm_gpt)です。
今回はMLOpsの話題です。
より正確に言うと、オンプレ環境のMLにおける学習部分の管理がテーマです。

私はもともとWeb系のバックエンドエンジニアをしていますが、最近、MLOps周りのタスクを業務内で行う機会が増えてきました。
例えば、SSHで自社サーバー(NVIDIAのDGXサーバーです)内に入り、その中で機械学習ワークフロー構築をする、といった業務です。

しかし、以下のような事情により、初めはとっつきづらかったです。

  • そもそも私はスーパーコンピューターを始めとするHPC環境での開発経験があまりなかった
  • ML領域でも最近はクラウドを使うことが増えており、オンプレ向けのオープンな情報が少ない
  • 特に弊社のようなHPC(High-Performance Computing)環境での学習ワークフロー管理に関する日本語の情報が少ない

そこで今回は、私が苦労したポイントの1つである「SLURMジョブをモダンな機械学習ワークフローオーケストレーションツール(Prefect)で管理する」というTipsについて、概要と手順をまとめておきたいと思います。

PrefectでSLURMジョブを管理することで

  • Prefectのワークフロー内のタスクを、他のチームメンバーの作業を極力妨害しないように実行出来る
  • SLURM側からの視点でみれば、ジョブの失敗時のリトライや通知が簡単に行える上、ジョブの進行情報や統計情報などをブラウザなどで確認できるようになる
  • 以上から結果的に開発生産性が向上する

といったメリットがあります。

なお、今回の記事作成にあたり、下記で行われている米国のNERSC(国立エネルギー研究科学計算センター)チームの議論を参考にしました。

https://github.com/PrefectHQ/prefect/issues/10136

この記事の対象者

この記事は、自分と同じくHPC環境での機械学習初心者に向けて書いています。例えば、大規模なモデル(LLMなど)を学習させるべく計算環境を手に入れたものの、その使い方が今一つ肌になじんでいないような人が対象です。

逆に、既にゴリゴリとスパコン環境での計算を回し続けています、みたいな人には少し物足りないかもしれません。

ここでいうHPC環境とは

まず前提として、今回の記事で言っている「HPC環境」が何を指すのかを明確化しておきたいと思います。

HPC環境をものすごくざっくり言うと、通常の計算機では不可能な規模での計算を可能にするために用意された超ハイパワーな環境ということになります。
こういった環境は複数の人もしくはチームがそれぞれの用途で使いまわしたいことが多いです。
そのため、そういったニーズに合わせて通常の計算機とは少し構成が異なっています。

クラスターとジョブスケジューラー

例えば、(弊社もそうですが)HPC環境は、いくつかの計算ノードと管理ノード(各ノードは計算機、すなわちサーバーです)に分割されています。
そして、それらの要素から構成されるクラスターをそれぞれのユーザーがリソースをうまく分担しながら使っていきます。
こうすることで計算リソースに無駄が出ないようにしています。

例えば弊社の場合、(詳細はぼかしますが)NVIDIAのDGX A100を複数積んだ計算ノードをいくつか保有しています。そして、それらをまとめて管理するサーバーが1台存在しています。

複数の計算ノードでうまく計算を行うために、ジョブスケジューラー(その名の通り、ジョブをスケジューリングするやつです)を介してそれぞれのユーザーがやりたい計算を実行していきます。

ジョブスケジューラーを介することで、ユーザーはそれぞれのジョブに対してGPU数や最大実行時間などの必要リソースを必要な分だけ割り当てることが可能です。
結果として、使用者全体としてリソースを効率的に分け合えるわけですね。

ちなみに我々の場合はジョブスケジューラーとして次に紹介するSLURMを使っています。

SLURMとは

SLURM(Simple Linux Utility for Resource Management)は、ジョブスケジューラーの一つです。

https://slurm.schedmd.com/overview.html

(私もそこまで詳しくないのですが、アカデミアの領域では結構使っている人が多いのだとか)

ちなみに、SLURMはオープンソースのソフトウェアです。
https://github.com/SchedMD/slurm

SLURMを使う場合、主に下記のようなコマンドを駆使していくことになります。

コマンド 説明 使用例
sbatch ジョブをキューに投入 sbatch job_script.sbatch
srun コマンドを実行(インタラクティブまたはスクリプト内) srun my_program
squeue 現在のジョブキューを表示 squeue -u username
scancel ジョブをキャンセル scancel job_id
sinfo クラスタやリソースの状態を表示 sinfo
sacct ジョブの履歴情報を表示 sacct -j job_id
sstat 実行中のジョブの統計情報を取得 sstat -j job_id
scontrol SLURMの設定やジョブ情報を管理 scontrol show job job_id
sprio ジョブの優先順位を確認 sprio
salloc インタラクティブなリソース割り当て salloc -N 1 -n 4

中でもsbatchコマンドを使ってsbatchファイルに記載されたジョブを実行していく流れは頻繁に行われます。今回のテーマであるPrefectでのワークフロー管理では、このsbatchファイルを書く作業は基本的に発生しないのですが、一応sbatchファイルについてもごく簡単に触れておきます。

sbatchファイルの例

sbatchファイルは典型的には下記のような内容が記載されます。
見てわかる通り、シェルスクリプトにかなり近いです。
しかし、最初の方のコメントっぽい部分も意味を持っているので注意が必要です。

#!/bin/bash
#SBATCH --job-name=my_job_name       # ジョブ名
#SBATCH --output=my_job_output.out   # 標準出力ファイルの名前
#SBATCH --error=my_job_error.err     # エラーログファイルの名前
#SBATCH --ntasks=1                   # 使用するタスク数
#SBATCH --cpus-per-task=4            # 1タスクあたりのCPU数
#SBATCH --mem=8G                     # メモリ要求 (例: 8GB)
#SBATCH --time=02:00:00              # ジョブの最大実行時間 (hh:mm:ss)
#SBATCH --partition=general          # パーティション名(クラスター設定に応じて指定)

# モジュールの読み込み (必要な場合)
module load python/3.8  # 例としてPython 3.8を読み込む場合

# 実行するコマンド
python my_script.py    # 実行したいプログラムやスクリプト

下記のようなコマンドでsbatchファイルに記入されたジョブを投入することができます。

sbatch sample.sbatch

Prefectとは

さて次は、最近注目が集まっているPythonのワークフロー管理ツール、Prefectについて簡単に見ていきましょう。(すでにPrefectに親しんでいる方は読み飛ばして構いません)

Prefectは、データエンジニアリングやデータパイプラインの自動化に特化したワークフローオーケストレーションツールです。
特にETL (Extract, Transform, Load) プロセスの管理とモニタリングに優れています。
Prefectはエラー管理やリトライの制御、スケジューリング、依存関係の設定など、ワークフローを効率よく構築・運用するための機能が豊富に備わっています。

ちなみに、ワークフローオーケストレーションといえばAirflowが有名です。Airflow開発チームのメンバーがAirflowの欠点を補う形でPrefectを開発したという歴史的な経緯があります。

そのため、Prefectでは単純にPythonコードを書く要領でワークフローを記述することができます。
なので、私のような普段あまりMLに携わっていない人にとってもとっつきやすいと思います。

https://www.prefect.io/

Prefectでハマったところ

通常の環境であればPrefectは初心者に優しいツールであると個人的には思っています。

しかし、HPC環境への導入時には色々とハマるところがありました。
特に、SLURMのようなジョブ管理とPrefectでのワークフロー管理を両立させる方法がわからずに結構苦しみました。

すごく単純にやるのであれば、Prefectのタスクとしてsbatchコマンドを無理やり叩いてジョブを実行する方法などもありますが、それだとPrefectのタスク自体はsbatchコマンドを叩いた時点で成功とみなされるため、ジョブの成功・失敗などの情報をスマートに取得するのが難しい等の問題が残ります。

PrefectからSLURMジョブタスクを管理する

結論、PrefectからSLURMジョブを管理することは一応可能です。例えば、下記はSLURMジョブ投入と成否の監視を行うサンプルコードです。

from __future__ import annotations
from prefect import flow, task
import time
from dask_jobqueue import SLURMCluster
from prefect_dask.task_runners import DaskTaskRunner

def make_dask_runner(
    cluster_kwargs: dict,
    adapt_kwargs: dict[str, int | None] | None = None,
    client_kwargs: dict = None,
    temporary: bool = False,
):
    cluster_class = SLURMCluster

    # Make the one-time-use DaskTaskRunner
    if temporary:
        return DaskTaskRunner(
            cluster_class=cluster_class,
            cluster_kwargs=cluster_kwargs,
            adapt_kwargs=adapt_kwargs,
            client_kwargs=client_kwargs,
        )

    # Make the Dask cluster
    cluster = _make_dask_cluster(cluster_kwargs)

    # Set up adaptive scaling
    if adapt_kwargs and (adapt_kwargs["minimum"] or adapt_kwargs["maximum"]):
        cluster.adapt(minimum=adapt_kwargs["minimum"], maximum=adapt_kwargs["maximum"])

    # Return the DaskTaskRunner with the cluster address
    return DaskTaskRunner(address=cluster.scheduler_address)

def _make_dask_cluster(cluster_kwargs: dict = {}, verbose: bool = True):
    cluster = SLURMCluster(**cluster_kwargs)
    if verbose:
        print(
            f"Workers are submitted with the following job script:\n{cluster.job_script()}"
        )
        print(f"Scheduler is running at {cluster.scheduler.address}")
        print(f"Dashboard is located at {cluster.dashboard_link}")

    return cluster

if __name__ == "__main__":
    account_name = "your_name"
    n_slurm_jobs = 1  # Number of Slurm jobs to launch in parallel.
    n_nodes_per_calc = 1  # Number of nodes to reserve for each Slurm job.
    n_cores_per_node = 2  # Number of CPU cores per node.
    mem_per_node = "64 GB"  # Total memory per node.

    cluster_kwargs = {
        "n_workers": n_slurm_jobs,
        "cores": n_cores_per_node,
        "memory": mem_per_node,
        "shebang": "#!/bin/bash",
        "account": account_name,
        "walltime": "00:10:00",
        "memory": "1000M",
        "job_directives_skip": ["-n", "--cpus-per-task"],
        "job_extra_directives": [f"-N {n_nodes_per_calc}", "-q debug"],
        "python": "python",
    }

    runner = make_dask_runner(cluster_kwargs, temporary=False)

    @task
    def test():
        print("do something...")
        time.sleep(10)

    @task
    def test2():
        print("do something2...")
        time.sleep(10)

    @flow(task_runner=runner)
    def workflow(*args, **kwargs):
        test()
        test2()

    workflow()

Prefect Serverの起動

これを実行する前には、あらかじめPrefectサーバーを起動しておく必要があります。

prefect server start

もしSSHでログインしている場合はnohupコマンドなどを使用して、接続終了時にPrefectサーバーが落ちないようにしておくと良いでしょう。

nohup prefect server start &

もしくは引数でbackground指定することも可能です。

prefect server start --background

ワークフローの実行

さて、Prefectサーバーが起動したらワークフローを実行するには先ほどのPythonファイルを実行するだけです。

python3 sample.py

これでPrefectに備わっている、基本的なワークフロー構築、失敗時のリトライ、タスク終了時の通知(Slackなどに通知可)、タスク状況のGUI監視、タスクのスケジュール実行などがSLURMジョブに対しても行えるようになったと思います。

終わりに

オンプレでML周りを色々と頑張ろうとする場合、コードが正しくてもインフラに起因するジョブの失敗はどうしても起こってしまいます。

そういった時に、手作業でやり直すのはとても面倒くさいです。

Prefectを導入することでそういった負担を軽減することができたのは、個人的にはかなり嬉しいポイントでした。


今後もWeb開発やLLMに関する発信を行っていく予定なので今回の記事が少しでも役立ったという方は、私のSNSなどをフォローしていただけると大変喜びます。

SpiralAIテックブログ

Discussion