Machine Learning Mathematical Derivations (14): Variational Inference and Variational EM
Chen Kai

Variational inference recasts Bayesian inference as optimization: when the posterior distribution is difficult to compute exactly, variational inference optimizes over a tractable family of distributions to approximate the true posterior, converting an integration problem into an optimization problem. From variational EM to variational autoencoders, from topic models to deep generative models, variational inference has become a core technique in modern machine learning. This chapter systematically derives the mathematical principles of variational inference, the mean-field approximation, coordinate ascent algorithms, and black-box variational inference.


Bayesian Inference and the Posterior Challenge

Bayesian Inference Framework

Observed data: $\mathbf{X} = \{\mathbf{x}_1, \dots, \mathbf{x}_N\}$

Latent variables: $\mathbf{Z} = \{\mathbf{z}_1, \dots, \mathbf{z}_N\}$

Parameters: $\boldsymbol{\theta}$, with prior $p(\boldsymbol{\theta})$

Objective: Compute the posterior distribution
$$p(\mathbf{Z}, \boldsymbol{\theta} \mid \mathbf{X}) = \frac{p(\mathbf{X} \mid \mathbf{Z}, \boldsymbol{\theta})\, p(\mathbf{Z} \mid \boldsymbol{\theta})\, p(\boldsymbol{\theta})}{p(\mathbf{X})}$$

Difficulty: The marginal likelihood (evidence) $p(\mathbf{X}) = \iint p(\mathbf{X} \mid \mathbf{Z}, \boldsymbol{\theta})\, p(\mathbf{Z} \mid \boldsymbol{\theta})\, p(\boldsymbol{\theta})\, d\mathbf{Z}\, d\boldsymbol{\theta}$ is typically analytically intractable and difficult to compute numerically (the integral is high-dimensional).

Exact Inference vs Approximate Inference

Exact inference:
- Conjugate priors: some models have closed-form posteriors
- Graphical models: variable elimination, belief propagation (tree structures)

Approximate inference (needed in most cases):
1. Sampling methods: MCMC (Markov Chain Monte Carlo)
   - Advantage: asymptotically exact
   - Disadvantage: slow convergence, difficult to diagnose
2. Variational methods: convert inference to optimization
   - Advantage: fast, deterministic
   - Disadvantage: biased approximation

Basic Principles of Variational Inference

ELBO Derivation

Idea: Use a simple distribution $q(\mathbf{Z})$ to approximate the complex posterior $p(\mathbf{Z} \mid \mathbf{X})$.

Optimization objective: Minimize the KL divergence
$$\mathrm{KL}\big(q(\mathbf{Z}) \,\|\, p(\mathbf{Z} \mid \mathbf{X})\big) = \mathbb{E}_{q}\left[\log \frac{q(\mathbf{Z})}{p(\mathbf{Z} \mid \mathbf{X})}\right]$$

Problem: This contains the unknown posterior $p(\mathbf{Z} \mid \mathbf{X})$. Transformation: substituting $\log p(\mathbf{Z} \mid \mathbf{X}) = \log p(\mathbf{X}, \mathbf{Z}) - \log p(\mathbf{X})$ gives
$$\mathrm{KL}\big(q \,\|\, p(\mathbf{Z} \mid \mathbf{X})\big) = \log p(\mathbf{X}) - \mathcal{L}(q)$$

where the Evidence Lower Bound (ELBO) is
$$\mathcal{L}(q) = \mathbb{E}_{q}[\log p(\mathbf{X}, \mathbf{Z})] - \mathbb{E}_{q}[\log q(\mathbf{Z})]$$

Key relationship:
$$\log p(\mathbf{X}) = \mathcal{L}(q) + \mathrm{KL}\big(q(\mathbf{Z}) \,\|\, p(\mathbf{Z} \mid \mathbf{X})\big) \geq \mathcal{L}(q)$$
since $\mathrm{KL} \geq 0$; the evidence is fixed, so raising the ELBO lowers the KL.

Variational inference objective: maximize the ELBO,
$$q^* = \arg\max_{q \in \mathcal{Q}} \mathcal{L}(q)$$
which is equivalent to minimizing $\mathrm{KL}(q \,\|\, p(\mathbf{Z} \mid \mathbf{X}))$ over the variational family $\mathcal{Q}$.
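The identity $\log p(\mathbf{X}) = \mathcal{L}(q) + \mathrm{KL}$ can be checked numerically on a toy model. The sketch below (an illustrative example, not part of the original derivation) uses a discrete latent variable with three states, so every expectation is an exact sum:

```python
import numpy as np

# Tiny discrete model: z takes 3 values, x is fixed. p_joint holds p(x, z).
# We verify log p(x) = ELBO(q) + KL(q || p(z|x)) for an arbitrary q.
p_joint = np.array([0.20, 0.15, 0.05])   # p(x, z) for z in {0, 1, 2}
p_x = p_joint.sum()                      # evidence p(x)
p_post = p_joint / p_x                   # true posterior p(z | x)

q = np.array([0.5, 0.3, 0.2])            # arbitrary variational distribution

elbo = np.sum(q * (np.log(p_joint) - np.log(q)))       # E_q[log p(x,z) - log q(z)]
kl = np.sum(q * (np.log(q) - np.log(p_post)))          # KL(q || p(z|x))

print(elbo + kl, np.log(p_x))  # identical up to rounding
```

Because all quantities are exact sums, the decomposition holds to floating-point precision, and the ELBO sits below the log evidence by exactly the KL gap.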

Mean-Field Approximation

Assumption: The variational distribution fully factorizes:
$$q(\mathbf{Z}) = \prod_{i=1}^{N} q_i(\mathbf{z}_i)$$

Or more generally, assuming the latent variables and parameters are partitioned into $M$ groups:
$$q(\mathbf{Z}) = \prod_{j=1}^{M} q_j(\mathbf{Z}_j)$$

Optimization: For each factor $q_j$, fix the other factors and maximize the ELBO.

Coordinate Ascent Variational Inference

ELBO expansion (under the mean-field factorization):
$$\mathcal{L}(q) = \int \prod_i q_i(\mathbf{Z}_i)\left[\log p(\mathbf{X}, \mathbf{Z}) - \sum_i \log q_i(\mathbf{Z}_i)\right] d\mathbf{Z} = \mathbb{E}_{q}[\log p(\mathbf{X}, \mathbf{Z})] + \sum_i \mathbb{H}[q_i]$$

where $\mathbb{H}[q_i] = -\mathbb{E}_{q_i}[\log q_i]$ is the entropy.

Optimize for $q_j$: Fix all other factors $q_i$, $i \neq j$.

Optimal factor:
$$\log q_j^*(\mathbf{Z}_j) = \mathbb{E}_{i \neq j}[\log p(\mathbf{X}, \mathbf{Z})] + \text{const}, \qquad q_j^*(\mathbf{Z}_j) \propto \exp\big(\mathbb{E}_{i \neq j}[\log p(\mathbf{X}, \mathbf{Z})]\big)
$$

Algorithm: Cyclically update each factor until convergence
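Coordinate ascent can be demonstrated end-to-end on a tiny model. The sketch below (an illustrative 2-by-2 discrete joint, not from the original text) fits a mean-field $q(z_1, z_2) = q_1(z_1)\,q_2(z_2)$ to a correlated joint $p(z_1, z_2)$ using the closed-form factor update, and records the ELBO, which never decreases:

```python
import numpy as np

# Correlated discrete joint p(z1, z2), already normalized; since p is a
# full distribution here, the "evidence" is 1 and ELBO = -KL(q || p) <= 0.
p = np.array([[0.30, 0.10],
              [0.15, 0.45]])
log_p = np.log(p)

q1 = np.array([0.5, 0.5])
q2 = np.array([0.5, 0.5])

def elbo(q1, q2):
    q = np.outer(q1, q2)
    return np.sum(q * (log_p - np.log(q)))

elbos = [elbo(q1, q2)]
for _ in range(20):
    # Update q1: log q1*(z1) = E_{q2}[log p(z1, z2)] + const
    log_q1 = log_p @ q2
    q1 = np.exp(log_q1 - log_q1.max()); q1 /= q1.sum()
    # Update q2: log q2*(z2) = E_{q1}[log p(z1, z2)] + const
    log_q2 = q1 @ log_p
    q2 = np.exp(log_q2 - log_q2.max()); q2 /= q2.sum()
    elbos.append(elbo(q1, q2))

print(elbos[-1])  # converged ELBO, still below 0 because p does not factorize
```

The monotone ELBO trace is the practical convergence diagnostic for CAVI; the residual gap to zero is exactly the KL cost of the mean-field assumption on this correlated joint.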

Variational EM Algorithm

Connection between EM and Variational Inference

Standard EM:
- E-step: $q(\mathbf{Z}) = p(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\theta}^{(t)})$ (exact posterior)
- M-step: $\boldsymbol{\theta}^{(t+1)} = \arg\max_{\boldsymbol{\theta}} \mathbb{E}_{q}[\log p(\mathbf{X}, \mathbf{Z} \mid \boldsymbol{\theta})]$ (point estimate)

Variational EM:
- E-step: $q(\mathbf{Z}) \approx p(\mathbf{Z} \mid \mathbf{X}, \boldsymbol{\theta}^{(t)})$ (variational approximation, used when the exact posterior is intractable)
- M-step: $\boldsymbol{\theta}^{(t+1)} = \arg\max_{\boldsymbol{\theta}} \mathcal{L}(q, \boldsymbol{\theta})$ (point estimate)

Variational Bayes EM:
- VE-step: variational update of $q(\mathbf{Z})$
- VM-step: variational update of $q(\boldsymbol{\theta})$ (Bayesian inference of the parameters, not a point estimate)

Variational Bayes GMM

Model:
- Priors: $\boldsymbol{\pi} \sim \mathrm{Dir}(\alpha_0)$, $\boldsymbol{\Lambda}_k \sim \mathcal{W}(\mathbf{W}_0, \nu_0)$, $\boldsymbol{\mu}_k \mid \boldsymbol{\Lambda}_k \sim \mathcal{N}(\mathbf{m}_0, (\beta_0 \boldsymbol{\Lambda}_k)^{-1})$
- Likelihood: $z_n \sim \mathrm{Cat}(\boldsymbol{\pi})$, $\mathbf{x}_n \mid z_n = k \sim \mathcal{N}(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k^{-1})$

Variational distribution: $q(\mathbf{Z}, \boldsymbol{\pi}, \boldsymbol{\mu}, \boldsymbol{\Lambda}) = q(\mathbf{Z})\, q(\boldsymbol{\pi}) \prod_{k} q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k)$

Update formulas (by conjugacy):

$q(\mathbf{Z})$: responsibilities
$$r_{nk} \propto \exp\left(\mathbb{E}[\log \pi_k] + \tfrac{1}{2}\mathbb{E}[\log |\boldsymbol{\Lambda}_k|] - \tfrac{1}{2}\mathbb{E}\big[(\mathbf{x}_n - \boldsymbol{\mu}_k)^{\top} \boldsymbol{\Lambda}_k (\mathbf{x}_n - \boldsymbol{\mu}_k)\big]\right)$$

$q(\boldsymbol{\pi})$: $\mathrm{Dir}(\boldsymbol{\alpha})$ with $\alpha_k = \alpha_0 + N_k$, where $N_k = \sum_n r_{nk}$

$q(\boldsymbol{\mu}_k, \boldsymbol{\Lambda}_k)$: Normal-Wishart, with parameters updated from the sufficient statistics $N_k$, $\bar{\mathbf{x}}_k = \frac{1}{N_k}\sum_n r_{nk}\mathbf{x}_n$, and $\mathbf{S}_k = \frac{1}{N_k}\sum_n r_{nk}(\mathbf{x}_n - \bar{\mathbf{x}}_k)(\mathbf{x}_n - \bar{\mathbf{x}}_k)^{\top}$:
$$\beta_k = \beta_0 + N_k, \qquad \mathbf{m}_k = \frac{\beta_0 \mathbf{m}_0 + N_k \bar{\mathbf{x}}_k}{\beta_k}, \qquad \nu_k = \nu_0 + N_k$$
$$\mathbf{W}_k^{-1} = \mathbf{W}_0^{-1} + N_k \mathbf{S}_k + \frac{\beta_0 N_k}{\beta_0 + N_k}(\bar{\mathbf{x}}_k - \mathbf{m}_0)(\bar{\mathbf{x}}_k - \mathbf{m}_0)^{\top}$$
(see Bishop, PRML, Section 10.2)

Black-Box Variational Inference (BBVI)

Gradient Estimation Problem

ELBO as a function of the variational parameters $\boldsymbol{\lambda}$:
$$\mathcal{L}(\boldsymbol{\lambda}) = \mathbb{E}_{q_{\boldsymbol{\lambda}}(\mathbf{z})}\big[\log p(\mathbf{x}, \mathbf{z}) - \log q_{\boldsymbol{\lambda}}(\mathbf{z})\big]$$

Gradient: $\nabla_{\boldsymbol{\lambda}} \mathcal{L}(\boldsymbol{\lambda})$

Difficulty: The gradient and the expectation cannot simply be exchanged, because the sampling distribution $q_{\boldsymbol{\lambda}}$ itself depends on $\boldsymbol{\lambda}$.

REINFORCE Gradient Estimator

Log-derivative trick:
$$\nabla_{\boldsymbol{\lambda}} q_{\boldsymbol{\lambda}}(\mathbf{z}) = q_{\boldsymbol{\lambda}}(\mathbf{z})\, \nabla_{\boldsymbol{\lambda}} \log q_{\boldsymbol{\lambda}}(\mathbf{z})$$

ELBO gradient:
$$\nabla_{\boldsymbol{\lambda}} \mathcal{L} = \mathbb{E}_{q_{\boldsymbol{\lambda}}}\big[\nabla_{\boldsymbol{\lambda}} \log q_{\boldsymbol{\lambda}}(\mathbf{z})\, \big(\log p(\mathbf{x}, \mathbf{z}) - \log q_{\boldsymbol{\lambda}}(\mathbf{z})\big)\big]$$

Monte Carlo estimate:
$$\hat{\nabla}_{\boldsymbol{\lambda}} \mathcal{L} = \frac{1}{S} \sum_{s=1}^{S} \nabla_{\boldsymbol{\lambda}} \log q_{\boldsymbol{\lambda}}(\mathbf{z}^{(s)})\, \big(\log p(\mathbf{x}, \mathbf{z}^{(s)}) - \log q_{\boldsymbol{\lambda}}(\mathbf{z}^{(s)})\big), \qquad \mathbf{z}^{(s)} \sim q_{\boldsymbol{\lambda}}$$

Problem: High variance.
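The estimator can be sanity-checked on a toy target where the gradient is known in closed form. In the sketch below (an illustrative setup, not from the original text), the target is an unnormalized standard normal, $\log p(z) = -z^2/2$, and $q = \mathcal{N}(\mu, 1)$, so the true ELBO gradient with respect to $\mu$ works out analytically to $-\mu$:

```python
import numpy as np

# Toy check of the score-function (REINFORCE) estimator.
# Target: log p(z) = -z^2/2 (constants cancel against log q below).
# q(z) = N(mu, 1), so d/dmu log q(z) = (z - mu) and the true gradient is -mu.
rng = np.random.default_rng(0)
mu = 1.0
S = 100_000

z = mu + rng.standard_normal(S)           # z ~ q(z) = N(mu, 1)
log_p = -0.5 * z**2                        # log p(z), up to a constant
log_q = -0.5 * (z - mu)**2                 # log q(z), up to the same constant
score = z - mu                             # d/dmu log q(z)

grad_samples = score * (log_p - log_q)     # one REINFORCE sample per z
print(grad_samples.mean())                 # approaches the true gradient -mu = -1
print(grad_samples.var())                  # per-sample variance: the "high variance" issue
```

Even on this one-dimensional problem the per-sample variance is sizeable (analytically 2.25 at $\mu = 1$), which is why variance-reduction techniques such as control variates, or the reparameterization trick below, matter in practice.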

Reparameterization Trick

Idea: Separate the randomness from the parameters. Reparameterization: $\mathbf{z} = g_{\boldsymbol{\lambda}}(\boldsymbol{\epsilon})$, where $\boldsymbol{\epsilon} \sim p(\boldsymbol{\epsilon})$ is a fixed, parameter-free distribution.

Example (Gaussian): $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$, $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. ELBO gradient:
$$\nabla_{\boldsymbol{\lambda}} \mathcal{L} = \mathbb{E}_{p(\boldsymbol{\epsilon})}\big[\nabla_{\boldsymbol{\lambda}}\big(\log p(\mathbf{x}, g_{\boldsymbol{\lambda}}(\boldsymbol{\epsilon})) - \log q_{\boldsymbol{\lambda}}(g_{\boldsymbol{\lambda}}(\boldsymbol{\epsilon}))\big)\big]$$

Monte Carlo estimate:
$$\hat{\nabla}_{\boldsymbol{\lambda}} \mathcal{L} = \frac{1}{S} \sum_{s=1}^{S} \nabla_{\boldsymbol{\lambda}}\big(\log p(\mathbf{x}, g_{\boldsymbol{\lambda}}(\boldsymbol{\epsilon}^{(s)})) - \log q_{\boldsymbol{\lambda}}(g_{\boldsymbol{\lambda}}(\boldsymbol{\epsilon}^{(s)}))\big), \qquad \boldsymbol{\epsilon}^{(s)} \sim p(\boldsymbol{\epsilon})$$

Advantages: Low variance, amenable to automatic differentiation
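The variance advantage can be seen directly on a toy problem (an illustrative setup, not from the original text): target $\log p(z) = -z^2/2$, variational family $q = \mathcal{N}(\mu, 1)$ with fixed unit variance, true ELBO gradient $-\mu$. The sketch compares the score-function and pathwise estimators on the same noise draws:

```python
import numpy as np

# Compare per-sample variance of the two gradient estimators at mu = 1.
# For q with fixed unit variance, the entropy term does not depend on mu,
# so the pathwise gradient samples are simply d/dmu log p(mu + eps).
rng = np.random.default_rng(1)
mu = 1.0
S = 100_000
eps = rng.standard_normal(S)
z = mu + eps                                  # reparameterized sample z = g_mu(eps)

# Score-function samples: (d/dmu log q) * (log p - log q)
score_samples = (z - mu) * (-0.5 * z**2 + 0.5 * (z - mu)**2)

# Pathwise samples: d/dmu log p(mu + eps) = -(mu + eps)
path_samples = -(mu + eps)

print(score_samples.mean(), path_samples.mean())   # both approach -mu = -1
print(score_samples.var(), path_samples.var())     # ~2.25 vs ~1.0 at mu = 1
```

Both estimators are unbiased, but the pathwise samples exploit the differentiable map $z = g_\mu(\epsilon)$ and carry substantially less noise, which is what makes reparameterization the default choice when it applies.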

Implementation Example

import numpy as np
from scipy.special import digamma

class VariationalGMM:
    """Variational Bayes Gaussian Mixture Model (simplified)."""
    def __init__(self, n_components=3, max_iter=100, tol=1e-3):
        self.K = n_components
        self.max_iter = max_iter
        self.tol = tol

    def fit(self, X):
        N, d = X.shape
        K = self.K

        # Hyperparameters of the priors
        alpha0 = 1.0
        m0 = np.mean(X, axis=0)
        beta0 = 1.0
        nu0 = d
        W0 = np.eye(d)

        # Initialize variational parameters
        self.alpha = np.ones(K) * alpha0 + N / K
        self.beta = np.ones(K) * beta0 + N / K
        self.nu = np.ones(K) * nu0 + N / K
        self.m = np.array([m0 + 0.1 * np.random.randn(d) for _ in range(K)])
        self.W = np.array([W0.copy() for _ in range(K)])

        # Initialize responsibilities
        r = np.random.dirichlet([1] * K, N)

        for iteration in range(self.max_iter):
            r_old = r.copy()

            # VE-step: update responsibilities q(Z)
            r = self._update_r(X, N, d, K)

            # VM-step: update q(pi) and q(mu, Lambda) from sufficient statistics
            N_k = np.sum(r, axis=0)
            x_bar_k = (r.T @ X) / N_k[:, np.newaxis]

            self.alpha = alpha0 + N_k
            self.beta = beta0 + N_k
            self.m = (beta0 * m0 + N_k[:, np.newaxis] * x_bar_k) / self.beta[:, np.newaxis]
            self.nu = nu0 + N_k

            for k in range(K):
                # Weighted scatter around the component mean (N_k * S_k in Bishop's notation)
                S_k = np.zeros((d, d))
                for i in range(N):
                    diff = X[i] - x_bar_k[k]
                    S_k += r[i, k] * np.outer(diff, diff)

                diff_m = x_bar_k[k] - m0
                self.W[k] = np.linalg.inv(
                    np.linalg.inv(W0) + S_k +
                    (beta0 * N_k[k]) / (beta0 + N_k[k]) * np.outer(diff_m, diff_m)
                )

            # Check convergence of the responsibilities
            if np.max(np.abs(r - r_old)) < self.tol:
                break

        return self

    def _update_r(self, X, N, d, K):
        """Update responsibilities r_nk."""
        r = np.zeros((N, K))

        for k in range(K):
            # E[log pi_k] under q(pi) = Dir(alpha)
            E_log_pi = digamma(self.alpha[k]) - digamma(np.sum(self.alpha))

            # E[log |Lambda_k|] under the Wishart factor
            E_log_det = np.sum([digamma((self.nu[k] + 1 - i) / 2) for i in range(1, d + 1)])
            E_log_det += d * np.log(2) + np.log(np.linalg.det(self.W[k]))

            # Expected Mahalanobis distance E[(x - mu_k)^T Lambda_k (x - mu_k)]
            for i in range(N):
                diff = X[i] - self.m[k]
                E_dist = self.nu[k] * diff @ self.W[k] @ diff + d / self.beta[k]
                r[i, k] = E_log_pi + 0.5 * E_log_det - 0.5 * E_dist

        # Normalize in log space for numerical stability
        r = np.exp(r - np.max(r, axis=1, keepdims=True))
        r /= np.sum(r, axis=1, keepdims=True)

        return r

    def predict(self, X):
        N, d = X.shape
        r = self._update_r(X, N, d, self.K)
        return np.argmax(r, axis=1)

# Black-box variational inference example (reparameterization)
class BBVI_Gaussian:
    """Black-box variational inference with a diagonal Gaussian approximation."""
    def __init__(self, dim, lr=0.01):
        self.mu = np.zeros(dim)
        self.log_sigma = np.zeros(dim)
        self.lr = lr

    def sample(self, n_samples=1):
        """Reparameterized sampling: z = mu + sigma * epsilon."""
        epsilon = np.random.randn(n_samples, len(self.mu))
        return self.mu + np.exp(self.log_sigma) * epsilon

    def elbo(self, log_p_func, n_samples=10):
        """Monte Carlo estimate of the ELBO."""
        z_samples = self.sample(n_samples)
        log_p = np.array([log_p_func(z) for z in z_samples])
        log_q = -0.5 * np.sum((z_samples - self.mu) ** 2 / np.exp(2 * self.log_sigma), axis=1)
        log_q -= 0.5 * len(self.mu) * np.log(2 * np.pi) + np.sum(self.log_sigma)
        return np.mean(log_p - log_q)

    def step(self, log_p_func, n_samples=10):
        """Single optimization step (noisy finite-difference gradient, for illustration)."""
        elbo_current = self.elbo(log_p_func, n_samples)

        eps = 1e-4
        grad_mu = np.zeros_like(self.mu)
        grad_log_sigma = np.zeros_like(self.log_sigma)

        for i in range(len(self.mu)):
            self.mu[i] += eps
            grad_mu[i] = (self.elbo(log_p_func, n_samples) - elbo_current) / eps
            self.mu[i] -= eps

            self.log_sigma[i] += eps
            grad_log_sigma[i] = (self.elbo(log_p_func, n_samples) - elbo_current) / eps
            self.log_sigma[i] -= eps

        # Gradient ascent on the ELBO
        self.mu += self.lr * grad_mu
        self.log_sigma += self.lr * grad_log_sigma

if __name__ == '__main__':
    # Variational GMM example
    from sklearn.datasets import make_blobs
    X, _ = make_blobs(n_samples=300, centers=3, n_features=2, random_state=42)

    vgmm = VariationalGMM(n_components=3, max_iter=50)
    vgmm.fit(X)
    labels = vgmm.predict(X)

    print(f"Clustering complete, mixture weight estimates: {vgmm.alpha / np.sum(vgmm.alpha)}")

Q&A

Q1: Variational Inference vs MCMC?

A:
- Variational: fast, deterministic, biased (the KL divergence to the true posterior is generally non-zero)
- MCMC: slow, stochastic, asymptotically unbiased

Variational inference suits large-scale data and online learning; MCMC suits settings where inference accuracy matters more than speed.


Q2: Why use KL(q||p) instead of KL(p||q)?

A: KL(q||p) is the "reverse KL": it forces $q$ to be small wherever $p$ is small (zero-forcing, mode-seeking). KL(p||q) is the "forward KL": it forces $q$ to cover all modes of $p$ (moment matching). The reverse KL only requires taking expectations under $q$ and does not need the normalizing constant of $p$, which is why variational inference uses it.
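The zero-forcing versus mode-covering behavior can be made concrete numerically. The sketch below (an illustrative example, not from the original text) fits a unit-variance Gaussian $\mathcal{N}(m, 1)$ to a bimodal mixture by grid search over $m$, computing each KL by numerical integration:

```python
import numpy as np

# Bimodal target p = 0.5 N(-3, 1) + 0.5 N(3, 1); candidate q = N(m, 1).
z = np.linspace(-12, 12, 4801)
dz = z[1] - z[0]

def normal_pdf(z, m):
    return np.exp(-0.5 * (z - m)**2) / np.sqrt(2 * np.pi)

p = 0.5 * normal_pdf(z, -3) + 0.5 * normal_pdf(z, 3)

ms = np.linspace(-5, 5, 201)
rev_kl = []   # KL(q || p) = integral of q log(q/p)
fwd_kl = []   # KL(p || q) = integral of p log(p/q)
for m in ms:
    q = normal_pdf(z, m)
    rev_kl.append(np.sum(q * (np.log(q) - np.log(p))) * dz)
    fwd_kl.append(np.sum(p * (np.log(p) - np.log(q))) * dz)

best_rev = ms[np.argmin(rev_kl)]
best_fwd = ms[np.argmin(fwd_kl)]
print(best_rev)   # near +/-3: reverse KL locks onto one mode (zero-forcing)
print(best_fwd)   # near 0: forward KL spreads over both modes (moment matching)
```

The forward-KL solution at $m \approx 0$ places its mass between the modes, where $p$ is nearly zero, exactly the kind of approximation the reverse KL refuses to make.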


Q3: When does the mean-field assumption fail?

A: When variables are strongly correlated: the factorized $q$ cannot represent posterior correlations and typically underestimates posterior variance. Solutions:
- Structured variational approximations (preserve some dependencies)
- Richer variational families (e.g., normalizing flows)


Q4: Variational Bayes vs point estimates (MAP/MLE)?

A: Variational Bayes preserves uncertainty and prevents overfitting. Cost: Higher computational complexity. Use variational Bayes for small data/high regularization needs; use point estimates for large data/speed requirements.


Q5: When is the reparameterization trick applicable?

A: It requires a continuous distribution whose samples can be written as a differentiable function of the variational parameters. Applicable: Gaussian, Logistic, Laplace. Not applicable: discrete distributions (use REINFORCE or the Gumbel-Softmax relaxation instead).


✏️ Exercises and Solutions

Exercise 1: ELBO Derivation

Problem: Prove $\log p(\mathbf{X}) \geq \mathcal{L}(q)$. Solution: $\log p(\mathbf{X}) = \mathcal{L}(q) + \mathrm{KL}\big(q(\mathbf{Z}) \,\|\, p(\mathbf{Z} \mid \mathbf{X})\big)$ and $\mathrm{KL} \geq 0$, hence the ELBO is a lower bound on the log evidence.

Exercise 2: Mean Field

Problem: Given the mean-field factorization $q(\mathbf{Z}) = \prod_j q_j(\mathbf{Z}_j)$, derive the optimal factor $q_j^*$. Solution: $\log q_j^*(\mathbf{Z}_j) = \mathbb{E}_{i \neq j}[\log p(\mathbf{X}, \mathbf{Z})] + \text{const}$.

Exercise 3: Variational EM

Problem: What do the E-step and M-step of variational EM optimize? Solution: E-step: fix $\boldsymbol{\theta}$, optimize $q(\mathbf{Z})$ to tighten the bound. M-step: fix $q(\mathbf{Z})$, optimize $\boldsymbol{\theta}$.

Exercise 4: VAE Reparameterization

Problem: Why use $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$ instead of sampling $\mathbf{z}$ directly? Solution: Direct sampling is not differentiable with respect to $(\boldsymbol{\mu}, \boldsymbol{\sigma})$. Reparameterization moves the randomness into $\boldsymbol{\epsilon}$, so gradients can flow through $\boldsymbol{\mu}$ and $\boldsymbol{\sigma}$.

Exercise 5: VI vs MCMC

Problem: When to use VI vs MCMC? Solution: VI: fast but biased, good for large data. MCMC: asymptotically unbiased but slow.


References

  1. Jordan, M. I., et al. (1999). An introduction to variational methods for graphical models. Machine Learning, 37(2), 183-233.
  2. Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational inference: A review for statisticians. JASA, 112(518), 859-877.
  3. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. ICLR.
  4. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. AISTATS.

Variational inference transforms the integration challenge of Bayesian inference into an optimization problem, trading off deterministic algorithms for computational efficiency. From classical mean-field approximation to modern black-box variational inference, from VAE to deep generative models, variational methods have become foundational tools in machine learning. Understanding variational inference is a necessary path toward probabilistic programming and Bayesian deep learning.

  • Post title:Machine Learning Mathematical Derivations (14): Variational Inference and Variational EM
  • Post author:Chen Kai
  • Create time:2021-11-11 14:30:00
  • Post link:https://www.chenk.top/Machine-Learning-Mathematical-Derivations-14-Variational-Inference-and-Variational-EM/
  • Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stated otherwise.