Adam(Adaptive moment)の導出と直感的理解 〜momentum、AdaGrad〜

DeepLearningの学習でよく用いられるAdam(Adaptive moment)ですが、式が少々複雑なので直感的な理解が難しいです。当記事ではAdamを構成する主な考え方のmomentumとAdaGradなどを元にAdamを表し、その直感的な理解に関して取りまとめを行いました。

深層学習」の$3.1$節の「確率的勾配降下法(SGD)」と$3.5$節の「SGDの改良」を参考に作成を行いました。

Adamを構成する考え方

確率的勾配降下法(SGD)

$(x_1,y_1), …, (x_n,y_n)$が観測された際に$f(x_i,w)$によって$y_i$を予測することを考える。DeepLearningを考えるにあたって$f(x_i,w)$は多層パーセプトロンやCNNなどが用いられるが、$f(x_i,w) = x_i^{\mathrm{T}}w$のようにシンプルな回帰を考えてもこれは同様である。この学習にあたっては、クロスエントロピー誤差関数などを用いてパラメータ$w$の損失関数$E(w)$を定義し、$E(w)$の最小化を行う。

このときの目的は$E(w)$を最小にする$w$を計算することであり、この計算にあたって勾配降下法のような最適化のアルゴリズムが用いられる。線形回帰では正規方程式の解を計算することができたが、GLMやDeepLearningでは解を直接解くことが難しいので勾配法などが用いられる。

ここで損失関数$E(w)$はサンプル$(x_i,y_i)$ごとに計算を行うことができるので、$i$番目のサンプルの損失関数を$E_i(w)$、サンプルセット全体の損失関数を$E(w)$のように表すことを考える。

このとき$E(w)$は$E_i(w)$を用いて下記のように表すことができる。
$$
\large
\begin{align}
E(w) = \sum_{i=1}^{n} E_i(w)
\end{align}
$$

上記に対し、$E(w)$を用いて勾配法を実行することをDeepLearningの分野ではバッチ学習(batch learning)という。バッチ学習では下記のような漸化式を用いて$w$を繰り返し演算によって求める。
$$
\large
\begin{align}
w_{t+1} &= w_{t} – \alpha \nabla E(w)
\end{align}
$$

バッチ学習のような手法は局所解から抜け出せない場合もありうるので、確率的に選んだサンプルに対して学習を行うことで大域最適解を探すにあたって確率的勾配降下法(SGD)の考え方が重要になる。

確率的勾配降下法(SGD; Stochastic Gradient Descent)では全体の損失関数の$E(w)$ではなく、サンプル$(x_i,y_i)$に対する損失関数の$E_i(w)$を用いることで下記のように$w$の更新を行う。
$$
\large
\begin{align}
w_{t+1} = w_{t} – \alpha \nabla E_i(w)
\end{align}
$$

一方で確率的勾配降下法ではサンプル$1$つずつに関して計算することで、計算の並列化が難しくパラメータの収束がなかなか進まない。この解決にあたって、サンプル$1$つずつではなく数十〜数百ほどのサンプルを同時に取り扱うことで効率化が行われる。

このような複数サンプルはミニバッチ(minibatch)といわれ、現行の多くのDeepLearningの学習においてはミニバッチに基づいてパラメータの推定が行われる。

$w_{t}$の更新に用いるミニバッチを$\mathcal{D}_t$、$\mathcal{D}_t$に対応する勾配を$E_t(w)$のようにおくとき、$E_t(w)$は下記のように表される。
$$
\large
\begin{align}
E_j(w) = \sum_{i \in \mathcal{D}_t} E_i(w)
\end{align}
$$

また、ミニバッチの損失関数$E_t(w)$を元に勾配法を考えると下記のような数式で表すことができる。
$$
\large
\begin{align}
w_{t+1} &= w_{t} – \alpha \nabla E_t(w) \quad (1)
\end{align}
$$

なお、ミニバッチを用いた学習のことも確率的勾配降下法ということは抑えておくとよい。

・注意事項
参照元の「深層学習」の表記では学習率は$\epsilon$で表されるが、AdaGrad以降で$\varepsilon$が出てきて判別しにくいので当記事では$\epsilon$から$\alpha$に置き換えを行った。

