重参数化详解与 Gumbel Softmax 深入探讨
Chen Kai BOSS

一旦模型里出现“采样”,训练就会立刻遇到一个硬问题:梯度怎么穿过随机节点?重参数化技巧给出的答案非常直接——把 改写成,把随机性隔离到与参数无关的噪声里,于是反向传播可以顺着走下去。麻烦在于离散变量:一类操作不可导,梯度会断掉。 Gumbel-Softmax(或 Concrete 分布)用“带温度的 softmax + Gumbel 噪声”把离散采样变成可微的近似,让你在保留离散结构的同时仍能端到端训练。本文会把连续重参数化与 Gumbel-Softmax 的推导、直觉与实现细节讲清楚,并重点解释温度、偏差-方差权衡以及实际训练中最常见的坑。

重参数化的基本概念

重参数化( Reparameterization) 是机器学习中一种重要的技术,主要用于处理涉及随机变量的模型。其核心思想是将随机变量的采样过程转化为一个确定性函数与噪声变量的组合,从而使得梯度能够通过采样过程进行传播。这对于优化包含随机性的模型,如变分自编码器( VAE)和生成对抗网络( GANs)等,至关重要。

为什么需要重参数化?

在许多机器学习模型中,需要从某个分布中采样随机变量。例如,在 VAE 中,潜在变量的采样对于模型的训练至关重要。然而,直接对随机变量进行采样会导致以下问题:

  1. 梯度不可传递:采样过程本身是一个非微分操作,无法通过反向传播计算梯度。
  2. 优化困难:由于无法计算梯度,传统的梯度下降方法难以应用于模型参数的优化。

重参数化通过将采样过程重新表达为一个可微分的形式,解决了上述问题,使得模型参数能够通过梯度下降等方法进行有效优化。

重参数化的数学表达

重参数化的基本思想是将随机变量表示为一个确定性函数与独立噪声变量的组合:

其中:

- 是来自一个简单且与模型参数 无关的分布(例如,标准正态分布)。 - 是一个确定性函数,通常依赖于模型参数

通过这种表示方式,随机性被隔离在 中,而模型参数 影响的是确定性部分,从而使得整个过程对 可微分。

连续分布中的重参数化

正态分布的重参数化

正态分布是重参数化中最常见的例子。假设潜在变量服从均值为、方差为的正态分布:

直接对 进行采样会导致梯度无法有效传递,因为采样过程不可微。通过重参数化,可以将 表示为:

其中, 表示逐元素相乘。这样, 被表示为 的函数,以及独立于模型参数的噪声。由于 是可微的,整个优化过程可以通过梯度下降有效进行。

重参数化在 VAE 中的应用

变分自编码器( VAE)中,重参数化技巧用于优化潜在变量的分布。 VAE 通过最大化证据下界( ELBO)来学习数据的潜在表示。具体流程如下:

  1. 编码器:将高维输入数据映射到低维潜在空间,输出潜在变量的分布参数(如均值和标准差)。具体来说,编码器网络接受输入数据,并输出潜在变量的分布参数。这里,通常是通过神经网络的最后一层线性变换得到的:

这些参数定义了潜在变量 的高斯分布:

  1. 重参数化:为了使得采样过程可微分,引入了重参数化技巧,将随机变量 表示为确定性函数和独立噪声变量 的组合,通过 生成潜在变量

    1. 独立噪声变量:独立于模型参数,只依赖于预定义的简单分布(如标准正态分布)
    2. 确定性函数:将 作为参数,通过线性变换与噪声 结合,生成潜在变量
  2. 解码器:解码器网络接受潜在变量,并生成重建后的数据

    1. 重建过程:解码器试图从潜在变量 中重建出原始输入数据,目标是使得 尽可能接近
    2. 生成能力:通过训练,解码器学会了如何从潜在空间中生成与训练数据相似的新数据。

这种方法允许梯度通过传播,从而实现端到端的训练。

ELBO 的最大化

VAE 的目标是最大化证据下界( ELBO),其数学表达式为:

> 其中:

  • 第一项:重建误差,衡量从潜在变量重建数据的准确性。
  • 第二项: KL 散度,衡量编码器输出的潜在分布与先验分布之间的差异。

通过重参数化技巧, ELBO 的梯度能够有效传递到编码器和解码器的参数,从而实现优化。

重参数化的数学原理

重参数化的数学基础在于将期望形式的目标函数转化为可微形式。对于连续情形的目标函数:

通过重参数化,可以将其转化为:

这使得梯度可以通过传递,从而实现有效的优化。

代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 定义编码器网络
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(Encoder, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim) # 输出均值
self.fc_logvar = nn.Linear(hidden_dim, latent_dim) # 输出对数方差

