GELU(Gaussian Error Linear Unit)の数式とグラフの描画

近年様々なタスクに用いられるTransformer処理では活性化関数にGELU(Gaussian Error Linear Unit)が用いられることが多いです。当記事ではGELUの数式の確認と、Pythonを用いたグラフの描画を行いました。
当記事の作成にあたっては、GELU論文や「深層学習 第$2$版」の第$2$章「ネットワークの基本構造」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

GELUの数式

標準正規分布の累積分布関数

GELU(Gaussian Error Linear Unit)の数式には標準正規分布$\mathcal{N}(0,1)$の累積分布関数が用いられます。標準正規分布の確率密度関数を$\phi(x)$、累積分布関数を$\Phi(x)$とおくとき、$\phi(x), \Phi(x)$はそれぞれ下記のように表されます。
$$
\large
\begin{align}
\phi(x) &= \frac{1}{\sqrt{2 \pi}} \exp{ \left( \frac{x^{2}}{2} \right) } \\
\Phi(x) &= \int_{-\infty}^{x} \phi(t) dt
\end{align}
$$

GELUの数式

前項で確認を行った標準正規分布の累積分布関数$\Phi(x)$を元にGELUの数式$\mathrm{GELU}(x)$は下記のように定義されます。
$$
\large
\begin{align}
\mathrm{GELU}(x) = x \Phi(x)
\end{align}
$$

GELUの微分

$\mathrm{GELU}(x) = x \Phi(x)$は下記のように計算することができます。
$$
\large
\begin{align}
\frac{d}{dx} \mathrm{GELU}(x) &= \frac{d}{dx} (x \Phi(x)) \\
&= \Phi(x) + x \cdot \frac{d}{dx} \Phi(x) \\
&= \Phi(x) + x \phi(x) \quad (1)
\end{align}
$$

GELUのグラフの描画

ReLUとGELUのグラフは下記を実行することで行うことができます。

import numpy as np
from scipy import stats

import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme()

x = np.arange(-2.5, 2.51, 0.01)

y_relu = np.maximum(0, x)
y_gelu = x * stats.norm.cdf(x)

plt.plot(x, y_relu, label="ReLU")
plt.plot(x, y_gelu, label="GELU")

plt.legend()
plt.show()

・実行結果

上記のGELUの理解にあたっては、下記を実行するとわかりやすいと思います。

x = np.arange(-1., 2., 0.01)

y = x
y_gelu = x * stats.norm.cdf(x)

plt.plot(x, y, label="ReLU")
plt.plot(x, y_gelu, label="GELU")

plt.plot(x, np.zeros(x.shape[0]), "k--")

print("Phi(-1): {:.2f}".format(stats.norm.cdf(-1)))
print("Phi(1): {:.2f}".format(stats.norm.cdf(1)))
print("Phi(2): {:.2f}".format(stats.norm.cdf(2)))

plt.legend()
plt.show()

・実行結果

$\Phi(-1)=0.16, \, \Phi(1)=0.84, \, \Phi(2)=0.98$のような値が得られた一方で、$x=-1$ではオレンジが青のおおよそ$0.16$倍、$x=1$ではオレンジが青のおおよそ$0.84$倍、$x=2$ではオレンジが青のおおよそ$0.98$倍であることがそれぞれ確認できます。

また、$x=0$で微分を行うことのできないReLUに対し、GELUでは微分を行うことが可能です。前節の$(1)$式に基づいてGELUの$x=0$における接線は下記のように描画することができます。

x = np.arange(-1., 2., 0.01)

y_gelu = x * stats.norm.cdf(x)
y_tangent = stats.norm.cdf(0) * x

plt.plot(x, y_gelu, label="GELU")
plt.plot(x, y_tangent, label="tangent_line")

plt.plot([0], [0], "go")

plt.legend()
plt.show()

・実行結果

参考

・GELU論文