momentum

確率的勾配法(SGD)を用いるにあたってはパラメータの更新毎に異なるサンプルセットを用いることで勾配がばらつくということが起こる。これにより学習が不安定になりやすいが、この対応にあたって、前のステップの修正量を用いるモメンタム(momentum)という手法がよく用いられる。

momentumではミニバッチ$\mathcal{D}_{t-1}$に対する修正量を$v_{t} = w_{t} – w_{t-1}$のように定義し、下記のような式に基づいて$w$に関する計算を行う。
$$
\large
\begin{align}
w_{t+1} = w_{t} – \alpha \nabla E_{t}(w) + \mu v_t
\end{align}
$$

上記の計算式の$w_{t+1} = w_{t} – \alpha \nabla E_{t}(w)$はミニバッチを用いるSGDと同じ式だが、momentumでは$\mu v_t = \mu(w_{t} – w_{t-1})$がさらに加算される。

ここで$\mu$は加算の割合を制御するハイパーパラメータであり、通常は$0.5$から$0.9$の範囲から選ばれる。

また、$(2)$式は$v_{t} = w_{t} – w_{t-1}$であることに基づいて下記のように表すこともできる。
$$
\large
\begin{align}
w_{t+1} &= w_{t} – \alpha \nabla E_{t}(w) + \mu v_t \\
w_{t+1} – w_{t} &= \mu v_t – \alpha \nabla E_{t}(w) \\
v_{t+1} &= \mu v_t – \alpha \nabla E_{t}(w) \quad (3)
\end{align}
$$

$\Delta w$を用いた式表記

シンプルな表記を行うにあたって$w$の$k$番目のパラメータに関して$\displaystyle \Delta w_{t,k} = w_{t+1,k}-w_{t,k}, g_{t,k} = \frac{\partial}{\partial w_k} E_t(w)$の表記を導入することを考える。ここで$\Delta w_{t,j}$は「momentum」で導入した$v_t$と同様の式であることに注意する。

確率的勾配降下法を表す$(1)$式は$\Delta w_t$を用いて下記のように書き直すことができる。
$$
\large
\begin{align}
w_{t+1,k} &= w_{t,k} – \alpha \frac{\partial}{\partial w_k} E_t(w) \\
w_{t+1,k} – w_{t,k} &= – \alpha \frac{\partial}{\partial w_k} E_t(w) \\
\Delta w_{t,k} &= – \alpha g_{t,k}
\end{align}
$$

適応的調整(AdaGrad)

AdaGradでは前項の表記に基づいて下記のような$\Delta w_{t,k}$を元に$w$の繰り返し演算を行う手法である。
$$
\large
\begin{align}
\Delta w_{t,k} = – \frac{\alpha}{\sqrt{\sum_{t’=1}^{t} g_{t’,k}^2+\varepsilon}} g_{t,k}
\end{align}
$$

上記の式は、「$k$方向の累積の修正量が多い場合は$k$方向の修正量を減らす」というように解釈することができる。また$k$方向の累積の修正量の$2$乗の和は$0$になることもあり得ることから$0$で割ることを避けるにあたって$\varepsilon$が導入されている。

このような$w$の修正量の調整は「適応的(Adaptive)」と表現され、AdaGradは”Adaptive Gradient”の略だと考えることができる。

移動平均(RMS Prop, Adadelta)

前項の「AdaGrad」では学習開始時からの$g_{t,k}^2$の総和を考えたが、これでは学習が進むにつれて修正量が$0$になるので、これを避けるにあたって総和ではなく移動平均を取る手法が考案されている。

移動平均を取る主な手法には「RMSProp」と「Adadelta」があるので、以下それぞれの移動平均の取り方に関して取りまとめ、その後に移動平均の基本式の理解に関して確認を行う。

