ブログ

MoE(Mixture of Experts)とSwitch Transformers

Transformerに分岐処理を行うMoE(Mixture of Experts)を導入することで計算コストを大きく増やさずにパラメータ数を増やすことが可能になります。当記事ではこのような方針に基づいてTransformerの学習を行った研究であるSwitch Transformerについて取りまとめを行いました。
MoE(Mixture of Experts)論文や、Switch Transformers論文などの内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformer

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

MoEとSwich Transformersの仕組み

Switch Transformersの概要

Switch Transformersでは大まかに下図のような処理が行われます。

Switch Transformers論文 Figure$\, 2$

図の左が通常のTransformerにおける処理を表しており、オレンジのブロックがself-attention、水色のブロックがMLP(FFN)処理にそれぞれ対応します。右側がSwitch Transformerの処理を表しており、複数のFFNをExpertと見なし、Routerでtokenを各Expertに割り当てる処理が導入されています。

このような処理を行うことで同一レイヤーに複数のMLP処理が存在することから計算コストは上げずにパラメータ数を増やすことが可能になります。以下、詳しい処理についてSwitch Transformers論文の数式を元に確認を行います。

Mixture of Expert Routing

Switch Transformerの論文ではtokenのベクトル表現を$\mathbf{x} \in d_{model}$とおくとき、全$N$個のExpertsの中で$i$番目のExpertのgate-valueの$p_{i}(\mathbf{x})$が下記のように定義されます。
$$
\large
\begin{align}
p_{i}(\mathbf{x}) &= \frac{e^{h_{i}(\mathbf{x})}}{\sum_{j=1}^{N} e^{h_{j}(\mathbf{x})}} = \mathrm{Softmax}(h_{i}(\mathbf{x})) \\
h(\mathbf{x}) &= W_{r} \mathbf{x} \\
W_{r} & \in \mathbb{R}^{N \times d_{model}}
\end{align}
$$

上記の$h(\mathbf{x})$は要素数が$N$のベクトルであり、$h_{i}(\mathbf{x})$や$h_{j}(\mathbf{x})$は$h(\mathbf{x})$の$i$番目の要素と$j$番目の要素にそれぞれ対応します。

一般的なMixture of ExpertにおけるRoutingでは、このように計算を行ったgate-valueの$p_{1}(\mathbf{x}), \cdots , p_{N}(\mathbf{x})$の中から上位$k$個の値のインデックスを選び各Expertの出力の線形和によって全体の出力を得ます。ここで上位$k$個に対応するインデックスの集合を$\mathcal{T}$とおくと、全体の出力$\mathbf{y} \in \mathbb{R}^{d_{model}}$は下記のような式で定義されます。
$$
\large
\begin{align}
\mathbf{y} &= \sum_{i \in \mathcal{T}} p_{i}(\mathbf{x}) E_{i}(\mathbf{x}) \\
E_{i}(\mathbf{x}) &= FFN_{i}(\mathbf{x})
\end{align}
$$

上記の$E_{i}(\mathbf{x})$は各Expertの出力に対応するので、$\mathbf{x}$に$i$番目のExpertのMLP処理を施したと理解すると良いです。

Switch Routing

MoE(Mixture of Experts)論文では前項の「Mixture of Expert Routing」のように複数のExpertを用いて処理を行うことが必須であるとされた一方で、Switch Transformerでは$1$つのトークンに対し$1$つのExpertのみが用いられます。

Switch Transformer論文では$k=1$のExpertを用いるRoutingの方針に基づいて構成されるレイヤーを「Switch Layer」、このようなRoutingを「Switch Routing」のようにそれぞれ表されます。

分散処理とauxiliary loss

Expert CapacityとCapacity Factor

Switch Transformerでは前節で確認を行ったようにRouterがExperts(複数のFFNに対応)に処理を分岐させます。したがって、各Expertの処理は分散処理が可能です。

一方で単に処理を分岐させるだけの場合、$1$つのExpertに処理が偏り結果的にRouterの処理の意義がなくなる懸念があります。このような場合の対処にあたって、Switch Transformerでは各Expertに下記の数式に基づいて「Expert Capacity」が設定されます。
$$
\large
\begin{align}
\mathrm{expert \,\, capacity} = \left( \frac{\mathrm{tokens \,\, per \,\, batch}}{\mathrm{number \,\, of \,\, experts}} \right) \times \mathrm{capacity \,\, factor}
\end{align}
$$

上記の$\mathrm{tokens \,\, per \,\, batch}$はSwitch Transformerに入力するバッチのトークンの数、$\mathrm{number \,\, of \,\, experts}$はExpertsの数にそれぞれ対応します。

要するに基本的には各Expertに均等に処理を分岐させる前提でExpertの容量の上限が定義され、$\mathrm{capacity \,\, factor}$は生じうる分岐の偏りに対するバッファと解釈すると良いです。

各Expertの容量の上限であるExpert Capacityの数をトークンが超えた場合は超えた分のトークンの計算がその層ではスキップされます。ここまでに確認した内容について論文では下記のような図で図式化されます。

Switch Transformer論文 Figure$\, 3$の一部

上図のようにSwitch Transformerでは各Expertにトークンを分岐させFFN処理を行います。

図の左のCapacity Factorが$1.0$の場合にExpert Capacityの上限がオーバーし処理されないトークンが出たことが、赤の点線より確認できます。

Load Balancing Loss

それぞれのExpertになるべく均等にtokenが配分されるように、Switch TransformerではLoad Balancing Lossというlossが導入されます。lossの式を下記で表しました。
$$
\large
\begin{align}
\mathrm{loss} &= \alpha N \sum_{i=1}^{N} f_{i} P_{i} \\
f_{i} &= \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbb{1} \{ \mathrm{argmax} \, p(x) = i \} \\
P_{i} &= \frac{1}{T} \sum_{x \in \mathcal{B}} p_{i}(x)
\end{align}
$$

上記は『Switch Routingに用いられる$1$hot表現の平均$(f_{1}, \cdots , f_{N})$と、一般的なMixture of Expert Routingに用いられる確率(gated-value)の平均$(P_{1}, \cdots , P_{N})$の分布が大きく乖離しないようにlossを導入した』と解釈すると良いです。

$N$のより具体的な解釈にあたっては、$f_{1}, \cdots , f_{N}$と$P_{1}, \cdots , P_{N}$の分布がどちらも一様分布の場合、$\displaystyle N \sum_{i=1}^{N} f_{i} P_{i}$が下記のように計算できることを確認しておくと良いと思います。
$$
\large
\begin{align}
N \sum_{i=1}^{N} f_{i} P_{i} &= N \sum_{i=1}^{N} \frac{1}{N} \cdot \frac{1}{N} \\
&= \cancel{N^{2}} \cdot \frac{1}{\cancel{N^{2}}} \\
&= 1
\end{align}
$$

全体処理とパラメータ

参考

・Transformer論文:Attention is All you need[2017]
・Switch Transformer論文
・Mixture of Experts論文

ブロック対角行列の行列式の計算と固有多項式(characteristic polynomial)

固有多項式(characteristic polynomial)は固有値を計算する際の固有方程式に用いられる多項式です。当記事ではブロック対角行列(block-diagonal matrix)の行列式の計算と、固有多項式の計算について取り扱いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$8$章「固有値問題と行列の対角化」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

ブロック対角行列の固有多項式

ブロック行列の行列式

$$
\large
\begin{align}
X = \left( \begin{array}{cc} A & B \\ O & D \end{array} \right)
\end{align}
$$

上記のように定義した$X$の行列式$\det{(X)}$について下記が成立する。
$$
\large
\begin{align}
\det{(X)} = \left| \begin{array}{cc} A & B \\ O & D \end{array} \right| = |A||D| = \det{(A)}\det{(D)} \quad (1)
\end{align}
$$

