"> "> 深度学习-Score-Based model | Yufei Luo's Blog

深度学习-Score-Based model

得分函数

定义

假设\(\{\boldsymbol x_i\in \mathbb{R}^D\}_{i=1}^N\)是从未知的数据分布\(p_{\text{data}}(\boldsymbol x)\)中采集到的一组独立同分布的数据集,定义得分函数(score function)\(\boldsymbol s (\boldsymbol x)=\nabla_{\boldsymbol x}\log p(\boldsymbol x)\),即概率密度的对数函数在数据点\(\boldsymbol x\)处的梯度。也就是说,得分函数对应于一个向量场,它指出了概率密度增长最大的方向。

得分函数通常表示为\(\boldsymbol s_{\boldsymbol \theta} (\boldsymbol x)=\nabla_{\boldsymbol x}\log p(\boldsymbol x)\)的形式,也就是用一个神经网络来对得分函数进行近似,其中\(\boldsymbol \theta\)表示神经网络的参数。

得分匹配

训练得分函数的过程被称为得分匹配(score matching),它的训练目标为最小化如下表达式: \[ \frac{1}{2}E_{p_{\text{data}}(\boldsymbol x)}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)-\nabla_{\boldsymbol x}\log p_{\text{data}}(\boldsymbol x)||_2^2 \right] \] 但是在这一公式中\(p_{\text{data}}(\boldsymbol x)\)未知,因此可以通过求解等价问题,即最小化下面的表达式来求解: \[ E_{p_{\text{data}}(\boldsymbol x)}\left[\text{tr}(\nabla_{\boldsymbol x} \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x))+ \frac{1}{2}||\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)||_2^2 \right] \]

注:上述两个问题等价的推导 \[ \begin{aligned} & E_{p_{\text{data}}(\boldsymbol x)}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)-\nabla_{\boldsymbol x}\log p_{\text{data}}(\boldsymbol x)||_2^2 \right] \\ =& \int p_{\text{data}}(\boldsymbol x)\cdot \left[ ||\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)||_2^2+ ||\nabla_{\boldsymbol x}\log p_{\text{data}}(\boldsymbol x)||_2^2 -2\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)^T \nabla_{\boldsymbol x}\log p_{\text{data}}(\boldsymbol x) \right]d\boldsymbol x \\ \end{aligned} \] 其中,\(||\nabla_{\boldsymbol x}\log p_{\text{data}}(\boldsymbol x)||_2^2\)与参数\(\boldsymbol \theta\)无关,可以省略掉。

\(\int p_{\text{data}}(\boldsymbol x)\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)^T \nabla_{\boldsymbol x}\log p_{\text{data}}(\boldsymbol x)d\boldsymbol x\)这一项可以化简为: \[ \begin{aligned} &\int p_{\text{data}}(\boldsymbol x)\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)^T \nabla_{\boldsymbol x}\log p_{\text{data}}(\boldsymbol x)d\boldsymbol x \\ =& \int p_{\text{data}}(\boldsymbol x) \sum_{i=1}^{n}s_{\boldsymbol \theta}(\boldsymbol x)_i \frac{\nabla_{x_i} p_{\text{data}}(\boldsymbol x)}{p_{\text{data}}(\boldsymbol x)}d\boldsymbol x \\ =& \sum_{i=1}^{n}\int s_{\boldsymbol \theta}(\boldsymbol x)_i \nabla_{x_i} p_{\text{data}}(\boldsymbol x)d\boldsymbol x \\ =& \sum_{i=1}^{n}\left(s_{\boldsymbol \theta}(\boldsymbol x)_i\cdot p_{\text{data}}(\boldsymbol x)\bigg|_{-\infty}^{\infty}-\int \nabla_{\boldsymbol x_i} (s_{\boldsymbol \theta}(\boldsymbol x)_i)\cdot p_{\text{data}}(\boldsymbol x) d\boldsymbol x\right) \end{aligned} \] 第一项\(s_{\boldsymbol \theta}(\boldsymbol x)_i\cdot p_{\text{data}}(\boldsymbol x)\big|_{-\infty}^{\infty}\)在积分的上下界都为0(这里假设当\(\boldsymbol x\)取负无穷或者正无穷时,对应的概率\(p_{\text{data}}(\boldsymbol x)\)为0);第二项中的\(\sum_{i=1}^{n} \nabla_{\boldsymbol x_i} s_{\boldsymbol \theta}(\boldsymbol x)_i\)是矩阵\(\nabla_{\boldsymbol x}\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)\)中所有对角线元素的和,等于\(\text{tr}(\nabla_{\boldsymbol x} \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x))\)

