拡散モデルのlossの導出②:正規分布のKLダイバージェンスの計算に基づくlossの導出

拡散とDenoisingに基づく拡散モデル(Diffision Model)は多くの生成モデル(generative model)に導入される概念です。当記事では正規分布のKLダイバージェンス(KL-Divergence)の計算を元にDDPM論文におけるlossの導出について取り扱いました。
Diffusion Model論文DDPM論文や「拡散モデル ーデータ生成技術の数理(岩波書店)」の$2$章の「拡散モデル」などを参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

DDPM論文$(5)$式

$$
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ D_{KL}(q(\mathbf{x}_{T}|\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{T})) + \sum_{t>1} D_{KL}(q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}))) – \log{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})} \right] \quad (1.1) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ L_{T} + \sum_{t>1} L_{t-1} + L_{0} \right] \quad (1.1)’
\end{align}
$$

上記の$(1.1), \, (1.1)’$式がDDPM論文の$\mathrm{Eq}. \: (5)$に対応する。導出については下記で詳しく取り扱った。

$2$つの正規分布のKLダイバージェンスの計算

$2$つの確率分布$p(x)$と$q(x)$に関するKLダイバージェンス$D_{KL}(p||q)$は下記のように表される。
$$
\large
\begin{align}
D_{KL}(p||q) = – \log{ \frac{q(x)}{p(x)} }
\end{align}
$$

$2$つの正規分布$\mathcal{N}(\mu_{a}, \Sigma_{a})$と$\mathcal{N}(\mu_{b}, \Sigma_{b})$のKLダイバージェンス$D_{KL}(\mathcal{N}(\mu_{a}, \Sigma_{a})||\mathcal{N}(\mu_{b}, \Sigma_{b}))$は$(1.2)$式より下記を用いて計算できる。
$$
\begin{align}
D_{KL}(\mathcal{N}(\mu_{a}, \Sigma_{a})||\mathcal{N}(\mu_{b}, \Sigma_{b})) &= \frac{1}{2} \left[ \log{ \frac{|\Sigma_{b}|}{|\Sigma_{a}|} } – d + \mathrm{tr} \left( \Sigma_{b}^{-1} \Sigma_{a} \right) + (\mu_{b}-\mu_{a})^{\mathrm{T}} \Sigma_{b}^{-1} (\mu_{b}-\mu_{a}) \right]
\end{align}
$$

拡散モデルのlossの導出

$q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})$の導出

$(1.1)$式の$q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})$は下記のように表される。
$$
\large
\begin{align}
q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0}) &= \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}), \tilde{\beta}_{t} \mathbf{I}) \quad (2.1) \\
\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) &= \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{\bar{\beta}_{t}} \mathbf{x}_{0} + \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \mathbf{x}_{t} \quad (2.2) \\
\bar{\beta}_{t} &= \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \beta_{t} \quad (2.3)
\end{align}
$$

$\displaystyle \small L_{t-1} = \mathbb{E}_{q} \left[ \frac{1}{2 \sigma_{t}^{2}} || \tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) – \mu_{\theta}(\mathbf{x}_{t}, t) ||^{2} \right] + C$の導出

$\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0})$の詳細

以下、$\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0})$の詳細について式変形を元に確認を行う。まず、$(2.2)$式より$\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0})$は下記のように表される。
$$
\large
\begin{align}
\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{\bar{\beta}_{t}} \mathbf{x}_{0} + \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \mathbf{x}_{t} \quad (2.2)
\end{align}
$$

また、任意時刻の拡散条件付き確率の式$q(\mathbf{x}_{t}|\mathbf{x}_{0})$は下記のように表すことができる。
$$
\large
\begin{align}
q(\mathbf{x}_{t}|\mathbf{x}_{0}) &= \mathcal{N}(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}, (1 \, – \, \bar{\alpha}_{t}) \mathbf{I}) \quad (2.4) \\
\alpha_{t} &= 1-\beta_{t} \\
\bar{\alpha}_{t} &= \prod_{s=1}^{t} \alpha_{s}
\end{align}
$$

上記の導出は下記で詳しく取り扱った。

$(2.4)$式より、$\mathbf{x}_{t}$は$\mathbf{x}_{0}$とノイズ$\epsilon$を用いて下記のように表せる。
$$
\large
\begin{align}
\mathbf{x}_{t}(\mathbf{x}_{0}, \epsilon) = \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0} + \sqrt{1 \, – \, \bar{\alpha}_{t}} \epsilon, \quad \mathcal{N}(\mathbf{0}, \mathbf{I}) \quad (2.5)
\end{align}
$$

$(2.5)$式を$\mathbf{x}_{0}$について変形を行うと下記が得られる。
$$
\large
\begin{align}
\mathbf{x}_{0} = \frac{1}{\sqrt{\bar{\alpha}_{t}}} \left( \mathbf{x}_{t}(\mathbf{x}_{0}, \epsilon) \, – \, \sqrt{1 \, – \, \bar{\alpha}_{t}} \epsilon \right) \quad (2.6)
\end{align}
$$

ここで$(2.2)$式に$(2.6)$式を代入すると下記のように変形できる。
$$
\large
\begin{align}
\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) &= \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{\bar{\beta}_{t}} \mathbf{x}_{0} + \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \mathbf{x}_{t} \quad (2.2) \\
&=
\end{align}
$$

$\mu_{\theta}(\mathbf{x}_{t}, t)$の詳細

$\displaystyle \small L_{t-1} = \mathbb{E}_{\mathbf{x}_{0}, \epsilon} \left[ \frac{\beta_{t}^{2}}{2 \sigma_{t}^{2} \alpha_{t} \bar{\beta}_{t}} \left| \middle| \epsilon – \epsilon_{\theta} \left( \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0} + \sqrt{\bar{\beta}_{t}} \epsilon, t \right) \middle| \right|^{2} \right] + C$の導出

DDPMを用いたサンプリング