固有多項式の定義

$n$次正方行列$A$の固有多項式$F_{A}(t)$は$F_{A}(t)=\det{(tI_{n} \, – \, A)}$のように定義される。固有多項式の定義は下記でも取り扱った。

ブロック対角行列の固有多項式

次節の基本例題$156$で取り扱った。

計算例

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$156$

$$
\large
\begin{align}
A = \left( \begin{array}{cccc} A_1 & O & \cdots & O \\ O & A_2 & \cdots & O \\ \vdots & \vdots & \ddots & \vdots \\ O & O & \cdots & A_r \end{array} \right)
\end{align}
$$

上記のブロック対角行列$A$に対し、$(1)$式を繰り返し適用することで下記のように固有方程式$F_{A}(t)$を得ることができる。
$$
\large
\begin{align}
F_{A}(t) &= \left| \begin{array}{cccc} t I_1 \, – \, A_1 & O & \cdots & O \\ O & t I_2 \, – \, A_2 & \cdots & O \\ \vdots & \vdots & \ddots & \vdots \\ O & O & \cdots & t I_r \, – \, A_r \end{array} \right| \\
&= \det{(t I_1 \, – \, A_1)} \left| \begin{array}{cccc} t I_2 \, – \, A_2 & \cdots & O \\ \vdots & \ddots & \vdots \\ O & \cdots & t I_r \, – \, A_r \end{array} \right| \\
&= \cdots \\
&= \det{(t I_1 \, – \, A_1)} \cdots \det{(t I_r \, – \, A_r)} \\
&= F_{A_1}(t) \cdots F_{A_r}(t)
\end{align}
$$

上記より、ブロック対角行列$A$の固有多項式について下記が成立する。
$$
\large
\begin{align}
F_{A}(t) = F_{A_1}(t) \cdots F_{A_r}(t)
\end{align}
$$

最小多項式(minimal polynomial)の定義と計算例

行列$A$を代入すると零行列$O$になる多項式の中で「次数が最小」かつ「最高次の係数が$1$」である多項式を最小多項式(minimal polynomial)といいます。当記事では最小多項式の定義とチャート式線形代数の演習を題材に計算例について取りまとめを行いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$8$章「固有値問題と行列の対角化」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

最小多項式の定義と求め方

最小多項式の定義

$n$次正方行列$A$に対して集合$I_{A}$を下記のように定義する。
$$
\large
\begin{align}
I_{A} = \{ f(t)| f(A) = O \}
\end{align}
$$

上記は$I_{A}$が「$A$を代入すると零行列になるような多項式の全体がなす集合」と解釈できる。このように定義を行なった集合$I_{A}$における「次数が最小」かつ「多項式の最高次数の係数が$1$」である多項式を最小多項式(minimal polynomial)という。

最小多項式の求め方

固有多項式$F_{A}(t)$を因数分解の形式で表した後に、$2$乗より大きい要素は$1$乗から順に最小多項式の定義が成立するかを確認すれば良い。具体的な導出の流れは次節の演習で取り扱った。

最小多項式の使用例

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$173$

・$[1]$
$$
\large
\begin{align}
A &= \left( \begin{array}{cc} 2 & 1 \\ 2 & 3 \end{array} \right)
\end{align}
$$

上記の$A$の固有多項式を$F_{A}(t)$とおくと、$F_{A}(t)$は下記のように表せる。
$$
\large
\begin{align}
F_{A}(t) &= \det{(tI_{2} \, – \, A)} = \left| \begin{array}{cc} t-2 & -1 \\ -2 & t-3 \end{array} \right| \\
&= (t-2)(t-3) – 2 \\
&= t^{2} – 5t + 6 – 2 \\
&= t^{2} – 5t + 4 \\
&= (t-1)(t-4)
\end{align}
$$

上記より、行列$A$の最小多項式は$(t-1)(t-4)$である。

・$[2]$
$$
\large
\begin{align}
A = \left( \begin{array}{ccc} 3 & 1 & 1 \\ 2 & 4 & 2 \\ 1 & 1 & 3 \end{array} \right)
\end{align}
$$

上記の$A$の固有多項式を$F_{A}(t)$とおくと、$F_{A}(t)$は下記のように表せる。
$$
\large
\begin{align}
F_{A}(t) &= \det{(tI_{2} \, – \, A)} = \left| \begin{array}{ccc} t-3 & -1 & -1 \\ -2 & t-4 & -2 \\ -1 & -1 & t-3 \end{array} \right| \\
&= -\left| \begin{array}{ccc} -1 & -1 & t-3 \\ -2 & t-4 & -2 \\ t-3 & -1 & -1 \end{array} \right| \\
&= \left| \begin{array}{ccc} 1 & 1 & 3-t \\ -2 & t-4 & -2 \\ t-3 & -1 & -1 \end{array} \right| \\
&= \left| \begin{array}{ccc} 1 & 0 & 0 \\ -2 & t-2 & -2(t-2) \\ t-3 & -(t-2) & (t-2)(t-4) \end{array} \right| \\
&= (-1)^{1+1} \left| \begin{array}{cc} t-2 & -2(t-2) \\ -(t-2) & (t-2)(t-4) \end{array} \right| \\
&= (t-2)^{2} \left| \begin{array}{cc} 1 & -2 \\ -1 & t-4 \end{array} \right| \\
&= (t-2)^{2} (t-4-2) = (t-2)^{2} (t-6)
\end{align}
$$

ここで$p(t)=(t-2)(t-6)$とおくと、$p(A)$は下記のように計算できる。
$$
\large
\begin{align}
p(A) &= (A \, – \, 2I_{3})(A \, – \, 6I_{3}) \\
&= \left( \begin{array}{ccc} 1 & 1 & 1 \\ 2 & 2 & 2 \\ 1 & 1 & 1 \end{array} \right) \left( \begin{array}{ccc} -3 & 1 & 1 \\ 2 & -2 & 2 \\ 1 & 1 & -3 \end{array} \right) \\
&= \left( \begin{array}{ccc} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{array} \right) = O
\end{align}
$$

よって最小多項式の定義より、行列$A$の最小多項式は$(t-2)(t-6)$である。

固有多項式(characteristic polynomial)の定義と三角行列の固有多項式

固有多項式(characteristic polynomial)は固有値を計算する際の固有方程式に用いられる多項式です。当記事では固有多項式の定義・活用と、三角行列(triangular matrix)における固有多項式の計算について取り扱いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$8$章「固有値問題と行列の対角化」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

固有多項式

固有多項式の定義

$n$次正方行列$A$の$t$を変数とする固有方程式$F_{A}(t)$は行列式$\det$と$n$次の単位行列$I_{n}$を用いて下記のように定義される。
$$
\large
\begin{align}
F_{A}(t) = \det{(tI_{n} \, – \, A)}
\end{align}
$$

固有方程式は上記を用いて$F_{A}(t)=\det{(tI_{n} \, – \, A)}=0$のように表す。

三角行列の固有多項式

固有多項式の使用例

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$154$

$[1]$
$$
\large
\begin{align}
A = \left( \begin{array}{cc} 2 & 1 \\ 2 & 3 \end{array} \right)
\end{align}
$$

上記の$A$の固有多項式を$F_{A}(t)$とおくと、$F_{A}(t)$は下記のように表せる。
$$
\large
\begin{align}
F_{A}(t) &= \det{(tI_{2} \, – \, A)} = \left| \begin{array}{cc} t-2 & -1 \\ -2 & t-3 \end{array} \right| \\
&= (t-2)(t-3) – 2 \\
&= t^{2} – 5t + 6 – 2 \\
&= t^{2} – 5t + 4 \\
&= (t-1)(t-4)
\end{align}
$$

