【Word2vecなどの出力層高速化】巨大なソフトマックス関数の課題と重点サンプリングによる解決

分布仮説(distributional hypothesis)に基づくWord$2$vecなどの学習にあたっては、出力層が語彙の数に対応する分類問題に対応するので、そのまま取り扱うと巨大なソフトマックス関数の取り扱いが必要になります。当記事はこの課題の重点サンプリングを用いた解決策について取りまとめました。
「深層学習による自然言語処理」$4.3$節の「出力層の高速化」などを参考に当記事の作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

ソフトマックス関数の巨大化

交差エントロピー損失関数とソフトマックス関数

Word$2$vecや翻訳・文書要約・対話などの文章の生成タスクの文章の生成タスクの学習を行う際は、入力文$\mathbf{x} \in \mathcal{X}$と予測対象の単語$y \in \mathcal{Y}$を元に、下記のように交差エントロピー損失関数を定義する。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) = -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1)
\end{align}
$$

文章の全体の生成における学習については$(1.1)$式を各単語について和を計算すれば良いが、式が複雑になるので以下では$1$つの単語の予測のみを取り扱う。ここで$(1.1)$式は条件付き確率分布$P(y|\mathbf{x})$やソフトマックス関数$\mathrm{softmax}(x)$を用いて下記のように表すこともできる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= -\log{P(y|\mathbf{x})} \quad (1.2) \\
P(y|\mathbf{x}) &= \mathrm{softmax}(f_{\boldsymbol{\theta}}(\mathbf{x},y)) = \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} \quad (1.3)
\end{align}
$$

$(1.2)$式は負の対数尤度に一致するので、損失関数の最大化は対数尤度の最大化に対応する。また、$(1.3)$式はソフトマックス関数の定義式を含む。

分配関数の定義

前項「交差エントロピー損失関数とソフトマックス関数」$(1.1)$式の$y$の取り扱いにあたって、下記のように関数$s(y)$を定義する。
$$
\large
\begin{align}
s(y) = f_{\boldsymbol{\theta}}(\mathbf{x},y) \quad (1.4)
\end{align}
$$

また、下記のように分配関数(partition function)の$Z(\mathcal{Y})$を定義する。
$$
\large
\begin{align}
Z(\mathcal{Y}) = \sum_{\tilde{y} \in \mathcal{Y}} \exp{(s(\tilde{y}))} = \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y})))} \quad (1.5)
\end{align}
$$

このとき$(1.1)$式は$(1.4), \, (1.5)$式を用いて下記のように表すことができる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1) \\
&= -\log{ \frac{\exp{(s(y))}}{Z(\mathcal{Y})} } \\
&= -s(y) + \log{Z(\mathcal{Y})} \quad (1.6)
\end{align}
$$

パラメータに関する勾配

$(1.6)$式の両辺をパラメータベクトル$\boldsymbol{\theta} \in \mathbb{R}^{p}$で方向微分すると下記が得られる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= – \nabla s(y) + \nabla \log{Z(\mathcal{Y})} \quad (1.6)’ \\
\nabla &= \frac{\partial}{\partial \boldsymbol{\theta}} = \left(\begin{array}{c} \displaystyle \frac{\partial}{\partial \theta_{1}} \\ \vdots \\ \displaystyle \frac{\partial}{\partial \theta_{p}} \end{array} \right)
\end{align}
$$

