机器学习数学推导(十四)变分推断与变分 EM
Chen Kai BOSS

变分推断(Variational Inference)将贝叶斯推断转化为优化问题——当后验分布难以精确计算时,变分推断通过优化一个易处理的分布族来近似真实后验,将积分问题转化为优化问题。从变分 EM 到变分自编码器,从主题模型到深度生成模型,变分推断已成为现代机器学习的核心技术。本章将系统推导变分推断的数学原理、平均场近似、坐标上升算法与黑盒变分推断。

贝叶斯推断与后验难题

贝叶斯推断框架

观测数据* *隐变量**: * *参数**:

目标:计算后验分布

$$

P(, ) = = $$

困难:边缘似然(证据)

$$

P() = P(, , ) d d$

通常无法解析计算,也难以数值积分(高维)。

精确推断 vs 近似推断

精确推断: - 共轭先验:某些模型后验有闭式解 - 图模型:变量消除、信念传播(树结构)

近似推断(大多数情况需要): 1. 采样方法: MCMC(马尔可夫链蒙特卡洛) - 优点:渐近精确 - 缺点:收敛慢,难以诊断 2. 变分方法:将推断转化为优化 - 优点:快速,确定性 - 缺点:有偏近似

变分推断的基本原理

ELBO 推导

想法:用简单分布 近似复杂后验

优化目标:最小化 KL 散度

$

问题:包含未知的

转换

$

其中证据下界(ELBO):

$

关键关系

$

变分推断目标

$$

q^{*} = _q (q) = _q (q | p) $$

平均场近似

假设:变分分布完全分解

$$

q(, ) = _{i=1}^N q_i(i) {j} q_j(_j) $$

或更简洁地,假设隐变量和参数划分为 组:

$$

q(, ) = _{j=1}^M q_j(_j) $$

优化:对每个因子,固定其他因子,最大化 ELBO

坐标上升变分推断

ELBO 展开

$

其中 是熵。

优化:固定 $

**最优$ q_j^$ *:

$q_j^{*}(j) = {q_{-j}} [P(, , )] + $

q_j^{*}(j) ( {q_{-j}} [P(, , )] ) $$

算法:循环更新每个因子直至收敛

变分 EM 算法

EM 与变分推断的联系

标准 EM: - E 步:(精确后验) - M 步: (点估计)

变分 EM: - E 步:(变分近似) - M 步:

变分贝叶斯 EM: - VE 步: 变分更新 - VM 步: 变分更新(贝叶斯推断参数而非点估计)

变分贝叶斯 GMM

模型: - 先验: ,_k (_0, _0^{-1} _k) - 似然: ,_i z_i=k (_k, _k^{-1})$

变分分布

$$

q(, , , ) = q() q() _{k=1}^K q(_k, _k) $$

更新公式(共轭性质):

*** *:

$$

r_{ik} = q(z_i=k) ( [_k] + [|_k|] - - [(_i - _k)^T _k (_i - _k)] ) $$

**$ q(),其中_k = _0 + N_kN _k = i r{ik} q(_k, _k)$* *:正态-Wishart 分布,参数由充分统计量更新(详见 Bishop PRML 10.2 节)

黑盒变分推断(BBVI)

梯度估计问题

ELBO

$

梯度

$