上記より行列$A$の固有値は$1$と$4$である。

・固有値$1$に対応する固有空間の基底
$A-I_{2}$は下記のように行基本変形を行うことができる。
$$
\large
\begin{align}
A \, – \, I_{2} &= \left( \begin{array}{cc} 1 & 1 \\ 2 & 2 \end{array} \right) \\
& \longrightarrow \left( \begin{array}{cc} 1 & 1 \\ 0 & 0 \end{array} \right)
\end{align}
$$

よって、$(A-I_{2})\mathbf{x}=\mathbf{0}$の解は$\displaystyle \mathbf{x} = c \left( \begin{array}{c} 1 \\ -1 \end{array} \right)$であり、このベクトルが固有値$1$に対応する固有空間の基底である。

・固有値$4$に対応する固有空間の基底
$A-4I_{2}$は下記のように行基本変形を行うことができる。
$$
\large
\begin{align}
A \, – \, 4I_{2} &= \left( \begin{array}{cc} -2 & 1 \\ 2 & -1 \end{array} \right) \\
& \longrightarrow \left( \begin{array}{cc} -2 & 1 \\ 0 & 0 \end{array} \right)
\end{align}
$$

よって、$(A \, – \, 4I_{2})\mathbf{x}=\mathbf{0}$の解は$\displaystyle \mathbf{x} = c \left( \begin{array}{c} 1 \\ 2 \end{array} \right)$であり、このベクトルが固有値$1$に対応する固有空間の基底である。

・$[2]$
$$
\large
\begin{align}
A = \left( \begin{array}{ccc} 3 & 1 & 1 \\ 2 & 4 & 2 \\ 1 & 1 & 3 \end{array} \right)
\end{align}
$$

上記の$A$の固有多項式を$F_{A}(t)$とおくと、$F_{A}(t)$は下記のように表せる。
$$
\large
\begin{align}
F_{A}(t) &= \det{(tI_{2} \, – \, A)} = \left| \begin{array}{ccc} t-3 & -1 & -1 \\ -2 & t-4 & -2 \\ -1 & -1 & t-3 \end{array} \right| \\
&= -\left| \begin{array}{ccc} -1 & -1 & t-3 \\ -2 & t-4 & -2 \\ t-3 & -1 & -1 \end{array} \right| \\
&= \left| \begin{array}{ccc} 1 & 1 & 3-t \\ -2 & t-4 & -2 \\ t-3 & -1 & -1 \end{array} \right| \\
&= \left| \begin{array}{ccc} 1 & 0 & 0 \\ -2 & t-2 & -2(t-2) \\ t-3 & -(t-2) & (t-2)(t-4) \end{array} \right| \\
&= (-1)^{1+1} \left| \begin{array}{cc} t-2 & -2(t-2) \\ -(t-2) & (t-2)(t-4) \end{array} \right| \\
&= (t-2)^{2} \left| \begin{array}{cc} 1 & -2 \\ -1 & t-4 \end{array} \right| \\
&= (t-2)^{2} (t-4-2) = (t-2)^{2} (t-6)
\end{align}
$$

上記より行列$A$の固有値は$2$と$6$であり、それぞれの重複度は$2$と$1$である。以下、それぞれの固有値に対応する固有ベクトルの計算を行う。

・固有値$2$に対応する固有空間の基底
$A \,- \, 2I_{3}$は下記のように行基本変形を行うことができる。
$$
\large
\begin{align}
A \, – \, 2I_{3} &= \left( \begin{array}{ccc} 1 & 1 & 1 \\ 2 & 2 & 2 \\ 1 & 1 & 1 \end{array} \right) \\
& \longrightarrow \left( \begin{array}{ccc} 1 & 1 & 1 \\ 0 & 0 & 0 \\ 0 & 0 & 0 \end{array} \right)
\end{align}
$$

よって、$(A \, – \, 2I_{3})\mathbf{x}=\mathbf{0}$の解は$\displaystyle \mathbf{x} = c \left( \begin{array}{c} 1 \\ -1 \\ 0 \end{array} \right) + d \left( \begin{array}{c} 1 \\ 0 \\ -1 \end{array} \right)$であるので、$\displaystyle c \left( \begin{array}{c} 1 \\ -1 \\ 0 \end{array} \right)$と$d \left( \begin{array}{c} 1 \\ 0 \\ -1 \end{array} \right)$が固有値$2$に対応する固有空間の基底である。

・固有値$6$に対応する固有空間の基底
$A \, – \, 6I_{3}$は下記のように行基本変形を行うことができる。
$$
\large
\begin{align}
A \, – \, 6I_{2} &= \left( \begin{array}{ccc} -3 & 1 & 1 \\ 2 & -2 & 2 \\ 1 & 1 & -3 \end{array} \right) \\
& \longrightarrow \left( \begin{array}{ccc} 1 & 1 & -3 \\ 2 & -2 & 2 \\ -3 & 1 & 1 \end{array} \right) \\
& \longrightarrow \left( \begin{array}{ccc} 1 & 1 & -3 \\ 0 & -4 & 8 \\ 0 & 4 & -8 \end{array} \right) \\
& \longrightarrow \left( \begin{array}{ccc} 1 & 1 & -3 \\ 0 & 1 & -2 \\ 0 & 0 & 0 \end{array} \right) \\
& \longrightarrow \left( \begin{array}{ccc} 1 & 0 & -1 \\ 0 & 1 & -2 \\ 0 & 0 & 0 \end{array} \right)
\end{align}
$$

よって、$(A \, – \, 6I_{2})\mathbf{x}=\mathbf{0}$の解は$\displaystyle \mathbf{x} = c \left( \begin{array}{c} 1 \\ 2 \\ 1 \end{array} \right)$であり、このベクトルが固有値$6$に対応する固有空間の基底である。

基本例題$155$

行列の固有多項式とケーリー・ハミルトンの定理(Cayley–Hamilton theorem)

ケーリー・ハミルトンの定理(Cayley–Hamilton theorem)は行列の次数下げなどにあたって用いられる式です。当記事では行列の固有多項式に基づくケーリー・ハミルトンの定理の一般的な式を確認した後に、$2$次正方行列のケーリー・ハミルトンの定理の式との対応について確認します。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$8$章「固有値問題と行列の対角化」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

前提の確認

行列の固有多項式

$n$次正方行列$A$の変数$t$の固有多項式$F_{A}(t)$は行列式$\det$と$n$次の単位行列$I_{n}$を元に下記のように定義される。
$$
\large
\begin{align}
F_{A}(t) = \det{(tI_{n} – A)}
\end{align}
$$

$2$次正方行列におけるケーリー・ハミルトンの定理

$$
\large
\begin{align}
A = \left( \begin{array}{cc} a & b \\ c & d \end{array} \right)
\end{align}
$$

上記のように定義される$2$次正方行列$A$について下記が成立する。
$$
\large
\begin{align}
A^{2} – (a+d)A + (ad – bc) I_{2} &= O \\
A^{2} &= (a+d)A – (ad – bc) I_{2} \quad (1)
\end{align}
$$

上記の$O$は零行列を表す。

ケーリー・ハミルトンの定理

固有多項式とケーリー・ハミルトンの定理

$n$次正方行列$A$の固有多項式が$F_{A}(t)$のように表されるとき、下記が成立する。
$$
\large
\begin{align}
F_{A}(A) = O
\end{align}
$$

上記をケイリー・ハミルトンの定理という。

$2$次正方行列の式の導出

$$
\large
\begin{align}
A = \left( \begin{array}{cc} a & b \\ c & d \end{array} \right)
\end{align}
$$