def forward(self, x):
h = F.relu(self.fc1(x))
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar

# 定义解码器网络
class Decoder(nn.Module):
def __init__(self, latent_dim, hidden_dim, output_dim):
super(Decoder, self).__init__()
self.fc3 = nn.Linear(latent_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, output_dim)

def forward(self, z):
h = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h))

# 定义 VAE 模型
class VAE(nn.Module):
def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
super(VAE, self).__init__()
self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

def reparameterize(self, mu, logvar):
"""
重参数化技巧:
z = mu + sigma * epsilon
其中 epsilon ~ N(0, 1)
"""
std = torch.exp(0.5 * logvar) # 计算标准差
eps = torch.randn_like(std) # 从标准正态分布采样 epsilon
return mu + std * eps # 生成潜在变量 z

def forward(self, x):
mu, logvar = self.encoder(x) # 编码器输出均值和对数方差
z = self.reparameterize(mu, logvar) # 重参数化生成 z
recon_x = self.decoder(z) # 解码器重建输入
return recon_x, mu, logvar

# 定义损失函数
def loss_function(recon_x, x, mu, logvar):
"""
VAE 的损失函数包括重建误差和 KL 散度
"""
BCE = F.binary_cross_entropy(recon_x, x, reduction='sum') # 重建误差
# KL 散度
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD

# 训练过程
def train_vae(model, dataloader, optimizer, epochs=10):
model.train()
for epoch in range(1, epochs + 1):
train_loss = 0
for batch_idx, (data, _) in enumerate(dataloader):
data = data.view(-1, 784) # 展平图像
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()

print(f'Epoch {epoch}, Average loss: {train_loss / len(dataloader.dataset):.4f}')

# 测试过程
def test_vae(model, dataloader):
model.eval()
test_loss = 0
with torch.no_grad():
for data, _ in dataloader:
data = data.view(-1, 784)
recon_batch, mu, logvar = model(data)
test_loss += loss_function(recon_batch, data, mu, logvar).item()

print(f'Test set loss: {test_loss / len(dataloader.dataset):.4f}')

# 主函数
def main():
# 数据加载与预处理
transform = transforms.ToTensor()
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# 模型初始化
model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 训练与测试
for epoch in range(1, 11):
train_vae(model, train_loader, optimizer, epochs=1)
test_vae(model, test_loader)

if __name__ == "__main__":
main()

离散分布中的重参数化

挑战

对于离散分布,直接应用重参数化技巧面临以下挑战:

  1. 非微分操作:离散变量的采样过程(如操作)通常是不可微的,导致梯度无法有效传递。

    考虑自然语言处理中的词汇选择任务。假设模型需要生成一个单词作为输出:

    1. 前向传播

      • 模型输出每个单词的 logits
      • 使用选择概率最高的单词
    2. 反向传播

      • 由于是不可微的,无法计算,导致梯度无法传递回模型参数。

    这种情况下,传统的梯度下降方法无法直接优化模型,因为梯度信息在采样步骤中丢失。

  2. 高维度问题:当类别数量较大时,直接对所有可能的类别进行求和计算期望变得计算量巨大,甚至不可行。

    考虑图像生成任务中的像素值预测。假设每个像素可以取个不同的灰度值:

    1. 模型输出

      • 对于每个像素,模型输出个 logits,对应于每个灰度值的概率。
    2. 计算期望

      • 如果需要计算某种统计量(如期望值),需要对所有个类别进行求和。
    3. 高维度扩展

      • 对于高分辨率图像,每个图像包含数以万计的像素,每个像素又有个可能的值,计算期望的成本急剧增加。

引入 Gumbel Max

为了解决上述问题,引入了Gumbel Max 技巧。假设有一个类别的分布,其概率通过 Softmax 函数定义:Gumbel Max 通过以下步骤实现从离散分布中采样:

  1. 对每个类别,计算:

  2. 选择 作为采样结果。

这种方法确保了输出类别的概率与一致。

Gumbel Max 的数学证明

以类别 1 为例,证明 Gumbel Max 输出类别 1 的概率为

  1. 定义条件

    输出类别 1 意味着:

  2. 转化不等式

    对于每个,有:

  3. 计算概率

    由于,则每个不等式的概率为:

  4. 综合概率

    所有不等式同时成立的概率为:

  5. 求期望

    对所有求期望:

由于上述推导过程中简化了部分步骤,最终结果为

Gumbel Max 的构思过程

