变分推断(Variational
Inference)将贝叶斯推断转化为优化问题——当后验分布难以精确计算时,变分推断通过优化一个易处理的分布族来近似真实后验,将积分问题转化为优化问题。从变分
EM
到变分自编码器,从主题模型到深度生成模型,变分推断已成为现代机器学习的核心技术。本章将系统推导变分推断的数学原理、平均场近似、坐标上升算法与黑盒变分推断。
贝叶斯推断与后验难题
贝叶斯推断框架
观测数据 : * *隐变量**: * *参数**:
目标 :计算后验分布
$$
P(, ) = = $$
困难 :边缘似然(证据)
$$
P() = P(, , ) d d $
通常无法解析计算,也难以数值积分(高维)。
精确推断 vs 近似推断
精确推断 : - 共轭先验:某些模型后验有闭式解 -
图模型:变量消除、信念传播(树结构)
近似推断 (大多数情况需要): 1.
采样方法 : MCMC(马尔可夫链蒙特卡洛) - 优点:渐近精确
- 缺点:收敛慢,难以诊断 2. 变分方法 :将推断转化为优化
- 优点:快速,确定性 - 缺点:有偏近似
变分推断的基本原理
ELBO 推导
想法 :用简单分布 近似复杂后验
优化目标 :最小化 KL 散度
$
问题 :包含未知的
转换 :
$
其中证据下界 (ELBO):
$
关键关系 :
$
变分推断目标 :
$$
q^{*} = _q (q) = _q (q | p) $$
平均场近似
假设 :变分分布完全分解
$$
q(, ) = _{i=1}^N q_i(i) {j} q_j(_j) $$
或更简洁地,假设隐变量和参数划分为 组:
$$
q(, ) = _{j=1}^M q_j(_j) $$
优化 :对每个因子 ,固定其他因子,最大化
ELBO
坐标上升变分推断
ELBO 展开 :
$
其中 是熵。
对 优化 :固定 $
**最优$ q_j^$ *:
$q_j^{*}(j) = {q_{-j}} [P(, , )] + $
q_j^{*}(j) ( {q_{-j}} [P(, , )] ) $$
算法 :循环更新每个因子直至收敛
变分 EM 算法
EM 与变分推断的联系
标准 EM : - E 步: (精确后验) - M
步: (点估计)
变分 EM : - E 步: (变分近似) - M 步:
变分贝叶斯 EM : - VE 步: 变分更新 - VM 步: 变分更新(贝叶斯推断参数而非点估计)
变分贝叶斯 GMM
模型 : - 先验: ,_k (_0, _0^{-1} _k), - 似然: ,_i z_i=k (_k, _k^{-1})$
变分分布 :
$$
q(, , , ) = q() q() _{k=1}^K q(_k, _k) $$
更新公式 (共轭性质):
** * *:
$$
r_{ik} = q(z_i=k) ( [_k] + [|_k|] - - [(_i - _k)^T _k (_i - _k)] )
$$
**$ q(): ,其中_k = _0 + N_k, N _k = i r {ik} q(_k,
_k)$* *:正态-Wishart 分布,参数由充分统计量更新(详见 Bishop PRML 10.2
节)
黑盒变分推断(BBVI)
梯度估计问题
ELBO :
$
梯度 :
$
困难 :梯度与期望不可交换( 依赖 )
REINFORCE 梯度估计器
对数导数技巧 :
$
ELBO 梯度 :
$
蒙特卡洛估计 :
$
其中
问题 :高方差
重参数化技巧
想法 :将随机性从 分离出来
重参数化 : ,其中 是固定分布
示例 (高斯): $ = + , (0, )$
ELBO 梯度 :
$
蒙特卡洛估计 :
$
优势 :低方差,可自动微分
实现示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 import numpy as npfrom scipy.stats import multivariate_normal, dirichletfrom scipy.special import digammaclass VariationalGMM : """变分贝叶斯高斯混合模型(简化版)""" def __init__ (self, n_components=3 , max_iter=100 , tol=1e-3 ): self.K = n_components self.max_iter = max_iter self.tol = tol def fit (self, X ): N, d = X.shape K = self.K alpha0 = 1.0 m0 = np.mean(X, axis=0 ) beta0 = 1.0 nu0 = d W0 = np.eye(d) self.alpha = np.ones(K) * alpha0 + N / K self.beta = np.ones(K) * beta0 + N / K self.nu = np.ones(K) * nu0 + N / K self.m = np.array([np.mean(X, axis=0 ) + 0.1 * np.random.randn(d) for _ in range (K)]) self.W = np.array([W0 for _ in range (K)]) r = np.random.dirichlet([1 ] * K, N) for iteration in range (self.max_iter): r_old = r.copy() r = self._update_r(X, N, d, K) N_k = np.sum (r, axis=0 ) x_bar_k = (r.T @ X) / N_k[:, np.newaxis] self.alpha = alpha0 + N_k self.beta = beta0 + N_k self.m = (beta0 * m0 + N_k[:, np.newaxis] * x_bar_k) / self.beta[:, np.newaxis] self.nu = nu0 + N_k for k in range (K): S_k = np.zeros((d, d)) for i in range (N): diff = X[i] - x_bar_k[k] S_k += r[i, k] * np.outer(diff, diff) diff_m = x_bar_k[k] - m0 self.W[k] = np.linalg.inv( np.linalg.inv(W0) + N_k[k] * S_k / N_k[k] + (beta0 * N_k[k]) / (beta0 + N_k[k]) * np.outer(diff_m, diff_m) ) if np.max (np.abs (r - r_old)) < self.tol: break return self def _update_r (self, X, N, d, K ): """更新责任度""" r = np.zeros((N, K)) for k in range (K): E_log_pi = digamma(self.alpha[k]) - digamma(np.sum (self.alpha)) E_log_det = np.sum ([digamma((self.nu[k] + 1 - i) / 2 ) for i in range (1 , d + 1 )]) E_log_det += d * np.log(2 ) + np.log(np.linalg.det(self.W[k])) for i in range (N): diff = X[i] - self.m[k] E_dist = self.nu[k] * diff @ self.W[k] @ diff + d / self.beta[k] r[i, k] = E_log_pi + 0.5 * E_log_det - 0.5 * E_dist r = np.exp(r - np.max (r, axis=1 , keepdims=True )) r /= np.sum (r, axis=1 , keepdims=True ) return r def predict (self, X ): N, d = X.shape r = self._update_r(X, N, d, self.K) return np.argmax(r, axis=1 ) class BBVI_Gaussian : """黑盒变分推断(高斯近似)""" def __init__ (self, dim, lr=0.01 ): self.mu = np.zeros(dim) self.log_sigma = np.zeros(dim) self.lr = lr def sample (self, n_samples=1 ): """重参数化采样""" epsilon = np.random.randn(n_samples, len (self.mu)) return self.mu + np.exp(self.log_sigma) * epsilon def elbo (self, log_p_func, n_samples=10 ): """估计 ELBO""" z_samples = self.sample(n_samples) log_p = np.array([log_p_func(z) for z in z_samples]) log_q = -0.5 * np.sum ((z_samples - self.mu) ** 2 / np.exp(2 * self.log_sigma), axis=1 ) log_q -= 0.5 * len (self.mu) * np.log(2 * np.pi) + np.sum (self.log_sigma) return np.mean(log_p - log_q) def step (self, log_p_func, n_samples=10 ): """单步优化(数值梯度)""" elbo_current = self.elbo(log_p_func, n_samples) eps = 1e-4 grad_mu = np.zeros_like(self.mu) grad_log_sigma = np.zeros_like(self.log_sigma) for i in range (len (self.mu)): self.mu[i] += eps grad_mu[i] = (self.elbo(log_p_func, n_samples) - elbo_current) / eps self.mu[i] -= eps self.log_sigma[i] += eps grad_log_sigma[i] = (self.elbo(log_p_func, n_samples) - elbo_current) / eps self.log_sigma[i] -= eps self.mu += self.lr * grad_mu self.log_sigma += self.lr * grad_log_sigma if __name__ == '__main__' : from sklearn.datasets import make_blobs X, _ = make_blobs(n_samples=300 , centers=3 , n_features=2 , random_state=42 ) vgmm = VariationalGMM(n_components=3 , max_iter=50 ) vgmm.fit(X) labels = vgmm.predict(X) print (f"聚类完成,权重估计: {vgmm.alpha / np.sum (vgmm.alpha)} " )
Q&A 精选
Q1: 变分推断 vs MCMC?
A: - 变分 : 快速、确定性、有偏(KL 散度非零) -
MCMC : 慢、随机、渐近无偏
变分适合大规模数据和在线学习,MCMC 适合精确推断。
Q2: 为什么用 KL(q||p)而非 KL(p||q)?
A: KL(q||p)是"反向 KL",使 在 小的地方也小(零逼近)。
KL(p||q)是"正向 KL",使 覆盖 的所有模式(矩匹配)。反向 KL
计算上只需 可采样,不需要归一化 。
Q3: 平均场假设何时失效?
A: 当变量强相关时。解决方法: - 结构化变分(保留部分依赖) - 更 rich
的变分族(normalizing flows)
Q4: 变分贝叶斯 vs 点估计(MAP/MLE)?
A:
变分贝叶斯保留不确定性,防止过拟合。代价:计算复杂度高。小数据/正则化需求高→变分贝叶斯;大数据/速度需求→点估计。
Q5: 重参数化技巧的适用范围?
A: 需要连续可微的分布。适用:高斯、 Logistic 、 Laplace
。不适用:离散分布(需 REINFORCE 或 Gumbel-Softmax)。
参考文献
Jordan, M. I., et al. (1999). An introduction to
variational methods for graphical models. Machine Learning ,
37(2), 183-233.
Blei, D. M., Kucukelbir, A., & McAuliffe, J. D.
(2017). Variational inference: A review for statisticians.
JASA , 112(518), 859-877.
Kingma, D. P., & Welling, M. (2014).
Auto-encoding variational Bayes. ICLR .
Ranganath, R., Gerrish, S., & Blei, D. (2014).
Black box variational inference. AISTATS .
变分推断将贝叶斯推断的积分难题转化为优化问题,以确定性算法换取计算效率。从经典的平均场近似到现代的黑盒变分推断,从
VAE
到深度生成模型,变分方法已成为机器学习的基础工具。理解变分推断,是通往概率编程、贝叶斯深度学习的必经之路。