ブログ

DeepLearningを用いた順序回帰(ordinal regression)

順序尺度を目的変数に持つ順序回帰(ordinal regression)は「単なる分類」や「連続値の回帰」と同様な取り扱いをするかどうかを注意して検討する必要があります。当記事ではDeepLearningを用いた順序回帰の取り扱いについて取りまとめを行いました。
当記事の作成にあたっては、「深層学習 第$2$版」の第$2$章「ネットワークの基本構造」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

順序回帰の問題設定と学習の方針

順序回帰の問題設定

口コミサイトやAmazonの$5$段階評価のように、順序が決まったカテゴリを持つ目的変数$y_1, \cdots , y_n$を下記のように定義する。
$$
\large
\begin{align}
y_1, \cdots , y_n & \in \{ C_1, C_2, C_3, C_4, C_5 \} \\
C_1 < C_2 &< C_3 < C_4 < C_5
\end{align}
$$

上記のカテゴリ$C_1$〜$C_5$はそれぞれ$5$段階評価の$1$〜$5$に対応するとここでは解釈する。ここで通常のカテゴリ分類ではこのように$C_1$〜$C_5$が得られた際に$C_3$が正解であれば$C_3$のみが正解、その他は不正解である一方で、上記のような場合、$C_2, C_4$は「正解に近い」と見なせる。

このように目的変数が順序が決まったクラスで表されるとき、「正解/不正解」以外に「正解に近い」というのも取り扱う必要が生じる。

DeepLearningを用いた順序回帰

前項の「順序回帰の問題設定」で確認したように、順序回帰では「多クラス分類」と「通常の回帰」の中間のような取り扱いが必要になる。よって、DeepLearningを順序回帰に用いるにあたっては出力層や損失関数の設計を柔軟に取り扱う必要が生じる。詳しくは次節で取り扱った。

DeepLearningの構成法

DeepLearningを順序回帰に用いるにあたっては、大まかに$2$パターンの構成法が存在する。以下、それぞれについて取りまとめた。

2値分類に基づく構成

DeepLearningを順序回帰に用いるにあたっての方法の$1$つ目が、「クラス数$K$に対し、$K-1$個の$2$値分類に帰着させる方法」である。文章で表すと処理の実体以上に難しいので、まず$3$クラス分類を元に具体的に確認を行う。
$$
\large
\begin{align}
[1, 0]
\end{align}
$$

たとえば$3-1=2$個の$2$値分類が上記のように得られた場合、「正解クラスが$1$より大きく、$2$以下であるので、カテゴリ$2$が対応する」という規則を定義する。このとき$[0, 0]$が得られた場合は「正解クラスが$1$以下」なのでカテゴリ$1$、$[1, 1]$が得られた場合は「正解クラスが$2$より大きい」のでカテゴリ$3$が対応する。同様な規則を用いることで$K$クラス分類問題を$K-1$個の$2$値分類を元に表すことが可能である。

上記のような規則を用いることで「$K$クラス分類を$K-1$個の$2$値分類に帰着させる」ことは基本的には可能である一方で、$[1, 1, 0, 0, 0]$のように$1$から$0$に変わるのは最大$1$回でなければならないという制約があることには注意が必要である。DeepLearningの学習時はアノテーション作成時に$[1, 1, 0, 0, 0]$が自動的に対応する一方で、学習済みのDeepLearningを元に予測を行う際は$[1, 0, 1, 0, 0]$のような結果が出力される場合がある。

$[1, 0, 1, 0, 0]$のような場合を取り扱うにあたっては、予測時は「$1$の数$+1$をカテゴリとする」ことで対処可能である。たとえば$[1, 0, 1, 0, 0]$の場合は$2+1=3$のように計算できる。

一方DeepLearningの学習時はたとえばカテゴリ$3$に対応する$[1, 1, 0, 0, 0]$が出力されるように、各出力層をロジスティック回帰に対応させてクロスエントロピーの最小化によってオーソドックスな学習を行えばよい。

ソフトラベルの使用

DeepLearningを順序回帰に用いるにあたっての方法の$2$つ目が、「$K$クラス分類問題に対しソフトラベル用いる方法」である。一般的な$K$クラス分類問題では目的変数の正解フラグに$1-$of$-K$符号(ハードラベル)を用いるが、確率化と同様な処理を行なったソフトラベルを用いる。
$$
\large
\begin{align}
d_{k} = \frac{\exp{(-|\bar{k}-k|)}}{\sum_{i=1}^{K} \exp{(-|\bar{k}-i|)}}
\end{align}
$$

$\bar{k}$が正解であるとき、ソフトラベル$\mathbf{d}$の$k$成分の$d_k$はたとえば上記のように定義される。$1-$of$-K$表現では$[0, 0, 1, 0, 0]$のように表される場合、$\mathbf{d}$は下記のように計算できる。

import numpy as np

k_t = 3.
idx = np.arange(1,6,1)
d = np.zeros(5)

for k in range(5):
    d[k] = np.exp(-np.abs(k_t-idx[k])) / np.sum(np.exp(-np.abs(k_t-idx)))

print("d: {}".format(d))
print("sum of d: {}".format(np.sum(d)))

・実行結果

d: [ 0.06745081  0.1833503   0.49839779  0.1833503   0.06745081]
sum of d: 1.0

上記のように計算された$\mathbf{d}$を用いてDeepLearningの学習を行うことで順序回帰の学習を行うことができる。推論時にはスコアが最大の$k$がクラスに対応する。

【CNN】DeepLearningで用いられる正規化 〜バッチ正規化、グループ正規化〜

バッチ正規化(batch normalization)のような正規化処理はMLP(Multi Layer Perceptron)に限らず広く用いられます。当記事ではCNN(Convolutional Neural Network)の学習にあたって用いられるバッチ正規化やグループ正規化などについて取りまとめました。
当記事の作成にあたっては、Group Normalization論文や「深層学習 第$2$版」の$5.5$節「畳み込み層の出力の正規化」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

MLPにおける正規化

下記で詳しく取り扱った。

CNNにおける正規化

Group Normalizationの論文とCNNにおける正規化の体系化

CNN(Convolutional Neural Network)における正規化を把握するにあたってはグループ正規化(Group Normalization)論文の図を元に確認すると良い。

Group Normalization論文 Figure$\, 2$

上図に基づいてCNNにおける「バッチ正規化(Batch Normalization)」、「レイヤー正規化(Layer Normalization)」、「インスタンス正規化(Instance Normalization)」、「グループ正規化(Group Normalization)」をそれぞれ理解することが可能である。

