🌳

TPEの簡単な説明と実装

2023/05/24に公開

初投稿です。

この記事ではTPEの原理と実装について簡単に説明していきます。深い原理については説明せず、必要最小限の実装を行うのでご了承ください。

TPEとは?

  • ベイズ最適化の一種です。
  • ガウス過程回帰を用いるベイズ最適化と比べると計算量が非常に軽いです。
  • OptunaやHyperoptなどのパラメータ最適化ライブラリの基幹アルゴリズムになっています。

簡単な原理

ここではブラックボックスな目的関数f(\bm{x})最小化 することを考えます。
「最大化」を問題にしていないので注意してください。

獲得関数EI

ベイズ最適化で良く用いられる獲得関数EI(Expected Immrovement)を考えます。EIは以下のように表されます。

\text{EI}=\int^{\infty}_{- \infty}\max(0, y^*-y)p(y|x)dy

ここで、y^*は閾値(threshold)、yは次の候補点の予測値、xは候補点の予測に用いるデータのベクトル、p(y|x)はそのデータxが与えられた時の予測値yの条件付き確率です。
予測値yy^*より大きいとき、被積分関数は\max(0, y^*-y)により0になるので本質的には以下の式を考えることになります。
\text{EI}=\int^{y^{*}}_{-\infty}(y^*-y)p(y|x)dy

これからこの式を変形しTPEが目標とする形に導きます。

ベイズの定理

まず、ベイズの定理よりp(y|x)は以下のようになります。

p(y|x)=\frac{p(x|y)p(y)}{p(x)}

これを用いてEIの式を変形します。

カーネル密度推定

ここでp(x|y)に注目します。y \ge y^*のとき、ある分布g(x)p(x|y)を代替し、y \lt y^*のとき、ある分布l(x)p(x|y)を代替することを考えます。つまり、以下のようなことを考えます。

p(x|y) = \begin{cases} l(x) &\text{if } y \lt y^* \\ g(x) &\text{if } y \ge y^* \end{cases}

ここで注意してほしいのがl(x)が上位の分布になることです。なぜならyを最小化する問題を考えているからです。l(x)が上位の分布、g(x)が下位の分布であることを頭に入れておいてください。l(x), g(x)の分布についてはカーネル密度推定によってノンパラメトリックに推定することになります。

y^*の決定

ところで、ここまであまり触れてきませんでしたが、y^*の決め方について説明します。とはいうものの自由に決めても大丈夫ではあります。例えば観測したyの中から上から5番目(5番目に小さいもの)をy^*としてもよいですし、上位10%に入るものの中で最大のものをy^*としてもよいです。ここでは観測したyの中から上位\gamma%に入るものを上位、それ以外を下位として考えます。つまり、これまでに取得したデータの個数をm、上位のデータの個数をn_lとすると

n_{l}=\gamma m

となります。OptunaやHyperoptでは独自の選択法を採用していますが、今回はあえて上位10%のものを上位組とします。

EIの最終的な形

以上で登場した事実をもとにEIを変形すると以下のような形になります。

\text{EI}=\frac{\gamma y^{*}l(x)-l(x)\int^{y^*}_{-\infty}p(y)dy}{\gamma l(x) + (1-\gamma )g(x)}\propto \left(\gamma+\frac{g(x)}{l(x)}(1-\gamma)\right)^{-1}

この式からEIは上位層の分布l(x)と下位層の分布g(x)にのみ依存することが分かります。よってTPEが提起する候補点はl(x),g(x)の比によって決まるのでこの値をもとに探索を執り行えばよいです。ただし、最適化するパラメータの個数をdd個のパラメータをx_i(i=1,2,..,d)とすると
l(x)=l_1(x_1)l_2(x_2)...l_d(x_d) \\ g(x)=g_1(x_1)g_2(x_2)...g_d(x_d)

といった形になっていることに注意してください。とはいうものの、
\frac{l(x)}{g(x)}=\frac{l_1(x_1)}{g_1(x_1)}\cdot \frac{l_2(x_2)}{g_2(x_2)}\cdots\frac{l_d(x_d)}{g_d(x_d)}

という形になっているのでEIを最大化しようと思うならば、それぞれのパラメータx_iに対して独立にカーネル密度推定を行って分布l_i(x),g_i(x)を作成するという操作だけで問題ないです。独立にパラメータを選択できる点は計算量を落とすのに一役買っていますが、パラメータ間の相関を無視するというデメリットを抱えてしまいます。

ガウス過程回帰との比較

