🐈

Act 16. ロジスティック回帰について学ぶ

2024/11/15に公開

はじめに

Act 01. AIで外国為替を自動売買するまでの道のりをベースに学習を進めて行く。

前回は線形回帰について学習した。
今回からロジスティック回帰について学習する。

今後学習する内容は以下の通り。
学習はChatGTPさん考案の以下の順序で進めて行く。

学習順序

教師あり学習

  1. 線形回帰: 基本的な回帰問題を学ぶ。 2024/11/13 学習済み
  2. ロジスティック回帰: 分類問題の基本を理解する。
  3. 決定木: シンプルで直感的なモデル。
  4. ランダムフォレスト: 決定木のアンサンブル学習。

教師なし学習

  1. k平均法: クラスタリングの基本。
  2. 主成分分析(PCA): 次元削減の手法。

強化学習

  1. 基本的な概念: エージェント、環境、報酬の関係を理解する。
  2. Q学習: 簡単な強化学習アルゴリズム。
    ※強化学習は重要だと判断したため、概念から再度学習する。

ロジスティック回帰

ChatGPTさんに聞いてみた。

とりあえずいつもの通りChatGPTさんにロジスティック回帰について聞いてみる。
以下の通り回答を貰った。

ロジスティック回帰(Logistic Regression)は、主に「二値分類」を行うための統計的手法であり、「ロジット変換」や「オッズ」を用いることで、データの特徴からその対象がどちらのクラスに属するかを確率的に推定します。これにより、ある事象が発生する「確率」を求めることができ、応用範囲の広い手法です。

オッズとロジット変換

まず、ロジスティック回帰の理解に重要な「オッズ」と「ロジット変換」について説明します。

  • オッズ(Odds):
    ある事象が発生する確率 p を用いて、「オッズ」とは「事象が発生する確率」と「発生しない確率」の比で定義されます。数式で表すと次の通りです。

    \text{オッズ} = \frac{p}{1 - p}

    たとえば、ある商品の購入確率が70%の場合、オッズは 0.7 / (1 - 0.7) = 2.33 となり、これは「その商品が購入される可能性が購入されない可能性の約2.33倍である」ことを意味します。

  • ロジット変換(Logit Transformation):
    ロジスティック回帰では、この「オッズ」をさらに「ロジット変換」することで、確率を「無限の範囲(負の無限大から正の無限大まで)」に変換し、線形モデルで扱いやすくします。ロジット変換は、オッズの対数をとることで次のように定義されます。

    \text{ロジット}(p) = \ln\left(\frac{p}{1 - p}\right)

    こうして得られたロジット値を、線形関数でモデル化できるため、ロジスティック回帰においてはこの変換が重要な役割を果たします。

ロジスティック回帰の仕組み

ロジスティック回帰は、次の2つのステップで構成されます。

1. 線形モデルの構築

まず、ロジット変換されたオッズを予測するための線形モデルを構築します。線形回帰と同様に、データの特徴量に重みを掛け合わせて次のような線形式を作ります。

z = w_1 x_1 + w_2 x_2 + \dots + w_n x_n + b

ここで、x_i は各特徴量(説明変数)、w_i はその重み、b はバイアス項です。線形モデルがオッズのロジット値を予測する形になります。

2. シグモイド関数で確率に変換

予測したロジット値 z をそのままでは確率として扱えないため、シグモイド関数(またはロジスティック関数)を用いて、0から1の間に収まるように変換します。シグモイド関数は次のように定義されます。

p = \sigma(z) = \frac{1}{1 + e^{-z}}

この関数は、 z の値が大きくなると p が1に近づき、値が小さくなると p が0に近づくという性質を持ちます。この p が「クラス1に属する確率」と解釈されます。

予測の決定

シグモイド関数を通して得られた確率が0.5以上なら「クラス1(例:スパム)」、0.5未満なら「クラス0(例:スパムでない)」と判定するのが一般的です。

学習(パラメータの最適化)

ロジスティック回帰の重みとバイアスは、データに対して誤差(ロス)が最小になるように調整されます。この誤差を測るために、交差エントロピー損失関数(ロジスティック損失関数) がよく使われます。損失関数を最小化するために、勾配降下法 などの最適化アルゴリズムを使ってパラメータを最適化します。

ロジスティック回帰の利点と欠点

  • 利点:

    • モデルがシンプルで解釈しやすく、計算量が少ない。
    • 結果が確率として出力されるため、予測結果の「確かさ」を評価しやすい。
  • 欠点:

    • 分類が線形で分けられる場合にのみ有効。
    • 複雑なデータ(非線形なデータ)には適用しにくい。

