PDE and Machine Learning (3): Variational Principles and Optimization
Chen Kai BOSS

What is the essence of neural network training? When we perform gradient descent in high-dimensional parameter space, does there exist a deeper continuous-time dynamics? As network width tends to infinity, does discrete parameter updating converge to some elegant partial differential equation? The answers to these questions lie at the intersection of calculus of variations, optimal transport theory, and partial differential equations.

Over the past decade, the success of deep learning has been built primarily on empirical insights and engineering practices. However, recent years have witnessed mathematicians discovering that viewing neural networks as particle systems on the space of probability measures and studying their evolution under Wasserstein geometry can reveal global properties of training dynamics, convergence guarantees, and the essence of phenomena like initialization and over-parameterization. The core tool of this perspective is variational principles— from the principle of least action in physics, to the JKO scheme in modern optimal transport theory, to the mean-field limit of neural networks.

This article systematically establishes this theoretical framework. We begin with classical calculus of variations, introducing fundamental tools such as functional derivatives and Euler-Lagrange equations. We then introduce Wasserstein metrics and gradient flow theory, demonstrating how the heat equation and Fokker-Planck equation can be unified as gradient flows of energy functionals. Finally, we focus on neural network training, deriving mean-field equations, proving global convergence, and validating theoretical predictions through numerical experiments.

Foundations of Calculus of Variations: From Functionals to Euler-Lagrange Equations

Functionals and First Variation

The core object of study in calculus of variations is the functional— a mapping that takes functions to real numbers. Unlike ordinary functions, the "input" to a functional is an entire function, while the "output" is a numerical value.

Definition (Functional): Letbe a function space (e.g.,). A functionalmaps each function$y to a real number.

Classical Examples:

  1. Arc length functional: The length of a curveon
  2. Surface area: Area of a surface of revolution
  3. Action functional (physics): Action of a particle trajectory whereis the Lagrangian function.

The fundamental problem in calculus of variations is: Among all functions satisfying boundary conditions, which one extremizes the functional?

Definition (Gateaux Derivative): The Gateaux derivative of a functionalatin the directionis defined asIf this limit exists and is linear in, we writeand callthe variational derivative or functional derivative ofwith respect to.

Theorem (Euler-Lagrange Equation): Consider the functionalIfis an extremum ofand satisfies boundary conditions, thensatisfies the Euler-Lagrange equation: Proof Sketch: Supposeis an extremum. For any functionwith, defineThe extremum condition requires. Computing givesIntegrating the second term by parts:Since, the boundary term vanishes, yieldingSinceis arbitrary, we obtain the Euler-Lagrange equation.

Classical Variational Problem: The Brachistochrone

Problem (Brachistochrone): In a gravitational field, along which smooth curve should a particle slide frictionlessly from pointto pointto minimize the descent time?

Set up coordinates withat the origin and the-axis pointing downward. The curve is, withat. By energy conservation, the velocity satisfiesThe arc length element is, the time element is, and the total time isHere. The Euler-Lagrange equation isComputing yieldsNote thatdoes not explicitly depend on, so there exists a first integral (Beltrami identity):Substituting givesSimplifying yieldswhere. Setting, we getParametrization: Let, thenThusIntegrating gives the cycloid equation:This is precisely the trajectory of a point on a circle of radiusrolling along a straight line.

Hamilton's Principle and Variational Methods in Physics

One of the most profound principles in physics is Hamilton's principle (the principle of least action): the actual motion trajectory extremizes the action functional.