因此,可以得到两个优化问题等价的结论。

但是上述优化问题中对于\(\text{tr}(\nabla_{\boldsymbol x} \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x))\)这一项的开销很大,尤其是对于高维输入以及深层网络来说。

因此对于\(\boldsymbol s_{\boldsymbol \theta} (\boldsymbol x)\)的训练,通常使用下面两种方法:

  1. 去噪得分匹配(denoising score matching)

    这一方法完全绕过了计算\(\text{tr}(\nabla_{\boldsymbol x} \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x))\)。它首先使用一个预先定义好的噪声分布\(q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\)生成干扰后的数据\(\tilde{\boldsymbol x}\),然后再计算\(\tilde{\boldsymbol x}\)的得分。\(\tilde{\boldsymbol x}\)的概率分布被定义为\(q_\sigma(\tilde{\boldsymbol x})=\int q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)p_{\text{data}}(\boldsymbol x)d\boldsymbol x\),因此优化目标如下: \[ \frac{1}{2}E_{q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)p_{\text{data}}(\boldsymbol x)}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})-\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)||_2^2 \right] \]

    这一优化目标的推导过程(主要参考了DenoisingScoreMatching_NeuralComp2011的附录部分):

    原始的优化目标: \[ \frac{1}{2}E_{q_\sigma(\tilde{\boldsymbol x})}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})-\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x})||_2^2 \right] \] 将上式展开有: \[ \begin{aligned} &E_{q_\sigma(\tilde{\boldsymbol x})}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})-\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x})||_2^2 \right] \\ =& E_{q_\sigma(\tilde{\boldsymbol x})}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})||^2+||\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x})||^2 -2\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x}) \right] \end{aligned} \] 上式中的第二项与参数\(\boldsymbol \theta\)无关,第一项\(E_{q_\sigma(\tilde{\boldsymbol x})} [\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})]\)可以做如下变换: \[ \begin{aligned} & E_{q_\sigma(\tilde{\boldsymbol x})} [\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})] \\ =& \int q_\sigma(\tilde{\boldsymbol x})\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})d\tilde{\boldsymbol x} \\ =&\int \left( \int q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)p_{\text{data}}(\boldsymbol x)d\boldsymbol x\right) \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})d\tilde{\boldsymbol x} \\ =& \iint q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)p_{\text{data}}(\boldsymbol x)\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})d\boldsymbol x d\tilde{\boldsymbol x} \\ =& E_{q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)p_{\text{data}}(\boldsymbol x)}[s_{\boldsymbol \theta}(\tilde{\boldsymbol x})] \end{aligned} \] 第三项\(\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x})\)可以对其做如下变换: \[ \begin{aligned} &E_{q_\sigma(\tilde{\boldsymbol x})}\left[ \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x})\right ] \\ =& \int q_\sigma(\tilde{\boldsymbol x})\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x}) d\tilde{\boldsymbol x} \\ =& \int \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T\nabla_{\tilde{\boldsymbol x}} q_\sigma(\tilde{\boldsymbol x}) d\tilde{\boldsymbol x} \\ =& \int \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T\nabla_{\tilde{\boldsymbol x}} \left(\int q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\cdot p_{\text{data}}(\boldsymbol x)d\boldsymbol x \right) d\tilde{\boldsymbol x} \\ =& \int \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T \left(\int \left(\nabla_{\tilde{\boldsymbol x}}q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\right)\cdot p_{\text{data}}(\boldsymbol x)d\boldsymbol x \right) d\tilde{\boldsymbol x} \\ =& \int \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T \left(\int \left(\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\right)\cdot q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\cdot p_{\text{data}}(\boldsymbol x)d\boldsymbol x \right) d\tilde{\boldsymbol x} \\ =&\iint \left( q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\cdot p_{\text{data}}(\boldsymbol x)\right) \cdot \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T \left(\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\right) d\boldsymbol x d\tilde{\boldsymbol x} \\ =& E_{q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)p_{\text{data}}(\boldsymbol x)}\left[ \boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})^T \nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)\right] \end{aligned} \] 也就是说,优化目标\(\frac{1}{2}E_{q_\sigma(\tilde{\boldsymbol x})}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})-\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x})||_2^2 \right]\)\(\frac{1}{2}E_{q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)p_{\text{data}}(\boldsymbol x)}\left[ ||\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x})-\nabla_{\tilde{\boldsymbol x}}\log q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)||_2^2 \right]\)只相差了一个常数,因此两个优化目标等价。

  2. 切片得分匹配(sliced score matching)

    这一方法使用随机投影去估算\(\text{tr}(\nabla_{\boldsymbol x} \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x))\)的值,也就是优化目标变为: \[ E_{p_\boldsymbol{v}}E_{p_{\text{data}}(\boldsymbol x)}\left[\boldsymbol v^T\nabla_{\boldsymbol x} \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)\boldsymbol v + \frac{1}{2}\|\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)\|_2^2 \right] \] 其中\(p_{\boldsymbol v}\)代表随机向量分布。

    详细可以参考https://arxiv.org/pdf/1905.07088.pdf

