【Word2vecなどの出力層高速化】雑音対照推定(NCE)・負例サンプリング

分布仮説に基づくWord$2$vecなどの学習にあたっては、出力層が語彙の数に対応する分類問題に対応するので、そのまま取り扱うと巨大なソフトマックス関数の取り扱いが必要になります。当記事はNCEや負例サンプリング(Negative Sampling)を用いた解決策について取り扱いました。
「深層学習による自然言語処理」$4.3$節の「出力層の高速化」などを参考に当記事の作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

語彙が大きい場合のソフトマックス関数の取り扱い

Word$2$vecや翻訳・文書要約・対話などの文章の生成タスクなどの学習を行う際は、入力文$\mathbf{x} \in \mathcal{X}$と予測対象の単語$y \in \mathcal{Y}$を元に、下記のように交差エントロピー損失関数を定義する。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) = -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1)
\end{align}
$$

上記に対し、下記のように$s(y)$と$Z(\mathcal{Y})$を定義する。
$$
\large
\begin{align}
s(y) &= f_{\boldsymbol{\theta}}(\mathbf{x},y) \quad (1.2) \\
Z(\mathcal{Y}) &= \sum_{\tilde{y} \in \mathcal{Y}} \exp{(s(\tilde{y}))} = \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y})))} \quad (1.3)
\end{align}
$$

$(1.2), \, (1.3)$式を元に$(1.1)$式は下記のように表せる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1) \\
&= -s(y) + \log{Z(\mathcal{Y})} \quad (1.4)
\end{align}
$$

ここで$(1.4)$式の両辺をパラメータベクトル$\boldsymbol{\theta} \in \mathbb{R}^{p}$で方向微分すると下記が得られる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) = – \nabla s(y) + \nabla \log{Z(\mathcal{Y})} \quad (1.5)
\end{align}
$$

ここまでの詳しい式変形は下記で取り扱った。

$(1.5)$式の取り扱いの際に語彙数$|\mathcal{Y}|$が大きいと出力層の計算量が大きくなるので、何らかの解決策があると良いが、上記の記事では重点サンプリングを用いた手法について取り扱った。当記事ではNCEと負例サンプリングに基づく手法を次節で取り扱う。

NEC・負例サンプリング

学習にあたっての基本的な方針

上記の重点サンプリングでは分配関数$Z$の近似を行ったが、当節では$Z$を未知のパラメータとみなして学習させる方法について取り扱う。この学習にあたっては、通常の最尤推定ではない目的関数を設定し、パラメータの推定を行う。

具体的には「訓練データ」と「無作為に生成したノイズ」を区別するような目的関数を用いることで$Z$の推定を行う。

以下、上記に基づく手法である「雑音対照推定(NCE; Noise Contrasive Estimation)」、「負例サンプリング(Negative Sampling)」についてそれぞれ取り扱った。

NCE

NCE(Noise Contrasive Estimation)では「訓練データ」と「ノイズの分布$q$から生成された標本」の識別を行うような分類器の学習を行う。分類を取り扱うにあたって、確率変数$D$を用いて「訓練データ」を$D=1$、「ノイズからの標本」を$D=0$で定義する。

ここで$D=0$が$D=1$の$k$倍出現しやすいと仮定するとき、$D$と単語に対応する$Y$の同時確率は$D=1$と$D=0$に対してそれぞれ下記のように表せる。
$$
\large
\begin{align}
P(D=1, Y=y) &= \frac{1}{k+1} p(y) \quad (2.1) \\
P(D=0, Y=y) &= \frac{k}{k+1} q(y) \quad (2.2)
\end{align}
$$

このとき、単語に関する周辺確率$P(Y=y)$は下記のように表せる。
$$
\large
\begin{align}
P(Y=y) &= P(D=1, Y=y) + P(D=0, Y=y) \\
&= \frac{1}{k+1} p(y) + \frac{k}{k+1} q(y) \quad (2.3)
\end{align}
$$

上記に基づいて単語$Y=y$がノイズからの標本である確率$P(D=0|Y=y)$は条件付き確率の定義式を元に下記のように得られる。
$$
\large
\begin{align}
P(D=0|Y=y) &= \frac{P(D=0, Y=y)}{P(Y=y)} \\
&= \frac{\displaystyle \frac{k}{\cancel{k+1}} q(y)}{\displaystyle \frac{1}{\cancel{k+1}} p(y) + \frac{k}{\cancel{k+1}} q(y)} \\
&= \frac{k q(y)}{p(y) + k q(y)} \quad (2.4)
\end{align}
$$