上記のように定義される$2$次正方行列$A$の固有多項式$F_{A}(t)$は下記のように表すことができる。
$$
\large
\begin{align}
F_{A}(t) &= \det{(tI_{n} – A)} \\
&= \left| \begin{array}{cc} t-a & b \\ c & t-d \end{array} \right| \\
&= (t-a)(t-d) – bc \\
&= t^{2} -(a+d)t + ad-bc
\end{align}
$$

上記より、$F_{A}(A)=O$は下記のように変形できる。
$$
\large
\begin{align}
F_{A}(A) &= O \\
A^{2} -(a+d)A + (ad-bc)I_{2} &= O \\
A^{2} &= (a+d)A – (ad-bc)I_{2} \quad (2)
\end{align}
$$

$(2)$式は$(1)$式に一致する。

拡散モデルのlossの導出②:正規分布のKLダイバージェンスの計算に基づくlossの導出

拡散とDenoisingに基づく拡散モデル(Diffision Model)は多くの生成モデル(generative model)に導入される概念です。当記事では正規分布のKLダイバージェンス(KL-Divergence)の計算を元にDDPM論文におけるlossの導出について取り扱いました。
Diffusion Model論文DDPM論文や「拡散モデル ーデータ生成技術の数理(岩波書店)」の$2$章の「拡散モデル」などを参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

DDPM論文$(5)$式

$$
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ D_{KL}(q(\mathbf{x}_{T}|\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{T})) + \sum_{t>1} D_{KL}(q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}))) – \log{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})} \right] \quad (1.1) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ L_{T} + \sum_{t>1} L_{t-1} + L_{0} \right] \quad (1.1)’
\end{align}
$$

上記の$(1.1), \, (1.1)’$式がDDPM論文の$\mathrm{Eq}. \: (5)$に対応する。導出については下記で詳しく取り扱った。

$2$つの正規分布のKLダイバージェンスの計算

$2$つの確率分布$p(x)$と$q(x)$に関するKLダイバージェンス$D_{KL}(p||q)$は下記のように表される。
$$
\large
\begin{align}
D_{KL}(p||q) = – \log{ \frac{q(x)}{p(x)} }
\end{align}
$$

$2$つの正規分布$\mathcal{N}(\mu_{a}, \Sigma_{a})$と$\mathcal{N}(\mu_{b}, \Sigma_{b})$のKLダイバージェンス$D_{KL}(\mathcal{N}(\mu_{a}, \Sigma_{a})||\mathcal{N}(\mu_{b}, \Sigma_{b}))$は$(1.2)$式より下記を用いて計算できる。
$$
\begin{align}
D_{KL}(\mathcal{N}(\mu_{a}, \Sigma_{a})||\mathcal{N}(\mu_{b}, \Sigma_{b})) &= \frac{1}{2} \left[ \log{ \frac{|\Sigma_{b}|}{|\Sigma_{a}|} } – d + \mathrm{tr} \left( \Sigma_{b}^{-1} \Sigma_{a} \right) + (\mu_{b}-\mu_{a})^{\mathrm{T}} \Sigma_{b}^{-1} (\mu_{b}-\mu_{a}) \right]
\end{align}
$$

拡散モデルのlossの導出

$q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})$の導出

$(1.1)$式の$q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})$は下記のように表される。
$$
\large
\begin{align}
q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0}) &= \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}), \tilde{\beta}_{t} \mathbf{I}) \quad (2.1) \\
\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) &= \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{\bar{\beta}_{t}} \mathbf{x}_{0} + \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \mathbf{x}_{t} \quad (2.2) \\
\bar{\beta}_{t} &= \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \beta_{t} \quad (2.3)
\end{align}
$$

$\displaystyle \small L_{t-1} = \mathbb{E}_{q} \left[ \frac{1}{2 \sigma_{t}^{2}} || \tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) – \mu_{\theta}(\mathbf{x}_{t}, t) ||^{2} \right] + C$の導出

$\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0})$の詳細

以下、$\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0})$の詳細について式変形を元に確認を行う。まず、$(2.2)$式より$\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0})$は下記のように表される。
$$
\large
\begin{align}
\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{\bar{\beta}_{t}} \mathbf{x}_{0} + \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \mathbf{x}_{t} \quad (2.2)
\end{align}
$$

また、任意時刻の拡散条件付き確率の式$q(\mathbf{x}_{t}|\mathbf{x}_{0})$は下記のように表すことができる。
$$
\large
\begin{align}
q(\mathbf{x}_{t}|\mathbf{x}_{0}) &= \mathcal{N}(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}, (1 \, – \, \bar{\alpha}_{t}) \mathbf{I}) \quad (2.4) \\
\alpha_{t} &= 1-\beta_{t} \\
\bar{\alpha}_{t} &= \prod_{s=1}^{t} \alpha_{s}
\end{align}
$$

上記の導出は下記で詳しく取り扱った。

$(2.4)$式より、$\mathbf{x}_{t}$は$\mathbf{x}_{0}$とノイズ$\epsilon$を用いて下記のように表せる。
$$
\large
\begin{align}
\mathbf{x}_{t}(\mathbf{x}_{0}, \epsilon) = \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0} + \sqrt{1 \, – \, \bar{\alpha}_{t}} \epsilon, \quad \mathcal{N}(\mathbf{0}, \mathbf{I}) \quad (2.5)
\end{align}
$$

$(2.5)$式を$\mathbf{x}_{0}$について変形を行うと下記が得られる。
$$
\large
\begin{align}
\mathbf{x}_{0} = \frac{1}{\sqrt{\bar{\alpha}_{t}}} \left( \mathbf{x}_{t}(\mathbf{x}_{0}, \epsilon) \, – \, \sqrt{1 \, – \, \bar{\alpha}_{t}} \epsilon \right) \quad (2.6)
\end{align}
$$

ここで$(2.2)$式に$(2.6)$式を代入すると下記のように変形できる。
$$
\large
\begin{align}
\tilde{\mu}_{t}(\mathbf{x}_{t}, \mathbf{x}_{0}) &= \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_{t}}{\bar{\beta}_{t}} \mathbf{x}_{0} + \frac{\bar{\beta}_{t-1}}{\bar{\beta}_{t}} \mathbf{x}_{t} \quad (2.2) \\
&=
\end{align}
$$

$\mu_{\theta}(\mathbf{x}_{t}, t)$の詳細

$\displaystyle \small L_{t-1} = \mathbb{E}_{\mathbf{x}_{0}, \epsilon} \left[ \frac{\beta_{t}^{2}}{2 \sigma_{t}^{2} \alpha_{t} \bar{\beta}_{t}} \left| \middle| \epsilon – \epsilon_{\theta} \left( \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0} + \sqrt{\bar{\beta}_{t}} \epsilon, t \right) \middle| \right|^{2} \right] + C$の導出

DDPMを用いたサンプリング

拡散モデルのlossの導出①:イェンセンの不等式に基づく変分下限とKLダイバージェンスを用いた表記

拡散とDenoisingに基づく拡散モデル(Diffision Model)は多くの生成モデル(generative model)に導入される概念です。当記事ではイェンセンの不等式(Jensen’s Inequality)やKLダイバージェンスの定義を用いた拡散モデルの負の対数尤度の変分下限の導出について取り扱いました。
Diffusion Model論文DDPM論文や「拡散モデル ーデータ生成技術の数理(岩波書店)」の$2$章の「拡散モデル」などを参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

イェンセンの不等式

$$
\large
\begin{align}
\lambda_i & \geq 0 \\
\sum_{i=1}^{M} \lambda_{i} &= 1
\end{align}
$$

