
Multi-class logistic regression

Published on 2024/02/17

Given a dataset \mathcal{D}=\{ (x_n, y_n) \}_{n=1}^N, where x_n \in \mathbb{R}^D and y_n \in \{1, 2, \dots, C \} with C the number of classes, the multi-class logistic regression model is described as:

\begin{align*} p(y_n) &= \begin{pmatrix} p(y_n=1) \\ p(y_n=2) \\ \vdots \\ p(y_n=C) \end{pmatrix} \in \mathbb{R}^C ,\quad \sum_{c=1}^C p(y_n=c) = 1\\ p(y_n=c) &= \mathrm{softmax}(w_c^\top x_n) = \frac{\exp(w_c^\top x_n)}{\sum_{j=1}^C \exp(w_j^\top x_n)} \end{align*}

where w = \{ w_1, w_2, \dots, w_C \} with w_c \in \mathbb{R}^D is the set of class-wise, independent weight vectors.
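To make the notation concrete, here is a minimal NumPy sketch of the forward pass. Stacking the weight vectors into a D \times C matrix W and the inputs into an N \times D matrix X is a convention chosen for this illustration; under it, X @ W yields the N \times C matrix of logits w_c^\top x_n:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max is the usual numerical-stability trick."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# With W of shape (D, C) holding w_1, ..., w_C as columns and X of shape (N, D)
# holding x_1, ..., x_N as rows, softmax(X @ W)[n, c] equals p(y_n = c),
# and each row of the result sums to 1.
```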


The log-likelihood of the probabilistic model is

\begin{align*} \mathcal{L}(w) &= \ln \prod_{n=1}^N \mathrm{Categorical}(y_n \mid p(y_n)) \\ &= \ln \prod_{n=1}^N \prod_{c=1}^C p(y_n=c)^{\delta_{y_n, c}} \\ &= \sum_{n=1}^N \sum_{c=1}^C \delta_{y_n, c} \ln p(y_n=c) \end{align*}

where \delta_{y_n, c} is the Kronecker delta: \delta_{y_n, c} = 1 if y_n = c and 0 otherwise.
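As a sketch, the log-likelihood can be computed with a one-hot matrix standing in for \delta_{y_n, c}. This reuses the softmax above and assumes zero-based labels y_n \in \{0, \dots, C-1\}, unlike the one-based labels in the text:

```python
def log_likelihood(W, X, y, C):
    """L(w) = sum_n sum_c delta_{y_n, c} * log p(y_n = c)."""
    P = softmax(X @ W)          # (N, C) class probabilities
    one_hot = np.eye(C)[y]      # (N, C); one_hot[n, c] = delta_{y_n, c}
    return np.sum(one_hot * np.log(P))
```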


Let's think about the gradient of \mathcal{L}(w) with respect to w_i.

\begin{align*} \frac{\partial \mathcal{L}(w)}{\partial w_i} &= \frac{\partial}{\partial w_i} \sum_{n=1}^N \sum_{c=1}^C \delta_{y_n, c} \ln p(y_n=c) \\ &= \sum_{n=1}^N \sum_{c=1}^C \delta_{y_n, c} \frac{1}{p(y_n=c)} \frac{\partial}{\partial w_i} p(y_n=c) \tag{1} \\ \end{align*}

When i = c, applying the quotient rule (the prime denotes \partial / \partial w_i),

\begin{align*} \frac{\partial}{\partial w_i} p(y_n=c) &= \frac{\partial}{\partial w_i} \frac{\exp(w_c^\top x_n)}{\sum_{j=1}^C \exp(w_j^\top x_n)} \\ &= \frac{(\exp(w_c^\top x_n))' \sum_{j=1}^C \exp(w_j^\top x_n) - \exp(w_c^\top x_n) (\sum_{j=1}^C \exp(w_j^\top x_n))'}{(\sum_{j=1}^C \exp(w_j^\top x_n))^2} \\ &= \frac{x_n \exp(w_c^\top x_n) \sum_{j=1}^C \exp(w_j^\top x_n) - \exp(w_c^\top x_n) x_n \exp(w_c^\top x_n)}{(\sum_{j=1}^C \exp(w_j^\top x_n))^2} \\ &= x_n \frac{\exp(w_c^\top x_n) (\sum_{j=1}^C \exp(w_j^\top x_n) - \exp(w_c^\top x_n))}{(\sum_{j=1}^C \exp(w_j^\top x_n))^2} \\ &= x_n \frac{\exp(w_c^\top x_n)}{\sum_{j=1}^C \exp(w_j^\top x_n)} \frac{\sum_{j=1}^C \exp(w_j^\top x_n) - \exp(w_c^\top x_n)}{\sum_{j=1}^C \exp(w_j^\top x_n)}\\ &= x_n p(y_n=c) (1 - p(y_n=c)) \\ \end{align*}

When i \neq c, the numerator \exp(w_c^\top x_n) does not depend on w_i, so its derivative vanishes and only the denominator term survives:

\begin{align*} \frac{\partial}{\partial w_i} p(y_n=c) &= \frac{\partial}{\partial w_i} \frac{\exp(w_c^\top x_n)}{\sum_{j=1}^C \exp(w_j^\top x_n)} \\ &= \frac{(\exp(w_c^\top x_n))' \sum_{j=1}^C \exp(w_j^\top x_n) - \exp(w_c^\top x_n) (\sum_{j=1}^C \exp(w_j^\top x_n))'}{(\sum_{j=1}^C \exp(w_j^\top x_n))^2} \\ &= \frac{- \exp(w_c^\top x_n) x_n\exp(w_i^\top x_n)}{(\sum_{j=1}^C \exp(w_j^\top x_n))^2} \\ &= - x_n \frac{\exp(w_c^\top x_n)}{\sum_{j=1}^C \exp(w_j^\top x_n)} \frac{\exp(w_i^\top x_n)}{\sum_{j=1}^C \exp(w_j^\top x_n)} \\ &= - x_n p(y_n=c) p(y_n=i) \\ &= x_n p(y_n=c) (0 - p(y_n=i)) \\ \end{align*}

We can combine both cases, i = c and i \neq c, using the Kronecker delta \delta_{i,c}:

\begin{align*} \frac{\partial}{\partial w_i} p(y_n=c) &= x_n p(y_n=c) (\delta_{i, c} - p(y_n=i)) \\ \end{align*}
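A quick finite-difference check of this combined formula on random data (the dimensions, tolerance, and zero-based class indices are arbitrary choices made for the check):

```python
def check_softmax_jacobian(D=4, C=3, eps=1e-6, seed=0):
    """Verify d p(y=c) / d w_i = x * p_c * (delta_{ic} - p_i) numerically."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=D)
    W = rng.normal(size=(D, C))
    p = softmax(x @ W)
    for i in range(C):
        for c in range(C):
            analytic = x * p[c] * ((i == c) - p[i])
            for d in range(D):
                Wp, Wm = W.copy(), W.copy()
                Wp[d, i] += eps
                Wm[d, i] -= eps
                numeric = (softmax(x @ Wp)[c] - softmax(x @ Wm)[c]) / (2 * eps)
                assert abs(analytic[d] - numeric) < 1e-5

check_softmax_jacobian()  # passes silently if the formula is right
```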

Substituting this back into (1), the factor p(y_n=c) cancels, and the sum over c with \delta_{y_n, c} simply picks out the term c = y_n:

\begin{align*} \frac{\partial \mathcal{L}(w)}{\partial w_i} &= \sum_{n=1}^N \sum_{c=1}^C \delta_{y_n, c} \cancel{\frac{1}{p(y_n=c)}} x_n \cancel{p(y_n=c)} (\delta_{i, c} - p(y_n=i)) \\ &= \sum_{n=1}^N x_n \sum_{c=1}^C \delta_{y_n, c} (\delta_{i, c} - p(y_n=i)) \\ &= \sum_{n=1}^N x_n (\delta_{i, y_n} - p(y_n=i)) \\ &= \underbrace{(x_1, x_2, \dots, x_N)}_{\mathbb{R}^{D \times N}} \underbrace{ \begin{pmatrix} \delta_{i, y_1} - p(y_1=i) \\ \delta_{i, y_2} - p(y_2=i) \\ \vdots \\ \delta_{i, y_N} - p(y_N=i) \end{pmatrix} }_{\mathbb{R}^{N}} \in \mathbb{R}^{D} \end{align*}

Stacking these per-class gradients column by column gives the full gradient:

\begin{align*} \frac{\partial \mathcal{L}(w)}{\partial w} &= \left( \frac{\partial \mathcal{L}(w)}{\partial w_1}, \frac{\partial \mathcal{L}(w)}{\partial w_2}, \dots, \frac{\partial \mathcal{L}(w)}{\partial w_C} \right) \\ &= \underbrace{(x_1, x_2, \dots, x_N)}_{\mathbb{R}^{D \times N}} \underbrace{ \begin{pmatrix} \delta_{1, y_1} - p(y_1=1) & \delta_{2, y_1} - p(y_1=2) & \dots & \delta_{C, y_1} - p(y_1=C) \\ \delta_{1, y_2} - p(y_2=1) & \delta_{2, y_2} - p(y_2=2) & \dots & \delta_{C, y_2} - p(y_2=C) \\ \vdots & \vdots & & \vdots \\ \delta_{1, y_N} - p(y_N=1) & \delta_{2, y_N} - p(y_N=2) & \dots & \delta_{C, y_N} - p(y_N=C) \end{pmatrix} }_{\mathbb{R}^{N \times C}} \in \mathbb{R}^{D \times C} \end{align*}

In matrix form this is X^\top (Y - P), where X \in \mathbb{R}^{N \times D} stacks the inputs row-wise, Y \in \mathbb{R}^{N \times C} is the one-hot label matrix with Y_{n,c} = \delta_{c, y_n}, and P \in \mathbb{R}^{N \times C} collects the predicted probabilities P_{n,c} = p(y_n=c).
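The matrix form can be checked the same way, against a finite-difference gradient of the log_likelihood sketch above (random toy data and tolerances chosen just for the check):

```python
def check_full_gradient(N=50, D=4, C=3, eps=1e-6, seed=1):
    """Compare X^T (OneHot - P) with a numerical gradient of log_likelihood."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(N, D))
    y = rng.integers(0, C, size=N)
    W = rng.normal(size=(D, C))
    analytic = X.T @ (np.eye(C)[y] - softmax(X @ W))   # the (D, C) matrix above
    numeric = np.zeros_like(W)
    for d in range(D):
        for c in range(C):
            Wp, Wm = W.copy(), W.copy()
            Wp[d, c] += eps
            Wm[d, c] -= eps
            numeric[d, c] = (log_likelihood(Wp, X, y, C)
                             - log_likelihood(Wm, X, y, C)) / (2 * eps)
    assert np.allclose(analytic, numeric, atol=1e-4)

check_full_gradient()
```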

A sample implementation in NumPy follows — a minimal sketch of batch gradient ascent on \mathcal{L}(w) using the gradient derived above (the learning rate, iteration count, and toy data are choices made for this illustration, not part of the derivation):
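```python
def fit_multiclass_logistic(X, y, C, lr=0.1, n_iter=1000):
    """Batch gradient ascent on the log-likelihood: dL/dW = X^T (OneHot - P)."""
    N, D = X.shape
    W = np.zeros((D, C))
    one_hot = np.eye(C)[y]                     # (N, C); one_hot[n, c] = delta_{y_n, c}
    for _ in range(n_iter):
        P = softmax(X @ W)                     # (N, C) predicted probabilities
        W += lr / N * (X.T @ (one_hot - P))    # the gradient derived above, averaged
    return W

# Toy usage: two Gaussian blobs.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-1, 1, size=(100, 2)), rng.normal(1, 1, size=(100, 2))])
y = np.repeat([0, 1], 100)
W = fit_multiclass_logistic(X, y, C=2)
pred = softmax(X @ W).argmax(axis=1)
print("training accuracy:", (pred == y).mean())
```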