同様に単語$Y=y$が訓練データである確率$P(D=1|Y=y)$は下記のように得られる。
$$
\large
\begin{align}
P(D=1|Y=y) &= \frac{P(D=1, Y=y)}{P(Y=y)} \\
&= \frac{\displaystyle \frac{1}{\cancel{k+1}} p(y)}{\displaystyle \frac{1}{\cancel{k+1}} p(y) + \frac{k}{\cancel{k+1}} q(y)} \\
&= \frac{p(y)}{p(y) + k q(y)} \quad (2.5)
\end{align}
$$

ここで$1$つの訓練データ$y$に対して、ノイズ分布$q$からの$k$個の標本$\tilde{\mathcal{D}}=(\tilde{y}_{1}, \cdots , \tilde{y}_{k})$の無作為抽出を行う。この$k+1$個の事例に対して下記の関数$l_{\boldsymbol{\theta}}^{\mathrm{NCE}}(y)$を定義する。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{NCE}}(y) &= -\log{P(D=1|Y=y)} – \sum_{\tilde{y} \in \mathcal{D}} \log{P(D=0|Y=\tilde{y})} \\
&= -\log{\frac{p(y)}{p(y) + k q(y)}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{k q(\tilde{y})}{p(\tilde{y}) + k q(\tilde{y})}} \quad (2.6)
\end{align}
$$

上記の$(2.6)$式は$k+1$個の事例に対する$D$の負の対数尤度であり、$\boldsymbol{\theta}$は関数$\displaystyle p(y)=\frac{\exp{(s(y))}}{Z}=\exp{(s(y)+c)}$で用いられるパラメータである。

$\displaystyle p(y)=\frac{\exp{(s(y))}}{Z}=\exp{(s(y)+c)}$の$c$は下記のように解釈することができる。
$$
\large
\begin{align}
\exp{(s(y)+c)} &= \exp{(s(y))} \exp{(c)} \\
&= \frac{\exp{(s(y))}}{\exp{(-c)}} = \frac{\exp{(s(y))}}{Z} \\
Z &= \exp{(-c)} \quad (2.7)
\end{align}
$$

$(2.6)$式を$\boldsymbol{\theta}$について最適化する際に、このように導入した$c$を同時に学習させることで分配関数$Z$の計算を省略することができる。また、$Z=\exp{(-c)}=1$とおいても結果が変わらないという実験結果もあり、この場合は下記の$(2.8)$式のような目的関数を用いる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{NCE}}(y) &= -\log{\frac{p(y)}{p(y) + k q(y)}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{k q(\tilde{y})}{p(\tilde{y}) + k q(\tilde{y})}} \quad (2.6) \\
&= -\log{\frac{\exp{(s(y))}}{\exp{(s(y))} + k q(y)}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{k q(\tilde{y})}{\exp{(s(\tilde{y}))} + k q(\tilde{y})}} \quad (2.8)
\end{align}
$$

雑音対照推定(NCE)は上記の$(2.6)$式や$(2.8)$式を用いてパラメータの推定を行う手法である。

負例サンプリング

負例サンプリング(Negative Sampling)はNCEの損失関数の$(2.6)$式の確率にシグモイド関数を用いて簡略化を行った手法である。負例サンプリングの目的関数$l_{\boldsymbol{\theta}}^{\mathrm{NS}}(y)$は下記のように定義される。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{NS}}(y) &= -\log{[\mathrm{sigmoid}(s(y))]} – \sum_{\tilde{y} \in \mathcal{D}} \log{[1-\mathrm{sigmoid}(s(y))]} \\
&= \log{\frac{\exp{(s(y))}}{\exp{(s(y))}+1}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{1}{\exp{(s(\tilde{y}))}+1}} \quad (2.9) \\
\mathrm{sigmoid}(x) &= \frac{\exp{(x)}}{\exp{(x)}+1}
\end{align}
$$

$(2.9)$式の$\tilde{y}$が一様分布によってサンプリングされる場合は、$(2.8)$式で$k=|\mathcal{Y}|$かつ$\displaystyle q(\tilde{y}) = \frac{1}{|\mathcal{Y}|}$である場合の式に一致することは抑えておくと良い。

一方で、負例サンプリングの論文では、単語の出現頻度に比例した確率であるユニグラム確率の$\displaystyle \frac{3}{4}$を使用すると単純なユニグラム確率や一様分布を使用する場合以上に性能が上がるとされる。この辺の設定は実験の前提にもよるので、参考に確認しておくで十分であると思われる。

参考文献

・NCE論文
・負例サンプリング論文