Theorem (Hamilton's Principle): Let the Lagrangian of a particle be, and the action beThen the actual trajectorysatisfies the Euler-Lagrange equation Example (Free Particle): For Lagrangian, the Euler-Lagrange equation giveswhich is Newton's first law.

Example (Harmonic Oscillator): For, we obtain Legendre Transform and Hamiltonian: Define the generalized momentumThrough the Legendre transform, define the Hamiltonian:This yields Hamilton's canonical equations:This transforms the second-order Euler-Lagrange equation into a first-order system, revealing the symplectic geometric structure of energy conservation.

Functional Derivatives and Gradient Descent

In optimization theory, variational derivatives correspond to "gradients in infinite-dimensional space." Consider the functionalIts variational derivativeis defined asfor all test functions.

Computation Rules:

  1. Point-wise functionals: If, then

  2. Derivative functionals: If, then (obtained through integration by parts)

  3. Chain rule: If, where, then Example (Dirichlet Energy): For, the variational derivative isThe Euler-Lagrange equationyields Laplace's equation.

Gradient Flow: In function space, evolution along the negative gradient of a functional yields a PDE:For example, the gradient flow of Dirichlet energy is precisely the heat equation:This reveals a profound connection between PDEs and optimization: many important PDEs can be understood as gradient flows of energy functionals.

Gradient Flow Theory and Wasserstein Geometry

Gradient Flows in Euclidean Space

In finite-dimensional Euclidean space, consider a smooth function. The gradient flow is the solution trajectory of the differential equationThis is the continuous-time version of steepest descent.

Properties:

  1. Energy Dissipation: Along trajectories,decreases monotonically:

  2. Equilibrium Points: Trajectories converge to points satisfying.

  3. Lyapunov Stability: Ifis bounded below, trajectories are bounded; ifis strongly convex, convergence to the unique global minimum is guaranteed.

Example (Quadratic Function): Forwithpositive definite, the gradient flow iswith solution, exponentially converging to the origin.

Wasserstein Space and Optimal Transport

When studying the evolution of probability distributions, Euclidean geometry is no longer suitable. We need to introduce the Wasserstein metric, which measures the "optimal transport cost" between distributions.

Definition (Wasserstein-2 Distance): Letbe probability densities on(absolutely continuous with respect to Lebesgue measure). Thedistance is defined aswhereis the set of all joint distributions (couplings) with marginals.

Geometric Intuition:is the minimum "transportation cost" to reconfigure mass distributioninto, where cost is proportional to mass times squared distance.

Monge-Kantorovich Duality: Thedistance can also be expressed via Kantorovich duality: Theorem (Brenier): Ifare absolutely continuous, the optimal transport map exists, is unique, and has the form, whereis a convex function satisfying Example (Gaussian Distributions): Let,. ThenIn particular, if covariances are identical,.

Wasserstein Gradient Flows: The JKO Scheme

Core Idea: On the space of probability measures (probability measures with finite second moment), how do we define the "gradient flow" of an energy functional?

Definition (JKO Scheme): Given a functionaland time step, starting from, recursively defineMissing or unrecognized delimiter for \right\rho_{k+1} = \arg\min_\rho \left\{\mathcal{E}[\rho] + \frac{1}{2\tau} W_2^2(\rho, \rho_k) \right}As, the limit of the discrete trajectory (if it exists) is called the Wasserstein gradient flow of.

This scheme was proposed by Jordan, Kinderlehrer, and Otto in 1998, abbreviated as the JKO scheme. It generalizes the implicit Euler scheme to the space of probability measures: Formal Derivation: In Euclidean space, the gradient flowdiscretized by implicit Euler givesMissing or unrecognized delimiter for \right\frac{x_{k+1} - x_k}{\tau} = -\nabla f(x_{k+1}) \quad \Leftrightarrow \quad x_{k+1} = \arg\min_x \left\{f(x) + \frac{1}{2\tau}|x - x_k|^2 \right}The JKO scheme simply replaces the Euclidean distancewith the Wasserstein distance.

Heat Equation as Entropy Gradient Flow

Theorem (Otto): Consider the Boltzmann entropyIts Wasserstein gradient flow is precisely the heat equation (Fokker-Planck equation): Proof Sketch:

  1. JKO Scheme: For,Missing or unrecognized delimiter for \right\rho_{k+1} = \arg\min_\rho \left\{\mathcal{H}[\rho] + \frac{1}{2\tau} W_2^2(\rho, \rho_k) \right}

  2. Variational Condition: The first-order optimality condition iswhereis the optimal transport potential fromto.

  3. Entropy Variational Derivative:

  4. Optimal Transport Relation: Brenier's theorem impliessatisfies

  5. Continuum Limit: Set,. As,Meanwhile, the evolution of the optimal transport potential gives velocity field, satisfyingCombined with the variational condition(from), we get Energy Dissipation: Along the heat equation, entropy decreases monotonically:whereis the Fisher information, always non-negative. This is precisely the energy dissipation property of gradient flows in Wasserstein geometry.

Other Gradient Flow Examples

Fokker-Planck Equation: Consider the free energy functionalIts Wasserstein gradient flow isThis is the Fokker-Planck equation (or Smoluchowski equation) with external potential. The equilibrium distribution is the Gibbs measure.

Porous Medium Equation: Consider the internal energy functionalIts gradient flow is the porous medium equation:This equation describes gas diffusion in porous media and exhibits finite propagation speed.

Keller-Segel Equation: The equation describing chemotaxiscan be understood as the gradient flow of the energy functionalembodying the competition between entropy increase (diffusion) and attractive potential energy (chemotaxis).

Mean-Field Theory of Neural Network Training

From Finite Width to Infinite Width

Consider a two-layer neural network:whereare parameters,is the hidden layer width, andis the activation function.

Loss Function: Given data, the empirical risk is Gradient Descent: Parameter updates areExpanding gives Particle System Interpretation: View each neuronas a particle. Neural network training is an interacting system ofparticles:Particles couple through the loss function (sinceincludes contributions from all particles).

Mean-Field Limit: As, the empirical measure of the particle systemunder appropriate assumptions converges to a continuous distribution, whose evolution is described by a mean-field equation (Vlasov equation).

Derivation of Mean-Field Equations

Assumptions:

  1. Initial parameters are independent and identically distributed:.
  2. Activation functionsatisfies Lipschitz conditions.
  3. Loss functionis smooth with bounded gradient.

Representation: Network output can be written asFor finite,.

Loss Functional: Gradient Flow: The evolution of particleis Computing Variational Derivative:whereTherefore Mean-Field Limit: As,satisfies the continuity equationwhere the velocity field is Explicit Form (simplified case with fixed):This is a nonlinear Fokker-Planck type equation.

Global Convergence Analysis

Theorem (Mei et al. 2018, Chizat & Bach 2018): Under the following conditions, the mean-field equation converges globally to zero loss:

  1. Over-parameterization:(or in the continuum limit, the support ofis sufficiently large).
  2. Positive Definiteness: The Neural Tangent Kernel (NTK)is positive definite on data points.
  3. Initialization:satisfies certain regularity conditions (e.g., Gaussian distribution).

Proof Sketch:

Step 1: Linearization. Under small learning rate or in the NTK regime, network evolution can be approximated asThis is a linear PDE with respect to.

Step 2: Energy Dissipation. DefineThenwhereis the minimum eigenvalue of the kernel matrix.

Step 3: Exponential Convergence. Solving givesThus loss converges exponentially to zero.

NTK vs. Mean-Field Comparison:

  • NTK Limit (Jacot et al. 2018): Width, fixed learning rate, parameters barely move (lazy training). Network linearizes near initialization.

  • Mean-Field Limit: Width, learning rate scales as, parameters move significantly. Captures global nonlinear dynamics.

Illustration: In parameter space, NTK corresponds to linear approximation in a small neighborhood, while mean-field describes large-scale particle flow.

Gradient Flow Representation on Wasserstein Space

Key Observation: The mean-field equation can be written as a gradient flow in Wasserstein formThis is a gradient flow in Euclidean space over. But in certain cases, it can be reformulated as a Wasserstein gradient flow.

Theorem (Chizat & Bach 2018): If the loss can be written aswhereis symmetric, then the mean-field equation is the Wasserstein gradient flow of the functional(degenerate when).

Application: This formulation reveals global convexity of training — though the loss is non-convex with respect to parameters, in measure space the functional may be convex (displacement convexity).

Example (Quadratic Loss): For output layer training (fixed features), the loss iswhereare fixed features. This can be written asIn RKHS,is convex.

Continuous-Time Interpretation of Deep Networks

ResNet and ODE: Residual networksas the number of layerswith step size, converge to an ODE:This is the foundation of Neural ODE (Chen et al. 2018).

Conditional Optimal Transport (Onken et al. 2021): Training ResNets can be understood as learning conditional optimal transport maps: given input, find the optimal trajectory mappingto target distribution.

Theorem (Deep ResNets and Conditional Optimal Transport): Training deep ResNets is equivalent to solving a conditional optimal transport problem:whereis the distribution obtained by starting fromand mapping through network.

Significance: This perspective understands representation learning in deep learning as: learning an optimal map that progressively "flattens" the complex distribution in input space, transporting it layer by layer to an easily classifiable target space.

Experimental Validation: Bridging Theory and Practice

To validate the preceding theory, we design three sets of experiments: (1) visualizing gradient flow trajectories; (2) verifying the mean-field limit; (3) studying the effect of initialization on convergence.

Experiment 1: Gradient Flow Trajectory Visualization

We visualize continuous gradient flow trajectories versus discrete updates on different functions.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

def gradient_descent(f, grad_f, x0, lr=0.01, num_steps=1000):
"""Gradient descent optimization"""
trajectory = [x0]
x = x0.copy()
for _ in range(num_steps):
x = x - lr * grad_f(x)
trajectory.append(x.copy())
return np.array(trajectory)

def gradient_flow_ode(grad_f, x0, t_span, num_points=1000):
"""Solve gradient flow using ODE solver"""
from scipy.integrate import solve_ivp

def dynamics(t, x):
return -grad_f(x)

sol = solve_ivp(dynamics, t_span, x0,
t_eval=np.linspace(t_span[0], t_span[1], num_points),
method='RK45')
return sol.y.T

# Example 1: Quadratic function
def quadratic_2d(x):
"""f(x,y) = x^2 + 4y^2"""
return x[0]**2 + 4*x[1]**2

def grad_quadratic_2d(x):
return np.array([2*x[0], 8*x[1]])

# Example 2: Rosenbrock function
def rosenbrock(x):
"""f(x,y) = (1-x)^2 + 100(y-x^2)^2"""
return (1 - x[0])**2 + 100*(x[1] - x[0]**2)**2

def grad_rosenbrock(x):
dx = -2*(1 - x[0]) - 400*x[0]*(x[1] - x[0]**2)
dy = 200*(x[1] - x[0]**2)
return np.array([dx, dy])

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

for idx, (f, grad_f, name) in enumerate([
(quadratic_2d, grad_quadratic_2d, 'Quadratic Function'),
(rosenbrock, grad_rosenbrock, 'Rosenbrock Function')
]):
ax = axes[idx]

# Plot contours
x = np.linspace(-2, 2, 400)
y = np.linspace(-1, 3, 400)
X, Y = np.meshgrid(x, y)
Z = np.array([[f(np.array([xi, yi])) for xi in x] for yi in y])

contour = ax.contour(X, Y, Z, levels=20, cmap='viridis', alpha=0.6)
ax.clabel(contour, inline=True, fontsize=8)

# Initial point
x0 = np.array([1.5, 2.5]) if idx == 1 else np.array([1.5, 1.0])

# Discrete gradient descent
traj_gd = gradient_descent(f, grad_f, x0, lr=0.001 if idx == 1 else 0.1, num_steps=500)
ax.plot(traj_gd[:, 0], traj_gd[:, 1], 'r-', linewidth=2, label='Discrete GD', alpha=0.7)
ax.plot(traj_gd[0, 0], traj_gd[0, 1], 'ro', markersize=10, label='Start')

# Continuous gradient flow (ODE)
if idx == 0:
traj_ode = gradient_flow_ode(grad_f, x0, [0, 2], num_points=1000)
ax.plot(traj_ode[:, 0], traj_ode[:, 1], 'b--', linewidth=2, label='Continuous Flow', alpha=0.7)

ax.set_xlabel('$x$', fontsize=12)
ax.set_ylabel('$y$', fontsize=12)
ax.set_title(f'Gradient Flow on {name}', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('gradient_flow_trajectories.png', dpi=300, bbox_inches='tight')
plt.show()

Experiment Description:

  • Quadratic Function: Gradient flow is a linear system, with trajectories being exponentially decaying ellipses. Discrete gradient descent (Euler method) highly agrees with continuous ODE solution (at small learning rates).

  • Rosenbrock Function: The banana-shaped valley makes optimization difficult. Gradient descent zigzags in the valley, differing from the smooth continuous flow trajectory (at large learning rates).

Observation: Smaller learning rates make discrete trajectories closer to continuous gradient flow; but computational cost is higher. This motivates accelerated methods (momentum, Adam), which correspond to different continuous-time dynamics (Lagrangian mechanics, twisted Riemannian metrics).

Experiment 2: Mean-Field Limit Verification

We train two-layer neural networks of varying widths, observing particle density evolution and comparing with theoretical predictions.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import numpy as np

# Generate data: fit simple function f(x) = x^2
np.random.seed(42)
torch.manual_seed(42)

n_data = 50
X_train = np.linspace(-1, 1, n_data).reshape(-1, 1)
y_train = X_train**2 + 0.05 * np.random.randn(n_data, 1)

X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train)

# Define two-layer neural network
class TwoLayerNet(nn.Module):
def __init__(self, width):
super().__init__()
self.width = width
self.fc1 = nn.Linear(1, width)
self.fc2 = nn.Linear(width, 1)

# Gaussian initialization
nn.init.normal_(self.fc1.weight, mean=0, std=1.0)
nn.init.normal_(self.fc1.bias, mean=0, std=1.0)
nn.init.normal_(self.fc2.weight, mean=0, std=1.0/np.sqrt(width))
nn.init.zeros_(self.fc2.bias)

def forward(self, x):
return self.fc2(torch.relu(self.fc1(x)))

def train_and_track(width, num_epochs=1000, lr=0.01):
"""Train network and track hidden layer weight distribution"""
model = TwoLayerNet(width)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
criterion = nn.MSELoss()

# Record weight history
weight_history = []
loss_history = []

# Snapshot epochs
snapshot_epochs = [0, 100, 300, 1000]

for epoch in range(num_epochs + 1):
optimizer.zero_grad()
output = model(X_train)
loss = criterion(output, y_train)
loss.backward()
optimizer.step()

if epoch in snapshot_epochs:
# Save hidden layer weights (first dimension)
weights = model.fc1.weight.data.cpu().numpy().flatten()
weight_history.append((epoch, weights))

if epoch % 100 == 0:
loss_history.append((epoch, loss.item()))

return weight_history, loss_history

# Train networks of different widths
widths = [10, 100, 1000]
results = {}

for width in widths:
print(f"Training network with width m={width}...")
weight_hist, loss_hist = train_and_track(width, num_epochs=1000, lr=0.01/np.sqrt(width))
results[width] = (weight_hist, loss_hist)

# Visualize weight density evolution
fig, axes = plt.subplots(len(widths), 4, figsize=(16, 4*len(widths)))

for i, width in enumerate(widths):
weight_hist, _ = results[width]

for j, (epoch, weights) in enumerate(weight_hist):
ax = axes[i, j]

# Plot histogram
ax.hist(weights, bins=30, density=True, alpha=0.6, color='steelblue', edgecolor='black')

# Plot kernel density estimate
if len(weights) > 1:
kde = gaussian_kde(weights)
x_range = np.linspace(weights.min() - 1, weights.max() + 1, 200)
ax.plot(x_range, kde(x_range), 'r-', linewidth=2, label='KDE')

ax.set_xlabel('Weight$w$', fontsize=10)
ax.set_ylabel('Density$\\rho(w)$', fontsize=10)
ax.set_title(f'm={width}, epoch={epoch}', fontsize=11)
ax.grid(True, alpha=0.3)
if j == 0:
ax.legend()

plt.tight_layout()
plt.savefig('meanfield_density_evolution.png', dpi=300, bbox_inches='tight')
plt.show()

# Visualize loss convergence
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

for width in widths:
_, loss_hist = results[width]
epochs, losses = zip(*loss_hist)
ax.plot(epochs, losses, marker='o', linewidth=2, label=f'm={width}')

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss (MSE)', fontsize=12)
ax.set_title('Loss Convergence for Different Network Widths', fontsize=14)
ax.set_yscale('log')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('meanfield_loss_convergence.png', dpi=300, bbox_inches='tight')
plt.show()

Experiment Description:

  1. Density Evolution: As training progresses, the weight distribution gradually evolves from the initial Gaussian. At small width (), discreteness is evident (histogram shows significant fluctuations); as width increases (), the distribution becomes smoother, approaching a continuous density.

  2. Mean-Field Limit: Theory predicts densitysatisfies a PDE. As, empirical densityshould converge to. In experiments, the distribution atis already quite smooth.

  3. Convergence Speed: Loss curves show that larger width leads to faster convergence (over-parameterization effect). Note that learning rate scales as, ensuring parameter movement scale remains consistent (mean-field scaling).

Theory Comparison: If initialization, theoretical analysis (Mei et al. 2018) predicts that under linear activation or small learning rate,remains approximately Gaussian, with mean shifting but variance nearly constant. Similar phenomena are observed in experiments (especially at).

Experiment 3: Wasserstein Distance Computation

We use the Python Optimal Transport (POT) library to compute Wasserstein distance between empirical distributions and target distributions, verifying whether training decreases alongdistance.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import ot
import numpy as np
import matplotlib.pyplot as plt

# Use weight history from previous training
# Assume target distribution is a known Gaussian (or extracted from sufficiently trained distribution)
target_mean = 0.0
target_std = 0.8

# Generate target distribution samples (Gaussian)
n_target_samples = 1000
target_samples = np.random.normal(target_mean, target_std, n_target_samples)

def compute_wasserstein_distance(source_samples, target_samples):
"""Compute Wasserstein-2 distance between two 1D distributions"""
# In 1D, W_2^2 = E[(X - Y)^2], where X, Y are matched samples
# Optimal transport plan is one-to-one correspondence after sorting
source_sorted = np.sort(source_samples)
target_sorted = np.sort(target_samples)

# If sample sizes differ, need to interpolate to same count
if len(source_sorted) != len(target_sorted):
from scipy.interpolate import interp1d
# Align using percentiles
n = min(len(source_sorted), len(target_sorted))
percentiles = np.linspace(0, 100, n)
source_aligned = np.percentile(source_sorted, percentiles)
target_aligned = np.percentile(target_sorted, percentiles)
else:
source_aligned = source_sorted
target_aligned = target_sorted

# W_2 distance
w2_dist = np.sqrt(np.mean((source_aligned - target_aligned)**2))
return w2_dist

# Compute Wasserstein distance during training for different width networks
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

for width in widths:
weight_hist, _ = results[width]
w2_distances = []
epochs_recorded = []

for epoch, weights in weight_hist:
w2 = compute_wasserstein_distance(weights, target_samples)
w2_distances.append(w2)
epochs_recorded.append(epoch)

ax.plot(epochs_recorded, w2_distances, marker='o', linewidth=2,
markersize=8, label=f'm={width}')

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Wasserstein-2 Distance$W_2(\\rho_t, \\rho_*)$', fontsize=12)
ax.set_title('Wasserstein Distance Between Weight Distribution and Target During Training', fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('wasserstein_distance_evolution.png', dpi=300, bbox_inches='tight')
plt.show()

Experiment Description:

  1. Distance Decrease: If training is a Wasserstein gradient flow of a functional,should decrease monotonically. Experiments observe an overall decreasing trend, though fluctuations may occur (discrete updates, finite sample effects).

  2. Width Effect: Larger width makes empirical measure closer to continuous distribution, makingdistance computation more stable.

  3. Theory Verification: This experiment directly validates the hypothesis that "training is a gradient flow on Wasserstein space." With appropriate choice of functional (e.g., Fisher-Rao metric in the paper Kernel Approximation of Fisher-Rao Gradient Flows), more precise correspondence can be obtained.

Experiment 4: Two-Layer Neural Network Loss Surface Visualization

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Simplified problem: loss surface with two parameters
# Consider network f(x; w1, w2) = w2 * relu(w1 * x), fitting single data point

X_single = torch.tensor([[1.0]])
y_single = torch.tensor([[0.5]])

def compute_loss(w1, w2):
"""Compute loss L(w1, w2) = 0.5 * (w2 * relu(w1) - 0.5)^2"""
output = w2 * torch.relu(w1 * X_single)
loss = 0.5 * (output - y_single)**2
return loss.item()

# Create grid
w1_range = np.linspace(-2, 2, 200)
w2_range = np.linspace(-2, 2, 200)
W1, W2 = np.meshgrid(w1_range, w2_range)

# Compute loss surface
L = np.zeros_like(W1)
for i in range(W1.shape[0]):
for j in range(W1.shape[1]):
w1 = torch.tensor(W1[i, j], requires_grad=False)
w2 = torch.tensor(W2[i, j], requires_grad=False)
L[i, j] = compute_loss(w1, w2)

# 3D surface plot
fig = plt.figure(figsize=(14, 6))

ax1 = fig.add_subplot(121, projection='3d')
surf = ax1.plot_surface(W1, W2, L, cmap='viridis', alpha=0.8, edgecolor='none')
ax1.set_xlabel('$w_1$', fontsize=12)
ax1.set_ylabel('$w_2$', fontsize=12)
ax1.set_zlabel('Loss$L(w_1, w_2)$', fontsize=12)
ax1.set_title('Two-Layer Neural Network Loss Surface (3D)', fontsize=14)
fig.colorbar(surf, ax=ax1, shrink=0.5)

# 2D contour plot + gradient flow trajectories
ax2 = fig.add_subplot(122)
contour = ax2.contour(W1, W2, L, levels=30, cmap='viridis')
ax2.clabel(contour, inline=True, fontsize=8)

# Simulate gradient descent trajectories
def gradient_descent_trajectory(w1_init, w2_init, lr=0.1, num_steps=100):
trajectory = [(w1_init, w2_init)]
w1 = torch.tensor(w1_init, requires_grad=True)
w2 = torch.tensor(w2_init, requires_grad=True)

for _ in range(num_steps):
output = w2 * torch.relu(w1 * X_single)
loss = 0.5 * (output - y_single)**2
loss.backward()

with torch.no_grad():
w1 -= lr * w1.grad
w2 -= lr * w2.grad
w1.grad.zero_()
w2.grad.zero_()

trajectory.append((w1.item(), w2.item()))

return np.array(trajectory)

# Different initializations
init_points = [(-1.5, 1.5), (1.5, 1.5), (-1.5, -1.5)]
colors = ['red', 'blue', 'green']

for init, color in zip(init_points, colors):
traj = gradient_descent_trajectory(init[0], init[1], lr=0.05, num_steps=150)
ax2.plot(traj[:, 0], traj[:, 1], color=color, linewidth=2, alpha=0.7, label=f'Init ({init[0]}, {init[1]})')
ax2.plot(init[0], init[1], 'o', color=color, markersize=10)

ax2.set_xlabel('$w_1$', fontsize=12)
ax2.set_ylabel('$w_2$', fontsize=12)
ax2.set_title('Loss Contours and Optimization Trajectories', fontsize=14)
ax2.legend(fontsize=9)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('neural_network_loss_surface.png', dpi=300, bbox_inches='tight')
plt.show()

Experiment Description:

  1. Non-convexity: Loss surface has multiple saddle points and flat regions. ReLU activation causes zero gradient when(dead neurons).

  2. Trajectory Depends on Initialization: Trajectories from different initial points converge to different local minima. But in the mean-field limit, global convergence is guaranteed (due to averaging effect of particle ensemble).

  3. Symmetry: Loss surface is symmetric about theaxis (positive-negative symmetry of ReLU).

Fisher-Rao Gradient Flows and Conditional Gradient Flows

Fisher-Rao Metric and Natural Gradient

Besides the Wasserstein metric, the space of probability distributions has another important metric — the Fisher-Rao metric, which is the Riemannian metric on statistical manifolds.

Definition (Fisher Information Matrix): For a parametric distribution, the Fisher information matrix is Fisher-Rao Metric: The infinitesimal distance in parameter space is defined as Natural Gradient (Amari 1998): The gradient flow under Fisher-Rao geometry iswhereis the inverse of the Fisher matrix. This converges faster than Euclidean gradient descent because it accounts for the intrinsic geometry of parameter space.

Theorem (Fisher-Rao Gradient Flow): For the KL divergence functionalits Fisher-Rao gradient flow isWhenis a flat distribution, this is equivalent to the heat equation.

Comparison with Wasserstein Gradient Flow:

  • Wasserstein: Measures "transport cost," suitable for describing particle movement.
  • Fisher-Rao: Measures "information geometric distance," suitable for describing distribution shape changes.

Paper Kernel Approximation of Fisher-Rao Gradient Flows studies how to approximate Fisher-Rao gradient flows using kernel methods and applies them to sampling algorithms (e.g., Langevin dynamics).

Conditional Gradient Flows and Frank-Wolfe Algorithm

Problem: Minimize functionalover constraint set: Conditional Gradient Method (Frank-Wolfe): Instead of moving directly along gradient direction, find steepest descent direction within constraint set:Then update Application to Neural Networks: If constraint set is(bounded total variation), conditional gradient flow corresponds to adding new neurons at each step (rather than updating all parameters). This explains greedy training strategies.

PDE Interpretation of Adaptive Optimization Algorithms

Adam Optimizer (Kingma & Ba 2015) uses first and second moment estimates:where.

Continuous-Time Limit: Formally, as step size, Adam corresponds towheresatisfyThis is a coupled dynamical system.acts as an adaptive learning rate (similar to metric reparametrization).

Geometric Interpretation: Adam is equivalent to gradient flow under a coordinate-dependent Riemannian metric:This is similar to the natural gradient idea but uses diagonal approximation (rather than full Fisher matrix).

Theoretical Deepening: Recent Research Advances

Convergence of Mean-Field SGD

Standard mean-field theory assumes continuous time and full-batch gradient. But actual training uses stochastic gradient descent (SGD), involving noise and discreteness.

Paper Mean-Field Analysis of Neural SGD-Ascent studies mean-field equations with noise:whereis SGD noise intensity, corresponding to stochastic differential equation (SDE): Main Results:

  1. Noise Accelerates Convergence: Appropriate noise helps escape saddle points, accelerating convergence to global minimum.
  2. Fluctuation-Dissipation Relation: Relationship between noise intensity, temperature, and batch size:, whereis batch size.
  3. Implicit Regularization: SGD noise corresponds to adding entropy regularization termto loss, favoring flat minima (generalization).

Mean-Field Limit of Multi-Layer Networks

The preceding theory mainly targets two-layer networks. For deep networks, mean-field analysis is more complex, requiring consideration of inter-layer coupling.

Layered Mean-Field Equations: For an-layer network, parameter distributionfor each layer satisfies a coupled PDE system: Challenges:

  1. Asymmetric Coupling: Changes in shallow layers affect deep layers, but feedback occurs through backpropagation.
  2. Gradient Vanishing/Explosion: Gradient flow in deep networks may be temporally unstable.
  3. Residual Connections: Skip connections in ResNets alter flow structure, corresponding to symplectic geometry or volume-preserving flows.

Current Progress: The paper Deep ResNets and Conditional Optimal Transport understands ResNets as discrete-time optimal transport steps, providing a new analytical framework.

Double Descent in Over-Parameterization

Experimental Observation (Belkin et al. 2019): Test error versus model complexity exhibits a "double descent" curve — first decreasing, then increasing (overfitting), then decreasing again (over-parameterization regime).

Mean-Field Explanation: In the over-parameterization regime (), the solution space of mean-field equations is very large, and optimization tends to find interpolating solutions (zero training error). But different initializationslead to different solutions.

Theorem (Implicit Bias): In the mean-field limit, gradient flow converges to the maximum entropy solution:whereis KL divergence relative to. This preference for maximum entropy (most dispersed) solutions has good generalization properties.

Lyapunov Function Construction for Neural Networks

For guaranteeing convergence, the key is finding a Lyapunov functionsatisfying: Candidate Lyapunov Functions:

  1. Training Loss:. But only decreases monotonically under convex or PL conditions.
  2. Free Energy:. Combines loss and entropy.
  3. Inter-Particle Distance:.

Open Problem: For general non-convex losses and arbitrary depth networks, constructing a unified Lyapunov function remains a challenge.

Outlook: Future Directions from the PDE Perspective

Theoretical Directions

  1. Stronger Convergence Guarantees: Precise convergence rates for non-convex losses, finite width, and discrete time.
  2. Generalization Theory: Connecting mean-field limits with PAC learning and Rademacher complexity.
  3. Adversarial Robustness: Characterizing adversarial perturbations under Wasserstein metric, designing robust training algorithms.
  4. PDE Theory for Transformers: How do attention mechanisms, understood as integral operators, evolve?

Algorithmic Directions

  1. PDE Numerical Methods for Optimization: Using high-order ODE/PDE solvers (e.g., Runge-Kutta) to design new optimizers.
  2. Control Theory: Viewing hyperparameter tuning (learning rate, momentum) as optimal control problems.
  3. Sampling Algorithms: Using Langevin dynamics and Wasserstein gradient flows to design more efficient MCMC samplers (for Bayesian deep learning).

Application Directions

  1. Generative Models: Diffusion models are essentially reverse Fokker-Planck equations; PDE theory provides theoretical foundation.
  2. Reinforcement Learning: Policy gradients can be understood as gradient flows on policy space; mean-field methods analyze multi-agent systems.
  3. Scientific Computing: Deep Ritz Method and Physics-Informed Neural Networks (PINNs) transform PDE solving into optimization problems, using PDE theory in reverse to improve training.

Interdisciplinary Crossover

  • Statistical Mechanics: Neural network training analogous to spin glass systems, phase transition phenomena.
  • Optimal Control: Pontryagin's maximum principle for end-to-end optimization in deep learning.
  • Differential Geometry: Deep applications of information geometry and symplectic geometry in optimization.

Summary

This article, starting from variational principles, systematically establishes a partial differential equation perspective on neural network optimization. We demonstrated:

  1. Calculus of variations bridges discrete optimization and continuous dynamics, with Euler-Lagrange equations unifying physics, geometry, and optimization.

  2. Wasserstein geometry provides a natural metric for the space of probability distributions; gradient flow theory unifies classical PDEs like heat and Fokker-Planck equations as gradient flows of energy functionals.

  3. Mean-field limit understands training of finite-width neural networks as collective behavior of particle systems, converging under appropriate scaling to Vlasov-type PDEs, providing global convergence guarantees.

  4. Experimental validation demonstrates correspondence between theoretical predictions and actual training: gradient flow trajectories, particle density evolution, Wasserstein distance decrease — all phenomena are clearly visible in numerical experiments.

  5. Frontier advances include Fisher-Rao gradient flows, stochastic PDE theory for SGD, conditional optimal transport interpretations of deep networks, pointing toward future research directions.

This perspective not only deepens understanding of the essence of neural network optimization but also provides powerful tools for designing new algorithms, analyzing generalization, and constructing theoretical guarantees. As mathematics and machine learning continue to intersect, PDE theory will surely play an increasingly important role in deep learning.

References

  1. L. Chizat and F. Bach, "On the Global Convergence of Gradient Descent for Over-parameterized Models using Optimal Transport," NeurIPS, 2018. arXiv:1805.09545

  2. S. Mei, A. Montanari, and P.-M. Nguyen, "A Mean Field View of the Landscape of Two-Layer Neural Networks," PNAS, 2018. arXiv:1804.06561

  3. G. M. Rotskoff and E. Vanden-Eijnden, "Neural Networks as Interacting Particle Systems: Asymptotic Convexity of the Loss Landscape and Universal Scaling of the Approximation Error," arXiv:1805.00915, 2018.

  4. A. Jacot, F. Gabriel, and C. Hongler, "Neural Tangent Kernel: Convergence and Generalization in Neural Networks," NeurIPS, 2018. arXiv:1806.07572

  5. W. E and B. Yu, "The Deep Ritz Method: A Deep Learning-Based Numerical Algorithm for Solving Variational Problems," CPAM, 2018. arXiv:1710.00211

  6. R. T. Q. Chen, Y. Rubanova, J. Bettencourt, and D. Duvenaud, "Neural Ordinary Differential Equations," NeurIPS, 2018. arXiv:1806.07366

  7. L. Ambrosio, N. Gigli, and G. Savar é, Gradient Flows in Metric Spaces and in the Space of Probability Measures, Birkh ä user, 2008.

  8. C. Villani, Optimal Transport: Old and New, Springer, 2009.

  9. R. Jordan, D. Kinderlehrer, and F. Otto, "The Variational Formulation of the Fokker-Planck Equation," SIAM J. Math. Anal., 1998.

  10. F. Otto, "The Geometry of Dissipative Evolution Equations: the Porous Medium Equation," Comm. PDE, 2001.

  11. Mean-Field Analysis of Neural SGD-Ascent, Y. Lu and J. Lu, 2024.

  12. Kernel Approximation of Fisher-Rao Gradient Flows, A. Kazeykina and M. Fornasier, 2024.

  13. Deep ResNets and Conditional Optimal Transport, D. Onken et al., 2024.

  14. M. Belkin, D. Hsu, S. Ma, and S. Mandal, "Reconciling Modern Machine Learning Practice and the Classical Bias-Variance Trade-off," PNAS, 2019.

  15. S. Amari, "Natural Gradient Works Efficiently in Learning," Neural Computation, 1998.

  16. D. P. Kingma and J. Ba, "Adam: A Method for Stochastic Optimization," ICLR, 2015. arXiv:1412.6980

  17. Y. Li and Y. Liang, "Learning Overparameterized Neural Networks via Stochastic Gradient Descent on Structured Data," NeurIPS, 2018.

  18. L. Chizat, E. Oyallon, and F. Bach, "On Lazy Training in Differentiable Programming," NeurIPS, 2019. arXiv:1812.07956

  19. G. Peyr é and M. Cuturi, "Computational Optimal Transport," Foundations and Trends in Machine Learning, 2019. arXiv:1803.00567

  20. J. Sirignano and K. Spiliopoulos, "Mean Field Analysis of Neural Networks: A Central Limit Theorem," Stoch. Proc. Appl., 2020. arXiv:1808.09372


Code Repository: Complete experimental code and visualization scripts have been uploaded to the GitHub repository (please replace with actual link).

Acknowledgments: Thanks to anonymous reviewers for valuable feedback and in-depth discussions with my advisor on calculus of variations and optimization theory.

  • Post title:PDE and Machine Learning (3): Variational Principles and Optimization
  • Post author:Chen Kai
  • Create time:2022-01-25 10:15:00
  • Post link:https://www.chenk.top/pde-ml-3-variational-principles/
  • Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stating additionally.
 Comments