スコアベースモデルと暗黙的スコアマッチング(Implicit Score Matching)

スコアを用いる生成モデルであるスコアベースモデル(SBM)ではスコアの学習にあたってスコアマッチング(Score Matching)を行います。当記事ではシンプルなスコアマッチングの手法である明示的スコアマッチングと暗黙的スコアマッチングについて取りまとめを行いました。
「拡散モデル ーデータ生成技術の数理(岩波書店)」の$1$章の「生成モデル」を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

概要

スコアベースモデル(SBM)

確率分布のスコアが得られるとき、ランジュバン・モンテカルロ法(Langevin Monte Carlo)法を用いることで確率分布からのサンプリングを行うことができる。

このように学習した確率分布のスコアを用いて実現される生成モデルをスコアベースモデル(SBM; Score Based Model)という。よって、SBMを用いるにあたってはスコアの値を得る必要があり、このスコアを得るプロセスをスコアマッチング(Score Matching)という。

明示的スコアマッチング

明示的スコアマッチング(ESM; Explicit Score Matching)はスコアマッチングの手法の一つである。ニューラルネットワークのようなパラメータ$\theta$に基づくスコア関数を$s_{\theta}(\mathbf{x}):\mathbb{R}^{d} \longrightarrow \mathbb{R}^{d}$で近似を行うにあたって、明示的スコアマッチングでは下記のように目的関数の$J_{ESM_{p}}(\theta)$を目標分布$p(\mathbf{x})$の期待値の形式で定義する。
$$
\large
\begin{align}
J_{ESM_{p}}(\theta) = \frac{1}{2} \mathbb{E}_{p(\mathbf{x})} \left[ || \nabla_{\mathbf{x}} \log{p(\mathbf{x})} – s_{\theta}(\mathbf{x}) ||^{2} \right] \quad (1)
\end{align}
$$

上記のように定義される$J_{ESM_{p}}(\theta)$は二乗和誤差関数の最小化と同様に解釈できる一方で、多くの生成モデルではスコアの$\nabla_{\mathbf{x}} \log{p(x)}$が未知であり、式をそのまま用いることができない。

このような場合の解決策の$1$つが暗黙的スコアマッチング(ISM; Implicit Score Matching)であり、次項で取り扱う。

暗黙的スコアマッチング

暗黙的スコアマッチング(ISM; Implicit Score Matching)ではスコア関数$s_{\theta}(\mathbf{x})$の学習にあたっての目的関数に下記を用いる。
$$
\large
\begin{align}
J_{ISM_{p}}(\theta) = \mathbb{E}_{p(\mathbf{x})} \left[ \frac{1}{2} ||s_{\theta}(\mathbf{x})||^{2} + \mathrm{tr}(\nabla_{\mathbf{x}} s_{\theta}(\mathbf{x})) \right] \quad (2)
\end{align}
$$

暗黙的スコアマッチングの式理解

数式の解釈

$$
\large
\begin{align}
J_{ISM_{p}}(\theta) = \mathbb{E}_{p(\mathbf{x})} \left[ \frac{1}{2} ||s_{\theta}(\mathbf{x})||^{2} + \mathrm{tr}(\nabla_{\mathbf{x}} s_{\theta}(\mathbf{x})) \right] \quad (2)
\end{align}
$$

$(2)$式は目標分布の$p(\mathbf{x})$を用いて期待値が定義されるが、実際の生成問題では$p(\mathbf{x})$は未知である代わりに訓練データ$D={ \mathbf{x}^{(1)}, \cdots , \mathbf{x}^{(N)} }$を元に下記のように近似する。
$$
\large
\begin{align}
J_{ISM_{p}}(\theta) \simeq \frac{1}{N} \sum_{i=1}^{N} \left[ \frac{1}{2} ||s_{\theta}(\mathbf{x}^{(i)})||^{2} + \mathrm{tr}(\nabla_{\mathbf{x}} s_{\theta}(\mathbf{x}^{(i)})) \right]
\end{align}
$$

