👋

Rayシリーズ:Ray Coreを利用したバッチ予測例の検証

に公開

今回はRay Coreの例として提供されているバッチ予測のサンプルを通して、バッチ予測の実装方法をみていきたいと思います。Rayに関するシリーズは以下でまとめていますのでぜひご覧ください。

https://zenn.dev/akasan/scraps/73a90764c065d1

早速例を試してみる

今回は以下の例を試してみます。この例では、バッチで取得したデータを対象として、どのように推論を行うかを試す例となっております。

https://docs.ray.io/en/latest/ray-core/examples/batch_prediction.html

環境構築

uvを利用して以下のように環境を構築します。

uv init batch_prediction -p 3.12
cd batch_prediction
uv add ray numpy pandas pyarrow

コードの実装

まずはコードの全体像を確認してみましょう。

batch_prediction.py
import pandas as pd
import numpy as np
import pyarrow.parquet as pq
import ray


def load_model():
    # A dummy model.
    def model(batch: pd.DataFrame) -> pd.DataFrame:
        # Dummy payload so copying the model will actually copy some data
        # across nodes.
        model.payload = np.zeros(100_000_000)
        return pd.DataFrame({"score": batch["passenger_count"] % 2 == 0})
    
    return model


@ray.remote
def make_prediction(model, shard_path):
    df = pq.read_table(shard_path).to_pandas()
    result = model(df)

    # Write out the prediction result.
    # NOTE: unless the driver will have to further process the
    # result (other than simply writing out to storage system),
    # writing out at remote task is recommended, as it can avoid
    # congesting or overloading the driver.
    # ...

    # Here we just return the size about the result in this example.
    return len(result)


# 12 files, one for each remote task.
input_files = [
        f"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_full_year_data.parquet"
        f"/fe41422b01c04169af2a65a83b753e0f_{i:06d}.parquet"
        for i in range(12)
]

# ray.put() the model just once to local object store, and then pass the
# reference to the remote tasks.
model = load_model()
model_ref = ray.put(model)

result_refs = []

# Launch all prediction tasks.
for file in input_files:
    # Launch a prediction task by passing model reference and shard file to it.
    # NOTE: it would be highly inefficient if you are passing the model itself
    # like make_prediction.remote(model, file), which in order to pass the model
    # to remote node will ray.put(model) for each task, potentially overwhelming
    # the local object store and causing out-of-disk error.
    result_refs.append(make_prediction.remote(model_ref, file))

results = ray.get(result_refs)

# Let's check prediction output size.
for r in results:
    print("Prediction output size:", r)

それでは順を追ってコードを確認しましょう。

まずはモデルの実装になります。今回は機械学習モデルを実装したと言うより、ルールベースで結果のデータフレームを作成することで擬似的にモデルを再現しています。

def load_model():
    # A dummy model.
    def model(batch: pd.DataFrame) -> pd.DataFrame:
        # Dummy payload so copying the model will actually copy some data
        # across nodes.
        model.payload = np.zeros(100_000_000)
        return pd.DataFrame({"score": batch["passenger_count"] % 2 == 0})
    
    return model

次にRayのリモート関数を定義しています。リモート関数ではモデル定義と処理対象のファイル名を受け取り、その結果を返すようになっています(この例では結果そのものではなく、リモート関数で何件のデータを処理したか返しています)。

@ray.remote
def make_prediction(model, shard_path):
    df = pq.read_table(shard_path).to_pandas()
    result = model(df)

    # Write out the prediction result.
    # NOTE: unless the driver will have to further process the
    # result (other than simply writing out to storage system),
    # writing out at remote task is recommended, as it can avoid
    # congesting or overloading the driver.
    # ...

    # Here we just return the size about the result in this example.
    return len(result)

リモート関数が定義された後は、以下のようにして実際にリモート関数に引数を与えて結果への参照を取得しています。今回利用するデータはS3上で一般向けに公開されているデータセットを用意し、ファイル数は12ファイル用意されています。

# 12 files, one for each remote task.
input_files = [
        f"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_full_year_data.parquet"
        f"/fe41422b01c04169af2a65a83b753e0f_{i:06d}.parquet"
        for i in range(12)
]

# ray.put() the model just once to local object store, and then pass the
# reference to the remote tasks.
model = load_model()
model_ref = ray.put(model)

result_refs = []

# Launch all prediction tasks.
for file in input_files:
    # Launch a prediction task by passing model reference and shard file to it.
    # NOTE: it would be highly inefficient if you are passing the model itself
    # like make_prediction.remote(model, file), which in order to pass the model
    # to remote node will ray.put(model) for each task, potentially overwhelming
    # the local object store and causing out-of-disk error.
    result_refs.append(make_prediction.remote(model_ref, file))

推論結果を取得するために、全ての参照から結果を取得し、それぞれのタスクで何件データを処理したかを出力します。

results = ray.get(result_refs)

# Let's check prediction output size.
for r in results:
    print("Prediction output size:", r)

実行してみる

それでは先ほどのコードを実行してみましょう。結果を見ると、12行の出力がされており、それぞれ約14万件のデータを処理したと言うことになります。

uv run batch_prediction.py

# 結果
Prediction output size: 141062
Prediction output size: 133932
Prediction output size: 144014
Prediction output size: 143087
Prediction output size: 148108
Prediction output size: 141981
Prediction output size: 136394
Prediction output size: 136999
Prediction output size: 139985
Prediction output size: 156198
Prediction output size: 142893
Prediction output size: 145976

OOMへの対策

バッチ予測をするときは特にデータ量が多くなる傾向があります。仮にロードされるデータのメモリ量が事前に想定できる場合は以下のようにしてoptionsを指定して実行することにより、適切に並列性を制御できるとのことです。

make_prediction.options(memory=100*1023*1025).remote(model_ref, file)

また、今回の例ではモデルをload_modelから取得していますが、この結果を毎回make_predictionにそのまま受け渡すと全ての結果がオブジェクトストレージに格納されてしまうため、ray.put()を利用して参照にした上でそれぞれのmake_predictionに受け渡しています。

まとめ

今回はバッチ予測のサンプルをみてみました。今回の例では稼働させるモデルやデータは公開されているモデルをベースとしていますが、次回は実際に開発してみたモデルを使って試してみようと思います。

Discussion