在现代机器学习中,重参数化(Reparameterization)
技巧成为优化包含随机变量模型的关键方法,尤其在变分自编码器(VAE)和生成对抗网络(GANs)等深度生成模型中发挥着重要作用。重参数化通过将随机变量
重参数化的基本概念
重参数化(Reparameterization) 是机器学习中一种重要的技术,主要用于处理涉及随机变量的模型。其核心思想是将随机变量的采样过程转化为一个确定性函数与噪声变量的组合,从而使得梯度能够通过采样过程进行传播。这对于优化包含随机性的模型,如变分自编码器(VAE)和生成对抗网络(GANs)等,至关重要。
为什么需要重参数化?
在许多机器学习模型中,我们需要从某个分布中采样随机变量。例如,在VAE中,潜在变量的采样对于模型的训练至关重要。然而,直接对随机变量进行采样会导致以下问题:
- 梯度不可传递:采样过程本身是一个非微分操作,无法通过反向传播计算梯度。
- 优化困难:由于无法计算梯度,传统的梯度下降方法难以应用于模型参数的优化。
重参数化通过将采样过程重新表达为一个可微分的形式,解决了上述问题,使得模型参数能够通过梯度下降等方法进行有效优化。
重参数化的数学表达
重参数化的基本思想是将随机变量
其中: -
通过这种表示方式,随机性被隔离在
连续分布中的重参数化
正态分布的重参数化
正态分布是重参数化中最常见的例子。假设潜在变量
直接对
其中,
重参数化在VAE中的应用
在变分自编码器(VAE)中,重参数化技巧用于优化潜在变量的分布。VAE通过最大化证据下界(ELBO)来学习数据的潜在表示。具体流程如下:
编码器:将高维输入数据映射到低维潜在空间,输出潜在变量的分布参数(如均值
和标准差 )。具体来说,编码器网络接受输入数据 ,并输出潜在变量 的分布参数 和 。这里, 和 通常是通过神经网络的最后一层线性变换得到的: 这些参数定义了潜在变量 的高斯分布: 重参数化:为了使得采样过程可微分,引入了重参数化技巧,将随机变量
表示为确定性函数和独立噪声变量 的组合,通过 生成潜在变量 。 - 独立噪声变量
:独立于模型参数 ,只依赖于预定义的简单分布(如标准正态分布) - 确定性函数:将
和 作为参数,通过线性变换与噪声 结合,生成潜在变量
- 独立噪声变量
解码器:解码器网络接受潜在变量
,并生成重建后的数据 : - 重建过程:解码器试图从潜在变量
中重建出原始输入数据 ,目标是使得 尽可能接近 。 - 生成能力:通过训练,解码器学会了如何从潜在空间中生成与训练数据相似的新数据。
- 重建过程:解码器试图从潜在变量
这种方法允许梯度通过
ELBO的最大化
VAE的目标是最大化证据下界(ELBO),其数学表达式为:
其中:
- 第一项:重建误差,衡量从潜在变量
重建数据的准确性。 - 第二项:KL散度,衡量编码器输出的潜在分布
与先验分布 之间的差异。 通过重参数化技巧,ELBO的梯度能够有效传递到编码器和解码器的参数,从而实现优化。
重参数化的数学原理
重参数化的数学基础在于将期望形式的目标函数转化为可微形式。对于连续情形的目标函数:
通过重参数化,可以将其转化为:
这使得梯度可以通过
代码示例
1 | import torch |
离散分布中的重参数化
挑战
对于离散分布,直接应用重参数化技巧面临以下挑战:
非微分操作:离散变量的采样过程(如
操作)通常是不可微的,导致梯度无法有效传递。 考虑自然语言处理中的词汇选择任务。假设模型需要生成一个单词作为输出:
- 前向传播:
- 模型输出每个单词的logits
。
- 使用
选择概率最高的单词 。
- 模型输出每个单词的logits
- 反向传播:
- 由于
是不可微的,无法计算 ,导致梯度无法传递回模型参数。
- 由于
这种情况下,传统的梯度下降方法无法直接优化模型,因为梯度信息在采样步骤中丢失。
- 前向传播:
高维度问题:当类别数量
较大时,直接对所有可能的类别进行求和计算期望变得计算量巨大,甚至不可行。 考虑图像生成任务中的像素值预测。假设每个像素可以取
个不同的灰度值: - 模型输出:
- 对于每个像素,模型输出
个logits,对应于每个灰度值的概率。
- 对于每个像素,模型输出
- 计算期望:
- 如果我们需要计算某种统计量(如期望值),需要对所有
个类别进行求和。
- 如果我们需要计算某种统计量(如期望值),需要对所有
- 高维度扩展:
- 对于高分辨率图像,每个图像包含数以万计的像素,每个像素又有
个可能的值,计算期望的成本急剧增加。
- 对于高分辨率图像,每个图像包含数以万计的像素,每个像素又有
- 模型输出:
引入 Gumbel Max
为了解决上述问题,引入了Gumbel Max 技巧。假设有一个
Gumbel Max通过以下步骤实现从离散分布中采样:
对每个类别
,计算: 选择
作为采样结果。
这种方法确保了输出类别的概率与
Gumbel Max 的数学证明
以类别1为例,证明Gumbel Max输出类别1的概率为
定义条件:
输出类别1意味着:
转化不等式:
对于每个
,有: 计算概率:
由于
,则每个不等式的概率为: 综合概率:
所有不等式同时成立的概率为:
求期望:
对所有
求期望: 由于上述推导过程中简化了部分步骤,最终结果为
。
Gumbel Max 的构思过程
要理解 Gumbel-Max 的推导和构思过程,首先要认识到其基础是极值理论和Gumbel 分布。Gumbel 分布的一个关键性质是它能帮助找到一组随机变量中的最大值。研究者们发现,通过将 Gumbel 噪声添加到类别的对数概率上,可以从离散分布中采样。这一构思来自于需要快速、高效的从离散概率分布中进行采样,而传统方法在处理大量类别时表现欠佳。
通过这个方法,研究者们设计了一个方法,使得从 Gumbel 分布中添加噪声,并选择最大值能够解决离散采样的难题。这种方法不仅速度快,而且保持了类别的相对概率顺序,最终得到了 Gumbel-Max 采样方法。
Gumbel Softmax:离散分布的重参数化
原理
尽管Gumbel Max能够实现从离散分布中采样,但其包含的
Gumbel Softmax的定义如下:
其中:
是温度参数
温度退火
温度参数
- 高温度(
较大):输出更加平滑,接近于均匀分布。 - 低温度(
较小):输出接近于 one-hot 向量,即更具确定性。
在训练过程中,通常采用温度退火策略,逐渐减小
Gumbel Softmax的数学推导
Gumbel Softmax基于Gumbel Max,通过softmax函数对Gumbel噪声进行了光滑处理,使得采样过程可微。具体步骤如下:
添加Gumbel噪声:
对每个类别
,计算: 其中,
, 。 应用Softmax函数:
将添加了噪声的 logits 通过softmax函数处理,并除以温度参数
: 可微性:
由于softmax函数是可微的,Gumbel Softmax允许梯度通过采样过程进行传播,从而实现端到端的训练。
Gumbel Softmax 的构思过程
从 Gumbel-Max 到 Gumbel Softma ·x
的过渡,主要的思考点是如何使得不可微的
具体来说,研究者们发现,通过添加Gumbel噪声后应用softmax,可以平滑化原本不可微的采样过程。随着温度参数
Gumbel Softmax 的优势与应用
优势
- 可微性:通过光滑近似,Gumbel Softmax允许梯度通过采样过程进行传播,实现端到端的训练。
- 降低方差:相比于传统的梯度估计方法(如REINFORCE),Gumbel Softmax显著降低了梯度估计的方差,提高了训练的稳定性。
- 灵活性:适用于多种离散分布,尤其适合处理高维度和大类别数的情境。
应用场景
- 离散隐变量的VAE:通过Gumbel Softmax,可以在VAE中引入离散潜在变量,实现更丰富的表示。
- 文本生成:在文本生成任务中,词汇选择是一个典型的离散过程,Gumbel Softmax为此提供了有效的训练方法。
- 强化学习:在策略优化中,动作选择通常是离散的,Gumbel Softmax可以用于策略的参数化与优化。
- 图像生成:在图像生成任务中,Gumbel Softmax可以用于处理离散的像素值或标签信息。
最新研究进展
近年来,针对Gumbel Softmax的改进和扩展不断涌现,主要集中在以下几个方面:
- 更高效的采样方法:研究人员提出了多种高效的Gumbel噪声采样方法,减少了计算开销,提高了采样速度。
- 温度调整策略:动态调整温度参数
的方法被提出,以更好地平衡采样的离散性与梯度的可传递性。 - 结合其他技术:Gumbel Softmax与其他技术(如注意力机制、变分推断等)相结合,进一步提升了模型的性能和应用范围。
- 理论分析:深入研究Gumbel Softmax的理论性质,如收敛性、方差分析等,为其应用提供了更坚实的理论基础。
重参数化背后的梯度估计
梯度估计的重要性
在涉及随机变量的模型中,梯度估计是优化过程的核心。传统的梯度估计方法,如Score Function Estimator(也称为REINFORCE),虽然通用,但通常伴随着高方差的问题,导致训练过程不稳定。而重参数化通过重新构造采样过程,有效降低了梯度估计的方差,提高了优化效率。
Score Function Estimator(REINFORCE)
Score Function Estimator的形式为:
总结
重参数化 作为一种强大的技术,在深度生成模型中发挥了关键作用。通过将随机变量的采样过程转化为可微分的形式,重参数化不仅提高了模型的训练效率,还拓展了其应用范围。尤其是在处理离散分布时,Gumbel Softmax 提供了一种有效的重参数化方法,使得梯度能够顺利传递,实现端到端的优化。
然而,重参数化技巧也并非万能。对于某些复杂分布,找到合适的重参数化形式可能具有挑战性。此外,选择适当的温度参数
参考文献
- Y. Jang, M. Gu, B. Poole. "Categorical Reparameterization with Gumbel-Softmax." International Conference on Learning Representations (ICLR), 2017.
- S. M. Ahmed, H. R. Mohiuddin, M. A. R. Khan. "GANS for Sequences of Discrete Elements with the Gumbel-softmax Distribution." arXiv preprint arXiv:1802.05011, 2018.
- M. Rolfe. "VIMCO: Variational Inference for Monte Carlo Objectives." NeurIPS, 2017.
- L. Kaiser. "Categorical Straight-Through Gradient Estimators." arXiv preprint arXiv:1812.02805, 2018.
- S. Y. Chen, K. Salakhutdinov. "Variational Recurrent Neural Networks." International Conference on Machine Learning (ICML), 2016.
推荐阅读
- 午夜惊奇:变分自编码器VAE低俗教程
- https://zhuanlan.zhihu.com/p/23705953
- 花式解释AutoEncoder与VAE
- https://zhuanlan.zhihu.com/p/27549418
- 变分自编码器(VAEs)
- https://zhuanlan.zhihu.com/p/25401928
- 条件变分自编码器(CVAEs)
- https://zhuanlan.zhihu.com/p/25518643
- Variational Autoencoder: Intuition and Implementation
- https://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/
- 变分自编码器vae的问题? - 知乎
- https://www.zhihu.com/question/55015966
- 【啄米日常】 7:Keras示例程序解析(4):变分编码器VAE
- https://zhuanlan.zhihu.com/p/25269592
- <模型汇总-10> Variational AutoEncoder...
- https://zhuanlan.zhihu.com/p/27280681
- 近似推断 – Deep Learning Book Chinese Translation
- https://exacity.github.io/deeplearningbook-chinese/Chapter19_approximate_inference/
- One Hot编码 | DevilKing's blog
- http://gqlxj1987.github.io/2017/08/07/one-hot/
- 自编码器 – Deep Learning Book Chinese Translation
- https://exacity.github.io/deeplearningbook-chinese/Chapter14_autoencoders/
- Android编译过程详解之一 | Andy.Lee's Blog
- Kevin Chan's blog - 《Deep Learning...
- https://applenob.github.io/deep_learning_14
- Variational Autoencoder in TensorFlow
- http://jmetzen.github.io/2015-11-27/vae.html
- 变分自编码器(Variational Autoencoder, VAE)
- https://snowkylin.github.io/autoencoder/2016/12/05/introduction-to-variational-autoencoder.html
- 自编码模型 - tracholar's personal knowledge wiki
- http://tracholar.github.io/wiki/machine-learning/auto-encoder.html
- Go的自举
- Medium LESS 编码指引 | Zoom's Blog
- http://zoomzhao.github.io/2015/07/30/medium-style-guide/
- 基于RNN的变分自编码器(施工中)
- https://snowkylin.github.io/autoencoder/rnn/2016/12/21/variational-autoencoder-with-RNN.html
- The variational auto-encoder | Lecture notes for Stanford cs228.
- https://ermongroup.github.io/cs228-notes/extras/vae/
- Post title:重参数化详解与Gumbel Softmax深入探讨
- Post author:Chen Kai
- Create time:2024-02-16 11:00:00
- Post link:https://www.chenk.top/重参数化详解与Gumbel Softmax深入探讨/
- Copyright Notice:All articles in this blog are licensed under BY-NC-SA unless stating additionally.