拡散モデルのlossの導出①:イェンセンの不等式に基づく変分下限とKLダイバージェンスを用いた表記

拡散とDenoisingに基づく拡散モデル(Diffision Model)は多くの生成モデル(generative model)に導入される概念です。当記事ではイェンセンの不等式(Jensen’s Inequality)やKLダイバージェンスの定義を用いた拡散モデルの負の対数尤度の変分下限の導出について取り扱いました。
Diffusion Model論文DDPM論文や「拡散モデル ーデータ生成技術の数理(岩波書店)」の$2$章の「拡散モデル」などを参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

イェンセンの不等式

$$
\large
\begin{align}
\lambda_i & \geq 0 \\
\sum_{i=1}^{M} \lambda_{i} &= 1
\end{align}
$$

上記のように$\lambda_1, \cdots , \lambda_M$を定義するとき、下に凸の関数$f(x)$の任意の点$(x_i, f(x_i))$について下記の不等式が成立する。
$$
\large
\begin{align}
f \left( \sum_{i=1}^{M} \lambda_{i} x_{i} \right) \leq \sum_{i=1}^{M} \lambda_{i} f \left( x_{i} \right) \quad (1.1)
\end{align}
$$

上記をイェンセンの不等式(Jensen’s Inequality)という。イェンセンの不等式については下記でも取り扱った。

当記事で取り扱う導出で出てくる関数$f(x)=-\log{x}$が下に凸の関数であるので、当項では下に凸の関数についてのイェンセンの不等式を取り扱ったが、上に凸の関数についてのイェンセンの不等式は不等号が逆になることも合わせて抑えておくと良い。

期待値の定義式へのイェンセンの不等式の適用

前項$(1.1)$式の$\lambda_{i}$について$\displaystyle \lambda_i \geq 0, \, \sum_{i=1}^{M} \lambda_i = 1$が成立することから、$\lambda_{i}$に確率関数$p(x_i)$を対応させることができる。このとき下に凸の関数$f$について下記のような式が導出できる。
$$
\large
\begin{align}
f \left( \sum_{i=1}^{M} p(x_i) x_{i} \right) & \leq \sum_{i=1}^{M} p(x_i) f \left( x_{i} \right) \\
f \left( \mathbb{E} \left[ x_{i} \right] \right) & \leq \mathbb{E} \left[ f \left( x_{i} \right) \right]
\end{align}
$$

上記は離散型確率分布の式から導出したが、連続変数についても同様に下記が成立する。
$$
\large
\begin{align}
f \left( \int x p(x) dx \right) & \leq \int f(x) p(x) dx
\end{align}
$$

KLダイバージェンスの定義と解釈

連続型確率分布$p(x)$と$q(x)$のKLダイバージェンス$\mathrm{KL}(p||q)$は下記のように定義される。
$$
\large
\begin{align}
\mathrm{KL}(p||q) &= -\int \left[ p(x) \log{\frac{q(x)}{p(x)}} \right] dx \\
&= -\int \left[ p(x) \log{q(x)} \, – \, p(x) \log{{p(x)}} \right] dx \\
&= \int \left[ p(x) \log{\frac{p(x)}{q(x)}} \right] dx \quad (1.2)
\end{align}
$$

上記の式は$p$に対応する確率分布の期待値の記号$\mathbb{E}_{p}$を用いて下記のように表すこともできる。
$$
\large
\begin{align}
\mathrm{KL}(p||q) &= \mathbb{E}_{p} \left[ – \log{\frac{q(x)}{p(x)}} \right] \quad (1.3)
\end{align}
$$

$(1.3)$式は離散型確率分布でも成立する。KLダイバージェンスの解釈にあたっては、確率分布$p(x)$と$q(x)$の類似度を表すと解釈すればよい。下記ではソフトマックス関数を題材にKLダイバージェンスの値の変化について具体的に取り扱った。

拡散モデルのlossの導出

DDPM論文$(3)$式の導出

$\mathbf{x}_{0}$は学習に用いるサンプルに対応するので、lossに用いるnegative log-likelihood関数の期待値$l$は下記のように表せる。
$$
\large
\begin{align}
l = \mathbb{E} \left[ -\log{p_{\theta}(\mathbf{x}_{0})} \right] \quad (2.1)
\end{align}
$$

ここで「拡散モデル(Diffusion Model)の概要と式定義まとめ」の$(1)$式より、$p_{\theta}(\mathbf{x}_{0})$は下記のように表せる。
$$
\large
\begin{align}
p_{\theta}(\mathbf{x}_{0}) = \int p_{\theta}(\mathbf{x}_{0:T}) \, d\mathbf{x}_{1:T}
\end{align}
$$