上記のように$\lambda_1, \cdots , \lambda_M$を定義するとき、下に凸の関数$f(x)$の任意の点$(x_i, f(x_i))$について下記の不等式が成立する。
$$
\large
\begin{align}
f \left( \sum_{i=1}^{M} \lambda_{i} x_{i} \right) \leq \sum_{i=1}^{M} \lambda_{i} f \left( x_{i} \right) \quad (1.1)
\end{align}
$$

上記をイェンセンの不等式(Jensen’s Inequality)という。イェンセンの不等式については下記でも取り扱った。

当記事で取り扱う導出で出てくる関数$f(x)=-\log{x}$が下に凸の関数であるので、当項では下に凸の関数についてのイェンセンの不等式を取り扱ったが、上に凸の関数についてのイェンセンの不等式は不等号が逆になることも合わせて抑えておくと良い。

期待値の定義式へのイェンセンの不等式の適用

前項$(1.1)$式の$\lambda_{i}$について$\displaystyle \lambda_i \geq 0, \, \sum_{i=1}^{M} \lambda_i = 1$が成立することから、$\lambda_{i}$に確率関数$p(x_i)$を対応させることができる。このとき下に凸の関数$f$について下記のような式が導出できる。
$$
\large
\begin{align}
f \left( \sum_{i=1}^{M} p(x_i) x_{i} \right) & \leq \sum_{i=1}^{M} p(x_i) f \left( x_{i} \right) \\
f \left( \mathbb{E} \left[ x_{i} \right] \right) & \leq \mathbb{E} \left[ f \left( x_{i} \right) \right]
\end{align}
$$

上記は離散型確率分布の式から導出したが、連続変数についても同様に下記が成立する。
$$
\large
\begin{align}
f \left( \int x p(x) dx \right) & \leq \int f(x) p(x) dx
\end{align}
$$

KLダイバージェンスの定義と解釈

連続型確率分布$p(x)$と$q(x)$のKLダイバージェンス$\mathrm{KL}(p||q)$は下記のように定義される。
$$
\large
\begin{align}
\mathrm{KL}(p||q) &= -\int \left[ p(x) \log{\frac{q(x)}{p(x)}} \right] dx \\
&= -\int \left[ p(x) \log{q(x)} \, – \, p(x) \log{{p(x)}} \right] dx \\
&= \int \left[ p(x) \log{\frac{p(x)}{q(x)}} \right] dx \quad (1.2)
\end{align}
$$

上記の式は$p$に対応する確率分布の期待値の記号$\mathbb{E}_{p}$を用いて下記のように表すこともできる。
$$
\large
\begin{align}
\mathrm{KL}(p||q) &= \mathbb{E}_{p} \left[ – \log{\frac{q(x)}{p(x)}} \right] \quad (1.3)
\end{align}
$$

$(1.3)$式は離散型確率分布でも成立する。KLダイバージェンスの解釈にあたっては、確率分布$p(x)$と$q(x)$の類似度を表すと解釈すればよい。下記ではソフトマックス関数を題材にKLダイバージェンスの値の変化について具体的に取り扱った。

拡散モデルのlossの導出

DDPM論文$(3)$式の導出

$\mathbf{x}_{0}$は学習に用いるサンプルに対応するので、lossに用いるnegative log-likelihood関数の期待値$l$は下記のように表せる。
$$
\large
\begin{align}
l = \mathbb{E} \left[ -\log{p_{\theta}(\mathbf{x}_{0})} \right] \quad (2.1)
\end{align}
$$

ここで「拡散モデル(Diffusion Model)の概要と式定義まとめ」の$(1)$式より、$p_{\theta}(\mathbf{x}_{0})$は下記のように表せる。
$$
\large
\begin{align}
p_{\theta}(\mathbf{x}_{0}) = \int p_{\theta}(\mathbf{x}_{0:T}) \, d\mathbf{x}_{1:T}
\end{align}
$$

上記は同時確率分布に関する基本演算に基づいて下記のように変形できる。
$$
\large
\begin{align}
p_{\theta}(\mathbf{x}_{0}) &= \int p_{\theta}(\mathbf{x}_{0:T}) \, d\mathbf{x}_{1:T} \\
&= \int p_{\theta}(\mathbf{x}_{0:T}) \cdot \frac{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \, d\mathbf{x}_{1:T} \\
&= \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot \frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \, d\mathbf{x}_{1:T} \\
&= \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \frac{p_{\theta}(\mathbf{x}_{0:(T-1)}|\mathbf{x}_{T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \, d\mathbf{x}_{1:T} \\
&= \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \, d\mathbf{x}_{1:T} \quad (2.2)
\end{align}
$$

$(2.1)$式に$(2.2)$式を代入することで下記が得られる。
$$
\large
\begin{align}
l &= \mathbb{E} \left[ -\log{p_{\theta}(\mathbf{x}_{0})} \right] \quad (2.1) \\
&= \int -\log{p_{\theta}(\mathbf{x}_{0})} d \mathbf{x}_{0} \\
&= \int -\log{ \left[ \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} d\mathbf{x}_{1:T} \right] } d \mathbf{x}_{0} \quad (2.3)
\end{align}
$$

ここで$(2.3)$式の確率関数$q(\mathbf{x}_{1:T}|\mathbf{x}_{0})$に着目しイェンセンの不等式を適用することで下記が得られる。
$$
\large
\begin{align}
l &= \int -\log{ \left[ \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} d\mathbf{x}_{1:T} \right] } d \mathbf{x}_{0} \quad (2.3) \\
& \leq \int q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) \cdot \left( -\log{ \left[ p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right] } \right) d \mathbf{x}_{0:T} \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \left( p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right) } \right] \quad (2.4) \\
\end{align}
$$

条件付き確率の期待値の式に基づいて$(2.4)$式は下記のように変形できる。
$$
\large
\begin{align}
l &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \left( p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right) } \right] \quad (2.4) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} } \right] \quad (2.5)
\end{align}
$$

また、$(2.4)$式を$\log$に着目することで下記のように和の形式で表すこともできる。
$$
\large
\begin{align}
l &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \left( p_{\theta}(\mathbf{x}_{T}) \prod_{t=1}^{T} \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} \right) } \right] \quad (2.4) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T})} \, – \, \sum_{t \geq 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \right] \quad (2.6)
\end{align}
$$

$(2.5)$式と$(2.6)$式より、DDPMの論文の$\mathrm{Eq}. \, (3)$が正しいことが確認できる。

DDPM論文$(5)$式の導出

$(2.5)$式、$(2.6)$式は$\mathbf{x}_{0}$の負の対数尤度の変分下限であり、この変分下限を$L$とおくと$L$は下記のように変形できる。
$$
\large
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \frac{p_{\theta}(\mathbf{x}_{0:T})}{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} } \right] \quad (2.5) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t \geq 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \right] \quad (2.6) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.7)
\end{align}
$$

ここで$t>1$のとき$q(\mathbf{x}_{t}|\mathbf{x}_{t-1})$はマルコフ連鎖の定義とベイズの定理に基づいて下記のように変形を行える。
$$
\large
\begin{align}
q(\mathbf{x}_{t}|\mathbf{x}_{t-1}) &= \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{t})q(\mathbf{x}_{t})}{q(\mathbf{x}_{t-1})} \\
&= \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})q(\mathbf{x}_{t}|\mathbf{x}_{0})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})} \quad (2.8)
\end{align}
$$

$(2.8)$式を$(2.7)$式に代入すると下記が得られる。
$$
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t}|\mathbf{x}_{t-1})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.7) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.9)
\end{align}
$$