朗之万动力学采样

在得分函数训练完成之后,便可以使用朗之万动力学采样的方法得到一个采样自\(p(\boldsymbol x)\)的一个样本。

给定一个固定的步长\(\epsilon>0\),初始值\(\tilde{\boldsymbol x}_0\sim \pi(\boldsymbol x)\)(其中\(\pi\)代表先验概率分布),朗之万动力学方法递归地执行下面的表达式: \[ \tilde{\boldsymbol x}_t=\tilde{\boldsymbol x}_{t-1}+\frac{\epsilon}{2}\nabla_{\boldsymbol x}\log p(\tilde{\boldsymbol x}_{t-1})+\sqrt{\epsilon}\boldsymbol z_t \] 其中\(\boldsymbol z_t\sim \mathcal{N}(0,I)\)

\(\epsilon\rightarrow 0\)以及\(T\rightarrow \infty\)时,\(\tilde{\boldsymbol x}_T\)的概率分布就等于\(p(\boldsymbol x)\)。当然实际中这样的条件很难达到,但是当\(\epsilon\)比较小以及\(T\)比较大的情况下,其中的误差通常可以被忽略。

备注:上述公式中,如果去掉最后的高斯项,那么其实就等同于梯度上升的过程,相当于是让样本往概率密度较大的方向移动。而高斯项则用于在优化过程中随机加入一些噪声。

另一个角度是将\(\nabla_{\boldsymbol x}\log p(\tilde{\boldsymbol x})\)理解为概率密度的势场,最后的\(\boldsymbol z_t\)项理解为热涨落的过程,这样就相当于是模拟粒子在势能和热涨落中运动。

条件噪声得分

上面提出的模型在实际应用中会面临如下的两个问题:

  1. 流形假设

    这一假设指的是真实世界中的数据会更倾向于分布在高维空间内的一个低维流形上。在这一假设下,得分模型则会遇到两个困难:一是由于得分函数是在高维空间中计算的梯度,因此如果\(\boldsymbol x\)被局限在低维流形上面,则相当于计算得到的梯度没有意义;二是对于上述定义的得分匹配这一训练目标,只有当数据分布的支撑(相当于概率不为0处)是整个空间时,训练得到的得分函数才具有一致性。

    例如作者在自己的实验中,在CIFAR-10数据集上训练得分模型,在不给训练数据加噪声的情况下模型的训练过程会一直震荡,而在加入了很小的噪声之后,模型的训练则会最终收敛。

    image-20221029151648516

  2. 低概率密度区域

    在概率密度较低的区域会遇到两个问题,一是得分函数的估计会不准确,二是朗之万动力学的采样也很难复原出真实的概率分布。

    对于第一个问题,由于训练得分函数使用的是一组采集自\(p_{\text{data}}\)的独立同分布样本,因此相应地概率密度较低的区域所采集到的数据也会很少。没有足够的样本便很难训练出一个准确的得分函数。

    image-20221029153139333

    而对于第二个问题,作者用了一个简单的例子来说明。假设数据的概率分布\(p\)服从一个混合模型\(p=\pi p_1+(1-\pi)p_2\),其中\(p_1\)\(p_2\)的支撑集不相交。但是得分函数\(\nabla \log p=\nabla \log p_1+\nabla \log p_2\),是一个与\(\pi\)无关的表达式。此时,除非使用很小的步长以及很大的迭代步数才能产生正确的采样。

    image-20221029153247477

