🚀

GraphCastをgooglecolabで使ってみた(2/2)

2024/06/29に公開

GraphCastとは

GraphCastとはGraph newral network(グラフニューラルネットワーク)を用いてforecast(天気予報)をするものです。従来の数値気象予測(NWP)に対して危害学習(ML)ベースの新しい手法です。
世界的な中期天気予報の精度がこの技術により向上しました。GraphCastは、再解析データから直接学習し、10日先までの気象変数を予測します。その予測精度は、現在最も精度の高い運用システム(ECMWFのHRES)を多くの評価対象で上回ります。

前回からの続き

前回の記事では、GraphCastをgooglecolabで使う際のデータセットとモデルの設定までを解説しました。https://zenn.dev/wataru923/articles/99653f75ed8aae

予測とモデルの学習

今回は、準備したデータセットとモデルを用いて、予測する方法とトレーニングする方法について見ていきたいと思います。

オートレグレッシブ予測

以下に、オートレグレッシブ予測を行うためのコードを説明します。このコードでは、JAXを使用してモデルを実行し、予測結果を生成します。

前提条件の確認

モデルの解像度がデータの解像度と一致しているかどうかを確認します。

assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

データの次元情報を表示

入力データ、ターゲットデータ、および強制データの次元情報を表示します。

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

オートレグレッシブ予測の実行

オートレグレッシブ予測を行うために、rollout.chunked_prediction 関数を使用します。この関数は、Pythonループを使用して予測を行います。オートレグレッシブ予測については、以下で簡単にまとめています。
https://zenn.dev/wataru923/articles/1c730a8dd44615

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
  • run_forward_jitted: JITコンパイルされたモデルの順伝播関数。
  • rng: 乱数生成器。予測の再現性を確保するために使用します。
  • inputs: 初期入力データ。
  • targets_template: ターゲットテンプレートデータ。NaNで初期化されています。
  • forcings: 強制データ。

予測結果の表示

生成された予測結果を表示します。

predictions

完全なコード

以下に、全体のコードを示します。

# 前提条件の確認
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

# データの次元情報を表示
print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

# オートレグレッシブ予測の実行
predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)

# 予測結果の表示
predictions

説明

  1. 前提条件の確認:

    • モデルの解像度がデータの解像度と一致していることを確認します。
  2. データの次元情報を表示:

    • 入力データ、ターゲットデータ、および強制データの次元情報を表示します。これにより、データが正しく読み込まれていることを確認します。
  3. オートレグレッシブ予測の実行:

    • rollout.chunked_prediction 関数を使用して、JITコンパイルされたモデルの順伝播関数を用いて予測を実行します。
    • rng パラメータを使用して乱数生成器を設定し、予測の再現性を確保します。
    • inputstargets_template、および forcings パラメータを使用して、初期入力データ、ターゲットテンプレートデータ、および強制データを指定します。
  4. 予測結果の表示:

    • 生成された予測結果を表示します。予測結果は、predictions 変数に格納されます。

このコードを実行することで、モデルのオートレグレッシブ予測が実行され、指定されたステップ数だけの予測結果が生成されます。

予測点の選択 (Choose predictions to plot)

今までと同じようなスタイルで、予測する要素と、圧力レベルを設定する。

予測点のプロット (Plot predictions)

予測されたデータ、実際のターゲットデータ、およびその差分をプロットします。

モデルの学習

以下の操作は大量のメモリを必要とし、使用するアクセラレータによっては、解像度の低いデータでは非常に小さな「ランダム」モデルしか適合しません。上記で選択したトレーニングステップ数を使用します。

損失の計算

以下は、オートレグレッシブ予測における損失(Loss)を計算するためのコードです。このコードは、事前にJITコンパイルされたloss_fn_jitted関数を使用して、トレーニングデータに対する損失を計算します。
事前にコンパイルされたloss_fn_jitted関数を使用して、損失と診断情報を計算します。

loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
  • rng: 乱数生成器。予測の再現性を確保するために使用します。
  • inputs: トレーニングの入力データ。
  • targets: トレーニングのターゲットデータ。
  • forcings: トレーニングの強制データ。

損失の表示

計算された損失を表示します。

print("Loss:", float(loss))

完全なコード

以下に、損失の計算と表示を行う完全なコードを示します。

# 損失の計算
loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)

# 損失の表示
print("Loss:", float(loss))

説明

  1. 損失の計算:

    • loss_fn_jitted関数を呼び出し、トレーニングデータに対する損失と診断情報を計算します。
    • rngパラメータを使用して乱数生成器を設定し、再現性を確保します。
    • inputstargets、およびforcingsパラメータを使用して、トレーニングの入力データ、ターゲットデータ、および強制データを指定します。
  2. 損失の表示:

    • 計算された損失を表示します。損失は浮動小数点数として表示されます。