上記の第$1$項の$\displaystyle ||s_{\theta}(\mathbf{x}^{(i)})||^{2}$は「訓練データ$\mathbf{x}^{(i)}$の位置におけるスコアの絶対値」に対応し、この値が$\mathbf{0}$になれば対数尤度が$\mathbf{x}^{(i)}$で停留点(極小値・鞍点・極大値)を持つ。

第$2$項は入力ベクトルの各成分の$2$次微分に対応し、この値が最小(負の値)であれば$1$次微分が単調減少であり、対数尤度の停留点が極大値を持つ。

$J_{ESM_{p}}(\theta) = J_{ISM_{p}}(\theta)+C_1$の導出

$(1)$式で定義した明示的スコアマッチングの目的関数$J_{ESM_{p}}(\theta)$と$(2)$式で定義した暗黙的スコアマッチングの目的関数$J_{ISM_{p}}(\theta)$間にはパラメータ$\theta$に関係ない項の$C_1$を用いて下記のような式が成立する。
$$
\large
\begin{align}
J_{ESM_{p}}(\theta) = J_{ISM_{p}}(\theta)+C_1 \quad (3)
\end{align}
$$

以下、$(3)$式の導出を行う。導出にあたっては下記の$4$つの仮定をおく。
仮定$1. \,$ $p(\mathbf{x})$が微分可能
仮定$2. \,$ $\mathbb{E}_{p(\mathbf{x})}[||\nabla_{\mathbf{x}} \log{p(\mathbf{x})}||^{2}]$が有限
仮定$3. \,$ 任意の$\theta$について$\mathbb{E}_{p(\mathbf{x})}[||s_{\theta}(\mathbf{x})]$が有限
仮定$4. \,$ $\displaystyle \lim_{||\mathbf{x}|| \to \infty} [p(\mathbf{x})s_{\theta}(\mathbf{x})]=0$

まず、$(1)$式は下記のように変形できる。
$$
\large
\begin{align}
J_{ESM_{p}}(\theta) &= \frac{1}{2} \mathbb{E}_{p(\mathbf{x})} \left[ || \nabla_{\mathbf{x}} \log{p(\mathbf{x})} – s_{\theta}(\mathbf{x}) ||^{2} \right] \quad (1) \\
&= \mathbb{E}_{p(\mathbf{x})} \left[ \frac{1}{2}||\nabla_{\mathbf{x}} \log{p(\mathbf{x})}||^{2} + \frac{1}{2}||s_{\theta}(\mathbf{x})||^{2} – \nabla_{\mathbf{x}} \log{p(\mathbf{x})}^{\mathrm{T}} s_{\theta}(\mathbf{x}) \right] \\
&= \mathbb{E}_{p(\mathbf{x})} \left[ \frac{1}{2}||s_{\theta}(\mathbf{x})||^{2} – \nabla_{\mathbf{x}} \log{p(\mathbf{x})}^{\mathrm{T}} s_{\theta}(\mathbf{x}) \right] + C_1 \quad (4)
\end{align}
$$

上記の変形にあたっては仮定$2.$を用いた。ここで$(4)$式の第$1$項は$(2)$式の第$1$項と一致するので、以下では下記の$(5)$式が成立することを示す。
$$
\large
\begin{align}
\mathbb{E}_{p(\mathbf{x})} \left[ \mathrm{tr}(\nabla_{\mathbf{x}} s_{\theta}(\mathbf{x})) \right] = -\mathbb{E}_{p(\mathbf{x})} \left[ \nabla_{\mathbf{x}} \log{p(\mathbf{x})}^{\mathrm{T}} s_{\theta}(\mathbf{x}) \right] \quad (5)
\end{align}
$$