図の$C$は畳み込みにおけるチャネル数、$N$は同時に処理するバッチに含まれるサンプル数にそれぞれ対応する。また、$H,W$は画像の高さ$H$と幅$W$を$2$次元から$1$次元に変換したものであると理解すると良い1

以下、「バッチ正規化」、「レイヤー正規化」、「インスタンス正規化」、「グループ正規化」のそれぞれの詳細について確認を行う。

バッチ正規化

バッチ正規化(Batch Normalization)はチャネル毎に「バッチに含まれる全てのサンプルの全ての位置の値の平均・分散を計算」し、正規化処理を行う手法である。チャネル毎に平均$\mu_{c}$を計算することを下記のような式で表すこともできる。
$$
\large
\begin{align}
\mu_{c} = \frac{1}{NWH} \sum_{i,j,n} u_{ijc}^{(n)} \quad (1)
\end{align}
$$

$(1)$式における$u_{ijc}^{(n)}$は$n$番目のサンプルの$c$番目のチャネルにおける位置$(i,j)$の値に対応する。また、$\displaystyle \sum_{i,j,n}$は下記のように置き換えて理解すれば良い。
$$
\large
\begin{align}
\sum_{i,j,n} \longrightarrow \sum_{n=1}^{N} \sum_{i=1}^{W} \sum_{j=1}^{H}
\end{align}
$$

レイヤー正規化

レイヤー正規化(Layer Normalization)はバッチに含まれるサンプル毎に「全てのチャネルの全ての位置の値の平均・分散を計算」し、正規化処理を行う手法である。チャネル毎に平均$\mu_{n}$を計算することを下記のような式で表すこともできる。
$$
\large
\begin{align}
\mu_{c} = \frac{1}{CWH} \sum_{i,j,c} u_{ijc}^{(n)} \quad (2)
\end{align}
$$

$(2)$式の$\displaystyle \sum_{i,j,c}$は下記のように置き換えて理解すれば良い。
$$
\large
\begin{align}
\sum_{i,j,n} \longrightarrow \sum_{c=1}^{C} \sum_{i=1}^{W} \sum_{j=1}^{H}
\end{align}
$$

グループ正規化

グループ正規化(Group Normalization)はレイヤー正規化をチャネル方向にいくつかグループを作成し、正規化を行う手法である。チャネル群の$k$番目のグループのチャネルのインデックスの集合を$\mathcal{S}_{k}$とおくと、サンプル$n$、グループ$k$の平均$\mu_{k}^{(n)}$は下記のように計算できる。
$$
\large
\begin{align}
\mu_{c} = \frac{1}{|\mathcal{S}_{k}|WH} \sum_{c \in \mathcal{S}_{k}} \sum_{i,j} u_{ijc}^{(n)} \quad (3)
\end{align}
$$

上記の$|\mathcal{S}_{k}|$は$\mathcal{S}_{k}$に含まれるチャネルのインデックスの数に対応する。$(3)$式の$\displaystyle \sum_{i,j}$は下記のように置き換えて理解すれば良い。
$$
\large
\begin{align}
\sum_{i,j} \longrightarrow \sum_{i=1}^{W} \sum_{j=1}^{H}
\end{align}
$$

参考

・Group Normalization論文
・バッチ正規化(Batch Normalization)論文

  1. CNNの入力は$(N, C, W, H)$のような$4$次元で表されるが、行列の積に基づく演算の場合は$2$次元、図で表す場合は$3$次元表記が基本となるので、ここで取り扱ったような行列の変換はよく行われることは抑えておくと良い。 ↩︎

DeepLearningの正則化 〜L2正則化、ドロップアウト(dropout)〜

DeepLearningなどの機械学習では学習時に用いたサンプルへの過学習(overfitting)が課題になります。当記事ではDeepLearningで過学習を防ぐにあたって導入される正則化(regularization)の手法であるドロップアウトなどについて取りまとめました。
当記事の作成にあたっては、Dropout論文や「深層学習 第$2$版」の第$3$章「確率的勾配降下法」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

過学習

学習に用いたサンプル以外のサンプルに対し正しく予測が行えることを汎化(generalization)という。関連して訓練データに対する誤差を訓練誤差(training error)、サンプルの母集団に対する誤差の期待値を汎化誤差(generalization error)という。

一方でサンプルの母集団は具体的に得られないので汎化誤差を直接計算することができない。この対応にあたっては学習に用いる訓練データ以外のテスト用データを用意し訓練誤差と同様に誤差を計算することが多い。テストデータを用いて計算される誤差をテスト誤差(test error)という。

汎化誤差を近似し得るテスト誤差の値が訓練誤差の値と大きく乖離することを過剰適合(overfitting)や過学習(overlearning)という。DeepLearningの学習にあたっては過学習が起こらないようにテスト誤差の値をチェックすることが重要になる。

正則化

DeepLearningのようにパラメータが多い際に過学習は生じやすい。この対応にあたっては「学習時にパラメータに一定の制約を課す」ことが多い。このような「パラメータに制約を課すこと」を「正則化(regularization)」という。

正則化の手法は様々であるが、重回帰やロジスティック回帰のような一般化線形モデル(GLM; Generalized Linear Model)の学習の際にはL$2$正則化、CNNの学習ではドロップアウト(dropout)が用いられることが多い。

L2 正則化

パラメータ$\mathbf{w}$によって予測される結果の$n$番目のサンプルの誤差関数(error function)を$E_{n}(\mathbf{w})$、ミニバッチ$\mathcal{D}_{t}$全体の誤差関数を$E_{t}(\mathbf{w})$とおくとき、L$2$正則化に基づいて$E_{t}(\mathbf{w})$は下記のように表される。
$$
\large
\begin{align}
E_{t}(\mathbf{w}) = \sum_{n \in \mathcal{D}_{t}} E_{n}(\mathbf{w}) + \frac{\lambda}{2} ||\mathbf{w}||^{2}
\end{align}
$$

上記の式は第$2$項を追加することにより、パラメータが大きくならないように勾配が生じると解釈すれば良い。$\lambda$は正則化をどのくらい行うかを表すハイパーパラメータである。$\lambda$の値の直感的な解釈にあたって$f(x,y)=(x-1)^{2}+(y-1)^{2}$の正則化を元に以下、確認を行う。
$$
\large
\begin{align}
f(x, y) &= (x-1)^{2}+(y-1)^{2} \\
g(x, y) &= f(x, y) + \frac{\lambda}{2} (x^{2} + y^{2})
\end{align}
$$

