🦜
【python】one-hot encodingの実装(4パターン)
- 機械学習モデルにカテゴリ変数を特徴量として入力するための手法として、one-hot encodingがある
- これの実装方法をいくつかまとめてみた
ライブラリのバージョン
python : 3.13.5
- numpy : 2.3.3
- pandas : 2.3.2
- Scikit-learn : 1.7.2
1. numpyを使う場合
numpyは処理が早く、使用頻度が高いため、大体はこの形式でよい
プログラム
import numpy as np
data = np.array(["red", "blue", "green", "red", "blue"])
# ユニーク値と各データのインデックスを取得
unique_values, inverse = np.unique(data, return_inverse=True)
print("unique values:", unique_values)
print("inverse:", inverse)
# one-hot行列を生成
one_hot = np.eye(len(unique_values))[inverse]
print("One-hot encoding:\n", one_hot)
出力結果

numpy.uniqueについて
2. pandasを使う場合
pandasもよく使うため、手っ取り早くできます
ちなみにget_dummiesはcategory型を使わずに使用すると、入力データの順番によって、位置が変わる可能性があるため注意
(これで、train, validation別々にかけると特徴量の列番号が変わってめちゃくちゃになることがある)
プログラム例
import pandas as pd
# データに "yellow" は出てこない
df = pd.DataFrame({"color": ["red", "blue", "red", "blue", "green"]})
print(df)
# category 型に "yellow" も含めて指定
df["color"] = pd.Categorical(df["color"], categories=["red", "blue", "green", "yellow"])
# One-hot encoding
one_hot = pd.get_dummies(df, columns=["color"], dtype="int8")
print(one_hot)
出力結果

pd.get_dummiesについて
3. scikit-learnを使う場合
多機能だけど、使い方にクセがあるので個人的にはあまり使わない
プログラム例
from sklearn.preprocessing import OneHotEncoder
import numpy as np
data = np.array([["red"], ["blue"], ["green"], ["red"], ["blue"]])
# encoder = OneHotEncoder(sparse_output=False)
encoder = OneHotEncoder(sparse_output=False, categories=[["red", "green", "blue", "yellow"]]) # 順番やdataにない列を入れる場合
one_hot = encoder.fit_transform(data)
print("Categories:", encoder.categories_)
print("One-hot encoding:\n", one_hot)
出力結果

OneHotEncoderについて
4. ライブラリ不使用
サードパーティ製を使えない状況の場合は、こんな感じ
プログラム
data = ["red", "blue", "green", "red", "blue"]
# ユニークな値を取得
# unique_vals = list(set(data)) # 自動取得の場合
unique_vals = ["red", "blue", "green", "yellow"] # 順番やdataにない要素を指定する場合
# One-hot encoding
one_hot = []
for val in data:
row = [1 if val == uv else 0 for uv in unique_vals]
one_hot.append(row)
print("Unique values:", unique_vals)
print("One-hot encoding:")
for row in one_hot:
print(row)
出力結果

まとめ
- 1-hot encodingの実装方法をまとめた
- numpy, pandas, scikit-learn, ライブラリなしの4パターンを紹介した
Discussion