・RMSPropで用いられる移動平均
RMSPropでは勾配の二乗和の移動平均を$\langle g_k^2 \rangle_t = \gamma \langle g_k^2 \rangle_{t-1} + (1-\gamma) g_{t,k}^2$のように定め、下記のように重み$w$の修正量の計算を行う。
$$
\large
\begin{align}
\Delta w_{t,k} = – \frac{\alpha}{\sqrt{\langle g_k^2 \rangle_t+\varepsilon}} g_{t,k}
\end{align}
$$

・Adadeltaで用いられる移動平均

・移動平均の式の理解
移動平均の式の理解にあたっては平均の漸化式的な計算を行う式に基づいて理解すればよい。移動平均に関して考える前に$a_1, a_2, …, a_n$の平均を$\bar{a}_n$とおき、$\bar{a}_{n}$を$\bar{a}_{n}$と$a_{n}$を用いて表すことを考える。
$$
\large
\begin{align}
\bar{a}_{n} &= \frac{1}{n} \sum_{i=1}^{n} a_i \\
&= \frac{1}{n} \sum_{i=1}^{n-1} a_i + \frac{1}{n} a_n \\
&= \frac{1}{n} \times \frac{n-1}{n-1} \times \sum_{i=1}^{n-1} a_i + \frac{1}{n} a_n \\
&= \frac{n-1}{n} \times \frac{1}{n-1} \sum_{i=1}^{n-1} a_i + \frac{1}{n} a_n \\
&= \frac{n-1}{n} \bar{a}_{n-1} + \frac{1}{n} a_n
\end{align}
$$

上記に対し、移動平均を$\tilde{a}_{n} = \gamma \tilde{a}_{n-1} + (1-\gamma) a_n = (1-\gamma) a_n + \gamma \tilde{a}_{n-1}$のように表すと考える。この式を元に$\tilde{a}_{n}$は下記のように導出できる。
$$
\large
\begin{align}
\tilde{a}_{n} &= (1-\gamma) a_{n} + \gamma \tilde{a}_{n-1} \quad (4) \\
&= (1-\gamma) a_{n} + \gamma ( (1-\gamma) a_{n-1} + \gamma \tilde{a}_{n-2}) \\
&= (1-\gamma) a_{n} + \gamma (1-\gamma) a_{n-1} + \gamma^2 ( (1-\gamma) a_{n-2} + \gamma \tilde{a}_{n-3}) \\
&= … \\
&=(1-\gamma) ( a_{n} + \gamma a_{n-1} + \gamma^2 a_{n-2} … )
\end{align}
$$

ここで$0 < \gamma < 1$の$\gamma$を考えると、上記は等比数列の和であることが確認でき、移動平均の定義を表すことが確認できる。

Adam

Adamの式表記

Adamでは勾配の$1$次と$2$次のモーメントをそれぞれ$m_{t,k}, v_{t,k}$のように定義し、それぞれを下記の移動平均で計算を行う。
$$
\large
\begin{align}
m_{t,k} &= \beta_{1} m_{t-1,k} + (1-\beta_{1}) g_{t,k} \quad (5) \\
v_{t,k} &= \beta_{2} v_{t-1,k} + (1-\beta_{2}) g_{t,k}^2 \quad (6)
\end{align}
$$

ここで上記は偏差を含むので、下記のように$m_{t,k}, v_{t,k}$の補正を行った$\hat{m}_{t,k}, \hat{v}_{t,k}$を定義する。
$$
\large
\begin{align}
\hat{m}_{t,k} &= \frac{m_{t,k}}{(1-\beta_{1}^t)} \\
\hat{v}_{t,k} &= \frac{v_{t,k}}{(1-\beta_{2}^t)}
\end{align}
$$

上記を元に$\Delta w_{t,k}$を下記のように考える。
$$
\large
\begin{align}
\Delta w_{t,k} = – \frac{\alpha}{\sqrt{\hat{v}_{t,k}}+\varepsilon} \hat{m}_{t,k}
\end{align}
$$

Adamでは上記を用いてパラメータ$w$の推定を行う。

Adamの式の導出

