
Playing with JAX (1) — Linear Regression


Purpose

A record of installing JAX and playing with it for a bit. This is a memorandum of performing standard linear regression, with content mostly similar to what is described in Linear Regression with JAX.

Since the article would end too quickly with just that, I've padded it with examples using direct calculations from statistics and scikit-learn.

Recap

Deep learning spans many areas, such as image classification, image generation, image recognition, object detection, and natural language processing. Fundamentally, however, it boils down to the problem of finding a differentiable mapping f from an input data space \mathcal{X} = \R^n to an output data space \mathcal{Y} = \R^m. Examples include:

  • \mathcal{Y} = \R^1 is the rent of a property, and \mathcal{X} = \R^3 consists of information such as distance from the station, building age, and whether the room is a corner unit.
  • \mathcal{Y} = \R^{64\times 64} is a photograph of a person's face, and \mathcal{X} = \R^{128} consists of random noise following a normal distribution.
  • \mathcal{Y} = \R^{3\times 128} is a set of three Japanese words, where each word is encoded as a 128-dimensional vector. \mathcal{X} = \R^{3\times 128} is a set of three English words, where each word is encoded as a 128-dimensional vector.

Among these problems, what is specifically called supervised learning is the problem of finding such a differentiable mapping, given examples \{(x_i, y_i)\} \subset \mathcal{X} \times \mathcal{Y}, such that:

  1. For the known data, the predictions \hat{y}_i = f(x_i) satisfy \hat{y}_i \approx y_i.
  2. For unknown data x \not\in \{x_i\}, f(x) \in \mathcal{Y} is plausible in some sense.

What are we doing this time?

Avoiding anything grandiose, we set \mathcal{X} = \R^1 and \mathcal{Y} = \R^1 and have JAX find the conversion formula between Celsius and Fahrenheit. For the training data, we imagine pairs of Celsius and Fahrenheit readings taken by people who were handed a thermometer of each scale and asked to record values at certain times. Since the readings are taken by eye, we assume some degree of fluctuation.

Data Creation

import numpy as np
import random
import matplotlib.pyplot as plt

xs = np.arange(-5, 15, 0.05)
ys = np.array([x*9/5+32 + random.gauss(0,3) for x in xs])

ys_ideal = np.array([x*9/5+32 for x in xs])
plt.scatter(xs,ys)
plt.plot(xs,ys_ideal, color='red')
plt.xlabel('Celsius')
plt.ylabel('Fahrenheit')
plt.show()

This will be our training data. As is well-known, the Fahrenheit temperature F corresponding to Celsius temperature C is given by:

F = \frac{9}{5} C + 32

The red line in the figure above corresponds to this linear relationship.
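As a quick sanity check of the conversion formula, it can be evaluated at the two familiar reference points for water (the helper name `c_to_f` is mine, not from the original):

```python
def c_to_f(c):
    """Exact Celsius -> Fahrenheit conversion: F = (9/5) C + 32."""
    return c * 9 / 5 + 32

print(c_to_f(0))    # 32.0  (freezing point of water)
print(c_to_f(100))  # 212.0 (boiling point of water)
```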

Solving with Statistics

Let the Celsius-Fahrenheit dataset be \{(x_i, y_i)\}_{1 \leq i \leq N}. We fit the linear regression model \hat{y}_i = \alpha x_i + \beta by choosing the coefficients \alpha and \beta that minimize the squared error:

\begin{align*} \argmin\limits_{\alpha,\beta} \sum_{j=1}^N (\hat{y}_j - y_j)^2 \tag{1} \end{align*}

By letting the sample means of the dataset be \bar{x} = \frac{1}{N}\sum x_i and \bar{y} = \frac{1}{N}\sum y_i, and setting the partial derivatives of Equation (1) with respect to \alpha and \beta to 0, we obtain the following solution:

\begin{align*} \alpha &= \frac{\sum (x_i -\bar{x})(y_i - \bar{y})}{\sum(x_i - \bar{x})^2} \\ \beta &= \bar{y} - \alpha \bar{x} \end{align*}

Just to confirm, solving this in Python yields:

xs_mean = np.mean(xs)
ys_mean = np.mean(ys)
alpha = np.sum((xs - xs_mean)*(ys - ys_mean))/np.sum((xs - xs_mean)**2)
beta = ys_mean - alpha * xs_mean

print('estimate:', alpha, beta)
print('ideal:', 9/5, 32)

estimate: 1.8103115195284565 32.08597124079442
ideal: 1.8 32

This was the result.

Solving with scikit-learn

To be honest, for a problem of this magnitude, using scikit-learn is likely the best choice. Let's try solving it for reference.

from sklearn.linear_model import LinearRegression

model_lr = LinearRegression()
model_lr.fit(xs.reshape(-1,1), ys.reshape(-1,1))
coef, intercept = model_lr.coef_[0][0], model_lr.intercept_[0]

print('estimate:', coef, intercept)
print('ideal:', 9/5, 32)

estimate: 1.8103115195284571 32.08597124079442
ideal: 1.8 32

It is simple and works perfectly fine.
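For completeness, a minimal self-contained usage sketch on noise-free data (variable names here are mine), showing that the fitted model recovers the conversion and can predict new points:

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Noise-free training data on the exact line y = 1.8 x + 32.
xs = np.arange(-5, 15, 0.05)
ys = xs * 9 / 5 + 32

# With a 1-D target, coef_ has shape (1,) and intercept_ is a scalar.
lr_model = LinearRegression().fit(xs.reshape(-1, 1), ys)

print(lr_model.coef_[0], lr_model.intercept_)       # ≈ 1.8, 32.0
print(lr_model.predict(np.array([[100.0]]))[0])     # ≈ 212.0
```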

Solving with JAX

Finally, we reach the main part. Since the result is already known, it's not particularly exciting, but it can't be helped.
By the way, frameworks like this often fail to produce good results unless the data is normalized, so I'll perform standard normalization.

Let \sigma_x and \sigma_y be the standard deviations of \{x_j\} and \{y_j\}, respectively[1]. We define the linear regression problem for the normalized data as:

\begin{align*} \frac{y_j - \bar{y}}{\sigma_y} = \tilde{\alpha} \frac{x_j - \bar{x}}{\sigma_x} + \tilde{\beta} \tag{2} \end{align*}

When the optimal coefficients \tilde{\alpha} and \tilde{\beta} are found, they correspond to the \alpha and \beta of problem (1) as follows:

\begin{align*} \alpha &= \tilde{\alpha} \frac{\sigma_y}{\sigma_x} \\ \beta &= \bar{y} + \tilde{\beta} \sigma_y - \tilde{\alpha} \bar{x} \frac{\sigma_y}{\sigma_x} \tag{3} \end{align*}
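To see where the back-conversion comes from, multiply (2) through by \sigma_y, solve for y, and collect the coefficient of x and the constant term:

```latex
% Unwinding the normalization: start from the fitted relation (2)
\begin{align*}
y &\approx \sigma_y\left(\tilde{\alpha}\,\frac{x - \bar{x}}{\sigma_x} + \tilde{\beta}\right) + \bar{y} \\
  &= \underbrace{\tilde{\alpha}\,\frac{\sigma_y}{\sigma_x}}_{\alpha}\, x
   \;+\; \underbrace{\bar{y} + \tilde{\beta}\,\sigma_y - \tilde{\alpha}\,\bar{x}\,\frac{\sigma_y}{\sigma_x}}_{\beta}
\end{align*}
```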

First, let's implement the data normalization in Python:

from jax import grad
import jax.numpy as jnp

xs_std = np.std(xs)
ys_std = np.std(ys)

xs_n = (xs - xs_mean) / xs_std
ys_n = (ys - ys_mean) / ys_std
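As a quick sanity check, standard-normalized data should have mean roughly 0 and standard deviation roughly 1 (shown here for the x series alone so the snippet is self-contained):

```python
import numpy as np

# Rebuild the x grid from the data-creation step and normalize it.
xs = np.arange(-5, 15, 0.05)
xs_n = (xs - np.mean(xs)) / np.std(xs)

print(np.isclose(np.mean(xs_n), 0.0))  # True
print(np.isclose(np.std(xs_n), 1.0))   # True
```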

Next, we implement the linear regression model:

def model(params, x):
    W, b = params
    return x * W + b

def loss(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y)**2)

def update(params, x, y, lr=0.1):
    return params - lr * grad(loss)(params, x, y)

Once we have this, all that remains is to run the training loop. Here, we simply run 5,000 iterations. Afterwards, we convert the fitted coefficients back to the original scale using Equation (3):

params = jnp.array([0., 0.])

for _ in range(5000):
    params = update(params, xs_n, ys_n)

a, b = params
a = a * ys_std / xs_std
b = ys_mean + b * ys_std - a * xs_mean  # a is already the rescaled slope alpha here

To check the results:

print('estimate:', a, b)
print('ideal:', 9/5, 32)

estimate: 1.8103114 32.085972
ideal: 1.8 32

It turned out like this. Let's also plot the fitted line:

plt.scatter(xs,ys)
params = jnp.array([a, b])
plt.plot(xs,model(params,xs), color='red')
plt.xlabel('Celsius')
plt.ylabel('Fahrenheit')
plt.show()

The expected results were successfully obtained. With this, the model has acquired the knowledge of Celsius to Fahrenheit conversion.

Summary

There isn't much to summarize, but I feel that I could write it more simply compared to TensorFlow or PyTorch. I was able to write casually like NumPy without being conscious of things like preparing data loaders or putting tensors on the GPU, and I could also run the training loop in a Pythonic way. It's also convenient that automatic differentiation is performed just by wrapping the loss with grad.
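The convenience noted above fits in one line: `grad` transforms a scalar-valued Python function into its derivative function, with no graph-building ceremony.

```python
from jax import grad

# Derivative of x^2 is 2x, so the gradient at x = 3 is 6.
df = grad(lambda x: x ** 2)
print(df(3.0))  # 6.0
```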

Footnotes
  1. In statistics, the standard deviation is typically taken as the square root of the unbiased variance. In the deep-learning context, however, this does not seem to be common practice, so here I simply use the square root of the (biased) sample variance.