nを取得したデータの個数、dをパラメータ数、mを次の予測候補点の個数とすると以下のようなテーブルの関係になる。

ガウス過程回帰 TPE
前計算の計算量
O(n^3)
O(dn\log{n})
予測の計算量
O(mn)
O(mn)
メリット (TPEと比べて)パラメータ間の相関を考慮できる。 (ガウス過程と比べて)計算量が軽い。カテゴリカル変数も扱える。
デメリット 計算量が重い。カテゴリカル変数が扱いにくい。 パラメータ間の相関を無視する。

実装の流れ

python3の擬似コードでTPEアルゴリズムの流れを表してみると以下のようになります。大きな流れは__main__スコープに書いてあります。suggest_by_TPE関数によって候補点を決定していきます。下記はあくまで簡易的に書いたコードですので後の実装で修正が加わります。また、今回は連続値のみを扱うので、離散的な整数値やカテゴリカル変数、対数スケールの変数については扱いません。

def objective(x):
    """
    目的関数
    xを受け取ってyを返す
    """
    pass
  
def generate_candidates(m, l, r):
    """
    [l,r]の間でm個の候補を生成する関数
    """
    pass
	
def select(candidates, l_data, g_data):
    """
    上位データ(l_data)と下位データ(g_data)を使ってcandidatesから最適な候補を返す関数
    ここでカーネル密度推定などを行う
    """
    pass

def suggest_by_TPE(data):
    data.sort()
    l_data, g_data = separate(data) # 上位層と下位層に分割
	
    # >>>> パラメータ1
    # 範囲[1, 4]で候補をランダムにm個生成
    candidates1 = generate_candidates(m, 1, 4) 
    #候補点の選定
    candidate1 = select(candidates1, l_data, g_data)
    
    # >>>> パラメータ2
    # 範囲[-20, 20]で候補をランダムにm個生成
    candidates2 = generate_candidates(m, -20, 20) 
    # 候補点の選定
    candidate2 = select(candidates2, l_data, g_data)
    
    return [candidate1, candidate2]
		

if __name__ == "__main__":
    data = []
    
    # 最適化ループ
    for i in range(total_search):
	x = suggest_by_TPE(data) # 候補の提案
	y = objective(x)                 # 候補の評価
	data.append((x, y))      # データの追加


簡単な実装

実行環境

  • Python3.8.13
  • numpy 1.21.5
  • plotly 5.6.0
  • scipy 1.7.3

目的関数の定義

ここでは2変数x_1,x_2の関数を-8\le x_1\le8, -8\le x_2\le8の範囲で最適化を行ってみます。以下のような式で表される目的関数を考えます。

f(x_1,x_2)=\left(\frac{x^2_1}{100}-\frac{x^2_2}{50} + \frac{x_1x_2}{10}\right)\sin(x_1-x_2)+10

plotlyで可視化すると以下のようになります。図ではわかりにくいのですがx_1=6,x_2=-8あたりに最小値を持っています。

def objective(x1: float, x2: float) -> float:
  return np.sin(x1 - x2) * (x1**2 / 100 - x2**2 / 50 + x1 * x2 / 10) + 10

x1_args = np.arange(-8, 8, 0.1)
x2_args = np.arange(-8, 8, 0.1)
y = [[objective(x1, x2) for x1 in x1_args] for x2 in x2_args]
fig = go.Figure(data=[go.Surface(z=y, x=x1_args, y=x2_args)])


# グラフのレイアウト設定
fig.update_layout(
    title='目的関数',
    scene=dict(
        xaxis_title='x1',
        yaxis_title='x2',
        zaxis_title='y'
    )
    
    autosize=False,
    width=500,
    height=500,
)

# グラフの表示
fig.show()

実装の全体

先に実装の全体を紹介します。コードとしては100行以内で書くことができました。

from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import scipy.stats

class Config:
    GAMMA = 0.2
    total_search = 100    # 探索回数
    num_candidate = 30    # 候補点の個数
    num_initial_data = 10 # 初期点の個数

@dataclass
class Data:
    x1: float
    x2: float
    y: float

    def __gt__(self, other):
        return self.y > other.y

    def __lt__(self, other):
        return self.y < other.y

def objective(x1: float, x2: float) -> float:
    return np.sin(x1 - x2) * (x1**2 / 100 - x2**2 / 50 + x1 * x2 / 10) + 10

def gen_init_data(num: int) -> List[Data]:
    x1_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
    x2_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
    y_list = [objective(x1, x2) for (x1, x2) in zip(x1_list, x2_list)]
    ret = [Data(x1, x2, y) for (x1, x2, y) in zip(x1_list, x2_list, y_list)]
    return ret