上記は同時確率分布に関する基本演算に基づいて下記のように変形できる。
$$
\large
\begin{align}
p_{\theta}(\mathbf{x}_{0}) &= \int p_{\theta}(\mathbf{x}_{0:T}) \, d\mathbf{x}_{1:T} \\
&= \int p_{\theta}(\mathbf{x}_{0:T}) \cdot \frac{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \, d\mathbf{x}_{1:T} \\
&= \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot \frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \, d\mathbf{x}_{1:T} \\
&= \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \frac{p_{\theta}(\mathbf{x}_{0:(T-1)}|\mathbf{x}_{T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \, d\mathbf{x}_{1:T} \\
&= \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \, d\mathbf{x}_{1:T} \quad (2.2)
\end{align}
$$

$(2.1)$式に$(2.2)$式を代入することで下記が得られる。
$$
\large
\begin{align}
l &= \mathbb{E} \left[ -\log{p_{\theta}(\mathbf{x}_{0})} \right] \quad (2.1) \\
&= \int -\log{p_{\theta}(\mathbf{x}_{0})} d \mathbf{x}_{0} \\
&= \int -\log{ \left[ \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} d\mathbf{x}_{1:T} \right] } d \mathbf{x}_{0} \quad (2.3)
\end{align}
$$

ここで$(2.3)$式の確率関数$q(\mathbf{x}_{1:T}|\mathbf{x}_{0})$に着目しイェンセンの不等式を適用することで下記が得られる。
$$
\large
\begin{align}
l &= \int -\log{ \left[ \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} d\mathbf{x}_{1:T} \right] } d \mathbf{x}_{0} \quad (2.3) \\
& \leq \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot \left( -\log{ \left[ p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right] } \right) d \mathbf{x}_{0:T} \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \left( p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right) } \right] \quad (2.4) \\
\end{align}
$$

条件付き確率の期待値の式に基づいて$(2.4)$式は下記のように変形できる。
$$
\large
\begin{align}
l &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \left( p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right) } \right] \quad (2.4) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} } \right] \quad (2.5)
\end{align}
$$

また、$(2.4)$式を$\log$に着目することで下記のように和の形式で表すこともできる。
$$
\large
\begin{align}
l &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \left( p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right) } \right] \quad (2.4) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T})} \, – \, \sum_{t \geq 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \right] \quad (2.6)
\end{align}
$$

$(2.5)$式と$(2.6)$式より、DDPMの論文の$\mathrm{Eq}. \, (3)$が正しいことが確認できる。

DDPM論文$(5)$式の導出

$(2.5)$式、$(2.6)$式は$\mathbf{x}_{0}$の負の対数尤度の変分下限であり、この変分下限を$L$とおくと$L$は下記のように変形できる。
$$
\large
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} } \right] \quad (2.5) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t \geq 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \right] \quad (2.6) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.7)
\end{align}
$$

ここで$t>1$のとき$q(\mathbf{x}_{t}|\mathbf{x}_{t-1})$はマルコフ連鎖の定義とベイズの定理に基づいて下記のように変形を行える。
$$
\large
\begin{align}
q(\mathbf{x}_{t}|\mathbf{x}_{t-1}) &= \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{t})q(\mathbf{x}_{t})}{q(\mathbf{x}_{t-1})} \\
&= \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})q(\mathbf{x}_{t}|\mathbf{x}_{0})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})} \quad (2.8)
\end{align}
$$

$(2.8)$式を$(2.7)$式に代入すると下記が得られる。
$$
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.7) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.9)
\end{align}
$$

ここで$(2.9)$式の$\displaystyle \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} }$について下記が成立する。
$$
\large
\begin{align}
& \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \sum_{t > 1} \log{ \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \log{ \prod_{t > 1} \frac{q(\mathbf{x}_{t-1}|\mathbf{x}{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \log{ \left[ \frac{q(\mathbf{x}_{1}|\mathbf{x}_{0}) \cdot \cancel{q(\mathbf{x}_{2}|\mathbf{x}_{0})} \cdots \cancel{q(\mathbf{x}_{T-1}|\mathbf{x}_{0})}}{\cancel{q(\mathbf{x}_{2}|\mathbf{x}_{0})} \cdots \cancel{q(\mathbf{x}_{T-1}|\mathbf{x}_{0})} \cdot q(\mathbf{x}_{T}|\mathbf{x}_{0})} \right] } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \log{ \frac{q(\mathbf{x}_{1}|\mathbf{x}_{0})}{q(\mathbf{x}_{T}|\mathbf{x}_{0})} } \quad (2.10)
\end{align}
$$

$(2.10)$式を$(2.9)$式に代入することで下記が得られる。
$$
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.9) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \log{ \frac{\cancel{q(\mathbf{x}_{1}|\mathbf{x}_{0})}}{q(\mathbf{x}_{T}|\mathbf{x}_{0})} } \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{\cancel{q(\mathbf{x}_{1}|\mathbf{x}_{0})}} } \right] \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \frac{p_{\theta}(\mathbf{x}_{T})}{q(\mathbf{x}_{T}|\mathbf{x}_{0})} } – \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } – \log{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})} \right] \quad (2.11) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ D_{KL}(q(\mathbf{x}_{T}|\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{T})) + \sum_{t>1} D_{KL}(q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}))) – \log{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})} \right] \quad (2.12) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ L_{T} + \sum_{t > 1} L_{t-1} + L_{0} \right] \quad (2.12)’
\end{align}
$$