上記の例では$f(x, y)$の偏微分が下記のように計算できる。
$$
\large
\begin{align}
\frac{\partial f(x, y)}{\partial x} &= 2(x-1) \\
\frac{\partial f(x, y)}{\partial y} &= 2(y-1)
\end{align}
$$

上記より$(x,y)=(1,1)$で$f(x,y)$は最小値を持つ1。同様に$g(x,y)$の偏微分は下記のように計算できる。
$$
\large
\begin{align}
\frac{\partial g(x, y)}{\partial x} &= 2(x-1) + \lambda x = (2+\lambda)x \, – \, 2 \\
\frac{\partial g(x, y)}{\partial y} &= 2(y-1) + \lambda y = (2+\lambda)x \, – \, 2
\end{align}
$$

上記より$\displaystyle (x,y) = \left( \frac{2}{2 + \lambda}, \frac{2}{2 + \lambda} \right)$で$g(x,y)$は最小値を持つ。

ここで$\lambda = 0, 2$や$\lambda \to \infty$のとき、$\displaystyle (x,y) = \left( \frac{2}{2 + \lambda}, \frac{2}{2 + \lambda} \right)$はそれぞれ下記のように計算できる。
・$\lambda = 0$
$$
\large
\begin{align}
(x,y) &= \left( \frac{2}{2 + \lambda}, \frac{2}{2 + \lambda} \right) \\
&= \left( \frac{2}{2 + 0}, \frac{2}{2 + 0} \right) = (1,1)
\end{align}
$$

・$\lambda = 2$
$$
\large
\begin{align}
(x,y) &= \left( \frac{2}{2 + \lambda}, \frac{2}{2 + \lambda} \right) \\
&= \left( \frac{2}{2 + 2}, \frac{2}{2 + 2} \right) = \left( \frac{1}{2}, \frac{1}{2} \right)
\end{align}
$$

・$\lambda \to \infty$
$$
\large
\begin{align}
(x,y) &= \left( \frac{2}{2 + \lambda}, \frac{2}{2 + \lambda} \right) \\
& \to (0,0)
\end{align}
$$

よって$\lambda = 0$のとき、$g(x,y)$の最小点は$f(x,y)$の最小点に一致し、$\lambda \to \infty$に近づくにつれて徐々に$(0,0)$に近づいていくことが確認できる。このようにL$2$正則化を行うことで、パラメータの大きさに制約を設定することができる。

L$2$正則化(L$2$ regularization)は主に重回帰やGLMの学習にあたって用いられる手法であるが、DeepLearningはGLMの延長と解釈することもできるので、DeepLearningにL$2$正則化を用いること自体はそれほど不自然ではないので合わせて抑えておく必要がある。

DeepLearningの正則化

ドロップアウト

ドロップアウト(dropout)はニューラルネットワークの学習時に中間層の変数をランダムに選別して削除する方法であり、正則化の$1$つに分類される。

Dropout論文 Figure$\, 1$

MLP(Multi Layer Perceptron)の学習におけるDropoutを図示すると上図の右のようになる。ドロップアウトの処理は、ランダムフォレストのようなアンサンブル学習と同様に解釈すると良い。

また、学習時の変数の採用確率が$p$のレイヤーでは、推論時に「出力を$p$倍する」か「パラメータを$p$倍する」操作が必要であることも注意しておく必要がある。たとえば$p=0.5$のレイヤーでは推論時の層の出力を$0.5$倍する必要がある。

陰的正則化

確率的勾配降下法(SGD; Stochastic Gradient Descent method)では「全てのサンプルを同時に学習させる」ではなく、「サンプルをランダムに選び学習させる」ということが行われる。

この際に目的関数が変化することで正則化のような効果が得られると解釈が可能である。このように「サンプルをランダムに選び学習させる」ことによって得られる「正則化効果」を「陰的正則化」という。学習の早期終了も陰的正則化の$1$つであると見なすことができる。

  1. 元の式が平方完成であることに着目すると$(x,y)=(1,1)$が最下点であることが自明であるが、ここでは$g(x,y)$の最小値問題も同じ手法で解くにあたって偏微分の傾きが$0$になる点を取得した。 ↩︎

【MLP】DeepLearningで用いられる正規化 〜バッチ正規化、レイヤー正規化〜

DeepLearningの学習にあたっては多層の計算処理を行うので、パラメータの値によっては計算結果が外れ値と同様に歪な分布になる場合があります。当記事ではこのような現象の解決にあたって導入されることが多いバッチ正規化やレイヤー正規化などの正規化(normalization)について取りまとめを作成しました。
当記事の作成にあたっては、「深層学習 第$2$版」の$3.6$節「層出力の正規化」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

正規化

サンプル集合$\mathcal{D} = \{ (\mathbf{x}_{1}, \mathbf{y}_{1}), \, \cdots \, (\mathbf{x}_{N}, \mathbf{y}_{N}) \}$が得られたとき、サンプル$\mathbf{x}_{n}$の$i$成分を$x_{n,i}$のように定義する。このとき$x_{n,i}$の正規化は下記のような式で表される。
$$
\large
\begin{align}
x_{n,i} & \longleftarrow \frac{x_{n,i} \, – \, \bar{x}_{i}}{\sqrt{\sigma_{i}^{2} + \varepsilon}} \\
\bar{x}_{i} &= \frac{1}{N} \sum_{n=1}^{N} x_{n, i} \\
\sigma_{i} &= \sqrt{\frac{1}{N} \sum_{n=1}^{N}(x_{n,i}-\bar{x}_{i})^{2}}
\end{align}
$$

上記の$\varepsilon$は$\sigma_{i}=0$のときも計算が行われるように$\varepsilon=10^{-5}$のような小さな数を設定する。任意の$n$について$x_{n,i}=\bar{x}_{i}$が成立する際に$\sigma_{i}=0$が成立することも合わせて抑えておくと良い。

DeepLearningにおける正規化

バッチ正規化

サンプル集合$\mathcal{D} = \{ (\mathbf{x}_{1}, \mathbf{y}_{1}), \, \cdots \, (\mathbf{x}_{N}, \mathbf{y}_{N}) \}$の$\mathbf{x}_{n}$の中間層$\mathbf{u}_{n} \in \mathbb{R}^{D}$を下記のように定義する。
$$
\large
\begin{align}
\hat{u}_{n} = \left( \begin{array}{ccc} u_{n1} & \cdots & u_{nD} \end{array} \right) \quad (1)
\end{align}
$$

