Everyone Should Learn Optimal Transport, Part 2

10 minute read


Continuing from the previous blog post, we will carry on with the core message of this two-part series: optimal transport gives us calculus on the space of probability distributions. If the first post did not already convince you (perhaps because Langevin diffusion is not everyone's cup of tea), then neural networks might be the keyword that piques your interest.

We will see that many neural networks have a geometric property that is intricately linked with the structure of optimal transport. This is best demonstrated via a two layer mean field neural network, which we will briefly review next.

Mean Field Neural Networks

In what was a bit of a miracle year, 2018 saw, beyond the discovery of the Neural Tangent Kernel [JGH18] and several related preprints on convergence results [ALS19, DLL+19] (which deserve their own blog post), four concurrent works on two-layer mean field neural networks [CB18, MMN18, RV22, SS20]. One more anecdote in support of the multiple discovery hypothesis, I guess.

Back to the main topic, let us consider input data $\{(x_i, y_i)\}_{i=1}^m \subset \mathbb{R}^{n_0} \times \mathbb{R}$, and a two-layer neural network defined by the weights $W \in \mathbb{R}^{n\times n_0}$ (with rows $w_1, \dots, w_n$), $a \in \mathbb{R}^{n}$, and an activation function $\varphi:\mathbb{R} \to \mathbb{R}$

\[f(x; W, a) = \frac{1}{n} \sum_{i=1}^n a_i \varphi( \langle w_i, x \rangle ) \,.\]

The main idea of this approach is to observe that, instead of writing the network as a sum, we can write it as an integral against the empirical measure $\rho^{(n)} = \frac{1}{n} \sum_{i=1}^n \delta_{(w_i, a_i)}$

\[f(x; \rho^{(n)}) = \int a \varphi( \langle w, x \rangle ) \, d\rho^{(n)}(w, a) \,.\]
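
To make this concrete, here is a minimal numpy sketch (assuming a tanh activation and arbitrary dimensions) checking that the finite sum and the integral against the empirical measure give the same number:

```python
import numpy as np

rng = np.random.default_rng(0)
n0, n = 3, 500                 # input dimension, number of neurons (arbitrary choices)
W = rng.normal(size=(n, n0))   # first-layer weights, rows w_i
a = rng.normal(size=n)         # second-layer weights a_i
x = rng.normal(size=n0)        # a single input

phi = np.tanh                  # assumed activation function

# Sum form: f(x; W, a) = (1/n) * sum_i a_i * phi(<w_i, x>)
f_sum = np.mean(a * phi(W @ x))

# Empirical-measure form: integrate (w, a) -> a * phi(<w, x>) against
# rho^(n) = (1/n) * sum_i delta_{(w_i, a_i)}, i.e. average over the particles.
particles = list(zip(W, a))    # each particle is (w_i, a_i)
f_int = sum(a_i * phi(w_i @ x) for w_i, a_i in particles) / n

assert np.isclose(f_sum, f_int)
```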

This approach and name came from the analysis of interacting particle systems and statistical physics. Here we can treat the weights corresponding to each neuron $(w_i, a_i)$ as a particle, and the population as $\rho^{(n)}$. Notice that we can only represent the particles using the empirical measure because the particles are indistinguishable (we will come back to this later).

The most important property arises in the limit as $n\to\infty$, where the law of a single particle $\mathscr{L}(w_i, a_i)$ becomes identical to the population distribution $\rho^{(\infty)}$. Consequently, to completely characterize the (training) dynamics in the limit, we only need to characterize a single distribution. This remark deserves to be made even more explicit in a special environment.


Remark The mean field limit $n \to \infty$ lets us characterize the training dynamics of the entire network through a single probability distribution $\rho$ over the parameters $(w, a)$ of one neuron.


Quotient Manifold

Without being precise, one way to think about quotients is from the perspective of symmetry, which is typically made formal using a group action. It's not super important to be an expert in abstract algebra for this blog post, as we will only consider very simple groups. For example, let us start by considering a construction of the torus $\mathbb{T}$.

  1. Original Manifold: Start with the real line $\mathbb{R}$.
  2. Action: Consider the group of integers (with respect to addition) $(\mathbb{Z}, +)$.
  3. Symmetry: Define the equivalence $x\sim y$ on $\mathbb{R}$ if there exists a $k \in \mathbb{Z}$ such that $x = y + k$.
  4. Quotient: Define $\mathbb{T} = \mathbb{R} / \sim$ as the set of all equivalence classes of the form $\pi(x) = \{ x+k | k \in \mathbb{Z} \}$.

We can think of $\pi:\mathbb{R} \to \mathbb{T}$ as a projection, and the quotient manifold is the interval $[0,1)$ that loops around onto itself, i.e. a circle that starts and ends at $0$. Most importantly, the quotient inherits the properties of the original manifold. In particular, this includes the differentiable structure.
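
As a tiny sketch (representing $\mathbb{T}$ by the set of representatives $[0,1)$, with numbers chosen arbitrarily), the projection is just a mod operation:

```python
import numpy as np

def pi(x):
    # Projection pi: R -> T = R / Z, using [0, 1) as the set of representatives.
    return np.mod(x, 1.0)

x, k = 0.7, 3
# x and x + k are in the same equivalence class, so they project to the same point.
assert np.isclose(pi(x), pi(x + k))
assert 0.0 <= pi(x) < 1.0
```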

For example, let us consider $F(x) = \cos( 2\pi x )$, which is a periodic function on $\mathbb{R}$ that is invariant to the symmetry group $\mathbb{Z}$. Next, let us consider the gradient flow dynamics on $F$

\[\dot X_t = - \nabla F(X_t) \,, \quad X_0 = x_0 \in \mathbb{R} \,,\]

note that the dynamics $X_t$ also respect the symmetry group. To be precise, if we consider another initial condition $y_0 = x_0 + k$ for some $k \in \mathbb{Z}$, then the new dynamics $Y_t$ satisfy

\[Y_t = X_t + k \,, \quad \text{ for all } t \geq 0 \,.\]
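
A quick numerical sanity check, using a simple forward Euler discretization of the gradient flow (step size and horizon chosen arbitrarily):

```python
import numpy as np

def grad_F(x):
    # F(x) = cos(2*pi*x), so F'(x) = -2*pi*sin(2*pi*x)
    return -2 * np.pi * np.sin(2 * np.pi * x)

def gradient_flow(x0, dt=1e-3, steps=5000):
    # Forward Euler discretization of dX_t = -grad F(X_t) dt
    x = x0
    traj = [x]
    for _ in range(steps):
        x = x - dt * grad_F(x)
        traj.append(x)
    return np.array(traj)

x0, k = 0.3, 7
X = gradient_flow(x0)        # started at x0
Y = gradient_flow(x0 + k)    # started at the shifted point x0 + k

# The two trajectories differ by the integer shift k for all times.
assert np.allclose(Y, X + k)
```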

The curious reader may want to ask: since the dynamics do not need the “extra information” removed by the symmetry, can we study the same dynamics on the quotient manifold $\mathbb{T}$ alone? Indeed, this is the key idea of this blog post. Letting $\text{grad}_{\mathbb{T}}$ denote the gradient operator on $\mathbb{T}$ (note: not always the same as on the original manifold) and $Z_t = \pi(X_t)$, we have that

\[\dot Z_t = - \text{grad}_{\mathbb{T}} \left( F \circ \pi^{-1} \right) ( Z_t ) \,,\]

where we note that $F \circ \pi^{-1}$ is a function on $\mathbb{T}$ and is well defined since $F$ is symmetric with respect to additive shifts by $\mathbb{Z}$. This equation is quite subtle, so let me emphasize the most important observation in a separate environment.


Remark If $X_t$ is invariant to the symmetry of $\mathbb{Z}$, then it is sufficient to study the quotient dynamics $Z_t$ on $\mathbb{T}$.
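
Continuing the numerical sketch (reusing the trajectories `X`, `Y` and the projection `pi` defined in the snippets above), the two shifted trajectories collapse to a single trajectory on the quotient:

```python
# Project both trajectories through pi(x) = x mod 1.
Z_from_X = pi(X)   # pi(X_t), started at x0
Z_from_Y = pi(Y)   # pi(Y_t), started at x0 + k

# Both project to the same quotient dynamics Z_t on T, so it suffices to
# study the dynamics on the quotient manifold alone.
assert np.allclose(Z_from_X, Z_from_Y)
```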


It is instructive to pause here and think about the general case. Without being completely precise (see e.g. Theorem 21.10 of [Lee12]): if we want to study gradient flow dynamics $X_t$ on a manifold $M$ that are invariant to some group action of $G$ on $M$, then it is sufficient to study the dynamics $Z_t$ on the quotient manifold $N = M/G$.

At this point, the reader may be wondering: this is all well and good, but what does it have to do with optimal transport? Indeed, we will see in the next section that the space of probability measures can be interpreted as a quotient manifold.

Permutation Symmetry

Let us consider a vector $x = (x_1, x_2, \cdots, x_n) \in \mathbb{R}^n$, and the permutation group $G_n$ that simply reorders the entries of the vector, i.e.

\[g(x) = ( x_{\sigma(i)} )_{i=1}^n \,, \quad \text{ where } \sigma \text{ is a bijection on } [n] \,.\]

There is a very natural quotient representation defined by the empirical measure

\[\pi^{(n)}: x \mapsto \mu^{(n)}(dx) = \frac{1}{n} \sum_{i=1}^n \delta_{x_i}(dx) \,,\]

or we simply put a point mass of weight $\frac{1}{n}$ at each location $x_i$. Notice this representation is clearly invariant to permutations, since addition is commutative. More importantly, observe that $\mu^{(n)}$ is a probability measure.

So we are starting to see how this relates back to optimal transport, but before we get there, we will need to equip the space with an appropriate Riemannian metric. In particular, we will consider the averaged Euclidean metric on $\mathbb{R}^n$

\[\langle x, y \rangle_n = \frac{1}{n} \sum_{i=1}^n x_i y_i \,.\]

The reason we choose this weighting is so that it matches the normalization of the probability measure. For example, this allows us to write

\[\| x \|_n^2 = \int x^2 \, d\mu^{(n)} \,.\]
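
As a small sanity check with arbitrary numbers: the averaged norm equals the second moment of the empirical measure, and the empirical measure itself does not care about the ordering of the entries:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5
x = rng.normal(size=n)

# Averaged Euclidean norm: ||x||_n^2 = (1/n) * sum_i x_i^2
norm_sq = np.mean(x ** 2)

# Second moment of the empirical measure mu^(n) = (1/n) * sum_i delta_{x_i}:
# integrating t -> t^2 against it is just the same average.
second_moment = sum(xi ** 2 for xi in x) / n
assert np.isclose(norm_sq, second_moment)

# The empirical measure is invariant under permutations: permuting the entries
# of x produces the same multiset of atoms (compare after sorting).
perm = rng.permutation(n)
assert np.allclose(np.sort(x), np.sort(x[perm]))
```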

Here comes the interesting part. What is the 2-norm after taking the quotient with respect to the permutation group $G_n$? This is a bit subtle, but the distance between two equivalence classes should be taken at the pair of representatives that achieves the minimum.

Consider the $n=2$ case, where we want to compare the distance between the points $(x_1, x_2), (y_1, y_2)$ modulo the permutation group $G_2$. Here, the axis of symmetry is the diagonal line $x_2 = x_1$ plotted in dashes below. Now there are two possible ways to compute the squared distance: either $\|x-y\|_2^2 = (x_1-y_1)^2 + (x_2-y_2)^2$ or the permuted version $(x_2 - y_1)^2 + (x_1 - y_2)^2$. This is illustrated by the red and green lines below.

However, to make sense of distance on the quotient manifold, we should first project onto a set of representatives, e.g. the half plane on one side of the axis of symmetry, say $\{x_1 \leq x_2\}$. Observe that this procedure is equivalent to choosing the minimum distance. More precisely, the squared distance on the quotient space $\mathbb{R}^n / G_n$ is as follows

\[\| x-y \|_{n,G_n}^2 = \min_{ g \in G_n } \| g(x) - y \|_n^2 = \min_{\sigma} \frac{1}{n} \sum_{i=1}^n ( x_{\sigma(i)} - y_i )^2 \,,\]

where once again $\sigma$ is a bijection on $[n]$.

The observant reader may have already realized something incredible here: this is in fact the Monge distance on probability measures. More precisely, let $\mu^{(n)} = \frac{1}{n} \sum_{i=1}^n \delta_{x_i}, \nu^{(n)} = \frac{1}{n} \sum_{i=1}^n \delta_{y_i}$, then we have that

\[\| x-y \|_{n,G_n}^2 = \min_{T \,:\, T_\# \mu^{(n)} = \nu^{(n)}} \int ( T(x) - x )^2 \, d\mu^{(n)}(x) \,,\]

where the minimum is taken over (Monge) transport maps $T$ that push $\mu^{(n)}$ forward to $\nu^{(n)}$; for two empirical measures with the same number of atoms, such maps are exactly the permutation matchings, which recovers the formula above.
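
As a numerical illustration (a sketch using scipy's linear_sum_assignment to solve the matching problem), the minimum over permutations, the optimal assignment between the atoms, and, since we are in one dimension, the sorted matching all agree:

```python
import numpy as np
from itertools import permutations
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(2)
n = 6
x = rng.normal(size=n)
y = rng.normal(size=n)

# Brute force over all permutations sigma: min (1/n) * sum_i (x_{sigma(i)} - y_i)^2
brute = min(np.mean((x[list(sigma)] - y) ** 2) for sigma in permutations(range(n)))

# The same quantity as an optimal assignment (Monge) problem between the atoms
# of mu^(n) and nu^(n), with cost c(i, j) = (x_i - y_j)^2 / n.
cost = (x[:, None] - y[None, :]) ** 2 / n
row, col = linear_sum_assignment(cost)
assignment = cost[row, col].sum()

# In one dimension the optimal transport map is monotone: match sorted to sorted.
sorted_cost = np.mean((np.sort(x) - np.sort(y)) ** 2)

assert np.isclose(brute, assignment) and np.isclose(brute, sorted_cost)
```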

Now this is not quite the Wasserstein distance, which is a relaxation of the Monge formulation. However, we know by Brenier's Theorem [Theorem 2.12, Vil03] that these two formulations are equivalent whenever one of the two measures has a density.

References

  • [CB18] Chizat, L., & Bach, F. (2018). On the global convergence of gradient descent for over-parameterized models using optimal transport. Advances in neural information processing systems, 31.
  • [ALS19] Allen-Zhu, Z., Li, Y., & Song, Z. (2019, May). A convergence theory for deep learning via over-parameterization. In International conference on machine learning (pp. 242-252). PMLR.
  • [DLL+19] Du, S., Lee, J., Li, H., Wang, L., & Zhai, X. (2019, May). Gradient descent finds global minima of deep neural networks. In International conference on machine learning (pp. 1675-1685). PMLR.
  • [JGH18] Jacot, A., Gabriel, F., & Hongler, C. (2018). Neural tangent kernel: Convergence and generalization in neural networks. Advances in neural information processing systems, 31.
  • [Lee12] Lee, J. M. (2012). Quotient manifolds. Introduction to Smooth Manifolds, 540-563.
  • [MMN18] Mei, S., Montanari, A., & Nguyen, P. M. (2018). A mean field view of the landscape of two-layer neural networks. Proceedings of the National Academy of Sciences, 115(33), E7665-E7671.
  • [RV22] Rotskoff, G., & Vanden‐Eijnden, E. (2022). Trainability and accuracy of artificial neural networks: An interacting particle system approach. Communications on Pure and Applied Mathematics, 75(9), 1889-1935.
  • [SS20] Sirignano, J., & Spiliopoulos, K. (2020). Mean field analysis of neural networks: A law of large numbers. SIAM Journal on Applied Mathematics, 80(2), 725-752.
  • [Vil03] Villani, C. (2003). Topics in Optimal Transportation. Graduate Studies in Mathematics, Vol. 58. American Mathematical Society.