ここで$(2.9)$式の$\displaystyle \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} }$について下記が成立する。
$$
\large
\begin{align}
& \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \sum_{t > 1} \log{ \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \log{ \prod_{t > 1} \frac{q(\mathbf{x}_{t-1}|\mathbf{x}{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \log{ \left[ \frac{q(\mathbf{x}_{1}|\mathbf{x}_{0}) \cdot \cancel{q(\mathbf{x}_{2}|\mathbf{x}_{0})} \cdots \cancel{q(\mathbf{x}_{T-1}|\mathbf{x}_{0})}}{\cancel{q(\mathbf{x}_{2}|\mathbf{x}_{0})} \cdots \cancel{q(\mathbf{x}_{T-1}|\mathbf{x}_{0})} \cdot q(\mathbf{x}_{T}|\mathbf{x}_{0})} \right] } \\
&= \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } + \log{ \frac{q(\mathbf{x}_{1}|\mathbf{x}_{0})}{q(\mathbf{x}_{T}|\mathbf{x}_{0})} } \quad (2.10)
\end{align}
$$

$(2.10)$式を$(2.9)$式に代入することで下記が得られる。
$$
\begin{align}
L &= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} \cdot \frac{q(\mathbf{x}_{t-1}|\mathbf{x}_{0})}{q(\mathbf{x}_{t}|\mathbf{x}_{0})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{q(\mathbf{x}_{1}|\mathbf{x}_{0})} } \right] \quad (2.9) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{p_{\theta}(\mathbf{x}_{T}}) \, – \, \log{ \frac{\cancel{q(\mathbf{x}_{1}|\mathbf{x}_{0})}}{q(\mathbf{x}_{T}|\mathbf{x}_{0})} } \, – \, \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } \, – \, \log{ \frac{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})}{\cancel{q(\mathbf{x}_{1}|\mathbf{x}_{0})}} } \right] \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ -\log{ \frac{p_{\theta}(\mathbf{x}_{T})}{q(\mathbf{x}_{T}|\mathbf{x}_{0})} } – \sum_{t > 1} \log{ \frac{p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t})}{q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})} } – \log{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})} \right] \quad (2.11) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ D_{KL}(q(\mathbf{x}_{T}|\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{T})) + \sum_{t>1} D_{KL}(q(\mathbf{x}_{t-1}|\mathbf{x}_{t},\mathbf{x}_{0})||p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}))) – \log{p_{\theta}(\mathbf{x}_{0}|\mathbf{x}_{1})} \right] \quad (2.12) \\
&= \mathbb{E}_{q(\mathbf{x}_{1:T}|\mathbf{x}_{0})} \left[ L_{T} + \sum_{t > 1} L_{t-1} + L_{0} \right] \quad (2.12)’
\end{align}
$$

機械学習・DeepLearning分野の論文読解に役に立つ論文著者索引

論文の本文中では「oo et al., yyyy」のように先行研究を参照することが多いです。それぞれ「References」に具体的な論文を確認することができる一方で、都度確認するのは大変です。そこで当記事では論文の著者名から該当する研究を確認できるような索引の作成を行いました。

作成にあたってはよく引用される各分野の有名論文を中心にabc順でまとめました。厳密さではなく、「概ねこの論文が該当するだろう」を重視しているので、正確にはそれぞれの論文の「References」を確認してください。特に複数該当する場合は多くの場合「yyyya, yyyyb, $\cdots$」のように表記されます。

DeepLearning

CNN

著者名該当研究
He et al., 2015ResNet
Krizhevsky et al., 2012AlexNet
Simonyan et al., 2014VGGNet

RNN・Transformer・LLM

著者名該当研究
Brown et al., 2020GPT-3
Devlin et al., 2018BERT
Du et al., 2021GLaM
Hoffmann et al., 2022Chinchilla
Kitaev et al., 2020Reformer
Liu et al., 2019RoBERTa
Mikolov et al., 2013Word2vec
Radford et al., 2018GPT・GPT-2
Rae et al., 2021Gopher
Raffel et al., 2020T5
Smith et al., 2022Megatron–Turing NLG
Sutskever et al., 2014seq2seq
Thoppilan et al., 2022LaMDA
Vaswani et al., 2017Transformer
Yang et al.,XLNet

生成モデル

著者名該当研究
Goodfellow et al., 2014GAN: Generative Adversarial networks
Ho et al., 2020DDPM
Sohl-Dickstein et al., 2015Diffusion Model
Radford et al., 2021CLIP
Ramesh et al., 2021DALL-E

GNN

著者名該当研究
Battaglia et al., 2018Graph Network・Inductive Bias
Gilmer et al., 2017MPNN: Message Passing Neural Network
Wang et al., 2018NLNN: Non-Local Neural Network

点群・集合

著者名該当研究
Lee et al., 2019Set Transformer
Qi et al., 2017PointNet
Zaheer et al., 2017Deep sets

強化学習

強化学習×Transformer

著者名該当研究
Chen et al., 2021Decision Transformer

拡散モデル(Diffusion Model)の概要と式定義まとめ

拡散とDenoisingに基づく拡散モデル(Diffision Model)は多くの生成モデル(generative model)に導入される概念です。当記事では拡散モデルの概要と式定義、イェンセンの不等式などを用いるloss関数の導出などについて取りまとめを行いました。
DDPM論文や「拡散モデル ーデータ生成技術の数理(岩波書店)」の$2$章の「拡散モデル」などを参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

拡散モデルの概要

DDPM論文 Figure$\, 2$

拡散モデルの概要は上図を元に理解すると良い。$\mathbf{x}_{0}$が画像、$\mathbf{x}_{T}$が$\mathbf{x}_{0}$と次元が同じ潜在変数を表す。また、図の$q$が拡散過程(diffusion process)、$p$が逆向き過程(reverse process)にそれぞれ対応する。

拡散モデルでは元データに対し、拡散過程を適用した結果を逆向き過程で復元できるように逆向き過程$p$の学習を行うことで、新たに潜在変数の$\mathbf{x}_{T}$が与えられた際に$\mathbf{x}_{0}$を計算し、生成を行うことができる。

拡散モデルの式定義

拡散過程と逆拡散過程の式定義

拡散モデル(Diffusion model)は潜在変数モデルであり、下記のような式で定義される。
$$
\large
\begin{align}
p_{\theta}(\mathbf{x}_{0}) = \int p_{\theta}(\mathbf{x}_{0:T}) d\mathbf{x}_{1:T} \quad (1)
\end{align}
$$

$(1)$式は生成される画像の分布は同時分布$p_{\theta}(\mathbf{x}_{0:T})$を$\mathbf{x}_{1}, \cdots , \mathbf{x}_{T}$について積分し、周辺分布を得たと解釈すると良い。ここで式に出てくる同時分布$p_{\theta}(\mathbf{x}_{0:T})$は下記のように定義する。
$$
\large
\begin{align}
p_{\theta}(\mathbf{x}_{0:T}) &= p(\mathbf{x}_{T}) \prod_{t=1}^{T} p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}) \quad (2) \\
p_{\theta}(\mathbf{x}_{t-1}|\mathbf{x}_{t}) &= \mathcal{N}(\boldsymbol{\mu}_{\theta}(\mathbf{x}_{t},t), \boldsymbol{\Sigma}_{\theta}(\mathbf{x}_{t}, t)) \quad (3) \\
p(\mathbf{x}_{T}) &= \mathcal{N}(\mathbf{0}, \mathbf{I})
\end{align}
$$

$(2)$式の同時確率分布の分解(factorization)ではマルコフ性(Markov Property)を仮定している。また、$(3)$式は前節における逆向き過程を表す。多次元正規分布の平均ベクトルと共分散行列はDeep Learningを用いて学習を行ったパラメータ$\theta$の関数から出力される値である。