因此,作者提出了两个改进方法:

第一个改进是使用噪声条件得分网络(Noise Conditional Score Networks)。设\(\{\sigma_i\}_{i=1}^L\)是一个等比数列,满足\(\sigma_i/\sigma_{i-1}>1\),并且\(\sigma_1\)的值足够大(指的是能够解决上述提到的两个问题),\(\sigma_L\)的值足够小(对数据造成的干扰尽可能小)。定义\(q_\sigma(\tilde{\boldsymbol x})=\int \mathcal{N}(\tilde{\boldsymbol x}| \boldsymbol x,\sigma^2I)p_{\text{data}}(\boldsymbol x)d\boldsymbol x\)为干扰后的数据分布,其中\(q_\sigma(\tilde{\boldsymbol x}| \boldsymbol x)=\mathcal{N}(\tilde{\boldsymbol x}| \boldsymbol x,\sigma^2I)\)相当于为原始数据添加高斯噪声。我们的目标是训练如下的噪声条件得分网络: \[ \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x,\sigma)\approx \nabla_{\boldsymbol x} \log q_\sigma(\boldsymbol x), ~\forall \sigma\in \{\sigma_i\}_{i=1}^L \] 使用上文去噪得分匹配的训练目标,对于一个给定的\(\sigma\),优化目标可以写为: \[ \ell(\boldsymbol \theta,\sigma)=\frac{1}{2}E_{\boldsymbol p_{\text{data}}(\boldsymbol x)}E_{\tilde{\boldsymbol x}\sim\mathcal{N}(\boldsymbol x,\sigma^2 I)}\left[ \left\|\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x},\sigma)+\frac{\tilde{\boldsymbol x} - \boldsymbol x}{\sigma^2}\right\|_2^2 \right] \] 因此,考虑所有的\(\sigma\),可以得到最终的优化目标: \[ \mathcal{L}(\boldsymbol \theta)=\frac{1}{L}\sum_{i=1}^L \lambda(\sigma_i)\ell(\boldsymbol \theta;\sigma_i) \] 其中\(\lambda(\sigma_i)>0\)是一个系数方程。

假设模型具有足够强大的近似能力,则当且仅当\(\boldsymbol s_{\boldsymbol \theta^*}(\boldsymbol x,\sigma_i)= \nabla_{\boldsymbol x} \log q_{\sigma_i}(\boldsymbol x)\)对于\(\forall \sigma\in \{\sigma_i\}_{i=1}^L,\forall i\in\{1,\dots,L\}\)都成立时,\(\boldsymbol s_{\boldsymbol \theta^*}(\boldsymbol x,\sigma)\)使得上式取最小值。

\(\lambda(\cdot)\)的取值方法有很多种,在理想情况下希望\(\forall \sigma\in \{\sigma_i\}_{i=1}^L,\forall i\in\{1,\dots,L\}\)\(\lambda(\sigma_i)\ell(\boldsymbol \theta;\sigma_i)\)的值大致都为一个量级。从经验来看,当得分网络训练完成之后,大致可以得到\(\| \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x,\sigma)\propto 1/\sigma \|\),因此通常取\(\lambda(\sigma)=\sigma^2\)