このコードを実行することで、トレーニングデータに対するオートレグレッシブ予測の損失が計算され、その値が表示されます。損失の値は、モデルの性能を評価するために使用されます。損失が低いほど、モデルの予測精度が高いことを示します。

勾配の計算

以下は、オートレグレッシブ予測における勾配(Gradient)を計算するためのコードです。このコードは、事前にJITコンパイルされたgrads_fn_jitted関数を使用して、トレーニングデータに対する損失の勾配を計算し、その結果を表示します。

事前にコンパイルされたgrads_fn_jitted関数を使用して、損失、診断情報、次の状態、および勾配を計算します。

loss, diagnostics, next_state, grads = grads_fn_jitted(
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
  • inputs: トレーニングの入力データ。
  • targets: トレーニングのターゲットデータ。
  • forcings: トレーニングの強制データ。

勾配の平均値を計算

勾配の絶対値の平均を計算します。

mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
  • jax.tree_util.tree_map: 勾配の各要素に対して絶対値の平均を計算する関数を適用します。
  • jax.tree_util.tree_flatten: 勾配ツリーをフラット化します。
  • np.mean: フラット化された勾配の絶対値の平均を計算します。

損失と勾配の表示

計算された損失と勾配の平均値を表示します。

print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

その結果 Loss: 0.763336181640625が得られました。この値は、使うデーセットによって大きく変わってくると思います。

完全なコード

以下に、勾配の計算と表示を行う完全なコードを示します。

# 勾配の計算
loss, diagnostics, next_state, grads = grads_fn_jitted(
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)

# 勾配の平均値を計算
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])

# 損失と勾配の表示
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")

説明

  1. 勾配の計算:

    • grads_fn_jitted関数を呼び出し、トレーニングデータに対する損失、診断情報、次の状態、および勾配を計算します。
    • inputstargets、およびforcingsパラメータを使用して、トレーニングの入力データ、ターゲットデータ、および強制データを指定します。
  2. 勾配の平均値を計算:

    • jax.tree_util.tree_mapを使用して、勾配ツリーの各要素に対して絶対値の平均を計算する関数を適用します。
    • jax.tree_util.tree_flattenを使用して、勾配ツリーをフラット化します。
    • np.meanを使用して、フラット化された勾配の絶対値の平均を計算します。
  3. 損失と勾配の表示:

    • 計算された損失と勾配の平均値を表示します。損失は小数点以下4桁まで表示され、勾配の平均値は小数点以下6桁まで表示されます。

このコードを実行することで、トレーニングデータに対するオートレグレッシブ予測の勾配が計算され、その結果が表示されます。勾配の情報は、モデルの学習を行う際に重要です。勾配が大きい場合、モデルは急速に学習する可能性がありますが、学習が不安定になることもあります。

実装結果

goolecolabで実装しましたが、RAMの容量を超えてクラッシュしてしまいました。
今度は、ローカルかリソースを増やして改めてやってみようと思います。

JAX内でオートレグレッシブ予測のループ

以下は、オートレグレッシブ予測をJAX内でループを保持しながら実行するためのコードです。このコードは、事前にJITコンパイルされたrun_forward_jitted関数を使用して予測を行います。

予測の実行

事前にJITコンパイルされたrun_forward_jitted関数を使用して、予測を実行します。

predictions = run_forward_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
  • rng: 乱数生成器。予測の再現性を確保するために使用します。
  • inputs: トレーニングの入力データ。
  • targets_template: ターゲットデータのテンプレート。NaNで初期化されます。
  • forcings: トレーニングの強制データ。

予測結果の表示

生成された予測結果を表示します。

predictions

説明

  1. データの次元情報を表示:

    • トレーニングデータの入力、ターゲット、および強制データの次元情報を表示します。これにより、データが正しく読み込まれていることを確認できます。
  2. 予測の実行:

    • run_forward_jitted関数を呼び出して、トレーニングデータを使用して予測を実行します。
    • rngパラメータを使用して乱数生成器を設定し、予測の再現性を確保します。
    • inputstargets_template、およびforcingsパラメータを使用して、トレーニングの入力データ、ターゲットデータのテンプレート、および強制データを指定します。
  3. 予測結果の表示:

    • 生成された予測結果を表示します。予測結果は、predictions変数に格納されます。

このコードを実行することで、トレーニングデータを使用したオートレグレッシブ予測がJAX内で実行され、予測結果が表示されます。

結果

最後のgrads_fn_jitted関数を用いて、損失と勾配を求めるところでは、メモリ不足でクラッシュしてしまいました。今後は、ローカル環境やリソースを増強して再度試みる予定です。

Discussion