def seperate(data: List[Data]) -> Tuple[List[Data], List[Data]]:
    num_best = int(np.floor(len(data) * Config.GAMMA))
    return data[:num_best], data[num_best:]


def generate_candidate(num_candidate: int, l: float, r: float) -> List[float]:
    return np.random.random(num_candidate) * (r - l) + l


def select(candidates: List[float], l_data: List[float], g_data: List[float]) -> float:
    kde_l = scipy.stats.gaussian_kde(l_data, 0.2) # band 幅 要変更
    kde_g = scipy.stats.gaussian_kde(g_data, 0.2) 
    g = kde_g(candidates)
    l = kde_l(candidates)

    EI = l * np.reciprocal(g)
    best_idx = np.argmax(EI)
    return candidates[best_idx]


def suggest_by_TPE(data: List[Data], num_candidate: int) -> List[float]:
    data.sort()
    l_data, g_data = seperate(data)

    # パラメータ1
    candidates1 = generate_candidate(Config.num_candidate, l=-8, r=8)
    l1_data = [d.x1 for d in l_data]
    g1_data = [d.x1 for d in g_data]
    candidate1 = select(candidates1, l_data=l1_data, g_data=g1_data)


    # パラメータ2
    candidates2 = generate_candidate(Config.num_candidate, l=-8, r=8)
    l2_data = [d.x2 for d in l_data]
    g2_data = [d.x2 for d in g_data]
    candidate2 = select(candidates2, l_data=l2_data, g_data=g2_data)

    return candidate1, candidate2


if __name__ == "__main__":
    # 初期データの作成
    data = gen_init_data(Config.num_initial_data)
    best_idx = np.argmin([d.y for d in data])
    best_param = (data[best_idx].x1, data[best_idx].x2)
    best_score = data[best_idx].y

    for i in range(Config.total_search):
        x1, x2 = suggest_by_TPE(data, Config.num_candidate)
        y = objective(x1, x2)

        if y < best_score:
            best_score = y
            best_param = (x1, x2)

        print(f"trial{i+1:04d}: param: {(x1,x2)}, score: {y}, best score: {best_score}, best_param: {best_param}")

        new_data = Data(x1, x2, y)
        data.append(new_data)

データ

パラメータ(x_1, x_2)とその評価値yについては以下のようにdataclassデコレ―タを用いて実装しています。yの値でソートできるように特殊メソッド(__lt__, __gt__)も実装します。

@dataclass
class Data:
    x1: float
    x2: float
    y: float

    def __gt__(self, other):
        return self.y > other.y

    def __lt__(self, other):
        return self.y < other.y

初期点データの作成

初期点データの作成は必ずしも必要ではありませんが、初期点の取り方で探索の方向性が変わることがあるため、考慮に入れる必要もある場面もあると思います。今回はgen_init_data関数で初期点を作りますがランダムに作成しています。初期点作成方法に関しては他にラテンハイパーキューブサンプリングなどがありますが今回は割愛します。

def gen_init_data(num: int) -> List[Data]:
    x1_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
    x2_list = np.random.random(num) * 16.0 - 8.0 # -8 ~ 8
    y_list = [objective(x1, x2) for (x1, x2) in zip(x1_list, x2_list)]
    ret = [Data(x1, x2, y) for (x1, x2, y) in zip(x1_list, x2_list, y_list)]
    return ret

上位と下位の分割

分割はConfigのGAMMAによって上位と下位に分けます。今回は下記のようにナイーブに上位\gamma%とそれ以外のグループに分けることとします。

def seperate(data: List[Data]) -> Tuple[List[Data], List[Data]]:
    num_best = int(np.floor(len(data) * Config.GAMMA))
    return data[:num_best], data[num_best:]

候補の探索

generate_candidate関数は候補点の作成を、select関数は候補の集合から一つに候補を絞る関数となります。今回はgenerate_candidate関数ではナイーブに乱数を生成して候補点を作成していますが、工夫の余地は残されています。例えば今までに観測したデータの近傍を候補点とすれば「探索」よりも「活用」に重きを置いた候補点を作成することもできます。
select関数ではカーネル密度推定によって候補を一つ見つけています。上位の分布l(x)と下位の分布g(x)をscipy.stats.gaussian_kdeによって作成しています。ここではバンド幅は0.2としています。Optunaではバンド幅やカーネルの重みを独自に決め、マジッククリップなどの操作をしていますが、今回はそのような機能は実装しません。

