🤖

AIにおける学習とは何なのかを高校生にもわかるように説明する

2024/11/15に公開

学習とはいったい何なんだろう?

昨今身近に使われるAIとしてChatGPTがあります。
ユーザーである私たちがChatGPTをテキストに入力すると、そのテキストに対してChatGPTが返答をしてくれます。このChatGPTは、AIの一種である自然言語処理というものを使って、私たちのテキストを理解し、適切な返答をしてくれます。ChatGPTが作られる背景には、ウェキペディアのようなプラットフォームから大量のテキストデータを拾ってきて、AIが学習をしています。
しかし、言葉だけで「学習」といわれても、実際にどんなことをしているのか想像がつきません。
また、ChatGPT以外でも、テキストを入力すると、それに基づいた画像を生成してくれるAIや、音声を入力すると、それに基づいたテキストを生成してくれるAIなどがあります。
とにかく、AIを使うと「何かを入力すると、それに基づいた何かを出力してくれる」ということができるようになります。
しかし、何故そのような多種多様なことができるかはよくわかりません。

そこで今回は、AIとは何かを高校生にもわかるように少しだけ粒度を落として説明していきます。

AIの実態

簡単な例として、手書きで数字が書いてある画像を入力すると、その数字が3であるか、否かを判別するAIを考えます。
実際に判別する際には以下の手順で行います。

  1. 画像をピクセルごとに分解して、それぞれのピクセルの色を0から255の数値で表現します。例えば、28x28ピクセルの画像であれば、784個の数値が得られます。それぞれの数値をx_1, x_2, ..., x_{784}とします。
  2. これらの数値を入力として、関数F(\cdot)に代入し、F(x_1, x_2, ..., x_{784})を計算します。計算した値をzとします。
  3. zを確率の形式に整えるため、zを0から1の範囲に収めるための関数\sigma(x)=\frac 1 {1+e^{-x}}(後程解説します)に代入し、\sigma(z)を計算します。計算した値を\hat{y}とします。
  4. \hat{y}が0.5以上1以下であれば、画像に書かれている数字は3であると判定します。0以上0.5未満の場合は、3でないと判定します。

手書き数字の判別

AIが学習する部分は、関数F(\cdot)の中身を決めるパラメータを表す数値です。

学習とは

AIが学習するとは、関数F(\cdot)の中身を決めるパラメータを、大量のデータを使って決定することです。
単純な例として、以下の問題を考えます。

問題: 入力される数字をxとして、3x5を超えるか否かを判定するAIを作成するように、関数F(x)=wx+bのパラメータw,bを決定する。

まず、この問題を学習せずにw,bを決め打ちで決定することを考えます。
AIの出力では、あくまで確率を出力として出す必要があるため、そのための関数としてシグモイド関数\sigma(z)を使います。
この関数は、zが大きいほど1に近づき、zが小さいほど0に近づく関数です。また、zが0以上の時、\sigma(z)は0.5以上1以下の値を取り、zが0未満の時、\sigma(z)は0以上0.5未満の値を取ります。
つまり、入力されるすべての実数に対する写像が存在して、常に0から1の範囲の値を取ることができます。

\sigma(z) = \frac 1 {1+e^{-z}}

シグモイド関数

また、シグモイド関数には以下の性質があります。

\begin{align*} 0.5 &\leq \hat{y} = \sigma(z) = \frac 1 {1+e^{-z}} \\ \Leftrightarrow e^{-z} &\leq 1 \\ \Leftrightarrow z &\geq 0 \end{align*}

入力変数zが0以上の時、\hat{y}は0.5以上1以下の値を取ることがわかります。同様にして、入力変数zが0未満の時、\hat{y}は0以上0.5未満の値を取ることがわかります。
つまり、3x5を超える時、F(x)を0以上に、そうでない時、F(x)を0未満にするように、w,bを決定すれば良いことがわかります。
よって、例えばw=3, b=-5とすることで、3x5を超える時、F(x)は0以上に、そうでない時、F(x)は0未満になります。

しかし、このようにw,bを決め打ちで決定することは、あまり汎用的とは言えません。そこで、データを用いてw,bを決定することを考えます。

データセットとして考えるものとして、以下のようなものが考えられます。
(x,y) = \{(1,0), (2,0), (3,1), (4,1), (5,1)\}
以下、(x_i,y_i)i番目のデータを表します。
例えば、(1,0)x=1の時、3x5を超えないため、正解の確率として0を表します。
(3,1)x=3の時、3x5を超えるため、正解の確率として1を表します。