ここで$(1.6)’$式の第$2$項の$\nabla \log{Z(\mathcal{Y})}$は合成関数の微分の公式を元に下記のように変形できる。
$$
\large
\begin{align}
\nabla \log{Z(\mathcal{Y})} &= \sum_{\tilde{y} \in \mathcal{Y}} \frac{\exp{(s(\tilde{y}))}}{Z(\mathcal{Y})} s'(\tilde{y}) \\
&= \sum_{\tilde{y} \in \mathcal{Y}} p(\tilde{y}) s'(\tilde(y)) \\
&= \mathbb{E}_{Y \sim p}[s'(Y)] \quad (1.7)
\end{align}
$$

$(1.7)$式を元に$(1.6)’$式は下記のように表せる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= – \nabla s(y) + \nabla \log{Z(\mathcal{Y})} \quad (1.6)’ \\
&= – \nabla s(y) + \mathbb{E}_{Y \sim p}[s'(Y)] \quad (1.8)
\end{align}
$$

$(1.8)$式の計算にあたって$\mathcal{Y}$が巨大な場合は、$\mathbb{E}_{Y \sim p}[s'(Y)]$は全ての語彙に対して計算する必要があることで計算量が大きくなる。この際に$\mathcal{Y}$に比例しない計算量で近似を行う手法にはいくつかあるが、次節では重点サンプリングを用いた手法を確認する。

重点サンプリングを用いた出力層の高速化

モンテカルロ法

$(1.8)$式における$\mathbb{E}_{Y \sim p}[s'(Y)]$のモンテカルロ法を用いた近似式は下記のように表せる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] & \simeq \frac{1}{n} \sum_{i=1}^{n} s'(\tilde{y}_{i}) \quad (2.1) \\
\tilde{y}_{1}, \, \cdots , & \, \tilde{y}_{n} \sim p \quad (2.2)
\end{align}
$$

ここで$(2.2)$式の$p$は$(1.3)$式に対応するので、ここでの$\tilde{y}_{i}$のサンプリングを行うにあたっては分配関数$Z$の計算が必要である。よって、次項では$p$を用いずにサンプリングを行う重点サンプリングを確認する。

重点サンプリング

対象の分布からの無作為抽出が難しい場合に重点サンプリング(importance sampling)はよく用いられる。ここでは$(1.3)$式で表される$p$の代わりに提案分布$q$を用いて重点サンプリングを行うことを考える。このとき$\mathbb{E}_{Y \sim p}[s'(Y)]$は下記のように表すことができる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] &= \sum_{\tilde{y} \in \mathcal{Y}} s'(\tilde{y})p(\tilde{y}) \\
&= \sum_{\tilde{y} \in \mathcal{Y}} s'(\tilde{y}) \frac{p(\tilde{y})}{q(\tilde{y})} q(\tilde{y}) \\
&= \mathbb{E}_{Y’ \sim q} \left[ s'(Y) \frac{p(Y’)}{q(Y’)} \right] \quad (2.3)
\end{align}
$$

重点サンプリングの数式表記と計算例については下記でも取り扱った。

ここで$(2.3)$式に基づくモンテカルロ近似の式は下記のように表せる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] = \mathbb{E}_{Y’ \sim q} \left[ s'(Y) \frac{p(Y’)}{q(Y’)} \right] \simeq \frac{1}{n} \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})} \quad (2.4)
\end{align}
$$

$(2.4)$式を用いることで提案分布$q$に基づくサンプリングを行える一方で、$(2.4)$式の$p(\tilde{y}_{i})$の計算には分配関数$Z$の計算が必要である。そこで次項では分配関数$Z$の近似値の取得について取り扱う。

分配関数の近似

単語の集合を$\mathcal{Y}$で表したので、語彙数は$|\mathcal{Y}|$で表せる。ここで下記のような一様分布$u$を定義する。
$$
\large
\begin{align}
u(x) = \frac{1}{|\mathcal{Y}|} \quad (2.5)
\end{align}
$$

$(2.5)$式で表した一様分布$u$に従う確率変数$X \sim u$について$\exp{(s(X))}$の期待値は下記のように表すことができる。
$$
\large
\begin{align}
\mathbb{E}_{X \sim u}[\exp{(s(X))}] = \frac{1}{|\mathcal{Y}|} \sum_{\tilde{x} \in \mathcal{Y}} \exp{(s(\tilde{x}))} \quad (2.6)
\end{align}
$$

ここで$(2.6)$式に$(1.5)$式を用いることで、下記のような変形を行うことができる。
$$
\large
\begin{align}
\mathbb{E}_{X \sim u}[\exp{(s(X))}] &= \frac{1}{|\mathcal{Y}|} \sum_{\tilde{x} \in \mathcal{Y}} \exp{(s(\tilde{x}))} \quad (2.6) \\
\mathbb{E}_{X \sim u}[\exp{(s(X))}] &= \frac{1}{|\mathcal{Y}|} Z(\mathcal{Y}) \\
Z(\mathcal{Y}) &= |\mathcal{Y}| \mathbb{E}_{X \sim u}[\exp{(s(X))}] \quad (2.7)
\end{align}
$$

上記の$(2.7)$式に対し、下記の導出を元に提案分布$q$を用いて重点サンプリングを行う。
$$
\large
\begin{align}
Z(\mathcal{Y}) &= |\mathcal{Y}| \mathbb{E}_{X \sim u}[\exp{(s(X))}] \quad (2.7) \\
&= |\mathcal{Y}| \mathbb{E}_{Y’ \sim q} \left[\exp{(s(Y’))} \frac{u(Y’)}{q(Y’)} \right] \\
&= |\mathcal{Y}| \mathbb{E}_{Y’ \sim q} \left[ \exp{(s(Y’))} \frac{1}{|\mathcal{Y}| q(Y’)} \right] \\
&= \frac{\cancel{|\mathcal{Y}|}}{\cancel{|\mathcal{Y}|}} \mathbb{E}_{Y’ \sim q} \left[ \frac{\exp{(s(Y’))}}{q(Y’)} \right] \\
&= \mathbb{E}_{Y’ \sim q} \left[ \frac{\exp{(s(Y’))}}{q(Y’)} \right] \\
& \simeq \frac{1}{n} \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})} \\
&= \hat{Z} \quad (2.8)
\end{align}
$$

勾配の式の導出

$(1.3)$式、$(1.5)$式、$(2.8)$式に基づいて下記が成立する。
$$
\large
\begin{align}
p(y) &= \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} \quad (1.3)’ \\
&= \frac{\exp{(s(y))}}{Z(\mathcal{Y})} \quad (1.5)’ \\
& \simeq \frac{\exp{(s(y))}}{\hat{Z}} \quad (2.9)
\end{align}
$$

$(2.9)$式を元に下記が成立する。
$$
\large
\begin{align}
s'(y) \frac{p(y)}{q(y)} & \simeq s'(y) \frac{\exp{(s(y))}/\hat{Z}}{q(y)} \\
&= s'(y) \frac{\exp{(s(y))}/q(y)}{\hat{Z}} \\
&= s'(y) \frac{\displaystyle \frac{\exp{(s(y))}}{q(y)}}{\displaystyle \frac{1}{n} \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})}} \quad (2.10)
\end{align}
$$

$(2.3), \, (2.4)$式の重点サンプリングの式に$(2.10)$式を適用すると、サンプル$\tilde{q}_{1}, \cdots , \tilde{q}_{n} \sim q$を元に下記が得られる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] &= \mathbb{E}_{Y’ \sim q} \left[ s'(Y) \frac{p(Y’)}{q(Y’)} \right] \quad (2.3) \\
& \simeq \frac{1}{n} \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})} \times \left( \frac{1}{n} \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})} \right)^{-1} \\
&= \frac{\displaystyle \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})}}{\displaystyle \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})}} \quad (2.11)
\end{align}
$$