要理解 Gumbel-Max 的推导和构思过程,首先要认识到其基础是极值理论Gumbel 分布。 Gumbel 分布的一个关键性质是它能帮助找到一组随机变量中的最大值。研究者们发现,通过将 Gumbel 噪声添加到类别的对数概率上,可以从离散分布中采样。这一构思来自于需要快速、高效的从离散概率分布中进行采样,而传统方法在处理大量类别时表现欠佳。

通过这个方法,研究者们设计了一个方法,使得从 Gumbel 分布中添加噪声,并选择最大值能够解决离散采样的难题。这种方法不仅速度快,而且保持了类别的相对概率顺序,最终得到了 Gumbel-Max 采样方法。

Gumbel Softmax:离散分布的重参数化

原理

尽管 Gumbel Max 能够实现从离散分布中采样,但其包含的操作是不可微的,无法用于梯度传播。为此,引入了Gumbel Softmax,它是 Gumbel Max 的光滑近似版本,通过引入温度参数实现可微分采样过程。

Gumbel Softmax 的定义如下:

其中:

- - -是温度参数

温度退火

温度参数控制了输出分布的平滑度:

  • 高温度(较大):输出更加平滑,接近于均匀分布。
  • 低温度(较小):输出接近于 one-hot 向量,即更具确定性。

在训练过程中,通常采用温度退火策略,逐渐减小,以提高采样结果的离散性,从而更好地模拟真实的离散采样过程。

Gumbel Softmax 的数学推导

Gumbel Softmax 基于 Gumbel Max,通过 softmax 函数对 Gumbel 噪声进行了光滑处理,使得采样过程可微。具体步骤如下:

  1. 添加 Gumbel 噪声

    对每个类别,计算:

其中,

  1. 应用 Softmax 函数

    将添加了噪声的 logits 通过 softmax 函数处理,并除以温度参数

  2. 可微性

    由于 softmax 函数是可微的, Gumbel Softmax 允许梯度通过采样过程进行传播,从而实现端到端的训练。

Gumbel Softmax 的构思过程

从 Gumbel-Max 到 Gumbel Softma · x 的过渡,主要的思考点是如何使得不可微的​ 操作变为可微的操作。研究者们通过将替换为softmax函数,设计出了一种平滑的近似操作,使得采样过程变得可微。同时,引入了温度参数,控制采样的连续性与离散性。

具体来说,研究者们发现,通过添加 Gumbel 噪声后应用 softmax,可以平滑化原本不可微的采样过程。随着温度参数逐渐减小, softmax 的输出趋近于 one-hot 形式,从而逐渐逼近真实的离散采样结果。这一设计使得 Gumbel Softmax 既可以在训练早期保持采样的连续性,确保梯度稳定传递,又能够在训练后期通过温度退火逐渐增强采样的离散性,从而更好地模拟实际的离散分布。

Gumbel Softmax 的优势与应用

优势

  1. 可微性:通过光滑近似, Gumbel Softmax 允许梯度通过采样过程进行传播,实现端到端的训练。
  2. 降低方差:相比于传统的梯度估计方法(如 REINFORCE), Gumbel Softmax 显著降低了梯度估计的方差,提高了训练的稳定性。
  3. 灵活性:适用于多种离散分布,尤其适合处理高维度和大类别数的情境。

应用场景

  1. 离散隐变量的 VAE:通过 Gumbel Softmax,可以在 VAE 中引入离散潜在变量,实现更丰富的表示。
  2. 文本生成:在文本生成任务中,词汇选择是一个典型的离散过程, Gumbel Softmax 为此提供了有效的训练方法。
  3. 强化学习:在策略优化中,动作选择通常是离散的, Gumbel Softmax 可以用于策略的参数化与优化。
  4. 图像生成:在图像生成任务中, Gumbel Softmax 可以用于处理离散的像素值或标签信息。

最新研究进展

近年来,针对 Gumbel Softmax 的改进和扩展不断涌现,主要集中在以下几个方面:

  1. 更高效的采样方法:研究人员提出了多种高效的 Gumbel 噪声采样方法,减少了计算开销,提高了采样速度。
  2. 温度调整策略:动态调整温度参数的方法被提出,以更好地平衡采样的离散性与梯度的可传递性。
  3. 结合其他技术: Gumbel Softmax 与其他技术(如注意力机制、变分推断等)相结合,进一步提升了模型的性能和应用范围。
  4. 理论分析:深入研究 Gumbel Softmax 的理论性质,如收敛性、方差分析等,为其应用提供了更坚实的理论基础。

重参数化背后的梯度估计

梯度估计的重要性

在涉及随机变量的模型中,梯度估计是优化过程的核心。传统的梯度估计方法,如Score Function Estimator(也称为 REINFORCE),虽然通用,但通常伴随着高方差的问题,导致训练过程不稳定。而重参数化通过重新构造采样过程,有效降低了梯度估计的方差,提高了优化效率。