応用例

ロジスティック回帰は医療診断やマーケティング、スパムフィルタ、信用リスクの評価など、二値分類が必要な多くの分野で利用されています。

感想

うーん、わかる個所と分からない箇所がある。
線形モデルを作るところとか、前回学習した線形回帰とほぼ同じ公式だもんね。

ただ、シグモイド関数とか交差エントロピー損失関数なんてのは完全に初めまして。
線形回帰の予習も含めて一つずつ学習していこうかな。

自分なりに解説

ロジスティック回帰ってつまり何をどうやってるの?を見てみたが分かりやすかった。
何となく何をしているのかイメージが付いたので、実際に学んでいこうと思う。

動画と被っている個所が多いが記事に残しておく。
今回は睡眠時間をもとに平均寿命より長生きする(1)か長生きしない(0)かの関係についてロジスティック回帰で分析する際の流れを書いていこうと思う。

データの紹介

例えばこんなデータがあったとする。
これは1時間~9時間の睡眠時間で平均寿命より長く生きたか否かを散布図として表したもの。
各時間に10人いると思ってほしい。

オッズを求める

オッズを求めるため、これを以下のように置き換える。

表にすると以下のような感じ。

睡眠時間 平均寿命を超えた人数 平均寿命を超えなかった人数 平均寿命を超える確率
1時間 1人 9人 10%
2時間 2人 8人 20%
3時間 3人 7人 30%
4時間 4人 6人 4%
5時間 4.5人 5.5人 45%
6時間 6人 4人 60%
7時間 7人 3人 70%
8時間 8人 2人 80%
9時間 9人 1人 90%

オッズの公式は以下の通りだった。
p には確率が入る。

\text{オッズ} = \frac{p}{1 - p}

例えば80%の場合は以下の通り。

\text{オッズ} = \frac{p}{1 - p} = \frac{0.8}{1 - 0.8} = \frac{0.8}{0.2} = 4

表にオッズを追加すると以下の通りになる。

睡眠時間 平均寿命を超えた人数 平均寿命を超えなかった人数 平均寿命を超える確率 オッズ
1時間 1人 9人 10% 0.11
2時間 2人 8人 20% 0.25
3時間 3人 7人 30% 0.43
4時間 4人 6人 4% 0.66
5時間 4.5人 5.5人 45% 0.82
6時間 6人 4人 60% 1.5
7時間 7人 3人 70% 2.33
8時間 8人 2人 80% 4
9時間 9人 1人 90% 9

つまり、睡眠時間1時間の場合は「長生きする確率が長生きしない可能性の0.11倍である」と言え、睡眠時間が9時間の場合は「長生きする確率が長生きしない確率の9倍である」と言える。
オッズは確率とは異なる表現方法だが、確率と密接な関係にある。

ロジット変換を行う

オッズをさらにロジット変換することで、線形モデルとして扱いやすいデータにする。
ロジット変換は、確率を「無限の範囲(負の無限大から正の無限大まで)」にすること。

ロジット変換の公式は以下の通りだった。
p にはオッズの時と同様に確率が入る。

\text{ロジット}(p) = \ln\left(\frac{p}{1 - p}\right)

つまり、80%の場合は以下の通り。

\text{ロジット}(0.8) = \ln\left(4\right) = 1.39

なんで1.39になったん?と頭に"?"が浮かびまくったので同じ人がいると思い説明しておく。(自分がバカなだけじゃないよな…!?)

まずは \ln\left(4\right) の部分について、なぜいきなりこの形式になったかというと、確率が80%の時のオッズは4だから。

\ln\left(\right) の中の式とオッズを求める式がイコールだからここは分かると思う。
では次に、\ln\left(4\right)1.39になる理由について。

この \ln\left(\right) というのは自然対数と呼ばれるものらしい。
なんか聞いたことあるような無いような…。

自然対数の計算方法は、数 e \approx 2.718 を基にした対数で、\ln(x) は「e を何回掛ければ x になるか」を示すらしい。

つまり、\ln\left(4\right) の場合は、「e を何回掛ければ 4 になるか」を計算する。
答えは、2.718を1.39回かけると4になる。

実際に計算する場合は、電卓やプログラムの「ln」関数を使って自然対数を求めるのが一般的らしいから、ここではあまり深追いはしない。
iPhoneの場合は、電卓を開いて画面を横にすると \ln\left(\right) が使える。

とりあえず、ロジット変換をするにはオッズの自然対数をとるという風に覚えておく。
ロジット変換の結果は以下の通り。