このときバッチ正規化(batch normalization)の演算は下記のような式で定義される。
$$
\large
\begin{align}
\hat{u}_{nj} &= \gamma_{j} \frac{u_{nj} \, – \, \mu_{j}}{\sqrt{\sigma_{i}^{2} + \varepsilon}} + \beta_{j} \\
\mu_{j} &= \frac{1}{N} \sum_{n=1}^{N} u_{nj} \\
\sigma_{j} &= \sqrt{\frac{1}{N} \sum_{n=1}^{N}(u_{nj}-\mu_{j})^{2}}
\end{align}
$$

上記のようにバッチ正規化では$n$番目のサンプルの中間層$\hat{u}_{nj}$を全ての中間層の位置$j$の値に基づいて正規化を行うことで得る。バッチ正規化はある程度の数のバッチサイズがある前提の計算であるので、バッチが少ない場合やサンプル$1$つの推論を行う際はそのままの処理を用いることができない。特に推論を行う際はサンプル$1$つの計算を行う場合が多いので、学習時に計算を行った$\mu_{j}$や$\sigma_{j}$の値の移動平均を用いるなどで代用することが多い。また、$\beta_{j}, \gamma_{j}$は初期値を$\beta_{j}=0, \gamma_{j}=1$に設定した上でMLPのアフィン変換のパラメータと同様に学習を行います。

レイヤー正規化

前項「バッチ正規化」の$(1)$のようにサンプル集合を定義するとき、レイヤー正規化(layer normalization)の演算は下記のような式で定義される。
$$
\large
\begin{align}
\hat{u}_{nj} &= \gamma_{j} \frac{u_{nj} \, – \, \mu}{\sqrt{\sigma^{2} + \varepsilon}} + \beta_{j} \\
\mu &= \frac{1}{ND} \sum_{n=1}^{N} \sum_{j=1}^{D} u_{nj} \\
\sigma &= \sqrt{\frac{1}{ND} \sum_{n=1}^{N} \sum_{j=1}^{D} (u_{nj}-\mu)^{2}}
\end{align}
$$

行列式と置換②:置換(permutation)の合成と使用例の確認

線形代数の枠組みで$n$次正方行列の行列式(determinant)を取り扱うにあたっては置換(permutation)という概念を抑えておく必要があります。当記事では置換(permutation)の合成の概要と具体的な使用例の確認について取りまとめを行いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$4$章「行列式」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

置換の合成

置換の定義

置換$\sigma$は下記のように表される。
$$
\large
\sigma :
\begin{cases}
1 \longmapsto 3 \\
2 \longmapsto 5 \\
3 \longmapsto 2 \\
4 \longmapsto 4 \\
5 \longmapsto 1
\end{cases}
$$

$\sigma$は下記のように表す場合もある。
$$
\large
\begin{align}
\sigma = \left[ \begin{array}{ccccc} 1 & 2 & 3 & 4 & 5 \\ 3 & 5 & 2 & 4 & 1 \end{array} \right]
\end{align}
$$

上記の詳細は下記で取り扱った。
https://www.hello-statisticians.com/math_basic/matrix_permutation2.html

置換の合成

置換$\sigma$を作用させた後に置換$\tau$を作用させる置換を$\sigma$と$\tau$の合成といい、$\tau \sigma$のように表す。基本的に置換は演算子$\nabla$のように右から作用させるので、$\tau \sigma$の場合は$\sigma, \, \tau$の順に置換の処理が行われる。

置換の使用例

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$053$

・$[1]$
$$
\large
\begin{align}
\sigma = \left[ \begin{array}{cccc} 1 & 2 & 3 & 4 \\ 3 & 2 & 4 & 1 \end{array} \right], \, \tau = \left[ \begin{array}{cccc} 1 & 2 & 3 & 4 \\ 2 & 4 & 3 & 1 \end{array} \right]
\end{align}
$$

「$1 \longmapsto 3 \longmapsto 3$」、「$2 \longmapsto 2 \longmapsto 4$」、「$3 \longmapsto 4 \longmapsto 1$」、「$4 \longmapsto 1 \longmapsto 2$」のように置換処理を行うので、合成置換$\tau \sigma$は下記のように表される。
$$
\large
\begin{align}
\tau \sigma = \left[ \begin{array}{cccc} 1 & 2 & 3 & 4 \\ 3 & 4 & 1 & 2 \end{array} \right]
\end{align}
$$

・$[2]$
$$
\large
\begin{align}
\sigma = \tau = \left[ \begin{array}{ccccc} 1 & 2 & 3 & 4 & 5 \\ 2 & 3 & 5 & 1 & 4 \end{array} \right]
\end{align}
$$

「$1 \longmapsto 2 \longmapsto 3$」、「$2 \longmapsto 3 \longmapsto 5$」、「$3 \longmapsto 5 \longmapsto 4$」、「$4 \longmapsto 1 \longmapsto 2$」、「$5\longmapsto 4 \longmapsto 1$」のように置換処理を行うので、合成置換$\tau \sigma$は下記のように表される。
$$
\large
\begin{align}
\sigma = \tau = \left[ \begin{array}{ccccc} 1 & 2 & 3 & 4 & 5 \\ 3 & 5 & 4 & 2 & 1 \end{array} \right]
\end{align}
$$

ResNet(Deep Residual Network)の構成の詳細とベンチマークまとめ

ResNetはCNNに基づくDeepLearningにResidual Blockを導入することで層の深いCNNの学習を可能にしたアーキテクチャです。当記事では現在画像認識タスクなどでデフォルトに用いられることが多いResNetの構成の詳細とベンチマークについて取りまとめを行いました。
当記事の作成にあたっては、ResNet論文や「深層学習 第$2$版」の第$5$章「畳み込みニューラルネットワーク」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

畳み込み演算の概要

畳み込み演算の数式

フィルタサイズ$W_{f} \times H_{f}$を用いた畳み込みによって出力の$(i,j)$成分$u_{ij}$が計算されるとき、入力に対応する$x$とフィルタに対応する$h$を元に$u_{ij}$は下記のように計算されます。
$$
\large
\begin{align}
u_{ij} = \sum_{p=0}^{W_f-1} \sum_{q=0}^{H_f-1} x_{i+p,j+q} h_{pq}
\end{align}
$$