Score Function Estimator( REINFORCE)

Score Function Estimator 的形式为:

总结

重参数化 作为一种强大的技术,在深度生成模型中发挥了关键作用。通过将随机变量的采样过程转化为可微分的形式,重参数化不仅提高了模型的训练效率,还拓展了其应用范围。尤其是在处理离散分布时,Gumbel Softmax 提供了一种有效的重参数化方法,使得梯度能够顺利传递,实现端到端的优化。

然而,重参数化技巧也并非万能。对于某些复杂分布,找到合适的重参数化形式可能具有挑战性。此外,选择适当的温度参数以及有效的退火策略,仍需根据具体任务进行调整与优化。随着研究的不断深入,重参数化与 Gumbel Softmax 的方法将进一步完善,为更多复杂模型的优化提供支持。

参考文献

  1. Y. Jang, M. Gu, B. Poole. "Categorical Reparameterization with Gumbel-Softmax." International Conference on Learning Representations (ICLR), 2017.
  2. S. M. Ahmed, H. R. Mohiuddin, M. A. R. Khan. "GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution." arXiv preprint arXiv:1802.05011, 2018.
  3. M. Rolfe. "VIMCO: Variational Inference for Monte Carlo Objectives." NeurIPS, 2017.
  4. L. Kaiser. "Categorical Straight-Through Gradient Estimators." arXiv preprint arXiv:1812.02805, 2018.
  5. S. Y. Chen, K. Salakhutdinov. "Variational Recurrent Neural Networks." International Conference on Machine Learning (ICML), 2016.

推荐阅读

  1. 午夜惊奇:变分自编码器 VAE 低俗教程

    • https://zhuanlan.zhihu.com/p/23705953
  2. 花式解释 AutoEncoder 与 VAE

    • https://zhuanlan.zhihu.com/p/27549418
  3. 变分自编码器( VAEs)

    • https://zhuanlan.zhihu.com/p/25401928
  4. 条件变分自编码器( CVAEs)

    • https://zhuanlan.zhihu.com/p/25518643
  5. Variational Autoencoder: Intuition and Implementation

    • https://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/
  6. 变分自编码器 vae 的问题? - 知乎

    • https://www.zhihu.com/question/55015966
  7. 【啄米日常】 7: Keras 示例程序解析( 4):变分编码器 VAE

    • https://zhuanlan.zhihu.com/p/25269592
  8. <模型汇总-10> Variational AutoEncoder...

    • https://zhuanlan.zhihu.com/p/27280681
  9. 近似推断 – Deep Learning Book Chinese Translation

    • https://exacity.github.io/deeplearningbook-chinese/Chapter19_approximate_inference/
  10. One Hot 编码 | DevilKing's blog

    • http://gqlxj1987.github.io/2017/08/07/one-hot/
  11. 自编码器 – Deep Learning Book Chinese Translation

    • https://exacity.github.io/deeplearningbook-chinese/Chapter14_autoencoders/
  12. Android 编译过程详解之一 | Andy.Lee's Blog

  13. Kevin Chan's blog - 《 Deep Learning...

    • https://applenob.github.io/deep_learning_14
  14. Variational Autoencoder in TensorFlow

    • http://jmetzen.github.io/2015-11-27/vae.html
  15. 变分自编码器( Variational Autoencoder, VAE)

    • https://snowkylin.github.io/autoencoder/2016/12/05/introduction-to-variational-autoencoder.html
  16. 自编码模型 - tracholar's personal knowledge wiki

    • http://tracholar.github.io/wiki/machine-learning/auto-encoder.html
  17. Go 的自举

  18. Medium LESS 编码指引 | Zoom's Blog

    • http://zoomzhao.github.io/2015/07/30/medium-style-guide/
  19. 基于 RNN 的变分自编码器(施工中)

    • https://snowkylin.github.io/autoencoder/rnn/2016/12/21/variational-autoencoder-with-RNN.html
  20. The variational auto-encoder | Lecture notes for Stanford cs228.

    • https://ermongroup.github.io/cs228-notes/extras/vae/
  • 本文标题:重参数化详解与 Gumbel Softmax 深入探讨
  • 本文作者:Chen Kai
  • 创建时间:2021-03-25 14:00:00
  • 本文链接:https://www.chenk.top/%E9%87%8D%E5%8F%82%E6%95%B0%E5%8C%96%E8%AF%A6%E8%A7%A3%E4%B8%8EGumbel-Softmax%E6%B7%B1%E5%85%A5%E6%8E%A2%E8%AE%A8/
  • 版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
 评论