拡散モデルの特徴は$p$で表される逆向き過程(reverse process)に対応する拡散過程(diffusion process)の$q$を定義することである。拡散過程$q$はvariance scheduleの$\beta_1, \cdots , \beta_{T}$を用いて下記のように定義される。
$$
\large
\begin{align}
q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) &= \prod_{t=1}^{T} q(\mathbf{x}_{t}|\mathbf{x}_{t-1}) \quad (4) \\
q(\mathbf{x}_{t}|\mathbf{x}_{t-1}) &= \mathcal{N} \left( \sqrt{1-\beta_{t}} \mathbf{x}_{t}, \beta_{t} \mathbf{I} \right) \quad (5)
\end{align}
$$

任意時刻の拡散条件付き確率の証明

$$
\large
\begin{align}
q(\mathbf{x}_{1:T}|\mathbf{x}_{0}) &= \prod_{t=1}^{T} q(\mathbf{x}_{t}|\mathbf{x}_{t-1}) \quad (4) \\
q(\mathbf{x}_{t}|\mathbf{x}_{t-1}) &= \mathcal{N} \left( \sqrt{1-\beta_{t}} \mathbf{x}_{t-1}, \beta_{t} \mathbf{I} \right) \quad (5)
\end{align}
$$

拡散過程$q$が上記のように表される時、任意時刻$t$における拡散条件付き確率$q(\mathbf{x}_{t}|\mathbf{x}_{0})$は下記のように表される。
$$
\large
\begin{align}
q(\mathbf{x}_{t}|\mathbf{x}_{0}) &= \mathcal{N}(\sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0}, (1 \, – \, \bar{\alpha}_{t}) \mathbf{I}) \quad (6) \\
\alpha_{t} &= 1-\beta_{t} \quad (7) \\
\bar{\alpha}_{t} &= \prod_{s=1}^{t} \alpha_{s} \quad (8)
\end{align}
$$

上記の$(6)$式はDDPM論文の$\mathrm{Eq}. \: (4)$に対応する。以下、数学的帰納法を用いて上記の証明を行う。

・$[1] \,$ $t=1$の時
$(7)$式と$(8)$式より下記が成立する。
$$
\large
\begin{align}
\sqrt{\bar{\alpha}_{1}} &= \sqrt{\alpha_{1}} = \sqrt{1-\beta_{1}} \\
1 \, – \, \bar{\alpha_{1}} &= 1 \, – \, \alpha_{1} = \beta_{1}
\end{align}
$$

上記より$t=1$の時$(6)$式は$(5)$式の定義式に一致するので$(6)$式は成立する。

・$[2] \,$ $t=k$で$(6)$式が成立する場合の$t=k+1$の時
$t=k$で$(6)$式が成立するとき、下記が成立する。
$$
\large
\begin{align}
q(\mathbf{x}_{k}|\mathbf{x}_{0}) = \mathcal{N}(\sqrt{\bar{\alpha}_{k}} \mathbf{x}_{0}, (1 \, – \, \bar{\alpha}_{k}) \mathbf{I}) \quad (6)’
\end{align}
$$

このとき$(6)’$式に基づいて$\mathbf{x}_{k}$は下記のように表せる。
$$
\large
\begin{align}
\mathbf{x}_{k} = \sqrt{\bar{\alpha}_{k}} \mathbf{x}_{0} + \sqrt{1 \, – \, \bar{\alpha}_{k}} \epsilon, \quad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \quad (9)
\end{align}
$$

ここで$(5)$式の$q$の定義式より、$q(\mathbf{x}_{k+1}|\mathbf{x}_{k})$は下記のように表される。
$$
\large
\begin{align}
q(\mathbf{x}_{k+1}|\mathbf{x}_{k}) = \mathcal{N} \left( \sqrt{1-\beta_{t+1}} \mathbf{x}_{t}, \beta_{t+1} \mathbf{I} \right) \quad (5)’
\end{align}
$$

ここで$(5)’$式と$(9)$式に基づいて$\mathbf{x}_{k+1}$は下記のように表せる。
$$
\large
\begin{align} \mathbf{x}_{k+1} &= \sqrt{1-\beta_{t+1}} \mathbf{x}_{t} + \sqrt{\beta_{t+1}} \epsilon, \quad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&= \sqrt{1-\beta_{t+1}} (\sqrt{\bar{\alpha}_{k}} \mathbf{x}_{0} + \sqrt{1 \, – \, \bar{\alpha}_{k}} \epsilon) + \sqrt{\beta_{t+1}} \epsilon \\
&= \sqrt{\alpha_{t+1}} (\sqrt{\bar{\alpha}_{k}} \mathbf{x}_{0} + \sqrt{1 \, – \, \bar{\alpha}_{k}} \epsilon) + \sqrt{\beta_{t+1}} \epsilon \\
&= \sqrt{\bar{\alpha}_{k+1}} \mathbf{x}_{0} + \sqrt{\alpha_{k+1}} \sqrt{1 \, – \, \bar{\alpha}_{k}} \epsilon + \sqrt{\beta_{t+1}} \epsilon \quad (10)
\end{align}
$$

$(10)$式より、$\mathbf{x}_{k+1}$の分散は$\mathbf{X}, \mathbf{Y} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$とする際の$\sqrt{\alpha_{k+1}} \sqrt{1 \, – \, \bar{\alpha}_{k}} \mathbf{X} + \sqrt{\beta_{t+1}} \mathbf{Y}$の分散に一致する。ここで正規分布の再生性より、$\sqrt{\alpha_{k+1}} \sqrt{1 \, – \, \bar{\alpha}_{k}} \mathbf{X} + \sqrt{\beta_{t+1}} \mathbf{Y}$の分散$\mathbf{\Sigma}$は下記のように計算できる。
$$
\large
\begin{align}
\mathbf{\Sigma} &= (\sqrt{\alpha_{k+1}} \sqrt{1 \, – \, \bar{\alpha}_{k}})^{2} \mathbf{I} + (\sqrt{\beta_{t+1}})^{2} \mathbf{I} \\
&= \alpha_{k+1} (1 \, – \, \bar{\alpha}_{k}) \mathbf{I} + \beta_{t+1} \mathbf{I} \\
&= \left( \cancel{\alpha_{k+1}} – \bar{\alpha}_{k+1} + 1 – \cancel{\alpha_{k+1}} \right) \mathbf{I} = (1-\bar{\alpha}_{k+1}) \mathbf{I}
\end{align}
$$

$(10)$式と$(11)$式より$q(\mathbf{x}_{k+1}|\mathbf{x}_{0})$は下記のように表せる。
$$
\large
\begin{align}
q(\mathbf{x}_{k+1}|\mathbf{x}_{0}) = \mathcal{N} \left( \sqrt{\bar{\alpha}_{k+1}} \mathbf{x}_{0}, (1-\bar{\alpha}_{k+1}) \mathbf{I} \right) \quad (12)
\end{align}
$$

$(12)$式より、$t=k$で$(6)$式が成立する場合、$t=k+1$でも$(6)$式が成立することが確認できる。

$[1], \, [2]$より、数学的帰納法に基づいて$(6)$式が成立することが示される。

モーメント母関数を用いた正規分布の再生性の導出は下記で詳しく取り扱った。

最大内積探索(MIPS; Maximum Inner Product Search)まとめ

Routing TransformerのようなContent-based Sparse Attentionでは最大内積探索(MIPS; Maximum Inner Product Search)と類似した処理が行われます。当記事では最大内積探索の概要やクラスタリングを用いた計算の効率化について取りまとめました。
Clustering is efficient for approximate maximum inner product search.」などを参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

最大内積探索の問題設定

最大内積探索の立式