困难:梯度与期望不可交换( 依赖

REINFORCE 梯度估计器

对数导数技巧

$

ELBO 梯度

$

蒙特卡洛估计

$

其中

问题:高方差

重参数化技巧

想法:将随机性从 分离出来

重参数化,其中 是固定分布

示例(高斯): $ = + , (0, )$

ELBO 梯度

$

蒙特卡洛估计

$

优势:低方差,可自动微分

实现示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import numpy as np
from scipy.stats import multivariate_normal, dirichlet
from scipy.special import digamma

class VariationalGMM:
"""变分贝叶斯高斯混合模型(简化版)"""
def __init__(self, n_components=3, max_iter=100, tol=1e-3):
self.K = n_components
self.max_iter = max_iter
self.tol = tol

def fit(self, X):
N, d = X.shape
K = self.K

# 初始化超参数
alpha0 = 1.0
m0 = np.mean(X, axis=0)
beta0 = 1.0
nu0 = d
W0 = np.eye(d)

# 初始化变分参数
self.alpha = np.ones(K) * alpha0 + N / K
self.beta = np.ones(K) * beta0 + N / K
self.nu = np.ones(K) * nu0 + N / K
self.m = np.array([np.mean(X, axis=0) + 0.1 * np.random.randn(d) for _ in range(K)])
self.W = np.array([W0 for _ in range(K)])

# 初始化责任度
r = np.random.dirichlet([1] * K, N)

for iteration in range(self.max_iter):
r_old = r.copy()

# 更新责任度
r = self._update_r(X, N, d, K)

# 更新参数
N_k = np.sum(r, axis=0)
x_bar_k = (r.T @ X) / N_k[:, np.newaxis]

self.alpha = alpha0 + N_k
self.beta = beta0 + N_k
self.m = (beta0 * m0 + N_k[:, np.newaxis] * x_bar_k) / self.beta[:, np.newaxis]
self.nu = nu0 + N_k

for k in range(K):
S_k = np.zeros((d, d))
for i in range(N):
diff = X[i] - x_bar_k[k]
S_k += r[i, k] * np.outer(diff, diff)

diff_m = x_bar_k[k] - m0
self.W[k] = np.linalg.inv(
np.linalg.inv(W0) + N_k[k] * S_k / N_k[k] +
(beta0 * N_k[k]) / (beta0 + N_k[k]) * np.outer(diff_m, diff_m)
)

# 检查收敛
if np.max(np.abs(r - r_old)) < self.tol:
break

return self

def _update_r(self, X, N, d, K):
"""更新责任度"""
r = np.zeros((N, K))

for k in range(K):
# E[log pi_k]
E_log_pi = digamma(self.alpha[k]) - digamma(np.sum(self.alpha))

# E[log |Lambda_k|]
E_log_det = np.sum([digamma((self.nu[k] + 1 - i) / 2) for i in range(1, d + 1)])
E_log_det += d * np.log(2) + np.log(np.linalg.det(self.W[k]))

# 马氏距离期望
for i in range(N):
diff = X[i] - self.m[k]
E_dist = self.nu[k] * diff @ self.W[k] @ diff + d / self.beta[k]
r[i, k] = E_log_pi + 0.5 * E_log_det - 0.5 * E_dist

# 归一化
r = np.exp(r - np.max(r, axis=1, keepdims=True))
r /= np.sum(r, axis=1, keepdims=True)

return r

def predict(self, X):
N, d = X.shape
r = self._update_r(X, N, d, self.K)
return np.argmax(r, axis=1)

# 黑盒变分推断示例(重参数化)
class BBVI_Gaussian:
"""黑盒变分推断(高斯近似)"""
def __init__(self, dim, lr=0.01):
self.mu = np.zeros(dim)
self.log_sigma = np.zeros(dim)
self.lr = lr

def sample(self, n_samples=1):
"""重参数化采样"""
epsilon = np.random.randn(n_samples, len(self.mu))
return self.mu + np.exp(self.log_sigma) * epsilon

def elbo(self, log_p_func, n_samples=10):
"""估计 ELBO"""
z_samples = self.sample(n_samples)
log_p = np.array([log_p_func(z) for z in z_samples])
log_q = -0.5 * np.sum((z_samples - self.mu) ** 2 / np.exp(2 * self.log_sigma), axis=1)
log_q -= 0.5 * len(self.mu) * np.log(2 * np.pi) + np.sum(self.log_sigma)
return np.mean(log_p - log_q)

def step(self, log_p_func, n_samples=10):
"""单步优化(数值梯度)"""
elbo_current = self.elbo(log_p_func, n_samples)

# 数值梯度(简化)
eps = 1e-4
grad_mu = np.zeros_like(self.mu)
grad_log_sigma = np.zeros_like(self.log_sigma)

for i in range(len(self.mu)):
self.mu[i] += eps
grad_mu[i] = (self.elbo(log_p_func, n_samples) - elbo_current) / eps
self.mu[i] -= eps

self.log_sigma[i] += eps
grad_log_sigma[i] = (self.elbo(log_p_func, n_samples) - elbo_current) / eps
self.log_sigma[i] -= eps

# 梯度上升
self.mu += self.lr * grad_mu
self.log_sigma += self.lr * grad_log_sigma

if __name__ == '__main__':
# 变分 GMM 示例
from sklearn.datasets import make_blobs
X, _ = make_blobs(n_samples=300, centers=3, n_features=2, random_state=42)

vgmm = VariationalGMM(n_components=3, max_iter=50)
vgmm.fit(X)
labels = vgmm.predict(X)

print(f"聚类完成,权重估计: {vgmm.alpha / np.sum(vgmm.alpha)}")

Q&A 精选

Q1: 变分推断 vs MCMC?

A: - 变分: 快速、确定性、有偏(KL 散度非零) - MCMC: 慢、随机、渐近无偏

变分适合大规模数据和在线学习,MCMC 适合精确推断。


Q2: 为什么用 KL(q||p)而非 KL(p||q)?

A: KL(q||p)是"反向 KL",使 小的地方也小(零逼近)。 KL(p||q)是"正向 KL",使 覆盖 的所有模式(矩匹配)。反向 KL 计算上只需 可采样,不需要归一化


Q3: 平均场假设何时失效?

A: 当变量强相关时。解决方法: - 结构化变分(保留部分依赖) - 更 rich 的变分族(normalizing flows)


Q4: 变分贝叶斯 vs 点估计(MAP/MLE)?

A: 变分贝叶斯保留不确定性,防止过拟合。代价:计算复杂度高。小数据/正则化需求高→变分贝叶斯;大数据/速度需求→点估计。


Q5: 重参数化技巧的适用范围?

A: 需要连续可微的分布。适用:高斯、 Logistic 、 Laplace 。不适用:离散分布(需 REINFORCE 或 Gumbel-Softmax)。


参考文献

  1. Jordan, M. I., et al. (1999). An introduction to variational methods for graphical models. Machine Learning, 37(2), 183-233.
  2. Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational inference: A review for statisticians. JASA, 112(518), 859-877.
  3. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. ICLR.
  4. Ranganath, R., Gerrish, S., & Blei, D. (2014). Black box variational inference. AISTATS.

变分推断将贝叶斯推断的积分难题转化为优化问题,以确定性算法换取计算效率。从经典的平均场近似到现代的黑盒变分推断,从 VAE 到深度生成模型,变分方法已成为机器学习的基础工具。理解变分推断,是通往概率编程、贝叶斯深度学习的必经之路。

  • 本文标题:机器学习数学推导(十四)变分推断与变分 EM
  • 本文作者:Chen Kai
  • 创建时间:2021-11-11 14:30:00
  • 本文链接:https://www.chenk.top/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%95%B0%E5%AD%A6%E6%8E%A8%E5%AF%BC%EF%BC%88%E5%8D%81%E5%9B%9B%EF%BC%89%E5%8F%98%E5%88%86%E6%8E%A8%E6%96%AD%E4%B8%8E%E5%8F%98%E5%88%86EM/
  • 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
 评论