上記は入力のチャネルとフィルタの枚数が$1$の場合の畳み込みに対応しますが、入力のチャネル数が$C$、フィルタの枚数$C_{out}$枚のとき畳み込みによって出力の$u_{ijk}$成分は下記のように計算されます。
$$
\large
\begin{align}
u_{ijk} = \sum_{c=1}^{C} \sum_{p=0}^{W_f-1} \sum_{q=0}^{H_f-1} x_{i+p,j+q,c} h_{pqck} + b_{k}
\end{align}
$$

上記の$c$は入力のチャネルのインデックス、$k$はフィルタのインデックスにそれぞれ対応します。フィルタのインデックスと出力のチャネルのインデックスが一致することも合わせて確認しておくと良いです。

Resの構成とパフォーマンス

Residual Block

ResNet論文 Figure$\, 2$

ResNetでは上図のResidual Blockが導入されたことが特徴的です。入力$\mathbf{x}$に対応するパラメータ処理を$\mathcal{F}$と表す場合、AlexNetやVGGNetのようなResNet以前のDeepLearningでは出力を$\mathcal{F}(\mathbf{x})$で計算するのに対し、ResNetでは下記のように計算を行います。
$$
\large
\begin{align}
\mathcal{F}(\mathbf{x}) + \mathbf{x}
\end{align}
$$

このような計算を行うことにより、層の深いDeepLearningにおける勾配消失(vanishing gradients)/勾配爆発(exploding gradients)問題を緩和することが可能になります。たとえば中間層における入力を$\mathbf{h}_{l}$、出力を$\mathbf{h}_{l+1}$とおくとき、Residual Blockの式に基づいて$\mathbf{h}_{l+1} = \mathcal{F}(\mathbf{h}_{l}) + \mathbf{h}_{l}$のように処理が表されます。

このとき、誤差逆伝播の式における$\displaystyle \frac{\partial \mathbf{h}_{l+1}}{\partial \mathbf{h}_{l}}$は下記のように得られます。
$$
\large
\begin{align}
\frac{\partial \mathbf{h}_{l+1}}{\partial \mathbf{h}_{l}} = \frac{\partial \mathcal{F}(\mathbf{h}_{l})}{\partial \mathbf{h}_{l}} + \mathbf{1} \quad (1)
\end{align}
$$

一方でResidual Blockを用いないオーソドックスなDeepLearningにおける$\displaystyle \frac{\partial \mathbf{h}_{l+1}}{\partial \mathbf{h}_{l}}$は下記のように計算されます。
$$
\large
\begin{align}
\frac{\partial \mathbf{h}_{l+1}}{\partial \mathbf{h}_{l}} = \frac{\partial \mathcal{F}(\mathbf{h}_{l})}{\partial \mathbf{h}_{l}} \quad (2)
\end{align}
$$

$(2)$式ではある層のパラメータが全て$0$に近くなった場合、$\displaystyle \frac{\partial \mathbf{h}_{l+1}}{\partial \mathbf{h}_{l}} \simeq 0$となり、誤差逆伝播における以降の勾配が全て零ベクトル/零行列となります。

ある層のパラメータが全て$0$に近くなる場合も$(1)$式が用いられていればそれまでの勾配が等倍されるので勾配が保存されます。

VGG-19とResNetの対応

ResNetの構造は下図のようなVGG-$19$との対応を元に理解すると良いです。

ResNet論文 Figure$\, 3$

プーリングとダウンサンプリング

前項「VGG-$19$とResNetの対応」の図でVGG-$19$とResNetの対応の確認を行いましたが、一番左のVGG-$19$ではプーリングを行なっているのに対し、真ん中と右側ではプーリングではなくストライドを$2$にすることでダウンサンプリングを行なっていることに注意が必要です。

入門書などのCNNの解説では「畳み込み」と「プーリング」がセットで解説されることが多い一方で、ストライドが$2$以上の畳み込みの二つの処理が代用されることが多くなっているようです1

bottleneck構造とResNetの層の数

オーソドックスなResNetでは$3 \times 3$の畳み込みが$2$回繰り返されるところを、$1 \times 1$、$3 \times 3$、$1 \times 1$の$3$回に置き換えた構造をbottleneck構造といいます。bottleneck構造は下図のように表されます。

ResNet論文 Figure$\, 5$

上記のbottleneck構造を元に、ResNetの構成は下記のようなパターンを持ちます。

ResNet論文 Table$\, 1$

ResNetのパフォーマンス

ResNet論文 Table$\, 4$
ResNet論文 Table$\, 5$

参考

・ResNet論文
・VGGNet論文

  1. 深層学習 第$2$版 $5.4$節 P.$87 \,$ l.$4$〜$5$ ↩︎

同変性(equivariance)と不変性(invariance)の定義と具体例に基づく解釈

点群(point clouds)の取り扱いやCNN(Convolutional Neural Network)を用いた画像処理の理解にあたって、同変性(equivariance)と不変性(invariance)を抑えておくと良いです。当記事では同変性と不変性の定義や具体的な処理例に基づく解釈について取りまとめを行いました。
当記事の作成にあたっては、「深層学習 第$2$版」の$5.7.1$「同変性と不変性」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

同変性と不変性の定義

同変性(equivariance)の定義

DeepLearningの学習結果に基づく推論は基本的に「入力$\mathbf{x}$に関数$f$を作用させて特徴量$\mathbf{x}’=f(\mathbf{x})$を得る」という流れで処理が行われる。このとき変換の$g$に対し、「関数$f$が同変である」場合、下記の式が成立するように変換$g’$が存在する。
$$
\large
\begin{align}
f(g(\mathbf{x})) = g'(f(\mathbf{x}))
\end{align}
$$

上記の変換$g’$は$g$に一致する場合もある。たとえば「畳み込み演算$f$」における「並進移動」を$g$とおくと、$f$がプーリングのようなダウンサンプリングを伴わない場合、$g’$は$g$に一致する。

不変性(invariance)の定義

$\mathbf{x}’=f(\mathbf{x})$に基づく特徴量の取得にあたって、変換$g$について下記が成立する場合、「$g$について$f$が不変である」という。
$$
\large
\begin{align}
f(\mathbf{x}) = f(g(\mathbf{x}))
\end{align}
$$

たとえば点群のような集合データにおける並び替えを$g$、PointNetやSet TransformerのようなDeepLearningに対応する関数を$f$とする場合、並び替えによって結果は変わらないので$f$は$g$について不変(invariant)である。

