TPEの簡単な説明と実装
初投稿です。
この記事ではTPEの原理と実装について簡単に説明していきます。深い原理については説明せず、必要最小限の実装を行うのでご了承ください。
TPEとは?
- ベイズ最適化の一種です。
- ガウス過程回帰を用いるベイズ最適化と比べると計算量が非常に軽いです。
- OptunaやHyperoptなどのパラメータ最適化ライブラリの基幹アルゴリズムになっています。
簡単な原理
ここではブラックボックスな目的関数
「最大化」を問題にしていないので注意してください。
獲得関数EI
ベイズ最適化で良く用いられる獲得関数EI(Expected Immrovement)を考えます。EIは以下のように表されます。
ここで、
予測値
これからこの式を変形しTPEが目標とする形に導きます。
ベイズの定理
まず、ベイズの定理より
これを用いてEIの式を変形します。
カーネル密度推定
ここで
ここで注意してほしいのが
y^* の決定
ところで、ここまであまり触れてきませんでしたが、
となります。OptunaやHyperoptでは独自の選択法を採用していますが、今回はあえて上位10%のものを上位組とします。
EIの最終的な形
以上で登場した事実をもとにEIを変形すると以下のような形になります。
この式からEIは上位層の分布
といった形になっていることに注意してください。とはいうものの、
という形になっているのでEIを最大化しようと思うならば、それぞれのパラメータ
ガウス過程回帰との比較
ガウス過程回帰 | TPE | |
---|---|---|
前計算の計算量 | ||
予測の計算量 | ||
メリット | (TPEと比べて)パラメータ間の相関を考慮できる。 | (ガウス過程と比べて)計算量が軽い。カテゴリカル変数も扱える。 |
デメリット | 計算量が重い。カテゴリカル変数が扱いにくい。 | パラメータ間の相関を無視する。 |
実装の流れ
python3の擬似コードでTPEアルゴリズムの流れを表してみると以下のようになります。大きな流れは__main__スコープに書いてあります。suggest_by_TPE関数によって候補点を決定していきます。下記はあくまで簡易的に書いたコードですので後の実装で修正が加わります。また、今回は連続値のみを扱うので、離散的な整数値やカテゴリカル変数、対数スケールの変数については扱いません。
def objective(x):
"""
目的関数
xを受け取ってyを返す
"""
pass
def generate_candidates(m, l, r):
"""
[l,r]の間でm個の候補を生成する関数
"""
pass
def select(candidates, l_data, g_data):
"""
上位データ(l_data)と下位データ(g_data)を使ってcandidatesから最適な候補を返す関数
ここでカーネル密度推定などを行う
"""
pass
def suggest_by_TPE(data):
data.sort()
l_data, g_data = separate(data) # 上位層と下位層に分割
# >>>> パラメータ1
# 範囲[1, 4]で候補をランダムにm個生成
candidates1 = generate_candidates(m, 1, 4)
#候補点の選定
candidate1 = select(candidates1, l_data, g_data)
# >>>> パラメータ2
# 範囲[-20, 20]で候補をランダムにm個生成
candidates2 = generate_candidates(m, -20, 20)
# 候補点の選定
candidate2 = select(candidates2, l_data, g_data)
return [candidate1, candidate2]
if __name__ == "__main__":
data = []
# 最適化ループ
for i in range(total_search):
x = suggest_by_TPE(data) # 候補の提案
y = objective(x) # 候補の評価
data.append((x, y)) # データの追加
簡単な実装
実行環境
- Python3.8.13
- numpy 1.21.5
- plotly 5.6.0
- scipy 1.7.3
目的関数の定義
ここでは2変数
plotlyで可視化すると以下のようになります。図ではわかりにくいのですが
def objective(x1: float, x2: float) -> float:
return np.sin(x1 - x2) * (x1**2 / 100 - x2**2 / 50 + x1 * x2 / 10) + 10
x1_args = np.arange(-8, 8, 0.1)
x2_args = np.arange(-8, 8, 0.1)
y = [[objective(x1, x2) for x1 in x1_args] for x2 in x2_args]
fig = go.Figure(data=[go.Surface(z=y, x=x1_args, y=x2_args)])
# グラフのレイアウト設定
fig.update_layout(
title='目的関数',
scene=dict(
xaxis_title='x1',
yaxis_title='x2',
zaxis_title='y'
)
autosize=False,
width=500,
height=500,
)
# グラフの表示
fig.show()
実装の全体
先に実装の全体を紹介します。コードとしては100行以内で書くことができました。
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import scipy.stats
class Config:
GAMMA = 0.2
total_search = 100 # 探索回数
num_candidate = 30 # 候補点の個数
num_initial_data = 10 # 初期点の個数
@dataclass
class Data:
x1: float
x2: float
y: float
def __gt__(self, other):
return self.y > other.y
def __lt__(self, other):
return self.y < other.y
def objective(x1: float, x2: float) -> float:
return np.sin(x1 - x2) * (x1**2 / 100 - x2**2 / 50 + x1 * x2 / 10) + 10
def gen_init_data(num: int) -> List[Data]:
x1_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
x2_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
y_list = [objective(x1, x2) for (x1, x2) in zip(x1_list, x2_list)]
ret = [Data(x1, x2, y) for (x1, x2, y) in zip(x1_list, x2_list, y_list)]
return ret
def seperate(data: List[Data]) -> Tuple[List[Data], List[Data]]:
num_best = int(np.floor(len(data) * Config.GAMMA))
return data[:num_best], data[num_best:]
def generate_candidate(num_candidate: int, l: float, r: float) -> List[float]:
return np.random.random(num_candidate) * (r - l) + l
def select(candidates: List[float], l_data: List[float], g_data: List[float]) -> float:
kde_l = scipy.stats.gaussian_kde(l_data, 0.2) # band 幅 要変更
kde_g = scipy.stats.gaussian_kde(g_data, 0.2)
g = kde_g(candidates)
l = kde_l(candidates)
EI = l * np.reciprocal(g)
best_idx = np.argmax(EI)
return candidates[best_idx]
def suggest_by_TPE(data: List[Data], num_candidate: int) -> List[float]:
data.sort()
l_data, g_data = seperate(data)
# パラメータ1
candidates1 = generate_candidate(Config.num_candidate, l=-8, r=8)
l1_data = [d.x1 for d in l_data]
g1_data = [d.x1 for d in g_data]
candidate1 = select(candidates1, l_data=l1_data, g_data=g1_data)
# パラメータ2
candidates2 = generate_candidate(Config.num_candidate, l=-8, r=8)
l2_data = [d.x2 for d in l_data]
g2_data = [d.x2 for d in g_data]
candidate2 = select(candidates2, l_data=l2_data, g_data=g2_data)
return candidate1, candidate2
if __name__ == "__main__":
# 初期データの作成
data = gen_init_data(Config.num_initial_data)
best_idx = np.argmin([d.y for d in data])
best_param = (data[best_idx].x1, data[best_idx].x2)
best_score = data[best_idx].y
for i in range(Config.total_search):
x1, x2 = suggest_by_TPE(data, Config.num_candidate)
y = objective(x1, x2)
if y < best_score:
best_score = y
best_param = (x1, x2)
print(f"trial{i+1:04d}: param: {(x1,x2)}, score: {y}, best score: {best_score}, best_param: {best_param}")
new_data = Data(x1, x2, y)
data.append(new_data)
データ
パラメータ(
@dataclass
class Data:
x1: float
x2: float
y: float
def __gt__(self, other):
return self.y > other.y
def __lt__(self, other):
return self.y < other.y
初期点データの作成
初期点データの作成は必ずしも必要ではありませんが、初期点の取り方で探索の方向性が変わることがあるため、考慮に入れる必要もある場面もあると思います。今回はgen_init_data関数で初期点を作りますがランダムに作成しています。初期点作成方法に関しては他にラテンハイパーキューブサンプリングなどがありますが今回は割愛します。
def gen_init_data(num: int) -> List[Data]:
x1_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
x2_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
y_list = [objective(x1, x2) for (x1, x2) in zip(x1_list, x2_list)]
ret = [Data(x1, x2, y) for (x1, x2, y) in zip(x1_list, x2_list, y_list)]
return ret
上位と下位の分割
分割はConfigのGAMMAによって上位と下位に分けます。今回は下記のようにナイーブに上位
def seperate(data: List[Data]) -> Tuple[List[Data], List[Data]]:
num_best = int(np.floor(len(data) * Config.GAMMA))
return data[:num_best], data[num_best:]
候補の探索
generate_candidate関数は候補点の作成を、select関数は候補の集合から一つに候補を絞る関数となります。今回はgenerate_candidate関数ではナイーブに乱数を生成して候補点を作成していますが、工夫の余地は残されています。例えば今までに観測したデータの近傍を候補点とすれば「探索」よりも「活用」に重きを置いた候補点を作成することもできます。
select関数ではカーネル密度推定によって候補を一つ見つけています。上位の分布
def generate_candidate(num_candidate: int, l: float, r: float) -> List[float]:
return np.random.random(num_candidate) * (r - l) + l
def select(candidates: List[float], l_data: List[float], g_data: List[float]) -> float:
kde_l = scipy.stats.gaussian_kde(l_data, 0.2) # band 幅
kde_g = scipy.stats.gaussian_kde(g_data, 0.2)
g = kde_g(candidates)
l = kde_l(candidates)
EI = l * np.reciprocal(g) # l(x) / g(x)
best_idx = np.argmax(EI)
return candidates[best_idx]
出力と考察
上で紹介した簡易なコードを実行すると以下のような出力になります。100回の探索で
この原因については、TPEが変数間の相関を無視しているためだと考えられます。今回設定した目的関数には
trial0001: param: (4.920359713652854, -7.943892431426473), score: 8.553450716256133, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0002: param: (5.016915630169173, 5.527603435276287), score: 8.820209404104943, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0003: param: (5.048957000284613, 5.4503562189494), score: 9.05736007708248, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0004: param: (4.730764402754124, 2.5180827774313777), score: 11.031833287359266, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0005: param: (-6.799638825332387, -5.285936013697743), score: 6.50792517985481, best score: 6.50792517985481, best_param: (-6.799638825332387, -5.285936013697743)
trial0006: param: (6.606123555933879, -7.871231363746391), score: 4.341444682042416, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0007: param: (7.9301374890694944, -7.90823013630671), score: 10.896366442979584, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0008: param: (-5.036477670790301, -4.972207149875496), score: 9.854627326524614, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0009: param: (7.0285318972930355, -7.879359974690598), score: 5.490590002183352, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0010: param: (7.324742669387627, -5.323892579881772), score: 9.677066782667556, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0011: param: (5.391425368390795, -7.591487964410833), score: 7.9952684600691315, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0012: param: (7.2247020530230035, -7.622800330500123), score: 5.33937194586759, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0013: param: (6.996322265509203, -7.835759625422414), score: 5.221870039888324, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0014: param: (6.736463279725834, -7.578726688338465), score: 4.2913460289424785, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0015: param: (7.1726735592113595, -7.546634473520719), score: 4.9569392310110505, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0016: param: (6.6479876306942, -7.395559217814455), score: 4.4558981538121865, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0017: param: (6.748949969842959, -7.745972656610368), score: 4.405894909137963, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0018: param: (6.841431525563536, -7.491318839475801), score: 4.330709141045963, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0019: param: (7.021836121250068, -7.442363620542677), score: 4.46893524397544, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0020: param: (6.685085918773094, -7.1492847019651755), score: 4.888904706183518, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0021: param: (6.770061092867438, -7.442921232877941), score: 4.3278334595263495, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0022: param: (6.564036472843911, -7.770332968363066), score: 4.236737626183018, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
trial0023: param: (6.607814272756059, -6.302410921859751), score: 8.475451687475875, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
trial0024: param: (6.610748096480984, -7.890726385871609), score: 4.370779198072773, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
trial0025: param: (6.875498317620059, -7.055547392284943), score: 4.739824069569588, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
...
trial0097: param: (6.414543421383989, -7.6651734615048355), score: 4.328878870363756, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)
trial0098: param: (6.3986740295756555, -7.441465133564476), score: 4.779445016665804, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)
trial0099: param: (6.350398653275951, -7.321735584684967), score: 5.24631071383829, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)
trial0100: param: (7.460420504900847, -7.367001764492722), score: 5.354246294579943, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)
Discussion