このデータセットを使って、w,bを更新していくことで、AIが学習を行います。
具体的には、各データの正解yと比較して、予測値\hat{y}が最も「尤もらしく」(もっともらしく)なるようにw,bを更新していきます。
尤もらしさを表す関数として、尤度関数を定義します。

P(\hat{y}_i|w,b) = \hat{y}_i^{y_i} (1-\hat{y}_i)^{1-y_i}

この関数は値が大きいほど予測が尤もらしいことを示します。データのy_i1の時、出力される確率は高いほど尤もらしいです。逆に、データのy_i0の時、出力される確率は低いほど尤もらしいです。
この性質をうまく利用してP(\hat{y}_i|w,b)は定義されています。
この尤度関数を最大化するようなw,bを求めることが学習の目的です。
P(\hat{y}_i|w,b)を最大化するにあたって、微分を使って、停留点を求めるなどをして解析的に求めることもできますが、AIの複雑度が増したときに、全てのパラメータについて解析的に求めることが困難になるため、大量のデータを使ってw,bを更新していくことが一般的です。
このように、w,bを更新していくことを学習といいます。

パラメータの更新は、高校数学の分野で言う、ニュートン法の発想と感覚が近いです。

ニュートン法とは

ニュートン法とは、関数F(x)=0の解を解析的に求めることなく、
初期値x_0から出発して、F(x)=0の解に収束するようにxを更新していく方法です。
具体的には、F(x)の接線のx軸との交点を新しいxとして更新していきます。
更に詳しく知りたい人は、ニュートン法で検索してみてください。

ニュートン法では、F(x)=0の解を解析的に求めることが出来ないにもかかわらず、何度も更新を繰り返すことで、xF(x)=0となるxに収束することができます。
しかし、AIにおける学習は、ニュートン法とは異なり、最急降下法という方法によって行われます。
最急降下法は、関数F(x)の最小値を求めるために、関数の勾配(傾き)を使って、徐々に最小値に近づいていく方法です。
ある出発点x_0から、関数F(x)の勾配を計算し、その勾配の逆方向に少しずつxを更新していきます。この更新を繰り返すことで、最終的にF(x)の最小値に収束することができます。

最急降下法

AIの学習において最急降下法を用いる時、w,bを更新する際には、ある関数\mathcal{L}(w,b)を定義して、それを最小化するように更新する問題に帰着されます。
この関数\mathcal{L}(w,b)損失関数といいます。
この損失関数を、値が大きいほど、AIの予測が正解から外れていることを示すように設計したいです。

ここで、先ほど定義した尤度関数を使って、損失関数の設計を考えると、尤もらしいほど損失が小さく、尤もらしくないほど損失が大きいような関数を設計することができます。
よって、尤度にマイナスをかけたものを損失関数として設計することを考えて以下のように定義します。

\mathcal{L}(w,b) = - log P(\hat{y}_i|w,b) = -y_i log \hat{y}_i - (1-y_i) log(1-\hat{y}_i) (ここで、対数を用いるのは、扱いやすくするためです。logは単調増加関数であるため、大小関係は特に変わりません。)

この損失関数を最小化するようにw,bを更新していくことで、AIが学習を行います。

実際にパラメータを更新してみましょう

  1. パラメータの初期化: w=1, b=0

  2. w の更新: w = w - \alpha \frac{\partial \mathcal{L}(w,b)}{\partial w}

  3. b の更新: b = b - \alpha \frac{\partial \mathcal{L}(w,b)}{\partial b}

  4. 2,3を、パラメータが収束するまで繰り返す

ここで、\alphaは学習率と呼ばれるハイパーパラメータで、更新の大きさを調整するために使います。
また、\frac{\partial \mathcal{L}(w,b)}{\partial w}, \frac{\partial \mathcal{L}(w,b)}{\partial b}は、それぞれw,bに関する損失関数の勾配を表します。
(\partialは偏微分を表し、w以外のパラメータは一定として微分を行うことです。詳しくは偏微分で検索してみてください。)
これらの勾配は、微分を使って計算することができます。

勾配の計算

損失関数\mathcal{L}(w,b)wについて微分すると、以下のようになります。

まず事前準備として、以下の微分を計算します。

\begin{align*} \frac{\partial \mathcal{L}(w,b)}{\partial \hat{y}_i} &= -\frac{y_i}{\hat{y}_i} + \frac{1-y_i}{1-\hat{y}_i}\\ \frac{\partial \hat{y}_i}{\partial z} &= -\frac 1 {(1+e^{-z})^2} e^{-z} = \frac 1 {1+e^{-z}} (1-\frac 1 {1+e^{-z}}) = \hat{y}_i (1-\hat{y}_i)\\ \frac{\partial z}{\partial w} &= x_i\\ \frac{\partial z}{\partial b} &= 1 \end{align*}