ここで$(5)$式の右辺は下記のように変形できる。
$$
\large
\begin{align}
-\mathbb{E}_{p(\mathbf{x})} \left[ \nabla_{\mathbf{x}} \log{p(\mathbf{x})}^{\mathrm{T}} s_{\theta}(\mathbf{x}) \right] &= -\int_{\mathbf{x} \in \mathbb{R}^{d}} p(\mathbf{x}) \left[ \nabla_{\mathbf{x}} \log{p(\mathbf{x})}^{\mathrm{T}} s_{\theta}(\mathbf{x}) \right] d \mathbf{x} \\
&= -\sum_{i=1}^{d} \int_{\mathbf{x} \in \mathbb{R}^{d}} p(\mathbf{x}) \left[ (\nabla_{\mathbf{x}} \log{p(\mathbf{x})})_{i} s_{\theta}(\mathbf{x})_{i} \right] d \mathbf{x} \quad (6)
\end{align}
$$

上記の$(\nabla_{\mathbf{x}} \log{p(\mathbf{x})})_{i}$、$s_{\theta}(\mathbf{x})_{i}$はそれぞれ$(\nabla_{\mathbf{x}} \log{p(\mathbf{x})})$と$s_{\theta}(\mathbf{x})$の$i$番目の成分に一致する。$(6)$式はさらに下記のように変形できる。
$$
\large
\begin{align}
& -\mathbb{E}_{p(\mathbf{x})} \left[ \nabla_{\mathbf{x}} \log{p(\mathbf{x})}^{\mathrm{T}} s_{\theta}(\mathbf{x}) \right] \\
&= -\sum_{i=1}^{d} \int_{\mathbf{x} \in \mathbb{R}^{d}} p(\mathbf{x}) \left[ (\nabla_{\mathbf{x}} \log{p(\mathbf{x})})_{i} s_{\theta}(\mathbf{x})_{i} \right] d \mathbf{x} \quad (6) \\
&= -\sum_{i=1}^{d} \int_{\mathbf{x} \in \mathbb{R}^{d}} p(\mathbf{x}) \left[ \frac{\partial \log{p(\mathbf{x})}}{\partial x_i} s_{\theta}(\mathbf{x})_{i} \right] d \mathbf{x} \\
&= -\sum_{i=1}^{d} \int_{\mathbf{x} \in \mathbb{R}^{d}} \frac{\cancel{p(\mathbf{x})}}{\cancel{p(\mathbf{x})}} \left[ \frac{\partial p(\mathbf{x})}{\partial x_i} s_{\theta}(\mathbf{x})_{i} \right] d \mathbf{x} \\
&= -\sum_{i=1}^{d} \int_{\mathbf{x} \in \mathbb{R}^{d}} \frac{\partial p(\mathbf{x})}{\partial x_i} s_{\theta}(\mathbf{x})_{i} d \mathbf{x} \quad (7)
\end{align}
$$

同様に$(5)$式の左辺は下記のように表せる。
$$
\large
\begin{align}
\mathbb{E}_{p(\mathbf{x})} \left[ \mathrm{tr}(\nabla_{\mathbf{x}} s_{\theta}(\mathbf{x})) \right] &= \int_{\mathbf{x} \in \mathbb{R}^{d}} p(\mathbf{x}) \mathrm{tr}(\nabla_{\mathbf{x}} s_{\theta}(\mathbf{x})) d \mathbf{x} \\
&= \sum_{i=1}^{d} \int_{\mathbf{x} \in \mathbb{R}^{d}} p(\mathbf{x}) \frac{\partial s_{\theta}(\mathbf{x})_{i}}{\partial x_i} d \mathbf{x} \quad (8)
\end{align}
$$

$(7)$式と$(8)$式より、下記の$(9)$式を示せば$(3)$式が成立することが示される。
$$
\large
\begin{align}
\int_{\mathbf{x} \in \mathbb{R}^{d}} \frac{\partial p(\mathbf{x})}{\partial x_i} s_{\theta}(\mathbf{x})_{i} d \mathbf{x} = \int_{\mathbf{x} \in \mathbb{R}^{d}} p(\mathbf{x}) \frac{\partial s_{\theta}(\mathbf{x})_{i}}{\partial x_i} d \mathbf{x} \quad (9)
\end{align}
$$

$(9)$式は部分積分の公式などを活用することによって示すことができる。

$(9)$式の詳細は当記事では省略しますが、「拡散モデル ーデータ生成技術の数理(岩波書店)」が詳しいので詳しくは下記を参照ください。