def generate_candidate(num_candidate: int, l: float, r: float) -> List[float]:
    return np.random.random(num_candidate) * (r - l) + l


def select(candidates: List[float], l_data: List[float], g_data: List[float]) -> float:
    kde_l = scipy.stats.gaussian_kde(l_data, 0.2) # band 幅
    kde_g = scipy.stats.gaussian_kde(g_data, 0.2) 
    g = kde_g(candidates)
    l = kde_l(candidates)

    EI = l * np.reciprocal(g) # l(x) / g(x)
    best_idx = np.argmax(EI)
    return candidates[best_idx]

出力と考察

上で紹介した簡易なコードを実行すると以下のような出力になります。100回の探索でx_1=6.358, x_2=-7.997がベストパラメータとなりました。一見すると良さそうな結果を与えていますが乱数次第では最小値が全く見つからないケースも私のPCの実験で確認されています。
この原因については、TPEが変数間の相関を無視しているためだと考えられます。今回設定した目的関数には\sin(x_1-x_2)が入っていますが、これは意図的に入れています。\sin(x_1-x_2)g(x_1)h(x_2)のように独立した二つの関数g(x), h(x)の積で表すことはできません。このため、変数の独立性を仮定してカーネル密度推定を行っている段階で候補としてふさわしくない候補点が出現することがあり得ます。このような事実があるためTPEでは探索がうまくいかないケースもあるのではないかと考えられます。

trial0001: param: (4.920359713652854, -7.943892431426473), score: 8.553450716256133, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0002: param: (5.016915630169173, 5.527603435276287), score: 8.820209404104943, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0003: param: (5.048957000284613, 5.4503562189494), score: 9.05736007708248, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0004: param: (4.730764402754124, 2.5180827774313777), score: 11.031833287359266, best score: 8.478849124539922, best_param: (4.970165901714612, -7.90952135221618)
trial0005: param: (-6.799638825332387, -5.285936013697743), score: 6.50792517985481, best score: 6.50792517985481, best_param: (-6.799638825332387, -5.285936013697743)
trial0006: param: (6.606123555933879, -7.871231363746391), score: 4.341444682042416, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0007: param: (7.9301374890694944, -7.90823013630671), score: 10.896366442979584, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0008: param: (-5.036477670790301, -4.972207149875496), score: 9.854627326524614, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0009: param: (7.0285318972930355, -7.879359974690598), score: 5.490590002183352, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0010: param: (7.324742669387627, -5.323892579881772), score: 9.677066782667556, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0011: param: (5.391425368390795, -7.591487964410833), score: 7.9952684600691315, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0012: param: (7.2247020530230035, -7.622800330500123), score: 5.33937194586759, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0013: param: (6.996322265509203, -7.835759625422414), score: 5.221870039888324, best score: 4.341444682042416, best_param: (6.606123555933879, -7.871231363746391)
trial0014: param: (6.736463279725834, -7.578726688338465), score: 4.2913460289424785, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0015: param: (7.1726735592113595, -7.546634473520719), score: 4.9569392310110505, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0016: param: (6.6479876306942, -7.395559217814455), score: 4.4558981538121865, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0017: param: (6.748949969842959, -7.745972656610368), score: 4.405894909137963, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0018: param: (6.841431525563536, -7.491318839475801), score: 4.330709141045963, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0019: param: (7.021836121250068, -7.442363620542677), score: 4.46893524397544, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0020: param: (6.685085918773094, -7.1492847019651755), score: 4.888904706183518, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0021: param: (6.770061092867438, -7.442921232877941), score: 4.3278334595263495, best score: 4.2913460289424785, best_param: (6.736463279725834, -7.578726688338465)
trial0022: param: (6.564036472843911, -7.770332968363066), score: 4.236737626183018, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
trial0023: param: (6.607814272756059, -6.302410921859751), score: 8.475451687475875, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
trial0024: param: (6.610748096480984, -7.890726385871609), score: 4.370779198072773, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
trial0025: param: (6.875498317620059, -7.055547392284943), score: 4.739824069569588, best score: 4.236737626183018, best_param: (6.564036472843911, -7.770332968363066)
...
trial0097: param: (6.414543421383989, -7.6651734615048355), score: 4.328878870363756, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)
trial0098: param: (6.3986740295756555, -7.441465133564476), score: 4.779445016665804, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)
trial0099: param: (6.350398653275951, -7.321735584684967), score: 5.24631071383829, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)
trial0100: param: (7.460420504900847, -7.367001764492722), score: 5.354246294579943, best score: 4.181403777942899, best_param: (6.357733388117499, -7.997094259329399)

Discussion