GraphCastをgooglecolabで使ってみた(2/2)
GraphCastとは
GraphCastとはGraph newral network(グラフニューラルネットワーク)を用いてforecast(天気予報)をするものです。従来の数値気象予測(NWP)に対して危害学習(ML)ベースの新しい手法です。
世界的な中期天気予報の精度がこの技術により向上しました。GraphCastは、再解析データから直接学習し、10日先までの気象変数を予測します。その予測精度は、現在最も精度の高い運用システム(ECMWFのHRES)を多くの評価対象で上回ります。
前回からの続き
前回の記事では、GraphCastをgooglecolabで使う際のデータセットとモデルの設定までを解説しました。https://zenn.dev/wataru923/articles/99653f75ed8aae
予測とモデルの学習
今回は、準備したデータセットとモデルを用いて、予測する方法とトレーニングする方法について見ていきたいと思います。
オートレグレッシブ予測
以下に、オートレグレッシブ予測を行うためのコードを説明します。このコードでは、JAXを使用してモデルを実行し、予測結果を生成します。
前提条件の確認
モデルの解像度がデータの解像度と一致しているかどうかを確認します。
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
"Model resolution doesn't match the data resolution. You likely want to "
"re-filter the dataset list, and download the correct data.")
データの次元情報を表示
入力データ、ターゲットデータ、および強制データの次元情報を表示します。
print("Inputs: ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)
オートレグレッシブ予測の実行
オートレグレッシブ予測を行うために、rollout.chunked_prediction
関数を使用します。この関数は、Pythonループを使用して予測を行います。オートレグレッシブ予測については、以下で簡単にまとめています。
predictions = rollout.chunked_prediction(
run_forward_jitted,
rng=jax.random.PRNGKey(0),
inputs=eval_inputs,
targets_template=eval_targets * np.nan,
forcings=eval_forcings)
-
run_forward_jitted
: JITコンパイルされたモデルの順伝播関数。 -
rng
: 乱数生成器。予測の再現性を確保するために使用します。 -
inputs
: 初期入力データ。 -
targets_template
: ターゲットテンプレートデータ。NaNで初期化されています。 -
forcings
: 強制データ。
予測結果の表示
生成された予測結果を表示します。
predictions
完全なコード
以下に、全体のコードを示します。
# 前提条件の確認
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
"Model resolution doesn't match the data resolution. You likely want to "
"re-filter the dataset list, and download the correct data.")
# データの次元情報を表示
print("Inputs: ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)
# オートレグレッシブ予測の実行
predictions = rollout.chunked_prediction(
run_forward_jitted,
rng=jax.random.PRNGKey(0),
inputs=eval_inputs,
targets_template=eval_targets * np.nan,
forcings=eval_forcings)
# 予測結果の表示
predictions
説明
-
前提条件の確認:
- モデルの解像度がデータの解像度と一致していることを確認します。
-
データの次元情報を表示:
- 入力データ、ターゲットデータ、および強制データの次元情報を表示します。これにより、データが正しく読み込まれていることを確認します。
-
オートレグレッシブ予測の実行:
-
rollout.chunked_prediction
関数を使用して、JITコンパイルされたモデルの順伝播関数を用いて予測を実行します。 -
rng
パラメータを使用して乱数生成器を設定し、予測の再現性を確保します。 -
inputs
、targets_template
、およびforcings
パラメータを使用して、初期入力データ、ターゲットテンプレートデータ、および強制データを指定します。
-
-
予測結果の表示:
- 生成された予測結果を表示します。予測結果は、
predictions
変数に格納されます。
- 生成された予測結果を表示します。予測結果は、
このコードを実行することで、モデルのオートレグレッシブ予測が実行され、指定されたステップ数だけの予測結果が生成されます。
予測点の選択 (Choose predictions to plot)
今までと同じようなスタイルで、予測する要素と、圧力レベルを設定する。
予測点のプロット (Plot predictions)
予測されたデータ、実際のターゲットデータ、およびその差分をプロットします。
モデルの学習
以下の操作は大量のメモリを必要とし、使用するアクセラレータによっては、解像度の低いデータでは非常に小さな「ランダム」モデルしか適合しません。上記で選択したトレーニングステップ数を使用します。
損失の計算
以下は、オートレグレッシブ予測における損失(Loss)を計算するためのコードです。このコードは、事前にJITコンパイルされたloss_fn_jitted
関数を使用して、トレーニングデータに対する損失を計算します。
事前にコンパイルされたloss_fn_jitted
関数を使用して、損失と診断情報を計算します。
loss, diagnostics = loss_fn_jitted(
rng=jax.random.PRNGKey(0),
inputs=train_inputs,
targets=train_targets,
forcings=train_forcings)
-
rng
: 乱数生成器。予測の再現性を確保するために使用します。 -
inputs
: トレーニングの入力データ。 -
targets
: トレーニングのターゲットデータ。 -
forcings
: トレーニングの強制データ。
損失の表示
計算された損失を表示します。
print("Loss:", float(loss))
完全なコード
以下に、損失の計算と表示を行う完全なコードを示します。
# 損失の計算
loss, diagnostics = loss_fn_jitted(
rng=jax.random.PRNGKey(0),
inputs=train_inputs,
targets=train_targets,
forcings=train_forcings)
# 損失の表示
print("Loss:", float(loss))
説明
-
損失の計算:
-
loss_fn_jitted
関数を呼び出し、トレーニングデータに対する損失と診断情報を計算します。 -
rng
パラメータを使用して乱数生成器を設定し、再現性を確保します。 -
inputs
、targets
、およびforcings
パラメータを使用して、トレーニングの入力データ、ターゲットデータ、および強制データを指定します。
-
-
損失の表示:
- 計算された損失を表示します。損失は浮動小数点数として表示されます。
このコードを実行することで、トレーニングデータに対するオートレグレッシブ予測の損失が計算され、その値が表示されます。損失の値は、モデルの性能を評価するために使用されます。損失が低いほど、モデルの予測精度が高いことを示します。
勾配の計算
以下は、オートレグレッシブ予測における勾配(Gradient)を計算するためのコードです。このコードは、事前にJITコンパイルされたgrads_fn_jitted
関数を使用して、トレーニングデータに対する損失の勾配を計算し、その結果を表示します。
事前にコンパイルされたgrads_fn_jitted
関数を使用して、損失、診断情報、次の状態、および勾配を計算します。
loss, diagnostics, next_state, grads = grads_fn_jitted(
inputs=train_inputs,
targets=train_targets,
forcings=train_forcings)
-
inputs
: トレーニングの入力データ。 -
targets
: トレーニングのターゲットデータ。 -
forcings
: トレーニングの強制データ。
勾配の平均値を計算
勾配の絶対値の平均を計算します。
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
-
jax.tree_util.tree_map
: 勾配の各要素に対して絶対値の平均を計算する関数を適用します。 -
jax.tree_util.tree_flatten
: 勾配ツリーをフラット化します。 -
np.mean
: フラット化された勾配の絶対値の平均を計算します。
損失と勾配の表示
計算された損失と勾配の平均値を表示します。
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")
その結果 Loss: 0.763336181640625
が得られました。この値は、使うデーセットによって大きく変わってくると思います。
完全なコード
以下に、勾配の計算と表示を行う完全なコードを示します。
# 勾配の計算
loss, diagnostics, next_state, grads = grads_fn_jitted(
inputs=train_inputs,
targets=train_targets,
forcings=train_forcings)
# 勾配の平均値を計算
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
# 損失と勾配の表示
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")
説明
-
勾配の計算:
-
grads_fn_jitted
関数を呼び出し、トレーニングデータに対する損失、診断情報、次の状態、および勾配を計算します。 -
inputs
、targets
、およびforcings
パラメータを使用して、トレーニングの入力データ、ターゲットデータ、および強制データを指定します。
-
-
勾配の平均値を計算:
-
jax.tree_util.tree_map
を使用して、勾配ツリーの各要素に対して絶対値の平均を計算する関数を適用します。 -
jax.tree_util.tree_flatten
を使用して、勾配ツリーをフラット化します。 -
np.mean
を使用して、フラット化された勾配の絶対値の平均を計算します。
-
-
損失と勾配の表示:
- 計算された損失と勾配の平均値を表示します。損失は小数点以下4桁まで表示され、勾配の平均値は小数点以下6桁まで表示されます。
このコードを実行することで、トレーニングデータに対するオートレグレッシブ予測の勾配が計算され、その結果が表示されます。勾配の情報は、モデルの学習を行う際に重要です。勾配が大きい場合、モデルは急速に学習する可能性がありますが、学習が不安定になることもあります。
実装結果
goolecolabで実装しましたが、RAMの容量を超えてクラッシュしてしまいました。
今度は、ローカルかリソースを増やして改めてやってみようと思います。
JAX内でオートレグレッシブ予測のループ
以下は、オートレグレッシブ予測をJAX内でループを保持しながら実行するためのコードです。このコードは、事前にJITコンパイルされたrun_forward_jitted
関数を使用して予測を行います。
予測の実行
事前にJITコンパイルされたrun_forward_jitted
関数を使用して、予測を実行します。
predictions = run_forward_jitted(
rng=jax.random.PRNGKey(0),
inputs=train_inputs,
targets_template=train_targets * np.nan,
forcings=train_forcings)
-
rng
: 乱数生成器。予測の再現性を確保するために使用します。 -
inputs
: トレーニングの入力データ。 -
targets_template
: ターゲットデータのテンプレート。NaNで初期化されます。 -
forcings
: トレーニングの強制データ。
予測結果の表示
生成された予測結果を表示します。
predictions
説明
-
データの次元情報を表示:
- トレーニングデータの入力、ターゲット、および強制データの次元情報を表示します。これにより、データが正しく読み込まれていることを確認できます。
-
予測の実行:
-
run_forward_jitted
関数を呼び出して、トレーニングデータを使用して予測を実行します。 -
rng
パラメータを使用して乱数生成器を設定し、予測の再現性を確保します。 -
inputs
、targets_template
、およびforcings
パラメータを使用して、トレーニングの入力データ、ターゲットデータのテンプレート、および強制データを指定します。
-
-
予測結果の表示:
- 生成された予測結果を表示します。予測結果は、
predictions
変数に格納されます。
- 生成された予測結果を表示します。予測結果は、
このコードを実行することで、トレーニングデータを使用したオートレグレッシブ予測がJAX内で実行され、予測結果が表示されます。
結果
最後のgrads_fn_jitted
関数を用いて、損失と勾配を求めるところでは、メモリ不足でクラッシュしてしまいました。今後は、ローカル環境やリソースを増強して再度試みる予定です。
Discussion