$(2.11)$式を$(1.8)$式に代入することで下記が得られる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= – \nabla s(y) + \mathbb{E}_{Y \sim p}[s'(Y)] \quad (1.8) \\
& \simeq – \nabla s(y) + \frac{\displaystyle \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})}}{\displaystyle \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})}} \quad (2.12)
\end{align}
$$

$(2.12)$式を用いることで、語彙数$|\mathcal{Y}|$の大きさに関わらず勾配の近似値の計算を行うことができる。

式の解釈

当記事で確認を行った導出は複雑かつ難解であるので、「語彙数$|\mathcal{Y}|$が大きくなる際の勾配計算の高速化」が主目的であることは常に念頭に置く必要がある。

語彙数$|\mathcal{Y}|$が大きくなると分配関数$Z(\mathcal{Y})$の計算量が増加するが、分配関数$Z(\mathcal{Y})$の計算なしでは$(2.1), \, (2.2)$式に基づいてモンテカルロ法をそのまま適用することができない。

上記に対し、「重点サンプリング」と「分配関数の近似」を用いて勾配の近似を行うというのが当節における導出の流れである。特に分配関数の近似にあたっては、ニューラルネットワークの出力層に対応する$s(\tilde{y}_{i})$を提案分布$q$からサンプリングされた$\tilde{y}_{i}$を元にいくつか計算することで、分配関数の大体の大きさを推測することができると解釈しておくとよい。ソフトマックス関数における分配関数$Z(\mathcal{Y})$の役割は値の正規化であり、$(2.8)$式のような近似で大まかな値を得ることができる。

また、提案分布$q$に$u$と同じく一様分布を用いる場合は$(2.8)$が下記のように表される。
$$
\large
\begin{align}
\hat{Z} \simeq \frac{1}{n} \sum_{i=1}^{n} |\mathcal{Y}| \exp{(s(\tilde{y}_{i}))} \quad (2.8)’
\end{align}
$$

上記を$1$つのサンプルを語彙数倍しサンプル数で割ったと解釈すると、サンプリング結果に基づく分配関数の近似式であることが理解しやすい。