$n$個の$D$次元ベクトルの集合$\mathcal{X}$を$\mathcal{X} = \{ \mathbf{x}_{1}, \cdots , \mathbf{x}_{n} \}, \, \mathbf{x}_{i} \in \mathbb{R}^{D}$のように定義する。このとき$D$次元のクエリ$q \in \mathbb{R}^{D}$と集合$\mathcal{X}$の要素であるベクトル$\mathbf{x}_{i}$に関する最大内積探索問題(MIPS; Maximum Inner Product Search problem)は下記のように立式することができる。
$$
\large
\begin{align}
\mathrm{arg} \max_{i} q^{\mathrm{T}} \mathbf{x}_{i}
\end{align}
$$

同様に、内積の計算結果が上位$K$に含まれるベクトルを探すK-MIPS問題は下記のように立式される。
$$
\large
\begin{align}
\mathrm{argmax}_{i}^{(K)} q^{\mathrm{T}} \mathbf{x}_{i} \quad (1.1)
\end{align}
$$

最近傍探索と一致する場合

最大内積探索(MIPS)問題は最近傍探索(NNS; Nearest Neighbor Search)と関連する。より厳密には「ベクトル$\mathbf{x}_{i} \in \mathbb{R}^{D}$が同じユークリッドノルム(Euclidean norm)を持つ場合」にMIPSとNNSが一致する。K-NNSの式はクエリ$q$とベクトル$\mathbf{x}_{i}$間のノルムに着目することで上記のように表すことができる。
$$
\large
\begin{align}
\mathrm{argmin}_{i}^{(K)} ||q \, – \, \mathbf{x}_{i}||^{2} \quad (1.2)
\end{align}
$$

ここでクエリ$q$とベクトル$\mathbf{x}_{i}$のユークリッドノルムが定数$C_1, C_2$を用いて$||q||^{2}=C_1, \, ||\mathbf{x}_{i}||^{2}=C_2$のように表せるとき、$(1.2)$式は下記のように同値変形できる。
$$
\large
\begin{align}
\mathrm{argmin}_{i}^{(K)} ||q \, – \, \mathbf{x}_{i}||^{2} &= \mathrm{argmin}_{i}^{(K)} (q \, – \, \mathbf{x}_{i})^{\mathrm{T}} (q \, – \, \mathbf{x}_{i}) \\
&= \mathrm{argmin}_{i}^{(K)} \left( ||q||^{2} \, – \, 2 q^{\mathrm{T}} \mathbf{x}_{i} + ||\mathbf{x}_{i}||^{2} \right) \\
&= \mathrm{argmin}_{i}^{(K)} \left( C_1 \, – \, 2 q^{\mathrm{T}} \mathbf{x}_{i} + C_2 \right) \\
&= \mathrm{argmin}_{i}^{(K)} \left( – \, 2 q^{\mathrm{T}} \mathbf{x}_{i} \right) \\
&= \mathrm{argmax}_{i}^{(K)} q^{\mathrm{T}} \mathbf{x}_{i} \quad (1.1)
\end{align}
$$

上記より、$\mathbf{x}_{i}$のノルムが同じ場合は「同一のクエリ$q$についてのK-MIPS問題がK-NNS問題に帰着する」ことが確認できる。

MCSSと一致する場合

最大内積探索(MIPS)問題は最大余弦類似度探索(MCSS; Maximum Cosine Similarity Search)とも関連する。NNSと同様に厳密には「ベクトル$\mathbf{x}_{i} \in \mathbb{R}^{D}$が同じユークリッドノルム(Euclidean norm)を持つ場合」にMIPSとMCSSSが一致する。K-MCSSの式はクエリ$q$とベクトル$\mathbf{x}_{i}$間のノルムに着目することで上記のように表すことができる。
$$
\large
\begin{align}
\mathrm{argmin}_{i}^{(K)} \frac{q^{\mathrm{T}} \mathbf{x}_{i}}{||q|| ||\mathbf{x}_{i}||} \quad (1.3)
\end{align}
$$

ここでクエリ$q$とベクトル$\mathbf{x}_{i}$のユークリッドノルムが定数$C_1, C_2$を用いて$||q||^{2}=C_1, \, ||\mathbf{x}_{i}||^{2}=C_2$のように表せるとき、$(1.3)$式は下記のように同値変形できる。
$$
\large
\begin{align}
\mathrm{argmin}_{i}^{(K)} \frac{q^{\mathrm{T}} \mathbf{x}_{i}}{||q|| ||\mathbf{x}_{i}||} &= \mathrm{argmin}_{i}^{(K)} \frac{q^{\mathrm{T}} \mathbf{x}_{i}}{C_1 C_2} \\
&= \mathrm{argmax}_{i}^{(K)} q^{\mathrm{T}} \mathbf{x}_{i} \quad (1.1)
\end{align}
$$

上記より、$\mathbf{x}_{i}$のノルムが同じ場合は「同一のクエリ$q$についてのK-MIPS問題がK-MCSS問題に帰着する」ことが確認できる。

クラスタリングを用いた最大内積探索の効率化

k-means

実際の問題でK-MIPS・K-NNS・K-MCSS問題をそれぞれクエリ$q$とベクトル集合$\mathcal{X} = \{ \mathbf{x}_{1}, \cdots , \mathbf{x}_{n} \}$の総当たりで解こうとすると計算量が大きい場合が多い。

総当たりで計算を行う場合は計算量が大きい一方で正確な結果が得られるが、「検索」のように概ね高い類似性が得られればよく「厳密な正確さ」が必要ない場合も多くある。このような際に、k-means法を用いたクラスタリングに基づく解の近似も有力な手法となる。

具体的には下記のような手順に基づいて近似解を得ることができる。
$[1] \,$ ベクトル集合$\mathcal{X} = \{ \mathbf{x}_{1}, \cdots , \mathbf{x}_{n} \}$をk-means法に基づいて$k$個のクラスタに分類
$[2] \,$ クエリ$q$とクラスタの平均ベクトルの内積 or 類似度を計算、値の大きなクラスタを選択
$[3] \,$ 類似度の高いクラスタからクエリ$q$との類似度が高いベクトル$\mathbf{x}_{i}$を選択

上記の$[1]$は参考論文では下記のようなアルゴリズムの形式で表される。

Algorithm$. 1 \,$ (Clustering is efficient for approximate maximum inner product search.)

上記のようにクラスタの平均ベクトルとの内積 or 類似度を計算することによって、計算量を減らすことが可能である。たとえば$10,000$ベクトルを$100$個のクラスタに$100$ずつ分類する場合、類似度の計算回数は$100 \times 2 = 200$であり、総当たりの$10,000$から大きく減らすことができる。

一方で、クエリ$q$がクラスタの境界付近に位置する場合など、必ずしも厳密な結果が得られるわけではないことに注意が必要である。この課題に対してはクラスタを複数選ぶなどの対応策が用いられる場合も多い。

階層k-means

前項で確認を行ったk-meansクラスタリングに基づく近似では「クラスタの数が多い:精度が高いが計算量が多い」、「クラスタの数が少ない:計算量が少ないが精度も低い」というトレードオフがある。この問題については階層k-meansが一つの解決策になりうる。

Figure$. 1 \,$ (Clustering is efficient for approximate maximum inner product search.)

上図のようにベクトルをいくつかの階層型のクラスタに分類しておくことで順々に分類を行うことが可能になる。たとえば$10,000$を$1,000$を$10$個、$1,000$を$10$個、$100$を$10$個のように階層的に分類すると、探索の計算量は$10 \times 4$まで削減できる。

この階層型クラスタは、多くのクラスタを先に作り、徐々に階層型クラスタリングを行うことで構築が可能である。