🤖

polarsでLabelEncoderを使おうとしたら詰まった話

2022/12/06に公開約2,100字

TL;DR

pyarrowを入れよう

polarsって?

polarsとは、PythonとRustのDataFrameライブラリです。DataFrameといえばpandasですが、なんとpolarsはpandasより高速に動作します!![1]

環境

python = "^3.10"
polars = "^0.15.2"
scikit-learn = "^1.1.3"

本題

タイタニックデータセットを利用します。

以下はSexをラベルエンコーディングして、csvとして出力するコードです。

import polars as pl
from sklearn import preprocessing


# データセットの読み込み
dataset_path = "./titanic"
train = pl.read_csv(f"{dataset_path}/train.csv")


# Sexを数値に変換
le = preprocessing.LabelEncoder()
sex = train["Sex"]
labels = le.fit_transform(sex)
train = train.with_column(pl.Series(labels).alias("Sex"))


# csvファイルとして出力
train.write_csv("out.csv")

さて、このコードを実行して、出力されたcsvファイルを見てみましょう。

なんということでしょう。
Sexがすべて0になってしまいました。

ちょっと寄り道

ここで、fit_transformに渡すSeriesをndarrayに変換して、ついでに出力して値を確認してみます。

import polars as pl
from sklearn import preprocessing


# データセットの読み込み
dataset_path = "./titanic"
train = pl.read_csv(f"{dataset_path}/train.csv")


# Sexを数値に変換
le = preprocessing.LabelEncoder()
- sex = train["Sex"]
+ sex = train["Sex"].to_numpy()
+ print(sex[:10])
labels = le.fit_transform(sex)
train = train.with_column(pl.Series(labels).alias("Sex"))


# csvファイルとして出力
train.write_csv("out.csv")

すると、以下のようにndarrayの要素がnanになっていることが確認できます。

[nan nan nan nan nan nan nan nan nan nan]

どうやらこのあたりに原因があるようです。

解決へ

ことの原因は見えてきました。
こういうときは公式ドキュメントを読むのが一番でしょう。

Series.to_numpyのAPIリファレンスのNoteに気になることが書かれています。

If you are attempting to convert Utf8 to an array you’ll need to install pyarrow.

どうやらpyarrowというライブラリが別途必要なようです。

pyarrowは以下のコマンドでインストールできます。

pip install pyarrow

インストールした後に、はじめのコードを実行してみると……

今度は無事にSexをラベルエンコーディングすることができました。

さいごに

kaggleに入門しようと思ったら、思わぬ落とし穴にハマってしまいました……

これは余談なのですが、実際に詰まったときにはfit_transformにSeriesを入れるとエラーがでました。正直エラー無しでは解決に相当時間が掛かっていたと思うので、改めてエラーの有り難みを実感しました。

脚注
  1. https://www.pola.rs/benchmarks.html ↩︎

Discussion

ログインするとコメントできます