DagsHub使ってみた
今回はAIのデータやモデルについて管理することに重きを置いたDagsHubというサービスについて、その紹介と使い方をご紹介していこうと思います。
DagsHubとは?
一言で言うと、データサイエンスに関わるコードだけでなくデータについても合わせて管理するためのレポジトリ
を提供してくれるサービスです。コードの管理だけであればGitHubなどで問題ありませんが、いわゆるdvcを利用したデータ管理を行っているようなプロジェクトであったり、複数人でチームを組んで開発するデータサイエンスプロジェクト向けのレポジトリを探している場合は選択肢としてDagsHubも候補かと思います。特徴として、DAG(有向非巡回グラフ)を利用したリネージの可視化機能がデフォルトで提供されており、データの流れの可視化があるという点もメリットの一つです。
dvcについて
DagsHubを使う上ではdvcというものが欠かせません。dvcとはデータのバージョンの管理をするためのgitと思ってもらえれば大丈夫です。実際にdvcを利用する場合はgitと組み合わせて利用することになります。この記事では詳細には言及しませんが、要点だけかいつまんでみます!
- gitの感覚でデータを管理することができる
- 大規模なデータ(それこそGB単位やそれ以上)をgitで直接扱うのではなく、データレポジトリにデータの実態は保存しつつ、任意のデータをpull/pushして利用できるようにするコマンドです
- データサイエンスの再現性にも有用であり、ハイパーパラメータやモデルのアーティファクトなどの管理を一元化できます。
具体的な使い方はまた別の記事で紹介できればと思います。
実際にDagsHubを使ってみよう
まずはdvcでデータの準備
※この記事はDagsHubの存在を周知したい目的で作っていますが、結果としてdvcの使い方が大半となります
今回はデモなので大規模なデータを実際に用意してどうにかするということは避けようと思います。その代わり、以下のような想定のもと環境構築をしてみます。
- データの実態はGoogle Cloud Storageに保存されるものとする
- 対象データはタイタニックデータを利用する(Kaggleで有名なデータですね)
- モデルはscikit-learnから利用する。なおハイパーパラメータチューニングは一旦しないものとする
まずはgitを使ってプロジェクトを作ってみましょう。最初のプロセスは以下のように通常のgitを使う時と同じようにプロジェクトを作ります。
mkdir dagshub_test_repo
cd dagshub_test_repo
git init
次に、dvcはPythonで利用できるので、Pythonを利用してみます。今回は簡単のため仮想環境などは省略します。
pip install dvc[gs]
インストールが完了したらプロジェクトでdvc環境を初期化します。以下のコマンドによって初期化ができます。
dvc init
このコマンドを実行すると、以下のように3つの設定ファイルが生成されますのでgitで管理対象にしてください。
No commits yet
Changes to be committed:
(use "git rm --cached <file>..." to unstage)
new file: .dvc/.gitignore
new file: .dvc/config
new file: .dvcignore
次にデータを登録してみます。タイタニックのデータセットについて、今回は学習用データセットをdata/train.csv
として利用します。まずはdata/train.csv
を作成いただき、以下のコマンドによってdvcの管理対象にデータを追加します。
dvc add data/train.csv
すると以下のような表示がされますので、その指示に従ってコマンドを実行してください。なお、このコマンドが何をしているかというと、gitではデータの実態を追跡せず、代わりにdvcがデータを参照するためのファイルをgitで追跡するというものです。
To track the changes with git, run:
git add data/.gitignore data/train.csv.dvc
次に、データの実態をGCSにアップロードしてみます。まずはdvcのリモートリポジトリとしてGCSを指定する必要がありますが、実行前にgcloudで適切に認証が完了していることを確認してください(詳細はこちら)。以下のコマンドでその設定を反映させます。
dvc remote add -d <リモート名> gs://<mybucket>/<path>
登録ができたら、以下のようにすることでリモートパスにデータの実態をpushできます。
dvc push data/train.csv
なお、バケット内を見ると大量のハッシュ値がついたファイルがあるかと思います。dvcではファイルを分割して保存するようになっており、人の目では解読できない形態になっております。そのため、データの実態はデータを取得して確認してください。
上記の方法でリモートバケットを設定に追加したので、.dvc/config
の設定をgitでコミットしてください。
git add .dvc/conig && git commit -m "update: .dvc config for remote bucket"
なお、データの実体を削除してもリモートレポジトリにデータがあれば以下の方法でデータを復元できます。
dvc pull data/train.csv
学習の準備
データの準備はできましたので、次はモデルの学習の準備をしたいと思います。
今回学習や前処理で利用するファイルはsrc
フォルダにて作成しようと思います。そのためまずはsrcフォルダを作成しておきます。
mkdir src
具体的なファイルを作る前に、dvcではパイプラインの作成の方法がいくつかあるのですが、今回はdvc.yaml
ファイルを用いた管理を実装します。具体的な方法や設定内容は公式ページを参照いただくとして、今回は前処理と学習フローを以下のように実装します。
stages:
preprocess:
cmd: python src/preprocess.py
deps:
- data/train.csv
- src/preprocess.py
outs:
- data/preprocessed_train.pkl
train:
cmd: python src/train.py
deps:
- data/preprocessed_train.pkl
- src/train.py
outs:
- artifacts/model.pkl
metrics:
- artifacts/summary.csv:
cache: false
では次に前処理を実装します。今回は簡単のため、数値データだけを持つカラムだけを残すように前処理を実装します(検証のために極めて手抜きでやってますが見逃してください)。
import pickle
import pandas as pd
def preprocess():
train_df = pd.read_csv("data/train.csv")
train_df = train_df.drop(["PassengerId", "Name", "Sex", "Ticket", "Cabin", "Embarked"], axis=1)
train_df = train_df.dropna(axis=0)
pickle.dump(train_df, open("data/preprocessed_train.pkl", "wb"))
if __name__ == "__main__":
preprocess()
また、学習ファイルは以下のようにします。
import pickle
from sklearn.ensemble import RandomForestClassifier
def train():
df = pickle.load(open("data/preprocessed_train.pkl", "rb"))
y = df["Survived"]
X = df.drop(["Survived"], axis=1)
model = RandomForestClassifier(max_depth=5)
model.fit(X, y)
pickle.dump(model, open("artifacts/model.pkl", "wb"))
with open("artifacts/summary.csv", "w") as f:
f.write("accuracy\n1.0")
if __name__ == "__main__":
train()
これで実験準備は完了です。dvc.yamlファイルも用意しているので、あとは以下のコマンドを実行すると前処理と学習が順番に実行されます。
dvc repro
実行するとdataフォルダとartifactsフォルダに新しいファイルが生成されます。
DagsHubへpush
ではいよいよDagsHubへpushしてみようと思います。
まずはアカウント登録ですが、無料枠の範疇で今回は実施しようと思ってます。アカウント認証が終わったらレポジトリを作成してください。レポジトリが作り終わったら、以下のようにして通常のリモートブランチにpushする要領で対応します。
git remote add origin <DagsHubのリポジトリ>
git push
pushしてレポジトリを開くと以下のようになっているかと思います。今回は前処理をした後にそのデータを使ってモデルを学習し、アーティファクトを作成しましたが、その内容がGUIとしてDAGで表現sらえているのがわかります。このような表示があると、データ・モデルのリネージを理解することができ、問題が起こった時にどこのプロセスで失敗したかを評価しやすくなります。
まとめ
今回はDagsHubの紹介(と言いつつ大半がdvcの操作)をさせていただきました。dvcを利用するとリネージの可視化ができ、DagsHubと組み合わせることでデータサイエンティストにとって必要なデータの流れが追いやすくなるということで、ぜひチーム開発をする方はツールの候補として選択してみてください!
Discussion