睡眠時間 長生きする確率 p オッズ \frac{p}{1 - p} ロジット変換の値 \ln\left(\frac{p}{1 - p}\right)
1時間 10% 0.11 -2.2
2時間 20% 0.25 -1.39
3時間 30% 0.43 -0.85
4時間 40% 0.66 -0.42
5時間 45% 0.82 -0.19
6時間 60% 1.5 0.41
7時間 70% 2.33 0.85
8時間 80% 4 1.39
9時間 90% 9 2.2

回帰分析

まずは長生きする確率の棒グラフだけで表示すると以下の通り。

長生きする確率を散布図で表示すると以下の通り。

長生きする確率をロジット変換後の散布図で表示すると以下の通り。

ロジット変換したものに対して回帰分析を行う。
これがロジスティック回帰分析になる。

なるほど。これで分析モデルが完成したって感じかな?
これで終わり!と思いきやそんなことはないと思う。

今回はロジスティック回帰で分析した結果をもとに、長生きするか否かを判別することが目的。
ということでもう少しお付き合い願う。

ロジスティック回帰モデルの予測プロセス

ロジスティック回帰分析で学習したモデルに対してデータを入力することで、そのデータのターゲット(分類結果)を求めることが出来る。

ターゲットの求め方は以下のプロセスで行う。

1. 特徴量の線形結合(ロジット値の計算)

学習済みのモデルには、各特徴量に対する重み(回帰係数)\beta_1, \beta_2, \dots, \beta_n と、切片 \beta_0 がある。

以前学習した線形回帰でも出てきたが、モデルのcoefficient(coef_)が各特徴量の回帰係数で、intercept(intercept_)が切片となる。
※()内はpythonで参照する場合のプロパティ名。

新しいデータ(特徴量のセット)(x_1, x_2, \dots, x_n) に対して、まず以下のように線形結合を行う。

z = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_n x_n

線形回帰の場合は zy だっただけで、見慣れた公式だと思う。
ここでの z はロジット値と呼ばれ、確率のロジット変換(オッズの対数)に相当する。
つまり、ロジット変換して求めた線形回帰を元に、ロジット値を予測しているイメージ?

2. シグモイド関数による確率変換

次に求めたロジット値を使って、シグモイド関数による確率変換を行う。
でた、シグモイド関数…。

一旦公式を載せておく。

p = \frac{1}{1 + e^{-z}}

p とは確率のことで、この確率はターゲットが「1」(例えば「長生きする」や「合格する」などの肯定的なクラス)である確率を表す。
今回の場合だと長生きする確率だね。

では、睡眠時間が6.5時間の場合を求めてみよう。
特徴量は睡眠時間の1つのみになるため、以下のような計算式でロジット値を求める。

z = \beta_0 + 6.5x_1

実際に学習してモデルを作ったわけではないから \beta_0x_1 は不明。
モデルが完成していたら回帰係数と切片が求まっているので z (ロジット値)が求められる。

今回は図があるため、何となく出求めてみる。
恐らく z = 0.8 とか?

z を求めることが出来たため、シグモイド関数に入れて確率を求めていく。
つまり以下のようなこと。

p = \frac{1}{1 + e^{-0.8}}

1乗や2乗なら分かるけど、-0.8乗ってどうやって求めるんだ…?
なんか難しかったが、電卓などを使って計算することが多いらしい。

e^x というやつか exp(x) というものを使うらしい。
ちなみにiPhoneの場合は e^x だった。

e^{-0.8} を求めると 0.4493 となった。
0.4493 になる理由は以下の通り。
※-0.8乗でマイナスなのがポイント

負の指数の場合は逆数をとるので e^{-0.8} = \frac{1}{e^{0.8}} となる。
\frac{1}{e^{0.8}} = \frac{1}{2.2255} = 0.4493 になるってこと。

なので最終的にはこんな感じ。

p = \frac{1}{1 + 0.4493} = \frac{1}{1.4493} = 0.690

つまり69%の確率で平均寿命より長生きすることになり、シグモイド関数を通して得られた値が0.5以上なら「クラス1(長生きする)」、0.5未満なら「クラス0(長生きしない)」と判断するのが一般的であるため、長生きするという分析結果になる。

とまあ、こんな感じで分析を行っているらしい。

さいごに

線形回帰をしっかり学んだからか、ロジスティック回帰でそんなに躓くことなく理解できたと思う。
楽しすぎてたまらない!もっと学んで為替相場分析もやってみたい…。

焦らずじっくりと頑張ろう!
次回はpythonを使って実装するところをやってみようと思う。

ではまた

Discussion