これらを使って、wについての損失関数の微分を計算します。

\begin{align*} \frac{\partial \mathcal{L}(w,b)}{\partial w} &= \frac{\partial \mathcal{L}(w,b)}{\partial \hat{y}_i} \frac{\partial \hat{y}_i}{\partial z} \frac{\partial z}{\partial w} \\ &= (-\frac{y_i}{\hat{y}_i} + \frac{1-y_i}{1-\hat{y}_i}) \hat{y}_i (1-\hat{y}_i) x_i\\ &= (\hat{y}_i - y_i) x_i \end{align*}

同様に、bについての損失関数の微分を計算します。

\begin{align*} \frac{\partial \mathcal{L}(w,b)}{\partial b} &= \frac{\partial \mathcal{L}(w,b)}{\partial \hat{y}_i} \frac{\partial \hat{y}_i}{\partial z} \frac{\partial z}{\partial b} \\ &= (-\frac{y_i}{\hat{y}_i} + \frac{1-y_i}{1-\hat{y}_i}) \hat{y}_i (1-\hat{y}_i)\\ &= \hat{y}_i - y_i \end{align*}

よって、更新式は以下のようになります。

w \leftarrow w - \alpha (\hat{y}_i - y_i) x_i \\ b \leftarrow b - \alpha (\hat{y}_i - y_i)

このように、データを使うことによって、そのデータに適合するようにパラメータが更新されていくことで学習されることが分かります。

ここで次節の事前準備として、このモデル全体をグラフィカルに表現してみましょう。

単純なネットワーク

実際に用いられるAI

実際に用いられるAIは、このような単純な問題ではなく、画像認識、音声認識、自然言語処理、予測分析など、多様な問題に対応するために、複雑なモデルが使われます。
最初の例で述べた画像に書かれた数字を判別する物も入力変数も784個の数値で入力されるため、それに応じたパラメータ数が必要になります。
また、関数F(\cdot)の中身も、先ほど扱ったwx+bのように線形な関数ではなく、より複雑にするように設計される必要があります。

複雑なネットワーク

この図にあるfは、活性化関数と呼ばれ、非線形性を持たせるための関数です。

非線形性を持たせるとは

線形性とは、入力変数xに対して、y=ax+bのように、xに比例してyが変化する関数のことです。一方、非線形性とは、y=ax^2+bのように、xに比例していない関数のことです。
AIのモデルには、非線形性を持たせることで、複雑な問題に対応することができます。
活性化関数がない場合、ネットワーク全体を表すF(以降、モデルとする)としては、単なる線形関数としてしか表現することができません。
具体的には、

\begin{align*} F(x) &= w_3(w_2(w_1x+b_1)+b_2)+b_3\\ &= w_3w_2w_1x + w_3w_2b_1 + w_3b_2 + b_3 &= w'x + b' \end{align*}

のように、どれだけ層を重ねても、線形関数でしか表現することができません。そのため、活性化関数を計算の途中に挟み、非線形性を持たせることで、複雑な問題に対応することができます。

また汎用的に使われる活性化関数の例として、ReLU関数があります。
ReLU関数は以下のように定義されています。

f(x) = \begin{cases} x & (x > 0)\\ 0 & (x \leq 0) \end{cases}

ReLU関数は、入力が0以上の時、そのまま出力し、0未満の時、0を出力する関数です。
この関数が汎用的に使われているのは、微分が簡単であり、計算が高速であるためです。

この図に示すように、ユニットを多層に重ねたり、非線形な関数を使ったりすることで、複雑なタスクに対応できるようになります。我々が便利に使えるようにするためには、更に複雑なモデルが必要になります。
ChatGPTなどの自然言語処理のAIは、RNNと呼ばれる出力されたものを次の入力に使ったり、画像認識のためのAIは、CNNと呼ばれる、画像の部分的な特徴を抽出して、それを次の層に渡すなど様々な工夫によってモデルをより複雑にすることで、複雑な問題に対応できるようにします。

まとめ

AIにおける学習とは、与えられたデータに適合するように、モデルのパラメータを更新していくことです。
この学習は、最急降下法を使って、損失関数を最小化するようにパラメータを更新していきます。
我々が日常的に使えるようにするためには、非常に自由度の高い複雑なネットワークを構成して、より柔軟に様々なタスクに対応できるようにします。
しかしながら、AIの学習の観点からは、損失関数をうまく設計することができさえすれば、どのようなモデルを使っても、データに適合するようにパラメータが更新されるため、学習を行うことができます。

GitHubで編集を提案

Discussion