而在\(\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x,\sigma)\)训练完成之后,便可以使用带有退火的朗之万动力学方法采样: \[ \tilde{\boldsymbol x}_t=\tilde{\boldsymbol x}_{t-1}+\frac{\alpha_i}{2}\boldsymbol s_{\boldsymbol \theta}(\tilde{\boldsymbol x}_{t-1},\sigma_i)+\sqrt{\alpha_i}\boldsymbol z_t \] 其中,\(\alpha_i=\epsilon\cdot \sigma_i^2/\sigma_L^2\)为迭代步长。在采样时,\(t=1,\dots,T\)为内层循环,代表在概率分布\(q_{\sigma_i}\)上采样;\(i=1,\dots,L\)为外层循环,代表在\(q_{\sigma_i}\)上采样点的基础上,减小步长并在\(q_{\sigma_{i+1}}\)上采样。通过这样的循环,最终得到\(q_{\sigma_L}\)上的采样点。由于\(\sigma_L\)足够小,因此可以认为最终是近似在\(p_{\text{data}}\)上采样。

实践与改进

噪声等级

关于噪声等级的改进主要包括三部分:

初始噪声尺度

假设数据集\(\{\boldsymbol{x}_1,\boldsymbol{x}_2,\dots,\boldsymbol{x}_N\}\)是采样自\(p_{\text{data}}(\boldsymbol{x})\)的独立同分布样本,并且\(N\)足够大,此时我们有\(\hat{p}_{\text{data}}(\boldsymbol{x})=\frac{1}{N}\sum_{i=1}^{N}\delta(\boldsymbol{x}=\boldsymbol{x}_i)\approx p_{\text{data}}(\boldsymbol{x})\),其中\(\delta(\cdot)\)代表狄拉克分布。当使用分布\(\mathcal{N}(0,\sigma_1^2I)\)对数据做干扰时,经验分布就变成了\(\hat{p}_{\sigma_1}(\boldsymbol{x})=\frac{1}{N}\sum_{i=1}^{N}p_i(\boldsymbol{x})\),其中\(p_i(\boldsymbol x)=\mathcal{N}(\boldsymbol x|\boldsymbol x_i,\sigma_1^2I)\)

\(r_i(\boldsymbol x)=p_i(\boldsymbol x)/\sum_{k=1}^N p_k(\boldsymbol x)\),则得分函数\(\nabla_{\boldsymbol x}\log \hat{p}_{\sigma_1}(\boldsymbol{x})=\sum_{i=1}^N r_i(\boldsymbol x)\nabla_{\boldsymbol x}\log p_i(\boldsymbol x)\),而且有如下结论成立: \[ E_{p_i(\boldsymbol x)}[r_j(\boldsymbol x)]\le \frac{1}{2}\exp\left( -\frac{\|\boldsymbol x_i-\boldsymbol x_j\|_2^2}{8\sigma_i^2} \right) \]

上述不等式的证明: \[ \begin{aligned} &E_{p_i(\boldsymbol x)}[r_j(\boldsymbol x)] \\ =& \int \frac{p_i(\boldsymbol x)p_j(\boldsymbol x)}{\sum_{k=1}^N p_k(\boldsymbol x)}d\boldsymbol x \le \int \frac{p_i(\boldsymbol x)p_j(\boldsymbol x)}{p_i(\boldsymbol x)+p_j(\boldsymbol x)} d\boldsymbol x\\ =&\frac{1}{2}\int \frac{2}{\frac{1}{p_i(\boldsymbol x)}+\frac{1}{p_j(\boldsymbol x)}} d\boldsymbol x \le \frac{1}{2} \int \sqrt{p_i(\boldsymbol x)p_j(\boldsymbol x)} d\boldsymbol x \\ =&\frac{1}{2\cdot(2\pi \sigma_1^2)^{D/2}}\int \exp \left( -\frac{1}{4\sigma_1^2}\left( \|\boldsymbol x-\boldsymbol x_i\|_2^2+\|\boldsymbol x-\boldsymbol x_j\|_2^2 \right) \right) d\boldsymbol x \\ =&\frac{1}{2\cdot(2\pi \sigma_1^2)^{D/2}}\int \exp \left( -\frac{1}{4\sigma_1^2}\left( \|\boldsymbol x-\boldsymbol x_i\|_2^2+\|\boldsymbol x-\boldsymbol x_i+\boldsymbol x_i-\boldsymbol x_j\|_2^2 \right) \right) d\boldsymbol x \\ =&\frac{1}{2\cdot(2\pi \sigma_1^2)^{D/2}}\int \exp \left( -\frac{1}{2\sigma_1^2}\left( \|\boldsymbol x-\boldsymbol x_i+\frac{\boldsymbol x_i-\boldsymbol x_j}{2}\|_2^2+\frac{\|\boldsymbol x_i-\boldsymbol x_j\|_2^2}{4} \right) \right) d\boldsymbol x \\ =& \frac{1}{2}\exp\left( -\frac{\|\boldsymbol x_i-\boldsymbol x_j\|_2^2}{8\sigma_1^2} \right) \frac{1}{(2\pi \sigma_1^2)^{D/2}}\int \exp \left( -\frac{1}{2\sigma_1^2}\left( \|\boldsymbol x-\boldsymbol x_i+\frac{\boldsymbol x_i-\boldsymbol x_j}{2}\|_2^2 \right) \right)d\boldsymbol x \\ =&\frac{1}{2}\exp\left( -\frac{\|\boldsymbol x_i-\boldsymbol x_j\|_2^2}{8\sigma_1^2} \right) \end{aligned} \]