$(5)$式の導出にあたっては、$(3)$式を$(4)$式に基づいて再構成を行ったと考えることができる。$(3)$式を成分$k$に関して考える場合、$m_{t,k}$を用いて下記のように表すことができる。
$$
\large
\begin{align}
v_{t+1,k} &= \mu v_{t,k} – \alpha \frac{\partial}{\partial w_k} E_{t}(w) \quad (3)’ \\
m_{t,k} &= \mu m_{t-1,k} – \alpha g_{t,k}
\end{align}
$$

このとき上記のパラメータ$\mu, \alpha$を$(4)$式のような移動平均の漸化式で表すことを考える。
$$
\large
\begin{align}
m_{t,k} &= \mu m_{t-1,k} – \alpha g_{t,k} \\
m_{t,k} &= \beta_{1} m_{t-1,k} + (1-\beta_{1}) g_{t,k} \quad (5)
\end{align}
$$

上記のように$(5)$式が導出できる。

・注意事項
当項の導出は「深層学習」の表記から筆者が推測したものに過ぎないので、要出典であることに注意が必要である。

偏差の補正の詳細

「Adamの式表記」では$(5)$式と$(6)$式の補正を行ったが、当項では補正の詳細に関して式変形の確認を行う。Adamの論文では$(6)$式を元に解説が行われているので、以下同様に$(6)$式の導出を確認する。

$v_t$に関して$v_0=0$とすると、$(6)$式に基づいて$v_t$は下記のように変形を行うことができる。
$$
\large
\begin{align}
v_{t,k} &= (1-\beta_{2}) g_{t,k}^2 + \beta_{2} v_{t-1,k} \quad (6) \\
&= (1-\beta_{2}) g_{t,k}^2 + \beta_{2} ((1-\beta_{2}) g_{t-1,k}^2 + \beta_{2} v_{t-2,k}) \\
&= (1-\beta_{2}) g_{t,k}^2 + \beta_{2} (1-\beta_{2}) g_{t-1,k}^2 + \beta_{2}^2 ((1-\beta_{2}) g_{t-2,k}^2 + \beta_{2} v_{t-3,k}) \\
&= \cdots \\
&= (1-\beta_{2}) \sum_{i=1}^{t} \beta_{2}^{t-i} g_{i}^2 + \beta_{2}^n v_{0,k} \\
&= (1-\beta_{2}) \sum_{i=1}^{t} \beta_{2}^{t-i} g_{i}^2 \quad (7)
\end{align}
$$

ここで上記の$(7)$式の両辺の期待値を考える。
$$
\large
\begin{align}
\mathbb{E}[v_{t,k}] &= \mathbb{E} \left[ (1-\beta_{2}) \sum_{i=1}^{t} \beta_{2}^{t-i} g_{i,k}^2 \right] \quad (7) \\
&= (1-\beta_{2}) \sum_{i=1}^{t} \beta_{2}^{t-i} \mathbb{E} \left[ g_{i,k}^2 \right] \\
&= (1-\beta_{2}) \sum_{i=1}^{t} \beta_{2}^{t-i} \mathbb{E} \left[ g_{t,k}^2 \right] \\
&= \mathbb{E} \left[ g_{t,k}^2 \right] (1-\beta_{2}) \sum_{j=1}^{t} \beta_{2}^{j-1} \\
&= \mathbb{E} \left[ g_{t,k}^2 \right] (1-\beta_{2}) \frac{1-\beta_{2}^t}{1-\beta_{2}} \\
&= \mathbb{E} \left[ g_{t,k}^2 \right] (1-\beta_{2}^t) \quad (8)
\end{align}
$$

ここで$v_{t,k}$を用いて$g_{t,k}^2$の推定を行うと考えるとき、上記の$(8)$式に基づいて不偏性を考える場合、$v_{t,k}/(1-\beta_{2}^t)$を元に推定を行うと良いことがわかる。この変形が$\displaystyle \hat{v}_{t,k} = \frac{v_{t,k}}{(1-\beta_{2}^t)}$に対応する。

またここで行った変形を$m_{t,k}, \hat{m}_{t,k}$に同様に適用することができることも合わせて抑えておくと良い。

・参考
Adam論文 Section.$3$ Initialization Bias Correction

Adamの解釈・直感的理解

Adamの改良

参考

・深層学習

・Adam論文