同変性と不変性の解釈

同変性(equivariance)は「畳み込み演算$f$」における「平行移動$g$」のように入力も出力も一定の順序で配置される場合に出てくると理解しておくとよい。一方、不変性(invariance)は「点群などの集合データ」に対して「Set Transformerのような演算$f$」を行う際の「並び変え$g$」や、「GAP(Global Average Pooling)を伴うCNNに対応する関数$f$」と「平行移動$g$」について成立するので、最終的に和を計算する場合などに成立しやすいと解釈しておくと良い。

「点群」などの処理に用いられるSet Transformerは下記で取り扱った。

ISAB(Induced Set Attention Block)とSet Transformer

点群(point clouds)のような集合の入力(input set)の処理にあたってTransformerを用いた研究にSet Transformerがあります。当記事ではISAB(Induced Set Attention Block)などを中心にSet Transformer論文の取りまとめを行いました。
「Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks」や「深層学習 改定第$2$版」の第$7$章「集合・グラフのためのネットワークと注意機構」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

Transformerの基本式

Transformerの基本的な処理であるDot Product Attentionを$\mathrm{Attention}(Q, K, V)$とおくと、$\mathrm{Attention}(Q, K, V)$は下記のような式で表されます。
$$
\large
\begin{align}
\mathrm{Attention}(Q, K, V) &= \mathrm{Softmax} \left( \frac{Q K^{\mathrm{T}}}{\sqrt{d}} \right) V \quad (1) \\
Q & \in \mathbb{R}^{m \times d}, K \in \mathbb{R}^{n \times d}, \, V \in \mathbb{R}^{n \times d} \\
Q K^{\mathrm{T}} & \in \mathbb{R}^{m \times n}, \, \mathrm{Attention}(Q, K, V) \in \mathbb{R}^{m \times d}
\end{align}
$$

Transformerでは基本的に$K=V$であり、さらにEncoderなどのSelf Attentionでは$Q=K=V$である場合が多いです。上記では$Q \neq K, \, K = V$の場合の立式を行いました。計算結果の$\mathrm{Attention}(Q, K, V)$が$\mathrm{Attention}(Q, K, V) \in \mathbb{R}^{m \times d}$のように$K, V$ではなく$Q$と同じサイズの行列が得られることに注意が必要です。また、$(1)$式における$\sqrt{d}$は$\mathrm{Softmax}$の計算結果が極端にならないように導入されます。

次に、Multi Head Attention演算を$\mathrm{Multihead}(Q, K, V)$のようにおくと、$\mathrm{Multihead}(Q, K, V)$は$(1)$式を元に下記のように定義されます。
$$
\large
\begin{align}
\mathrm{Multihead}(Q, K, V) &= \mathrm{concat}(O_1, \cdots , O_h)W^{O} \\
O_i &= \mathrm{Attention}(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V}) \\
O_i & \in \mathbb{R}^{m \times d_v}, \, \mathrm{concat}(O_1, \cdots , O_h) \in \mathbb{R}^{m \times h d_v} \\
W_{i}^{Q} & \in \mathbb{R}^{d \times d_k}, \, W_{i}^{K} \in \mathbb{R}^{d \times d_k}, W_{i}^{K} \in \mathbb{R}^{d \times d_v} \\
W^{O} & \in \mathbb{R}^{h d_v \times d} \\
d &= h d_{k} = h d_{v} \\
\mathrm{Multihead}(Q, K, V) & \in \mathbb{R}^{m \times d}
\end{align}
$$

Set Transformer

Multihead Attention Block

Set Transformerの論文ではMulti Head Attentionが下記のような$\mathrm{MAB}(X, Y)$で定義されます。
$$
\large
\begin{align}
\mathrm{MAB}(X, Y) &= \mathrm{LayerNorm}(H + \mathrm{FFN}(H)) \\
H &= \mathrm{LayerNorm}(X + \mathrm{Multihead}(X, Y, Y)) \\
X \in \mathbb{R}^{m \times d}, \, Y & \in \mathbb{R}^{n \times d}, H \in \mathbb{R}^{m \times d}, \, \mathrm{FFN}(H) \in \mathbb{R}^{m \times d}, \, \mathrm{MAB}(X, Y) \in \mathbb{R}^{m \times d}
\end{align}
$$

$\mathrm{MAB}(X, Y)$の$\mathrm{MAB}$はMultihead Attention Blockの略です。Set Transformerの論文では$\mathrm{MAB}(X, Y)$が下図のようにも表されます。

Set Transformer論文 Figure$\, 1, (b)$

Multihead Attention Blockの理解にあたっては、出力の$\mathrm{MAB}(X, Y)$の行列のサイズが$X$の行列のサイズに一致することに注意しておくと良いです。$X$は通常のTransformerの$Q, K, V$の$Q$に対応します。

Set Attention Block

Set Transformerの論文ではSet Attention Blockの$\mathrm{SAB}(X)$を下記のように定義します。
$$
\large
\begin{align}
\mathrm{SAB} = \mathrm{MAB}(X, X)
\end{align}
$$

上記の式はSet Transformerの論文では下記のように図式化されます。

Set Transformer論文 Figure$\, 1, (c)$

$\mathrm{SAB}(X)$はTransformerのEncoderにおける$Q=K=V$のSelf Attentionと同様の処理であると理解すると良いです。

Induced Set Attention Block

Set Transformerの論文ではInduced Set Attention Blockを表す$\mathrm{ISAB}_{m}(X)$が下記のように定義されます。
$$
\large
\begin{align}
\mathrm{ISAB}_{m}(X) &= \mathrm{MAB}(X, H) \\
H &= \mathrm{MAB}(I, X) \\
X & \in \mathbb{R}^{n \times d}, \, I \in \mathbb{R}^{m \times d}, \, H \in \mathbb{R}^{m \times d}, \, \mathrm{ISAB}_{m}(X) \in \mathbb{R}^{n \times d}
\end{align}
$$

上記の式はSet Transformerの論文では下記のように図式化されます。

Set Transformer論文 Figure$\, 1, (d)$

ISABの式は、「通常のTransformerにおけるAttentionの計算量が$\mathcal{O}(n^{2})$であり、点が多くなると処理が難しい」ので、「計算量が$\mathcal{O}(mn)$になるように$m$個のinducing pointsを導入しベクトル表現を$I \in \mathbb{R}^{m \times d}$のように定義した」と解釈すると良いです。また、ここで$\mathrm{MAB}$の演算を二回繰り返すことで、$\mathbb{R}^{n \times d} \longrightarrow \mathbb{R}^{m \times d} \longrightarrow \mathbb{R}^{n \times d}$のように元の$X$と同じ行列のサイズに戻すことができることも合わせて抑えておくと良いです。