根据这一结论,在使用朗之万动力学采样从\(p_i(\boldsymbol x)\)过渡到\(p_j(\boldsymbol x)\)时,如果要使得这一过程更容易发生,则需要尽量使\(E_{p_i(\boldsymbol x)}[r_j(\boldsymbol x)]\)尽可能地大,也就是让\(\sigma_1\)的值与\(\|\boldsymbol x_i-\boldsymbol x_j\|_2^2\)基本处于一个量级。否则因为这一期望值过小,会导致朗之万动力学采样时,生成样本的多样性会减小。

考虑训练集中所有的数据都要满足这一条件,因此\(\sigma_1\)的选取原则是,与所有训练数据之间欧氏距离的最大值处于同一量级。

作者在CIFAR-10数据集上进行实验,这一数据集上样本欧氏距离的中位数是18,取\(\sigma_1=1\)\(\sigma_1=50\)分别进行实验,结果显示\(\sigma_1=50\)的时候生成的图像更多样化:

image-20221030110740344

其它噪声尺度

在设置好\(\sigma_L\)\(\sigma_1\)之后,还需要选取\(L\)的值以及其余的\(\{\sigma_i\}_{i=1}^L\)。在设置这些\(\sigma_i\)的值时,出发点是在从\(p_{\sigma_{i-1}}(\boldsymbol x)\)中采样之后,这一样本接下来设置为\(p_{\sigma_i}(\boldsymbol x)\)朗之万动力学采样的初始化值时能够生成合理的梯度值。

考虑一个简单的情形,数据集中只包含有一个数据点,即\(p_{\sigma_i}(\boldsymbol x)=\mathcal{N}(\boldsymbol x|\boldsymbol 0,\sigma_i^2 I)\)。为了更好地理解高维空间内的概率分布,将这一概率分布分解为超球坐标下的表示\(p(\phi)p_{\sigma_i}(r)\),其中\(\phi\)\(r\)分别代表\(\boldsymbol x\)的角度坐标和径向坐标。由于\(p_{\sigma_i}(\boldsymbol x)\)是一个各向同性的高斯分布,因此\(p(\phi)\)对于所有的噪声尺度\(\sigma_i\)都是一样的。

\(\boldsymbol x\in R^D\sim \mathcal{N}(\boldsymbol 0,\sigma^2I)\)\(r=\|\boldsymbol x\|_2\),有下面的结论: \[ p(r)=\frac{1}{2^{D/2-1}\Gamma(D/2)}\frac{r^{D-1}}{\sigma^D}\exp\left( -\frac{r^2}{2\sigma^2} \right) \] 而且当\(D\rightarrow \infty\)时,有\(r-\sqrt{D}\sigma\stackrel{d}\longrightarrow \mathcal{N}(0,\sigma^2/2)\)

第一个结论的证明:

