
Thinking about Optimal Transport (1) — Starting from Scratch


Purpose

I don't understand optimal transport. I don't understand the optimal transport distance. I don't understand Earth Mover's Distance. I don't understand the Wasserstein distance. When I don't understand something, I feel like it's best to explore while experimenting to find something that seems right, so I'll try some trial and error.

Note that, strictly speaking, there seem to be subtle distinctions between these terms, but this time I will use them interchangeably without making any particular distinction.

The Beginning (A Musing)

It's not for any particular reason, but I occasionally hear about optimal transport distance. So, I did a quick search and found Introduction to Optimal Transport. I read it about three times, but I didn't understand anything... My question is simple:

  • How do you map points on one distribution to points on another? It doesn't seem like it's given as prior knowledge, but...

That's what it boils down to. As I flipped through the pages, matrices appeared, optimization problems appeared, and entropy appeared, so I was overwhelmed 🤯😵‍💫

arXiv:1506.05439 Learning with a Wasserstein Loss was listed as a reference, so I looked at it. Eq. (2) seems to be the one:

\begin{align*} W_c (\mu_1, \mu_2) = \inf_{\gamma \in \Pi (\mu_1, \mu_2)} \int_{\mathcal{K} \times \mathcal{K}} c(\kappa_1, \kappa_2) \gamma (d \kappa_1, d \kappa_2) \end{align*}

This equation is what is called the optimal transport distance, and it seems that:

  • c: \mathcal{K} \times \mathcal{K} \to \mathbb{R} is a given cost function
  • \mu_1 and \mu_2 are probability measures on \mathcal{K}
  • \Pi (\mu_1, \mu_2) is the set of joint probability distributions on \mathcal{K} \times \mathcal{K} that have \mu_1 and \mu_2 as marginal distributions

Apparently, it's important when c is a distance d_{\mathcal{K}}(\cdot, \cdot) on \mathcal{K}. It doesn't click at all. Furthermore, Cédric Villani's "Optimal transport, old and new" is cited as a reference, but it's so thick that the number of pages alone discouraged me.
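Concretely, in the discrete setting used later in this article, the integral becomes a finite sum over a coupling matrix T. This is my own paraphrase of Eq. (2), not taken from the paper, using the index convention of the experiments below (columns marginalize to \mu_1, rows to \mu_2):

\begin{align*} W_c (\mu_1, \mu_2) = \min_{T_{ij} \geq 0} \sum_{i, j} c(i, j) \, T_{ij} \quad \text{s.t.} \quad \sum_{i} T_{ij} = \mu_1 (j), \quad \sum_{j} T_{ij} = \mu_2 (i) \end{align*}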

Continuing my search, I found Yossi Rubner et al.'s The Earth Mover's Distance as a Metric for Image Retrieval. The content on p.8 probably corresponds to it, but this time inequality constraints appeared instead of marginal distributions, making it even more confusing.

Next, I found Impression Estimation of Images Using Earth Mover's Distance by Yuko Sakuta et al. Since it's in Japanese, it feels more accessible. "Figure 3 Calculation of distance between color histograms by EMD" looks like a relevant image. It appears to redistribute the source distribution so that it matches the target distribution. But did any paper explicitly write down that map? That's where I got stuck.

So, my heart was completely broken, and I thought I'd just forget about it, but since today is Sunday, I decided to try something with NumPy.

The Simplest Case

Looking back at p.13 of Introduction to Optimal Transport, there is a diagram showing bar graphs of the same height being translated. It seems to suggest that shorter movements result in smaller optimal transport distances. And in the case where this bar is a delta function concentrated at one point, it seems that the distance between those two support points is the optimal transport distance.

Therefore, I first want to consider a case where \mathcal{K} = \{ 0, 1, 2, 3, 4 \}, and prepare discrete stand-ins for \delta_0 (x) and \delta_3 (x). Presumably, the optimal transport distance is 3: the cost of moving from 0 to 3.
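A minimal NumPy sketch of these pseudo delta distributions (the same arrays reappear in the Experiment section):

```python
import numpy as np

N = 5  # K = {0, 1, 2, 3, 4}

# Pseudo delta functions: all probability mass concentrated at a single point
mu1 = np.zeros(N)
mu1[0] = 1.0  # delta at x = 0
mu2 = np.zeros(N)
mu2[3] = 1.0  # delta at x = 3
```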

In conclusion, running a certain optimization yielded a coupling that achieves "optimal transport," with an approximate optimal transport distance of 2.764. The marginal distributions reconstructed from the coupling achieving this distance preserve the original \mu_1 and \mu_2 reasonably well.

The "transition" that realizes the optimal transport is approximately as follows:

\begin{align*} \hat{T} = \begin{bmatrix} 0.034 & 0 & 0 & 0 & 0 \\ 0.009 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \\ 0.914 & 0 & 0.007 & 0.032 & 0.007 \\ 0 & 0 & 0 & 0 & 0 \end{bmatrix} \end{align*}

Summing down each column of \hat{T} yields roughly [1 \quad 0 \quad 0 \quad 0 \quad 0], which corresponds to \mu_1. Summing across each row yields roughly [0 \quad 0 \quad 0 \quad 1 \quad 0], which corresponds to \mu_2.

In other words, the probability mass 1 that sat at column index 0 was distributed to row indices 0, 1, and 3 with values 0.034, 0.009, and 0.914, respectively; the remaining entries are numerical error. The optimal transport distance is therefore 0.034 \times |0 - 0| + 0.009 \times |1 - 0| + 0.914 \times |3 - 0| = 2.751, or 2.764 once all the error terms are included.
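As a sanity check, the hand calculation above can be reproduced numerically. This sketch uses the rounded values of \hat{T} shown above, so the total differs slightly from the reported 2.764:

```python
import numpy as np

# the approximate coupling found by the optimization (rounded values)
T_hat = np.array([
    [0.034, 0, 0,     0,     0],
    [0.009, 0, 0,     0,     0],
    [0,     0, 0,     0,     0],
    [0.914, 0, 0.007, 0.032, 0.007],
    [0,     0, 0,     0,     0],
])

# cost c(row, col) = |row - col|; total transport cost = sum of cost * mass
N = T_hat.shape[0]
cost = np.abs(np.subtract.outer(np.arange(N), np.arange(N)))
print((cost * T_hat).sum())  # close to the reported 2.764
```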

Note that the true transition is:

\begin{align*} T = \begin{bmatrix} 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \\ 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 \end{bmatrix} \end{align*}

and the true optimal transport distance is 3.

...That is what I finally understood. I reached this conclusion after confirming the intuition with experimental results in a slightly more general setting, and the clarity of the result gave me confidence that the experiments were probably correct.

Experiment

I would like to describe the flow of the experiment that led to the above results below.

First, import the necessary modules:

import numpy as np
import pprint
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator

The naive idea was as follows:

np.random.seed(seed=42)

N = 5  # Number of grid points when treating the space K discretely
joint_prob = np.random.rand(N * N).reshape(N, N)  # An N x N array, a first guess at the joint distribution

Here,

  • np.sum(joint_prob, axis=0) \approx mu1
  • np.sum(joint_prob, axis=1) \approx mu2

I felt it would be good if these were close, but there are N^2 unknowns for 2N equations, so the solution cannot be determined. Even so, if I want them to be determined nicely somehow... that's when I began to understand the meaning of "optimization problems."
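The underdetermination is easy to see: for example, the independence coupling \mu_1 \otimes \mu_2 always satisfies both marginal constraints, whatever \mu_1 and \mu_2 are. A sketch using the same index convention as this article (columns sum to \mu_1, rows to \mu_2):

```python
import numpy as np

rng = np.random.default_rng(42)
N = 5
mu1 = rng.random(N); mu1 /= mu1.sum()
mu2 = rng.random(N); mu2 /= mu2.sum()

# independence coupling: gamma[i, j] = mu2[i] * mu1[j]
gamma = np.outer(mu2, mu1)

print(np.allclose(gamma.sum(axis=0), mu1))  # True
print(np.allclose(gamma.sum(axis=1), mu2))  # True
```

Since this coupling exists for any pair of marginals, the constraints alone can never pin down a unique \gamma; an objective to minimize is what singles one out.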

Determining the Source and Target Distributions for Optimized Transport

Set up the problem by defining some appropriate data.

N = 5
X = np.arange(N, dtype=int)

np.random.seed(seed=42)

# Randomly determine distributions mu1 and mu2. The seed is fixed for reproducibility.
mu1 = np.random.rand(N)
mu1 = mu1 / np.sum(mu1)

mu2 = np.random.rand(N)
mu2 = mu2 / np.sum(mu2)

# Visualization
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].stairs(mu1, fill=True)
axs[0].set_title(r"$\mu_1$")
axs[0].set_xlabel("x")
axs[0].set_ylabel(r"$\mu_1$")
axs[0].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[0].yaxis.set_major_locator(MultipleLocator(0.1)) 
axs[0].set_ylim([0, 1])
axs[0].grid()

axs[1].stairs(mu2, fill=True)
axs[1].set_title(r"$\mu_2$")
axs[1].set_xlabel("x")
axs[1].set_ylabel(r"$\mu_2$")
axs[1].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[1].yaxis.set_major_locator(MultipleLocator(0.1)) 
axs[1].set_ylim([0, 1])
axs[1].grid()

plt.tight_layout()
plt.show()

At this point, I'm not quite sure, but I want to find a matrix joint_prob that represents something like the optimal transport of these two distributions.

Putting it into PyTorch

I want to optimize joint_prob, but I'm uneasy about pinning down the values with optimization methods that don't use gradients... or rather, since I'm still feeling my way through this problem, I'll prioritize an approach of gradual improvement and use gradient-based optimization. Therefore, I decided to port everything to PyTorch.

Importing additional modules to use PyTorch:

import torch
from torch import nn
from torch import optim

Below, I listed the constraints in the spirit of penalty terms in Quadratic Unconstrained Binary Optimization (QUBO): the matrix must be a joint probability distribution, and its marginal distributions must be \mu_1 and \mu_2.

First, I wrote the code to find the list joint_prob that satisfies these constraints through training:

# Start from a random N x N tensor
joint_prob = torch.rand(N * N).reshape(N, N)
# Convert mu1 to a tensor
mu1_tensor = torch.tensor(mu1)
# Convert mu2 to a tensor
mu2_tensor = torch.tensor(mu2)

# Train joint_prob
params = nn.parameter.Parameter(joint_prob, requires_grad=True)
optimizer = optim.Adam([params])

for epoch in range(5000):
    optimizer.zero_grad()

    # Since joint_prob itself is a probability distribution, the sum of all elements is 1.
    prob_constraint = (torch.sum(params) - 1) ** 2
    # The marginal distribution obtained by summing over rows becomes mu1.
    mu1_constraint = torch.sum((torch.sum(params, dim=0) - mu1_tensor) ** 2) / N
    # The marginal distribution obtained by summing over columns becomes mu2.
    mu2_constraint = torch.sum((torch.sum(params, dim=1) - mu2_tensor) ** 2) / N
    constraint = prob_constraint + mu1_constraint + mu2_constraint
    loss = constraint
    loss.backward()
    optimizer.step()

    # Since it is a probability distribution, the value range is [0, 1], and outliers are clipped.
    with torch.no_grad():
        params.clamp_(0, 1)

This went reasonably well, and the marginal distributions obtained from params.cpu().detach().numpy() roughly match mu1 and mu2.
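To make "roughly match" concrete, a small helper can verify a candidate coupling (`check_marginals` is my own hypothetical name, not from any library):

```python
import numpy as np

def check_marginals(gamma, mu1, mu2, atol=1e-2):
    """Hypothetical helper: do gamma's marginals match mu1 (columns) and mu2 (rows)?"""
    ok_mu1 = np.allclose(gamma.sum(axis=0), mu1, atol=atol)
    ok_mu2 = np.allclose(gamma.sum(axis=1), mu2, atol=atol)
    return ok_mu1 and ok_mu2

# the exact coupling from the "Simplest Case": all mass at (row 3, col 0)
T = np.zeros((5, 5))
T[3, 0] = 1.0
d0 = np.zeros(5); d0[0] = 1.0
d3 = np.zeros(5); d3[3] = 1.0
print(check_marginals(T, d0, d3))  # True
```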

Finding the Optimal Transport Distance through Optimization

When the optimization was complete, I thought, "So, what is the optimal transport distance?" and realized that I needed the original loss function in addition to the constraints mentioned above. Recalling once more:

\begin{align*} W_c (\mu_1, \mu_2) = \inf_{\gamma \in \Pi (\mu_1, \mu_2)} \int_{\mathcal{K} \times \mathcal{K}} c(\kappa_1, \kappa_2) \gamma (d \kappa_1, d \kappa_2) \end{align*}

I realized that I had implemented \gamma as joint_prob and \inf_{\gamma \in \Pi (\mu_1, \mu_2)} in the form of the optimization loop above. What remains, then, is \int_{\mathcal{K} \times \mathcal{K}} c(\kappa_1, \kappa_2) \gamma (d \kappa_1, d \kappa_2). In discrete terms, this part looks like the code below. Since the "given cost function" c was mine to choose, I picked the absolute difference in position, i.e., the L^1 distance. While implementing this, the meaning of

it's important when c is a distance d_{\mathcal{K}}(\cdot, \cdot) on \mathcal{K}

became clear to me.

# The optimal value of this loss function is the optimal transport distance
EM_loss = 0
for col in range(N):
    for row in range(N):
        amount = params[row, col]  # γ(dκ₁, dκ₂)
        dist = abs(row - col)  # c(κ₁, κ₂); L1 distance
        EM_loss += dist * amount
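As an aside, the double loop can be replaced by a precomputed cost matrix. This is just a vectorized sketch of the same computation, with a random tensor standing in for the trained `params`:

```python
import torch

N = 5
params = torch.rand(N, N)  # stand-in for the coupling being optimized

# cost matrix C[row, col] = |row - col| (the L1 distance on the grid)
idx = torch.arange(N, dtype=torch.float32)
C = (idx[:, None] - idx[None, :]).abs()

EM_loss = (C * params).sum()  # same value as the double loop
```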

I combined this with the previous constraints into a single objective function and completed the training loop as follows. As optimization progressed, it first satisfied the constraints (i.e., staying a probability distribution and matching the specified marginal distributions), but eventually it began lowering the overall objective by shrinking the transport loss even at the cost of violating the constraints. For that reason, I added an early-stopping mechanism to terminate training.

joint_prob = torch.rand(N * N).reshape(N, N)
mu1_tensor = torch.tensor(mu1)
mu2_tensor = torch.tensor(mu2)

params = nn.parameter.Parameter(joint_prob, requires_grad=True)
optimizer = optim.Adam([params])

losses = []
EM_losses = []
constraints = []

best_constraint = 10000
best_constraint_epoch = -1
patience = 20  # Early stopping logic
patience_cnt = 0

for epoch in range(12000):
    optimizer.zero_grad()

    EM_loss = 0
    for col in range(N):
        for row in range(N):
            amount = params[row, col]
            dist = abs(row - col)
            EM_loss += amount * dist
    EM_loss = EM_loss / (N * N)

    prob_constraint = (torch.sum(params) - 1) ** 2
    mu1_constraint = torch.sum((torch.sum(params, dim=0) - mu1_tensor) ** 2) / N
    mu2_constraint = torch.sum((torch.sum(params, dim=1) - mu2_tensor) ** 2) / N
    constraint = prob_constraint + mu1_constraint + mu2_constraint
    loss = EM_loss + 4 * constraint  # Weight and sum the loss and constraints appropriately
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        params.clamp_(0, 1)

    # Terminate training if constraints start to be violated, using an early stopping approach
    if constraint.item() < best_constraint:
        best_constraint = constraint.item()
        best_constraint_epoch = epoch
        patience_cnt = 0
    else:
        patience_cnt += 1
    if patience_cnt >= patience:
        break

    losses.append(loss.item())
    EM_losses.append(EM_loss.item())
    constraints.append(constraint.item())

This training completes in about 20 seconds.

Confirming Experimental Results

First, visualize the transition of the loss function values. Since the values become quite small, I took the logarithm of the vertical axis.

fig, axs = plt.subplots(1, 3, figsize=(18, 4))
axs[0].set_yscale("log")
axs[0].plot(losses)
axs[0].set_title("losses")
axs[0].set_xlabel("epoch")
axs[0].set_ylabel("loss")
axs[0].grid()

axs[1].set_yscale("log")
axs[1].plot(EM_losses)
axs[1].set_title("EM_losses")
axs[1].set_xlabel("epoch")
axs[1].set_ylabel("EM_loss")
axs[1].grid()

axs[2].set_yscale("log")
axs[2].plot(constraints)
axs[2].set_title("constraints")
axs[2].set_xlabel("epoch")
axs[2].set_ylabel("constraint")
axs[2].grid()

plt.tight_layout()
plt.show()

The values decrease overall, but the constraint term converges slightly earlier. After that, it began to rise, so training was terminated by early stopping.

Checking Marginal Distributions

Does the joint_prob after optimization represent the marginal distributions \mu_1 and \mu_2 well?

final_joint_prob = params.cpu().detach().numpy()

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
mu1_ = np.sum(final_joint_prob, axis=0)
axs[0].stairs(mu1_, fill=True)
axs[0].set_title(r"$\mu_1$")
axs[0].set_xlabel("x")
axs[0].set_ylabel(r"$\mu_1$")
axs[0].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[0].yaxis.set_major_locator(MultipleLocator(0.1)) 
axs[0].set_ylim([0, 1])
axs[0].grid()

mu2_ = np.sum(final_joint_prob, axis=1)
axs[1].stairs(mu2_, fill=True)
axs[1].set_title(r"$\mu_2$")
axs[1].set_xlabel("x")
axs[1].set_ylabel(r"$\mu_2$")
axs[1].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[1].yaxis.set_major_locator(MultipleLocator(0.1)) 
axs[1].set_ylim([0, 1])
axs[1].grid()

plt.tight_layout()
plt.show()

While there are some slight discrepancies in the details, they look generally similar.

[Reference] For comparison, the initially set probability distributions \mu_1 and \mu_2 are shown in the earlier plots.

Being a Joint Probability Distribution

The marginal distributions are fine, but can we consider it a joint probability distribution in the first place?

print(f"{np.sum(final_joint_prob)=}")

np.sum(final_joint_prob)=0.9981522

Since the sum of all elements is approximately 1, it can reasonably be regarded as a probability distribution.

What's Inside the Joint Probability Distribution?

I could just do pprint.pprint(final_joint_prob), but let's format it a bit:

\begin{align*} T = \begin{bmatrix} 0.092 & 0 & 0 & 0 & 0 \\ 0 & 0.055 & 0 & 0 & 0 \\ 0.018 & 0.157 & 0.152 & 0.026 & 0 \\ 0 & 0.109 & 0.032 & 0.078 & 0.011 \\ 0 & 0 & 0.080 & 0.121 & 0.066 \end{bmatrix} \end{align*}

It seems that for the quantity of about 0.11 at column index 0, 0.092 was distributed to row index 0, and 0.018 was distributed to row index 2. I was surprised that it moved a bit far away. However, basically, the values are concentrated around the diagonal components, so it seems there are many cases where the values are transported to the neighborhood for adjustment. Focusing on column 3, we can see that about half of the amount, 0.121 out of about 0.225, was moved to row 4. While the right end of \mu_2 is quite high, the right end of \mu_1 is low, so it means the difference was received from the neighbor. For the part that became insufficient for row 3, it seems to have been mostly compensated by transport from column 1 to row 3.

While looking at this transition matrix, I began to understand the intuition behind "Figure 3 Calculation of distance between color histograms by EMD" in Impression Estimation of Images Using Earth Mover's Distance.

Optimal Transport Distance

What is the crucial optimal transport distance?

print(f"EM distance={EM_losses[-1] * (N * N)}")

EM distance=0.7671486120671034

That's the result. This value itself doesn't mean much this time.

In The Earth Mover's Distance as a Metric for Image Retrieval and arXiv:1701.07875 Wasserstein GAN, neural networks are trained by optimizing further on top of this metric. However, as this experiment shows, naively computing the optimal transport distance carries a fairly high computational cost.
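For comparison, the same small discrete problem can be solved exactly as a linear program. This sketch assumes SciPy is available (it is not used anywhere else in this article) and keeps the same convention as above, where columns marginalize to \mu_1 and rows to \mu_2:

```python
import numpy as np
from scipy.optimize import linprog

N = 5
mu1 = np.zeros(N); mu1[0] = 1.0  # pseudo delta at 0
mu2 = np.zeros(N); mu2[3] = 1.0  # pseudo delta at 3

# flatten gamma[i, j] -> x[i * N + j]; cost c(i, j) = |i - j|
c = np.abs(np.subtract.outer(np.arange(N), np.arange(N))).ravel()

# equality constraints: each row sums to mu2[i], each column to mu1[j]
A_eq = np.zeros((2 * N, N * N))
for i in range(N):
    A_eq[i, i * N:(i + 1) * N] = 1.0  # row i sums to mu2[i]
    A_eq[N + i, i::N] = 1.0           # column i sums to mu1[i]
b_eq = np.concatenate([mu2, mu1])

res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
print(res.fun)  # 3.0, the exact optimal transport distance
```

The LP confirms the value 3 expected for the two delta distributions, without the penalty-weight tuning and early stopping that the gradient-based approach needed.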

Note that the result of running this experimental code with:

mu1 = np.zeros(N)
mu1[0] = 1

mu2 = np.zeros(N)
mu2[3] = 1

is the "Simplest Case" at the beginning.

Summary

Starting from knowing nothing at all about optimal transport and the optimal transport distance, I think I managed to grasp a rough outline by implementing it from intuition. Given the "distance" function c, and assuming an "amount" is moved according to \gamma (d \kappa_1, d \kappa_2) so that the distributions \mu_1 and \mu_2 are matched, the problem asks for the optimal way to do so.

Rephrasing slightly: if moving an amount of 1 requires a force of 1, then the minimum "force × distance = work" is the optimal transport "work," that is, the least effort required to carry the luggage.