ここで導入した$I \in \mathbb{R}^{m \times d}$はSet Transformerのパラメータ(trainable parameters)であり、Multi Head Attentionの写像計算のパラメータやFFNのMLP処理のパラメータと同様にTransformerの学習の際に値の推定が行われます。

Pooling

点群のように点の集合のPooling処理を取り扱う際に「CNNのような局所的なPoolingを行うことができない」点について注意が必要です。よって、点の集合のPoolingにあたっては全ての点の「平均」や「最大値」を計算することが一般的です。

一方で、Set TransformerではMultihead Attentionを用いたPooling(PMA; Pooling by Multihead Attention)が行われます。

Set Transformer論文ではPoolingによって$n$個の点を$k$個に集約する場合、下記のようなMultihead Attentionの処理が実行されます。
$$
\large
\begin{align}
\mathrm{PMA}_{k}(Z) &= \mathrm{MAB}(S, FFN(Z)) \\
S & \in \mathbb{R}^{k \times d}, \, Z \in \mathbb{R}^{n \times d}
\end{align}
$$

上記の$Z$は$X$の処理結果に対応し、$S$は前項の$I$のようにSet Transformerにおける学習パラメータです。

多くの場合は$k=1$が用いられる一方で、クラスタリング(clustering)のようなタスクの場合は$k > 1$が用いられることも合わせて抑えておくと良いです。$k=1$の場合は下記で取り扱ったグラフ分類と概ね同様なイメージで理解すると良いと思います。

Overall Architecture

SABを用いる場合

Set Attention Blockを用いる場合のEncoderとDecoderの計算例は複数のSABなどを用いて下記のように表されます。
$$
\large
\begin{align}
\mathrm{Encoder}(X) &= \mathrm{SAB}(\mathrm{SAB}(X)) = Z \\
\mathrm{Decoder}(Z) &= FFN(\mathrm{SAB}(\mathrm{PMA}_{k}(Z))) \\
X & \in \mathbb{R}^{n \times d}, Z \in \mathbb{R}^{n \times d}, \, \mathrm{Decoder}(Z) \in \mathbb{R}^{k \times d}
\end{align}
$$

ISABを用いる場合

EncoderにInduced Set Attention Blockを用いる場合もSet Attention Blockを用いる場合と同様に複数の$\mathrm{ISAB}_{m}$を用いて下記のように計算例が定義されます。
$$
\large
\begin{align}
\mathrm{Encoder}(X) &= \mathrm{ISAB}_{m}(\mathrm{ISAB}_{m}(X)) = Z \\
X & \in \mathbb{R}^{n \times d}, Z \in \mathbb{R}^{n \times d}
\end{align}
$$

Positional Encoding

Set Transformerでは基本的に前節で取り扱ったTransformerのアーキテクチャを用いる一方で、Positional Encodingは用いないことに注意が必要です。点群のような集合は入力間に順序がない(permutation invariant)ので、位置をEncodingするPositional Encodingの必要がありません。

むしろ「元々のTransformerには位置の情報がない一方で機械翻訳などのNLPタスクでは順序を取り扱う必要がありPositional Encodingが導入された」ので、Transformerのアルゴリズム自体は点群のような集合の取り扱いにより即していると解釈することもできると思います。

参考

・Transformer論文:Attention is All you need[2017]
・Set Transformer論文


行列式と置換①:置換(permutation)の定義と使用例の確認

線形代数の枠組みで$n$次正方行列の行列式(determinant)を取り扱うにあたっては置換(permutation)という概念を抑えておく必要があります。当記事では置換(permutation)の定義と具体的な使用例の確認について取りまとめを行いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$4$章「行列式」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

置換の定義

「置換(permutation)」は文字や数字の入れ替えの操作に対応する。たとえば$12345$という数字の列に対し「$1$を$3$」、「$2$を$5$」、「$3$を$2$」、「$4$を$4$」、「$5$を$1$」にそれぞれ入れ替えると$35241$という数字の列が得られる。

上記の「$1$を$3$」、「$2$を$5$」、「$3$を$2$」、「$4$を$4$」、「$5$を$1$」に入れ替える置換操作は$\sigma$という記号を用いて下記のように表される。
$$
\large
\sigma :
\begin{cases}
1 \longmapsto 3 \\
2 \longmapsto 5 \\
3 \longmapsto 2 \\
4 \longmapsto 4 \\
5 \longmapsto 1
\end{cases}
$$

$\sigma$は下記のように表す場合もある。
$$
\large
\begin{align}
\sigma = \left[ \begin{array}{ccccc} 1 & 2 & 3 & 4 & 5 \\ 3 & 5 & 2 & 4 & 1 \end{array} \right]
\end{align}
$$

また、$\sigma$によって$i \longmapsto j$になることを写像と同様な式を用いて$j = \sigma(i)$のように表すこともできる。上記の例はそれぞれ$\sigma(1)=3, \, \sigma(2)=5, \, \sigma(3)=2, \, \sigma(4)=4, \, \sigma(5)=1$のように表せる。

置換の使用例

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$052$

・$[1]$
$$
\large
\begin{align}
\sigma = \left[ \begin{array}{cccc} 1 & 2 & 3 & 4 \\ 3 & 2 & 4 & 1 \end{array} \right], \quad (1324)
\end{align}
$$

「$1 \longmapsto 3$」、「$2 \longmapsto 2$」、「$3 \longmapsto 4$」、「$4 \longmapsto 1$」のように置換処理を行うので、求める順列は$3421$である。

・$[2]$
$$
\large
\begin{align}
\sigma = \left[ \begin{array}{ccccc} 1 & 2 & 3 & 4 & 5 \\ 2 & 3 & 5 & 1 & 4 \end{array} \right], \quad (35124)
\end{align}
$$

「$1 \longmapsto 2$」、「$2 \longmapsto 3$」、「$3 \longmapsto 5$」、「$4 \longmapsto 1$」、「$5\longmapsto 4$」のように置換処理を行うので、求める順列は$54231$である。

・$[3]$
$$
\large
\begin{align}
\sigma = \left[ \begin{array}{cccc} 1 & 2 & 3 & 4 & 5 & 6 \\ 5 & 3 & 4 & 6 & 2 & 1 \end{array} \right], \quad (632514)
\end{align}
$$