根据假设\(\boldsymbol x\in R^D\sim \mathcal{N}(\boldsymbol 0,\sigma^2I)\),可得\(s\triangleq \|\boldsymbol x\|_2^2/\sigma^2\sim \chi^2_D\),也就是 \[ p_s(s)=\frac{1}{2^{D/2}\Gamma(D/2)}s^{D/2-1}e^{-s/2} \] 由于\(r=\|\boldsymbol x\|_2=\sigma\sqrt{s}\),可以使用换元法进一步得到第一个结论。

第二个结论的证明:

如果\(x\sim \mathcal{N}(0,\sigma^2)\),则\(x^2/\sigma^2\sim \chi_1^2\),从而有\(E[x^2]=\sigma^2,Var[x^2]=2\sigma^4\)。因此,如果\(x_1,\dots,x_D\stackrel{i.i.d.}\sim \mathcal{N}(0,\sigma^2)\),根据大数定律和中心极限定理,可得: \[ \sqrt{D}\left(\frac{\sum_{i=1}^D x_i^2}{D}-\sigma^2 \right)\stackrel{d} \longrightarrow \mathcal{N}(0,2\sigma^4) \] 等价地有 \[ \sqrt{D}\left(\frac{r^2}{D}-\sigma^2 \right)\stackrel{d} \longrightarrow \mathcal{N}(0,2\sigma^4) \] 根据概率论中的按分布收敛的性质,可以进一步得到: \[ \sqrt{D}\left(\frac{r}{\sqrt{D}}-\sigma \right)\stackrel{d} \longrightarrow \mathcal{N}(0,\sigma^2/2) \]

如果是图像数据的话,它的维度可以达到几十万到百万维,足够使得\(p(r)\approx \mathcal{N}(r|\sqrt{D}\sigma,\sigma^2/2)\)成立。

因此,如果想要\(p_{\sigma_{i-1}}(\boldsymbol x)\)\(p_{\sigma_i}(\boldsymbol x)\)能够很好地过渡,则它们的径向部分\(p(r)\)最好有较大的重叠部分。根据正态分布中的\(3\sigma\)原则,一个自然的选择是取\(\Phi(\sqrt{2D}(\gamma_i-1)+3\gamma_i)-\Phi(\sqrt{2D}(\gamma_i-1)-3\gamma_i)=C\)。其中\(C\)是一个大于0的较大的常数,且对于所有的\(1< i \le L\)都使用同样的\(C\);而\(\gamma_i\triangleq \sigma_{i-1}/\sigma_i\)\(\Phi(\cdot)\)代表标准高斯分布的分布函数。这也就是意味着\(\gamma_2=\gamma_3=\cdots=\gamma_l\),即\(\{\sigma_i\}_{i=1}^L\)是一个等比数列。此外,作者推荐使用\(C=0.5\)

合并噪声信息

在构造条件噪声得分网络\(\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x,\sigma)\)时,作者最开始的做法是在标准化层为每一个噪声等级都构造一组独立的参数。但是这样的话需要消耗大量的内存,而且当网络中没有标准化层时这一方法也变得不可行。

按照假设\(p_{\sigma_i}(\boldsymbol x)=\mathcal{N}(\boldsymbol x|\boldsymbol 0,\sigma_i^2 I)\),可以得到\(E[\|\nabla_{\boldsymbol x}\log p_\sigma(\boldsymbol x)\|_2]\approx \sqrt{D}/\sigma\),而且在实际数据上对条件噪声得分网络训练完成之后,也可以发现\(\| \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x,\sigma)\propto 1/\sigma \|\)的结果。这也就是说,可以考虑使用噪声信息来对一个非条件得分网络\(\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)\)的输出进行重新放缩,即定义 \[ \boldsymbol s_{\boldsymbol \theta}(\boldsymbol x,\sigma)=\boldsymbol s_{\boldsymbol \theta}(\boldsymbol x)/\sigma \] 这样定义的条件噪声得分网络更加容易训练,而且可以处理更多不同的噪声等级(即使是连续的情形)。

朗之万动力学

步长设置

