Thinking about Optimal Transport (1) — Starting from Scratch
Purpose
I don't understand optimal transport. I don't understand the optimal transport distance. I don't understand Earth Mover's Distance. I don't understand the Wasserstein distance. When I don't understand something, I feel like it's best to explore while experimenting to find something that seems right, so I'll try some trial and error.
Note that, strictly speaking, there seem to be subtle distinctions between these terms, but this time I will use them interchangeably without making any particular distinction.
The Beginning (A Musing)
It's not for any particular reason, but I occasionally hear about optimal transport distance. So, I did a quick search and found Introduction to Optimal Transport. I read it about three times, but I didn't understand anything... My question is simple:
- How do you map points on one distribution to points on another? It doesn't seem like it's given as prior knowledge, but...
That's what it boils down to. As I flipped through the pages, matrices appeared, optimization problems appeared, and entropy appeared, so I was overwhelmed 🤯😵💫😵💫
arXiv:1506.05439 Learning with a Wasserstein Loss was listed as a reference, so I looked at it. Eq. (2) seems to be the one:

$$W_c(\mu_1, \mu_2) = \inf_{\gamma \in \Pi(\mu_1, \mu_2)} \int_{\mathcal{K} \times \mathcal{K}} c(\kappa_1, \kappa_2)\, \gamma(d\kappa_1, d\kappa_2)$$

This equation is what is called the optimal transport distance, and it seems that:
- $c: \mathcal{K} \times \mathcal{K} \to \mathbb{R}$ is a given cost function
- $\mu_1$ and $\mu_2$ are probability measures on $\mathcal{K}$
- $\Pi(\mu_1, \mu_2)$ is the set of joint probability distributions on $\mathcal{K} \times \mathcal{K}$ that have $\mu_1$ and $\mu_2$ as marginal probability distributions
Apparently, it's important when the cost function $c$ is a distance $d_{\mathcal{K}}(\cdot, \cdot)$ on $\mathcal{K}$.
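Concretely, in the discrete setting this definition is a linear program over the transport matrix $\gamma$. Below is a minimal sketch using scipy.optimize.linprog (ot_distance is a hypothetical helper name of my own; assumes SciPy is installed), checked on two delta distributions:

```python
import numpy as np
from scipy.optimize import linprog

def ot_distance(mu1, mu2, C):
    # Discrete optimal transport as a linear program:
    # minimize sum_ij C[i, j] * G[i, j] over G >= 0,
    # subject to column sums of G = mu1 and row sums of G = mu2.
    n, m = len(mu2), len(mu1)
    A_eq, b_eq = [], []
    for j in range(m):  # column sums reproduce mu1
        M = np.zeros((n, m)); M[:, j] = 1.0
        A_eq.append(M.ravel()); b_eq.append(mu1[j])
    for i in range(n):  # row sums reproduce mu2
        M = np.zeros((n, m)); M[i, :] = 1.0
        A_eq.append(M.ravel()); b_eq.append(mu2[i])
    res = linprog(C.ravel(), A_eq=np.array(A_eq), b_eq=np.array(b_eq),
                  bounds=(0, None), method="highs")
    return res.fun

N = 5
# cost c(i, j) = |i - j|, the same L1 ground distance used later in this post
C = np.abs(np.subtract.outer(np.arange(N), np.arange(N))).astype(float)
mu1 = np.zeros(N); mu1[0] = 1.0  # delta at index 0
mu2 = np.zeros(N); mu2[3] = 1.0  # delta at index 3
d = ot_distance(mu1, mu2, C)
print(d)  # 3.0: the distance between the two support points
```

For two deltas, the only feasible plan moves all the mass from index 0 to index 3, so the optimum equals the ground distance.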
Continuing my search, I found Yossi Rubner et al.'s The Earth Mover's Distance as a Metric for Image Retrieval. The content on p.8 probably corresponds to it, but this time inequality constraints appeared instead of marginal distributions, making it even more confusing.
Next, I found Impression Estimation of Images Using Earth Mover's Distance by Yuko Sakuta et al. Since it's in Japanese, it feels more accessible. "Figure 3 Calculation of distance between color histograms by EMD" looks like a relevant image. It seems to sort out the source distribution and bring it into the target distribution. But was there any paper that explicitly wrote about that map? That's where I got stuck.
So, my heart was completely broken, and I thought I'd just forget about it, but since today is Sunday, I decided to try something with NumPy.
The Simplest Case
Looking back at p.13 of Introduction to Optimal Transport, there is a diagram showing bar graphs of the same height being translated. It seems to suggest that shorter movements result in smaller optimal transport distances. And in the case where this bar is a delta function concentrated at one point, it seems that the distance between those two support points is the optimal transport distance.
Therefore, I first want to consider the case where $\mu_1$ is a delta function concentrated at index 0 and $\mu_2$ is a delta function concentrated at index 3, on a discrete space of N = 5 points.

In conclusion, as a result of running a certain optimization, I found a transition that achieves "optimal transport," and the approximately determined optimal transport distance was 2.764. The marginal distributions reconstructed from the transition achieving this distance are as follows, and they approximately reproduce the original delta distributions.

The "transition" that realizes the optimal transport is approximately as follows:
Summing the elements of each row for each column of this transition matrix recovers the marginal distributions. In other words, the probability 1 that was at column index 0 was distributed to row indices 0, 1, and 3 with values of 0.034, 0.009, and 0.914, respectively; the other entries are numerical error. Therefore, in this case, the optimal transport distance is approximately 0.034 × 0 + 0.009 × 1 + 0.914 × 3 ≈ 2.75.
Note that the true transition puts all of the mass at row index 3, column index 0, and the true optimal transport distance is |3 − 0| = 3.
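The arithmetic above can be checked numerically from the reported transition values:

```python
import numpy as np

# approximate transition reported above (values copied from the text)
G = np.zeros((5, 5))
G[0, 0], G[1, 0], G[3, 0] = 0.034, 0.009, 0.914
# cost matrix c(i, j) = |i - j|
C = np.abs(np.subtract.outer(np.arange(5), np.arange(5)))
cost = float(np.sum(C * G))
print(cost)  # ≈ 2.751, close to the reported 2.764
```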
...This is what I finally understood. I reached this conclusion after first running the experiment in a slightly more general setting, but this clear special case gave me confidence that the experiments were likely correct.
Experiment
I would like to describe the flow of the experiment that led to the above results below.
First, import the necessary modules:
import numpy as np
import pprint
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator
The naive idea was as follows:
np.random.seed(seed=42)
N = 5 # Width when considering space K discretely
joint_prob = np.random.rand(N * N).reshape(N, N) # Something like a 2D list
Here, I felt it would be good if these were close:

- `np.sum(joint_prob, axis=0)` ≈ `mu1`
- `np.sum(joint_prob, axis=1)` ≈ `mu2`

but a randomly initialized matrix of course has no reason to satisfy these constraints.
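To see the problem concretely, here is a sketch of what the marginals of a randomly initialized (and, as my own addition here, normalized) joint_prob look like; they are valid distributions, but bear no relation to any target mu1 or mu2:

```python
import numpy as np

np.random.seed(seed=42)
N = 5
joint_prob = np.random.rand(N * N).reshape(N, N)
joint_prob = joint_prob / joint_prob.sum()  # normalize so all elements sum to 1
marg_cols = np.sum(joint_prob, axis=0)  # candidate for mu1 -- arbitrary values
marg_rows = np.sum(joint_prob, axis=1)  # candidate for mu2 -- arbitrary values
print(marg_cols)
print(marg_rows)
```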
Determining the Source and Target Distributions for Optimal Transport
Set up the problem by defining some appropriate data.
N = 5
X = np.arange(N, dtype=int)
np.random.seed(seed=42)
# Randomly determine distributions mu1 and mu2. The seed is fixed for reproducibility.
mu1 = np.random.rand(N)
mu1 = mu1 / np.sum(mu1)
mu2 = np.random.rand(N)
mu2 = mu2 / np.sum(mu2)
# Visualization
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].stairs(mu1, fill=True)
axs[0].set_title(r"$\mu_1$")
axs[0].set_xlabel("x")
axs[0].set_ylabel(r"$\mu_1$")
axs[0].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[0].yaxis.set_major_locator(MultipleLocator(0.1))
axs[0].set_ylim([0, 1])
axs[0].grid()
axs[1].stairs(mu2, fill=True)
axs[1].set_title(r"$\mu_2$")
axs[1].set_xlabel("x")
axs[1].set_ylabel(r"$\mu_2$")
axs[1].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[1].yaxis.set_major_locator(MultipleLocator(0.1))
axs[1].set_ylim([0, 1])
axs[1].grid()
plt.tight_layout()
plt.show()

At this point, I'm not quite sure, but I want to find a matrix joint_prob that represents something like the optimal transport of these two distributions.
Putting it into PyTorch
I want to manipulate joint_prob, but I'm uneasy about determining the values through optimization methods that don't use gradients... or rather, in this case, I'm not sure, so I'll prioritize an approach of gradual improvement and use gradient-based optimization. Therefore, I decided to put it into PyTorch.
Importing additional modules to use PyTorch:
import torch
from torch import nn
from torch import optim
Below, I've written down the constraints with the mindset of a Quadratic Unconstrained Binary Optimization (QUBO)-style penalty method. The constraints are that joint_prob is a joint probability distribution and that its marginal distributions are mu1 and mu2.
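In symbols, the penalty objective I have in mind (my own notation, matching the penalty terms in the code) is:

```latex
L_{\text{constraint}}(\gamma)
  = \Big(\sum_{i,j} \gamma_{ij} - 1\Big)^2
  + \frac{1}{N} \sum_{j} \Big(\sum_{i} \gamma_{ij} - \mu_{1,j}\Big)^2
  + \frac{1}{N} \sum_{i} \Big(\sum_{j} \gamma_{ij} - \mu_{2,i}\Big)^2
```

Each term is zero exactly when the corresponding constraint holds, so driving this penalty to zero enforces all three at once.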
First, I wrote the code to find the list joint_prob that satisfies these constraints through training:
# Start from a random 2D list
joint_prob = torch.rand(N * N).reshape(N, N)
# Convert mu1 to a tensor
mu1_tensor = torch.tensor(mu1)
# Convert mu2 to a tensor
mu2_tensor = torch.tensor(mu2)
# Train joint_prob
params = nn.parameter.Parameter(joint_prob, requires_grad=True)
optimizer = optim.Adam([params])
for epoch in range(5000):
    optimizer.zero_grad()
    # Since joint_prob itself is a probability distribution, the sum of all elements is 1.
    prob_constraint = (torch.sum(params) - 1) ** 2
    # The marginal distribution obtained by summing over rows becomes mu1.
    mu1_constraint = torch.sum((torch.sum(params, dim=0) - mu1_tensor) ** 2) / N
    # The marginal distribution obtained by summing over columns becomes mu2.
    mu2_constraint = torch.sum((torch.sum(params, dim=1) - mu2_tensor) ** 2) / N
    constraint = prob_constraint + mu1_constraint + mu2_constraint
    loss = constraint
    loss.backward()
    optimizer.step()
    # Since it is a probability distribution, the value range is [0, 1]; out-of-range values are clipped.
    with torch.no_grad():
        params.clamp_(0, 1)
This went reasonably well, and the marginal distributions obtained from params.cpu().detach().numpy() roughly match mu1 and mu2.
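The same penalty scheme can also be checked without autograd; below is a pure-NumPy projected-gradient sketch of the same idea (the learning rate and iteration count are my own choices):

```python
import numpy as np

# Same distributions as in the article (seed 42), then a random initial matrix.
np.random.seed(seed=42)
N = 5
mu1 = np.random.rand(N); mu1 /= mu1.sum()
mu2 = np.random.rand(N); mu2 /= mu2.sum()
G = np.random.rand(N, N)
lr = 0.01
for _ in range(5000):
    # gradients of the three squared-penalty terms
    grad = 2 * (G.sum() - 1) * np.ones_like(G)
    grad += 2 * (G.sum(axis=0) - mu1)[None, :] / N
    grad += 2 * (G.sum(axis=1) - mu2)[:, None] / N
    G -= lr * grad
    G = np.clip(G, 0, 1)  # projection back onto [0, 1]
err1 = np.abs(G.sum(axis=0) - mu1).max()
err2 = np.abs(G.sum(axis=1) - mu2).max()
print(err1, err2)  # both deviations should be very small
```

Since the penalty is a convex quadratic and the box $[0, 1]$ is convex, projected gradient descent converges to a feasible point where all penalties vanish.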
Finding the Optimal Transport Distance through Optimization
When the optimization was complete, I thought, "So, what is the optimal transport distance?" and realized that I needed the original loss function in addition to the constraints mentioned above. Recalling the definition once more, I realized that I had implemented joint_prob (the $\gamma$ part) but not the cost term, and the earlier remark that it's important when the cost function $c$ is a distance $d_{\mathcal{K}}(\cdot, \cdot)$ on $\mathcal{K}$ finally became clear to me.
# The optimal value of this loss function is the optimal transport distance
EM_loss = 0
for col in range(N):
    for row in range(N):
        amount = params[row, col]  # γ(dκ₁, dκ₂)
        dist = abs(row - col)  # c(κ₁, κ₂); L1 distance
        EM_loss += dist * amount
I combined this with the previous constraints into a single objective function and completed the training loop as follows. As optimization progressed, it first satisfied the constraints (i.e., remaining a probability distribution and matching the specified marginal distributions), but eventually it started to lower the overall objective by shrinking the transport cost even at the price of violating the constraints. For this reason, I added an early stopping mechanism to terminate the training.
joint_prob = torch.rand(N * N).reshape(N, N)
mu1_tensor = torch.tensor(mu1)
mu2_tensor = torch.tensor(mu2)
params = nn.parameter.Parameter(joint_prob, requires_grad=True)
optimizer = optim.Adam([params])
losses = []
EM_losses = []
constraints = []
best_constraint = 10000
best_constraint_epoch = -1
patience = 20  # Early stopping logic
patience_cnt = 0
for epoch in range(12000):
    optimizer.zero_grad()
    EM_loss = 0
    for col in range(N):
        for row in range(N):
            amount = params[row, col]
            dist = abs(row - col)
            EM_loss += amount * dist
    EM_loss = EM_loss / (N * N)
    prob_constraint = (torch.sum(params) - 1) ** 2
    mu1_constraint = torch.sum((torch.sum(params, dim=0) - mu1_tensor) ** 2) / N
    mu2_constraint = torch.sum((torch.sum(params, dim=1) - mu2_tensor) ** 2) / N
    constraint = prob_constraint + mu1_constraint + mu2_constraint
    loss = EM_loss + 4 * constraint  # Weight and sum the loss and constraints appropriately
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        params.clamp_(0, 1)
    # Terminate training if constraints start to be violated, using an early stopping approach
    if constraint.item() < best_constraint:
        best_constraint = constraint.item()
        best_constraint_epoch = epoch
        patience_cnt = 0
    else:
        patience_cnt += 1
        if patience_cnt >= patience:
            break
    losses.append(loss.item())
    EM_losses.append(EM_loss.item())
    constraints.append(constraint.item())
This training completes in about 20 seconds.
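As an aside, the Python double loop computing EM_loss can be replaced by an elementwise product with a cost matrix built once outside the loop; the same one-liner works on torch tensors. A NumPy sketch of the equivalence:

```python
import numpy as np

np.random.seed(seed=0)  # arbitrary example matrix
N = 5
params = np.random.rand(N, N)
# double loop, as in the training code above
loop_val = 0.0
for col in range(N):
    for row in range(N):
        loop_val += abs(row - col) * params[row, col]
loop_val /= N * N
# vectorized: elementwise product with a precomputed cost matrix
C = np.abs(np.subtract.outer(np.arange(N), np.arange(N)))
vec_val = (C * params).sum() / (N * N)
print(np.isclose(loop_val, vec_val))  # True
```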
Confirming Experimental Results
First, visualize the transition of the loss function values. Since the values become quite small, I took the logarithm of the vertical axis.
fig, axs = plt.subplots(1, 3, figsize=(18, 4))
axs[0].set_yscale("log")
axs[0].plot(losses)
axs[0].set_title("losses")
axs[0].set_xlabel("epoch")
axs[0].set_ylabel("loss")
axs[0].grid()
axs[1].set_yscale("log")
axs[1].plot(EM_losses)
axs[1].set_title("EM_losses")
axs[1].set_xlabel("epoch")
axs[1].set_ylabel("EM_loss")
axs[1].grid()
axs[2].set_yscale("log")
axs[2].plot(constraints)
axs[2].set_title("constraints")
axs[2].set_xlabel("epoch")
axs[2].set_ylabel("constraint")
axs[2].grid()
plt.tight_layout()
plt.show()

The overall values are decreasing, but the constraint term starts to converge slightly earlier. After this, the values began to rise, so I terminated it with early stopping.
Checking Marginal Distributions
Does the joint_prob after optimization reproduce the marginal distributions $\mu_1$ and $\mu_2$?
final_joint_prob = params.cpu().detach().numpy()
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
mu1_ = np.sum(final_joint_prob, axis=0)
axs[0].stairs(mu1_, fill=True)
axs[0].set_title(r"$\mu_1$")
axs[0].set_xlabel("x")
axs[0].set_ylabel(r"$\mu_1$")
axs[0].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[0].yaxis.set_major_locator(MultipleLocator(0.1))
axs[0].set_ylim([0, 1])
axs[0].grid()
mu2_ = np.sum(final_joint_prob, axis=1)
axs[1].stairs(mu2_, fill=True)
axs[1].set_title(r"$\mu_2$")
axs[1].set_xlabel("x")
axs[1].set_ylabel(r"$\mu_2$")
axs[1].xaxis.set_major_locator(MaxNLocator(integer=True))
axs[1].yaxis.set_major_locator(MultipleLocator(0.1))
axs[1].set_ylim([0, 1])
axs[1].grid()
plt.tight_layout()
plt.show()
While there are some slight discrepancies in the details, they look generally similar.

[Reference] Probability distributions set initially:

Being a Joint Probability Distribution
The marginal distributions are fine, but can we consider it a joint probability distribution in the first place?
print(f"{np.sum(final_joint_prob)=}")
np.sum(final_joint_prob)=0.9981522
Since the sum of all elements is approximately 1, it seems it can properly be regarded as a probability distribution.
What's Inside the Joint Probability Distribution?
I could just do pprint.pprint(final_joint_prob), but let's format it a bit:
It seems that for the quantity of about 0.11 at column index 0, 0.092 was distributed to row index 0, and 0.018 was distributed to row index 2. I was surprised that some of it moved a bit far away. Basically, though, the values concentrate around the diagonal, so in many cases the mass is transported to a nearby neighborhood for adjustment. Focusing on column 3, we can see that about half of the amount, 0.121 out of about 0.225, was moved to row 4. Since the right end of $\mu_1$ carries little mass while $\mu_2$ has substantial mass at index 4, mass has to be carried in from the left.
While looking at this transition matrix, I began to understand the intuition behind "Figure 3 Calculation of distance between color histograms by EMD" in Impression Estimation of Images Using Earth Mover's Distance.
Optimal Transport Distance
What is the crucial optimal transport distance?
print(f"EM distance={EM_losses[-1] * (N * N)}")
EM distance=0.7671486120671034
That's the result. This value itself doesn't mean much this time.
In The Earth Mover's Distance as a Metric for Image Retrieval or arXiv:1701.07875 Wasserstein GAN, neural networks would be optimized in a way that further optimizes this metric. However, as seen in this experiment, naively finding the optimal transport distance seems to involve a fairly high computational cost.
Note that the result of running this experimental code with:
mu1 = np.zeros(N)
mu1[0] = 1
mu2 = np.zeros(N)
mu2[3] = 1
is the "Simplest Case" at the beginning.
Summary
Starting from knowing nothing at all about optimal transport and the optimal transport distance, I think I was able to grasp a rough outline by implementing it based on intuition. Under the given "distance" function $c$, the optimal transport distance is the smallest total cost of rearranging one distribution into the other.
Rephrasing the situation slightly: if the force required to move an amount of 1 is 1, then the minimum "force × distance = work" is the optimal transport "work," i.e., the least effort required to carry the luggage.
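As a closing side note, in the 1-D setting used throughout this post (cost |i − j|), the optimal "work" actually has a closed form: the area between the two cumulative distribution functions. A minimal check against the "Simplest Case":

```python
import numpy as np

def w1_1d(mu1, mu2):
    # In 1-D with cost |i - j|, the optimal transport distance has a closed
    # form: the L1 distance between the two cumulative distribution functions.
    return float(np.abs(np.cumsum(mu1) - np.cumsum(mu2)).sum())

mu1 = np.array([1.0, 0.0, 0.0, 0.0, 0.0])  # delta at index 0
mu2 = np.array([0.0, 0.0, 0.0, 1.0, 0.0])  # delta at index 3
print(w1_1d(mu1, mu2))  # 3.0, matching the "Simplest Case"
```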
Discussion