「$1 \longmapsto 5$」、「$2 \longmapsto 3$」、「$3 \longmapsto 4$」、「$4 \longmapsto 6$」、「$5 \longmapsto 2$」、「$6 \longmapsto 1$」のように置換処理を行うので、求める順列は$143256$である。

「強化学習」×「Transformer」①:Decision Transformer

Transformerは系列モデリングの学習にあたって様々な用途に用いられており、近年では「強化学習」分野へのTransformerの応用も研究されています。当記事ではTransformerを強化学習に応用した論文の一つであるDecision Transformerについての大枠を取りまとめました。
Decision Transformer論文や「ゼロから作るDeep Learning④ー強化学習編」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformer

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

強化学習の基本知識

Deep Learningを用いた強化学習を理解するにあたって知っておくとよい内容を下記で取りまとめました。

・仕組みから理解するChatGPT(統計の森作成)

強化学習の基本トピックについて詳しく学ぶ際には日本語では「ゼロから作るDeep Learning④ー強化学習編」、英語では「Reinforcement Learning, second edition: An Introduction」がおすすめです。

Offline RL

Deep Q NetworkやAlphaZeroなど、多くの強化学習のアルゴリズムではエージェントと環境の相互作用によって軌道(trajectory)を生成し、「価値関数の近似」や「方策の最適化」を通して意思決定の最適化を行います。

一方で、予め既に得た軌道(trajectory)を元に教師あり学習のような学習を行うことで意思決定の最適化を行うことができます。このような強化学習の枠組みはオフライン強化学習(Offline RL)といわれます。

Decision TransformerではこのOffline RLの枠組みを用いるので、このような強化学習の研究分野があることは抑えておくと良いです。Offline RLについて詳しくは下記などを参照すると良いと思います。

・Offline RL論文
・A Survey on Transformers in Reinforcement Learning

Decision Transformer

Causal Transformer

Decision Transformer論文 Figure$\, 1$

Decision Transformerの処理概要は上図のように表されます。図より、メインの処理がCausal Transformerによって実現できることが確認できますが、Decision Transformer論文のSection.$3$のMethodより、基本的にはGPTと同様の処理を用いることが確認できます。

GPT(Generative Pre-Training)は自己回帰型(Auto Regressive)の処理に基づきます。処理の詳細はGPTの論文が参照するTransformer Decoderの解説で取り扱いました。

returns-to-go

Decision Transformerでは通常の強化学習で用いるRewardの$r_{t}$に基づいて下記のように定義される報酬和$\hat{R}_{t}$を用います。
$$
\large
\begin{align}
\hat{R}_{t} = \sum_{t’=t}^{T} r_{t’} \quad (1)
\end{align}
$$

上記をDecision Transformer論文では「returns-to-go」のような用語で表されます。一般的な強化学習では軌道(trajectory)の$\tau$を$\tau=(r_0,s_0,a_0,r_1,s_1,a_1, \cdots)$のように表すことが多い一方で、Decision Transformerでは$(1)$式で定義した「returns-to-go」を用いて下記のように軌道$\tau$を定義します。
$$
\large
\begin{align}
\tau=(\hat{R}_0,s_0,a_0,\hat{R}_1,s_1,a_1, \cdots , \hat{R}_{t}, s_{t}, a_{t}, \cdots)
\end{align}
$$

また、上記の$s_{t}$は時点$t$における状態、$a_{t}$は時点$t$における行動を表します。

Decision Transformerの擬似コード

Decision Transformer論文のAlgorithm$\, 1$にDecision Transformerの擬似コードが掲載されています。

# R, s, a, t: returns-to-go, states, actions, or timesteps
# transformer: transformer with causal masking (GPT)
# embed_s , embed_a , embed_R: linear embedding layers
# embed_t: learned episode positional embedding
# pred_a: linear action prediction layer

# main model
def DecisionTransformer(R, s, a, t):
    # compute embeddings for tokens
    pos_embedding = embed_t(t) # per-timestep (note: not per-token) s_embedding = embed_s(s) + pos_embedding
    a_embedding = embed_a(a) + pos_embedding
    R_embedding = embed_R(R) + pos_embedding

    # interleave tokens as (R_1, s_1, a_1, ..., R_K, s_K)
    input_embeds = stack(R_embedding , s_embedding , a_embedding)

    # use transformer to get hidden states
    hidden_states = transformer(input_embeds=input_embeds)

    # select hidden states for action prediction tokens
    a_hidden = unstack(hidden_states).actions

    # predict action
    return pred_a(a_hidden)

# training loop
for (R, s, a, t) in dataloader: # dims: (batch_size, K, dim)
    a_preds = DecisionTransformer(R, s, a, t)
    loss = mean((a_preds - a)**2) # L2 loss for continuous actions optimizer.zero_grad(); loss.backward(); optimizer.step()

# evaluation loop
target_return = 1 # for instance , expert -level return
R, s, a, t, done = [target_return], [env.reset()], [], [1], False
while not done: # autoregressive generation/sampling
    # sample next action
    action = DecisionTransformer(R, s, a, t)[-1] # for cts actions new_s, r, done, _ = env.step(action)
    
    # append new tokens to sequence
    R = R + [R[-1] - r] # decrement returns-to-go with reward s, a, t = s + [new_s], a + [action], t + [len(R)]
    R, s, a, t = R[-K:], ... # only keep context length of K

上記は28行目のa_preds = DecisionTransformer(R, s, a, t)がDecision Transformerを用いたActionの予測に対応するところから理解すると理解しやすいです。Decision Transformerは「状態」、「行動」、「returns-to-go」の入力に基づいて行動の予測を行います。

Decision Transformerの学習

Decision Transformerの学習では学習に用いる「軌道」のサンプルセットが最適である前提で、同様の振る舞いをTransformerが模倣するように学習を行います。

予測結果の$a_t$と「学習に用いる軌道サンプル」を元に、行動が離散値の場合はCross Entropy、連続値の場合は平均二乗誤差(mean-squared error)がlossに用いられます。前項で取り扱った擬似コードでは平均二乗誤差が計算されます。

Evaluations

Decision Transformer論文 Figure$\, 3$

参考

・Transformer論文:Attention is All you need[2017]
・Decision Transformer論文
・Offline RL論文
・A Survey on Transformers in Reinforcement Learning
・Transformer Decoder論文