在使用带有退火的朗之万动力学采样时,需要根据噪声等级的设定,来设置步长\(\epsilon\)和采样步数\(T\)这两个参数。下面首先从理论角度寻找一些依据。

仍然考虑数据集中只包含有一个数据点,即\(p_{\sigma_i}(\boldsymbol x)=\mathcal{N}(\boldsymbol x|\boldsymbol 0,\sigma_i^2 I)\)这一简单假设。带有退火的朗之万动力学使用\(p_{\sigma_{i-1}}(\boldsymbol x)\)中的采样结果来作为\(p_{\sigma_i}(\boldsymbol x)\)的采样过程的初始化,从而将两个相邻的噪声等级\(\sigma_{i-1}>\sigma_i\)联系起来。

\(p_{\sigma_i}(\boldsymbol x)\)的采样过程中,有\(\boldsymbol x_{t+1}\leftarrow \boldsymbol x_t+\alpha \nabla_{\boldsymbol x}\log p_{\sigma_i}(\boldsymbol x_t)+\sqrt{2\alpha}\boldsymbol z_t\),其中\(\boldsymbol x_0\sim p_{\sigma_{i-1}}(\boldsymbol x)\)\(\boldsymbol z_t\sim \mathcal{N}(\boldsymbol 0,I)\)\(\alpha=\epsilon\cdot \sigma_i^2/\sigma_L^2\)。由此可得: \[ \boldsymbol x_T\sim \mathcal{N}(\boldsymbol 0,s_T^2I) \] 其中 \[ \frac{s_T^2}{\sigma_i^2}=\left( 1-\frac{\epsilon}{\sigma_L^2} \right)^{2T}\left( \gamma^2-\frac{2\epsilon}{\sigma_L^2-\sigma_L^2\left(1-\frac{\epsilon}{\sigma_L^2}\right)^2} \right)+\frac{2\epsilon}{\sigma_L^2-\sigma_L^2\left(1-\frac{\epsilon}{\sigma_L^2}\right)^2} \] 从中可以看出,由于\(\{\sigma_i\}\)为等比数列,对于所有的\(1<i\le T\)都有同样的\(\gamma\),从而\(s_T^2/\sigma_i^2\)的值也保持不变,而且它的取值与\(\boldsymbol x\)的维度\(D\)无关。

因此,为了使退火朗之万动力学采样中的每个噪声等级之间能够更好地过渡,最好是选取合适的\(\epsilon\)使得\(s_T^2/\sigma_i^2\approx 1\)。但是这样很可能会使得\(T\)变得很大。考虑到计算开销,实际操作中常常选择一个最大计算开销之内的\(T\),然后选取\(\epsilon\)\(s_T^2/\sigma_i^2\)的值尽可能地接近1。

额外去噪步

在朗之万动力学采样完成之后,再添加一个额外的去噪步骤,返回\(\boldsymbol x_T+\sigma_T^2\boldsymbol{s_\theta}(\boldsymbol x_T,\sigma_T)\),相当于是最后把噪声\(\mathcal{N}(0,\sigma_T^2I)\)去掉。

移动平均

在训练得分网络时,常常会遇到训练曲线剧烈变化的情况,而且生成样本的质量也不稳定。这种情况可以使用移动平均技术来缓解,即每一步参数更新使用下面的表达式: \[ \boldsymbol{\theta}'\leftarrow m \boldsymbol{\theta}'+(1-m)\boldsymbol{\theta}_i \] 其中\(\boldsymbol{\theta}_i\)指的是第\(i\)轮训练完成之后的模型参数,\(\boldsymbol{\theta}'\)指的是当前模型参数的一个副本。模型训练完成之后,使用\(\boldsymbol{\theta}'\)作为最终的模型参数。

参考

  1. Generative Modeling by Estimating Gradients of the Data Distribution (arxiv.org)
  2. Improved Techniques for Training Score-Based Generative Models (arxiv.org)
  3. 抽样理论中有哪些令人印象深刻(有趣)的结论? - 知乎 (zhihu.com)
  4. 朗之万方程,机器学习与液体中的粒子运动,一个意想不到的联系 - 知乎 (zhihu.com)
  5. DenoisingScoreMatching_NeuralComp2011