Playing with JAX (1) — Linear Regression
Purpose
A record of installing JAX and playing with it for a bit. This is a memorandum of performing standard linear regression, with content mostly similar to what is described in Linear Regression with JAX.
Since the article would end too quickly with just that, I've padded it with examples using direct calculations from statistics and scikit-learn.
Recap
Deep learning encompasses various genres such as image classification, image generation, image recognition, object detection, and natural language processing. However, it fundamentally boils down to the problem of finding a differentiable mapping $f: \mathcal{X} \to \mathcal{Y}$. For example:
- $\mathcal{X} = \R^3$ consists of information such as distance from the station, building age, and whether the room is a corner unit, and $\mathcal{Y} = \R^1$ is the rent of a property.
- $\mathcal{X} = \R^{128}$ consists of random noise following a normal distribution, and $\mathcal{Y} = \R^{64\times 64}$ is a photograph of a person's face.
- $\mathcal{X} = \R^{3\times 128}$ is a set of three Japanese words, where each word is encoded as a 128-dimensional vector, and $\mathcal{Y} = \R^{3\times 128}$ is a set of three English words, encoded the same way.
Among various problems, what is specifically called supervised learning is the problem of finding such a differentiable mapping given examples $(x_1, y_1), \dots, (x_n, y_n) \in \mathcal{X} \times \mathcal{Y}$, such that:
- For known data $x_i$, the prediction $\hat{y}_i = f(x_i)$ satisfies $\hat{y}_i \approx y_i$.
- For unknown data $x \not\in \{x_i\}$, $f(x) \in \mathcal{Y}$ is plausible in some sense.
What are we doing this time?
Avoiding anything grandiose, we will set $\mathcal{X} = \mathcal{Y} = \R$ and fit a linear model $f(x) = ax + b$ that converts a Celsius temperature $x$ into the corresponding Fahrenheit temperature.
Data Creation
import numpy as np
import random
import matplotlib.pyplot as plt
xs = np.arange(-5, 15, 0.05)
ys = np.array([x*9/5+32 + random.gauss(0,3) for x in xs])
ys_ideal = np.array([x*9/5+32 for x in xs])
plt.scatter(xs,ys)
plt.plot(xs,ys_ideal, color='red')
plt.xlabel('Celsius')
plt.ylabel('Fahrenheit')
plt.show()

This will be our training data. As is well-known, the Fahrenheit temperature $F$ is obtained from the Celsius temperature $C$ by $F = \frac{9}{5}C + 32$. The red line in the figure above corresponds to this linear relationship.
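As a quick sanity check of the conversion formula (my addition, not in the original), the well-known reference points of water behave as expected:

```python
def c_to_f(c):
    """Convert Celsius to Fahrenheit: F = (9/5)C + 32."""
    return c * 9 / 5 + 32

print(c_to_f(0))    # freezing point of water: 32.0
print(c_to_f(100))  # boiling point of water: 212.0
```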
Solving with Statistics
Let the Celsius-Fahrenheit dataset be $(x_1, y_1), \dots, (x_n, y_n)$, modeled as $y = \alpha x + \beta$. By letting the sample means of the dataset be $\bar{x} = \frac{1}{n}\sum_i x_i$ and $\bar{y} = \frac{1}{n}\sum_i y_i$, the least-squares estimates are

$$
\hat{\alpha} = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sum_i (x_i - \bar{x})^2}, \qquad
\hat{\beta} = \bar{y} - \hat{\alpha}\,\bar{x}
$$
Just to confirm, solving this in Python yields:
xs_mean = np.mean(xs)
ys_mean = np.mean(ys)
alpha = np.sum((xs - xs_mean)*(ys - ys_mean))/np.sum((xs - xs_mean)**2)
beta = ys_mean - alpha * xs_mean
print('estimate:', alpha, beta)
print('ideal:', 9/5, 32)
estimate: 1.8103115195284565 32.08597124079442
ideal: 1.8 32
This was the result.
Solving with scikit-learn
To be honest, for a problem of this magnitude, using scikit-learn is likely the best choice. Let's try solving it for reference.
from sklearn.linear_model import LinearRegression
model_lr = LinearRegression()
model_lr.fit(xs.reshape(-1,1), ys.reshape(-1,1))
coef, intercept = model_lr.coef_[0][0], model_lr.intercept_[0]
print('estimate:', coef, intercept)
print('ideal:', 9/5, 32)
estimate: 1.8103115195284571 32.08597124079442
ideal: 1.8 32
It is simple and works perfectly fine.
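Not in the original article, but NumPy's np.polyfit gives yet another one-line cross-check; the sketch below regenerates similar synthetic data (with a seeded generator, names xs_demo/ys_demo are my own) and fits a degree-1 polynomial:

```python
import numpy as np

rng = np.random.default_rng(0)
xs_demo = np.arange(-5, 15, 0.05)
ys_demo = xs_demo * 9 / 5 + 32 + rng.normal(0, 3, size=xs_demo.shape)

# A degree-1 polynomial fit returns [slope, intercept]
slope, intercept = np.polyfit(xs_demo, ys_demo, 1)
print(slope, intercept)  # should land close to 1.8 and 32
```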
Solving with JAX
Finally, we reach the main part. Since the result is already known, it's not particularly exciting, but it can't be helped.
By the way, frameworks like this often fail to produce good results unless the data is normalized, so I'll perform standard normalization.
Let $\bar{x}, \bar{y}$ be the sample means and $\sigma_x, \sigma_y$ the sample standard deviations, and normalize the data as $\tilde{x}_i = (x_i - \bar{x})/\sigma_x$, $\tilde{y}_i = (y_i - \bar{y})/\sigma_y$.
When the optimal coefficients $W, b$ of the model $\tilde{y} = W\tilde{x} + b$ on the normalized data are found, the coefficients on the original scale are recovered by

$$
a = W\,\frac{\sigma_y}{\sigma_x}, \qquad
\bar{y} + b\,\sigma_y - a\,\bar{x}
$$

where the second expression is the intercept.
First, let's implement the data normalization in Python:
from jax import grad
import jax.numpy as jnp
xs_std = np.std(xs)
ys_std = np.std(ys)
xs_n = (xs - xs_mean) / xs_std
ys_n = (ys - ys_mean) / ys_std
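A quick self-contained check (my addition) that this standardization really produces zero mean and unit variance:

```python
import numpy as np

x = np.arange(-5, 15, 0.05)
x_n = (x - np.mean(x)) / np.std(x)  # same standardization as above

print(np.mean(x_n))  # effectively 0
print(np.std(x_n))   # effectively 1
```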
Next, we implement the linear regression model:
def model(params, x):
    W, b = params
    return x * W + b

def loss(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y)**2)

def update(params, x, y, lr=0.1):
    return params - lr * grad(loss)(params, x, y)
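Though the article doesn't use it, wrapping the update step in jax.jit typically speeds up a loop like this considerably by compiling it with XLA. A minimal self-contained sketch on tiny synthetic data (the data and iteration count here are my own, for illustration):

```python
import jax
import jax.numpy as jnp
from jax import grad

def loss(params, x, y):
    W, b = params
    return jnp.mean((x * W + b - y) ** 2)

@jax.jit  # compile the whole update step
def update(params, x, y, lr=0.1):
    # grad(loss) differentiates with respect to the first argument (params)
    return params - lr * grad(loss)(params, x, y)

params = jnp.array([0.0, 0.0])
x = jnp.array([0.0, 1.0, 2.0])
y = 2 * x + 1  # noiseless line, so the optimum is W=2, b=1
for _ in range(2000):
    params = update(params, x, y)
print(params)
```

The first call pays a one-time compilation cost; subsequent iterations reuse the compiled kernel.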
Once we have this, all that remains is to run the training loop. Here, we will run 5,000 iterations without much further thought. After the iterations, we restore the obtained optimal values to the original scale using the de-normalization formula above:
params = jnp.array([0., 0.])
for _ in range(5000):
    params = update(params, xs_n, ys_n)
a, b = params
a = a * ys_std / xs_std
b = ys_mean + b * ys_std - a * xs_mean  # de-normalize the intercept using the rescaled slope
To check the results for now:
print('estimate:', a, b)
print('ideal:', 9/5, 32)
estimate: 1.8103114 31.748932
ideal: 1.8 32
It turned out like this. Let's also create a plot to commemorate:
plt.scatter(xs,ys)
params = jnp.array([a, b])
plt.plot(xs,model(params,xs), color='red')
plt.xlabel('Celsius')
plt.ylabel('Fahrenheit')
plt.show()

The expected results were successfully obtained. With this, the model has acquired the knowledge of Celsius to Fahrenheit conversion.
Summary
There isn't much to summarize, but I feel that I could write it more simply compared to TensorFlow or PyTorch. I was able to write casually like NumPy without being conscious of things like preparing data loaders or putting tensors on the GPU, and I could also run the training loop in a Pythonic way. It's also convenient that automatic differentiation is performed just by wrapping the loss with grad.
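For instance (a minimal illustration of my own, not from the article), grad turns a plain Python function into its derivative:

```python
from jax import grad

# f(x) = x^2, so the derivative is f'(x) = 2x
def f(x):
    return x ** 2

df = grad(f)
print(df(3.0))  # 6.0
```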
- In statistics, the unbiased standard deviation is typically found using the square root of the unbiased variance. However, in the context of deep learning, this doesn't seem to be common practice, so I will simply use the standard deviation derived from the square root of the sample variance. ↩︎