ブログ

Pyramid ViTとSpatial Reduction Attention

Transformerを用いてセグメンテーション(Segmentation)やObject DetectionのようなDense Predictionタスクを学習させるには解像度を高くする必要がある一方で、ViTでは解像度を高くすると計算量の問題が生じます。当記事ではこの解決にあたって考案されたSRA(Spatial Reduction Attention)に基づくPyramid Vision Transformer(PVT)の論文について取りまとめを行いました。PVTの論文である「Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

Pyramid Vision Transformer

処理の概要

Pyramid Vision Transformer(PVT)の処理の全体像は下図を元に掴むと良いです。

Pyramid Vision Transformer論文 Figure$\, 3$

基本的にはViTの処理と同じである一方で、$16 \times 16$のピクセルをパッチと見なすViTに対してPyramid Vision Transformerではまず$4 \times 4$のピクセルをパッチと見なし、ステージ(stage)単位で徐々に解像度を粗くします。

各ステージの処理は「Patch embedding」と「Transformer encoder」に分けられます。大まかには前処理とTransformer処理のように解釈しておくと良いと思います。

ハイパーパラメータの定義と基本的な設定

Pyramid Vision Transformerではステージ単位で処理を行います。ステージ$i$におけるハイパーパラメータはそれぞれ下記のように定義されます。

ハイパーパラメータ意味
$F_i$ステージ$i$におけるFeature Mapの縮小率。$F_1=4, F_2=8, F_3=16, F_4=32$である。
$P_i$ステージ$i$におけるパッチのサイズ。基本的には$P_1=4$、$i \geq 2$では$P_i=2$。ステージ$i-1$の出力をベースにパッチのサイズを定義することに注意。
$C_i$ステージ$i$の出力のチャネル数。解像度が低くなるにつれて$C_i$は大きくなる。
$L_i$ステージ$i$におけるTransformer encoderの層の数。特に$i=2, 3$のときにモデルが大きくなるにつれて$L_i$は大きくなる。
$R_i$ステージ$i$のSpatial Reduction Attentionにおけるreduction ratio。$R_i$が大きくなるにつれてAttentionにおける解像度が粗くなる。
$N_i$ステージ$i$におけるAttentionのヘッドの数。
$E_i$FFNレイヤーにおける隠れ層の倍率1。一般的なTransformerでは$4$が多い。

上記のハイパーパラメータと具体的なモデルの対応は下記を元に確認することができます。

Pyramid Vision Transformer論文 Table$\, 1$

Patch embedding

Pyramid Vision Transformer論文 Figure$\, 3$の一部

Patch embeddingでは入力画像や$i-1$番目のステージの出力を元に$i$番目のステージにおける入力の調整を行います。
$$
\large
\begin{align}
H_{i-1} \times W_{i-1} \times C_{i-1}
\end{align}
$$

ステージ$i-1$の出力のサイズは上記のように表されますが、ステージ$i$ではまず始めに近隣パッチを連結することにより、下記のようなサイズのFeature Mapを得ます。
$$
\large
\begin{align}
\frac{H_{i-1}}{P_{i}} \times \frac{W_{i-1}}{P_{i}} \times (P_{i}^{2} C_{i-1}) \quad (1)
\end{align}
$$

上記はステージ$i$ではステージ$i-1$の出力を$P_{i} \times P_{i}$単位でパッチの連結を行い、その分チャネル(TransformerのMLP層のサイズに対応)が増えたと理解できます。

Patch embedding層では$(1)$式における$P_{i}^{2} C_{i-1}$のチャネルに対しMLP処理を行うことで、チャネル数を$C_{i}$に変えます。
$$
\large
\begin{align}
\frac{H_{i-1}}{P_{i}} \times \frac{W_{i-1}}{P_{i}} \times C_{i} \quad (1)’
\end{align}
$$

さらにReshape処理を行うことで下記のようなサイズを持つTransformer encoder層の入力が作成されます。
$$
\large
\begin{align}
\frac{H_{i-1} W_{i-1}}{P_{i}^{2}} \times C_{i}
\end{align}
$$

また、$H_{i-1}, W_{i-1}$と$H_{i}, W_{i}$について下記が成立することも合わせて抑えておくと良いです。
$$
\large
\begin{align}
\frac{H_{i-1}}{P_{i}} &= H_{i} \\
\frac{W_{i-1}}{P_{i}} &= W_{i}
\end{align}
$$

Transformer encoder

Pyramid Vision Transformer論文 Figure$\, 3$の一部

Transformer encoderではPatch embeddingによって作成されたTransformerの入力を元にTransformerの演算が行われます。ステージ$i$では$1$度のPatch embedding処理に対し、$L_i$回のTransformer処理が行われることに注意が必要です。

Spatial Reduction Attention(SRA)はCNNと同様に画像の局所的な相関を活用することで処理の効率化を行う手法です。詳しくは次項で取り扱います。

Spatial Reduction Attention(SRA)

Spatial Reduction Attention(SRA)はAttentionの入力であるKeyとValueの解像度を低くすることで計算の効率化を実現する手法です。

Pyramid Vision Transformer論文 Figure$\, 4$

Spatial Reduction Attention処理は下記のような数式で表されます2
$$
\large
\begin{align}
\mathrm{SRA}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_{1}, \cdots , \mathrm{head}_{N_{i}})W^{O} \quad (2) \\
\mathrm{head}_{j} &= \mathrm{Attention}(Q W_{j}^{Q}, \mathrm{SR}(K) W_{j}^{K}, \mathrm{SR}(V) W_{j}^{V}) \quad (3) \\
W_{j}^{Q} & \in \mathbb{R}^{C_i \times d_{head}}, \, W_{j}^{K} \in \mathbb{R}^{C_i \times d_{head}}, \, W_{j}^{V} \in \mathbb{R}^{C_i \times d_{head}} \\
W^{O} & \in \mathbb{R}^{C_i \times C_i} \\
d_{head} &= \frac{C_i}{N_i} \quad (4)
\end{align}
$$

$(2)$式は一般的なMultiHead Attentionの計算と同じである一方で、$(3)$式の$\mathrm{head}_{j}$の計算過程に$\mathrm{SR}(K)$が出てくることに注意が必要です。行列$X$についてこの$\mathrm{SR}(X)$の計算は下記のように定義されます。
$$
\large
\begin{align}
\mathrm{SR}(X) &= \mathrm{Norm}(\mathrm{Reshape}(X, R_i)W^{S}) \quad (5) \\
X \in \mathbb{R}^{(H_i W_i) \times C_i}, \,\, & \mathrm{Reshape}(X, R_i) \in \mathbb{R}^{\frac{H_i W_i}{R_i^{2}} \times (R_i^{2} C_i)}, \,\, W^{S} \in \mathbb{R}^{(R_i^{2} C_i) \times C_i} \quad (6)
\end{align}
$$

$\displaystyle \mathrm{Reshape}(X, R_i) \in \mathbb{R}^{\frac{H_i W_i}{R_i^{2}} \times (R_i^{2} C_i)}$で表したように、$(5)$式のReshape処理はPatch embeddingと同様な処理が行われるので対応させて抑えておくと良いと思います。

$(5)$式と$(6)$式より、$\mathrm{SR}(X)$は$\displaystyle \mathrm{SR}(X) \in \mathbb{R}^{\frac{H_i W_i}{R_i^{2}} \times C_i}$で表される行列であり、$X \in \mathbb{R}^{(H_i W_i) \times C_i}$の行数が$\displaystyle \frac{1}{R_i^{2}}$になったことが確認できます。この処理は「Queryに対応して重み付け和を計算するValueの空間方向の解像度を低くすることで、CNNにおけるダウンサンプリングと同様な効果が得られる」と解釈することができます3

参考

・Transformer論文:Attention is All you need[$2017$]
・Pyramid Vision Transformer論文

  1. expansion ratioの厳密な説明が論文内に見当たらなかったので、要出典。 ↩︎
  2. 論文では$d_{head}$と表記がありますが、$(4)$式で定義されるようにステージ$i$毎に$d_{head}$の値が変わる可能性があることは注意しておくと良いです。 ↩︎
  3. TransformerにおけるAttentionでは$K=V$かつ$Q, K, V$の列数が一致する必要がある一方で、$Q$と$K=V$の行数が必ずしも一致しないことは抑えておくと良いと思います。Encoderでは$Q=K=V$であることが多い一方で、Decoderでは$Q \neq K$の計算も行われます。 ↩︎

【SimCLR】対照学習(Contrastive Learning)に基づくベクトル表現の取得①

SimCLR(Simple Framework for Contrastive Learning of Visual Representations)は対照学習(Contrastive Learning)を用いて画像のベクトル表現を抽出する手法です。当記事ではSimCLRの一連の学習手順について取りまとめを行いました。
SimCLRの論文の「A Simple Framework for Contrastive Learning of Visual Representations」を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

指示関数

指示関数(indicator function)の$\mathbb{1}_{[k \neq i]} \in \{ 0, 1 \}$は下記のように定義される。
$$
\large
\mathbb{1}_{[k \neq i]} =
\begin{cases}
1 \quad \mathrm{if} \quad k \neq i \\
0 \quad \mathrm{otherwise}
\end{cases}
$$

対照学習

SimCLR

SimCLRの全容

SimCLRは下記の$4$つの主要要素によって構成される。
・確率的データ拡張(A stochastic data augmentation)
・エンコーダ(A neural network base encoder $f(·)$)
・A small neural network projection head $g(·)$
・対照損失関数(A contrastive loss function)

上記の$4$つの主要要素は下図を元に抑えておくと良い。

SimCLR論文 Figure$\, 2$

上図の$\mathbf{x}$から$\tilde{\mathbf{x}}_{i}$や$\tilde{\mathbf{x}}_{j}$を作成するプロセスが確率的データ拡張(A stochastic data augmentation)、$\tilde{\mathbf{x}}_{i}, \, \tilde{\mathbf{x}}_{j}$から$\mathbf{h}_{i}, \, \mathbf{h}_{j}$を計算するプロセスがエンコーダ、$\mathbf{h}_{i}, \, \mathbf{h}_{j}$から$\mathbf{z}_{i}, \, \mathbf{z}_{j}$を計算するプロセスがprojection head、$\mathbf{z}_{i}, \, \mathbf{z}_{j}$を用いて定義される損失関数が対照損失関数(A contrastive loss function)にそれぞれ対応する。

以下、$4$つの主要要素についてそれぞれ確認を行う。

stochastic data augmentation

データ拡張(data augmentation)の主要な手法は下図に基づいて把握することができる。

SimCLR論文 Figure$\, 4$

SimCLRでは$(\mathrm{c})$の「crop(切り抜き)+リサイズ+反転(flip)」、$(\mathrm{d}), \, (\mathrm{e})$の「color distortion」、$(\mathrm{i})$の「Gaussian blur」の$3$つが用いられ、cropとcolor distortionが有効であったと報告されている。

$(\mathrm{c})$の切り抜きにあたっては「切り抜く場所」や「切り抜くサイズ」にランダム性があることから、SimCLRにおけるデータ拡張は確率的データ拡張(stochastic data augmentation)と表されていることも合わせて抑えておくと良い。

ニューラルネットの構成①:encoder

encoderの$f(\cdot)$は入力の$\tilde{\mathbf{x}}_{i}$や$\tilde{\mathbf{x}}_{j}$からベクトル表現(Visual Representation)を抽出する関数に対応する。SimCLRでは下記の数式で表されるようにResNetが用いられる。
$$
\large
\begin{align}
\mathbf{h}_{i} &= f(\mathbf{x}_{i}) = \mathrm{ResNet}(\mathbf{x}_{i}) \\
\mathbf{h}_{i} & \in \mathbb{R}^{d}
\end{align}
$$

$d$は抽出する画像のベクトル表現の次元に対応する。

ニューラルネットの構成②:projection head

projection headの$g(\cdot)$は抽出したベクトル表現の$\mathbf{h}_{i}$を対照損失(contrastive loss)の計算用に変換する処理に対応する。原理的には$\mathbf{h}_{i}$をそのまま用いてlossの計算を行うことは可能であるが、SimCLRの論文では$g$を用いることがbeneficialとされる。

SimCLRでは$g$に二層のMLP(Multi Layer Perceptron)が用いられている。このMLPは下記のような数式で表される。
$$
\large
\begin{align}
\mathbf{z}_{i} &= g(\mathbf{h}_{i}) = W^{(2)} \mathrm{ReLU}(W^{(1)} \mathbf{h}_{i}) \\
\mathrm{ReLU}(x) &= \max(0, x)
\end{align}
$$

loss function

$(i,j)$を正例の組とおき、$k$組用意するとき、サンプルは$2k$個となる。このとき、$(i,j)$に関するloss functionを下記のように定義する。
$$
\large
\begin{align}
l_{i,j} &= -\log{\left[ \frac{\exp{(\mathrm{sim}(\mathbf{z}_{i},\mathbf{z}_{j}))/\tau}}{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp{(\mathrm{sim}(\mathbf{z}_{i},\mathbf{z}_{k}))/\tau}} \right]} \\
\mathrm{sim}(\mathbf{u},\mathbf{v}) &= \frac{\mathbf{u}^{\mathrm{T}} \mathbf{v}}{||\mathbf{u}|| ||\mathbf{v}||}
\end{align}
$$

$\tau$は温度パラメータである。上記の式に基づいてSimCLRの学習を行う。

参考

・SimCLR論文

Swin Transformer: 階層型Vision Transformer まとめ

Transformerの画像処理への応用にあたってはViT(Vision Transformer)などが有名である一方で、画像の局所特徴量の抽出の観点からは少々処理が非効率です。当記事では階層型のAttentionを用いることで改善を行なった研究であるSwin Transformerについて取りまとめを行いました。
「Swin Transformer: Hierarchical Vision Transformer using Shifted Windows」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

Swin Transformer

パッチの作成と連結

Swin Transformerでは$4 \times 4$のピクセル単位でパッチの作成を行います。基本的にはRGB値を持つカラー画像が入力であるので、それぞれのパッチが持つ要素数は$4 \times 4 \times 3 = 48$となります。

このように作成を行なったパッチを処理を行うにつれて隣接するパッチと連結を行い、Feature Mapの作成を行います。

この「パッチの連結によるFeature Mapの作成」はVGGNetやResNetにおけるpoolingを用いたFeature Mapの作成と対応づけて抑えておくと良いです。

Swin Transformerの処理概要

Swin Transformer論文 Figure$\, 3, (a)$

Swin Transformerの処理の全体は上図のように表されます。まず$H \times W \times 3$の入力を$4 \times 4$単位でパッチの作成を行います。

$4 \times 4$単位でそれぞれ$48$の要素を持つパッチの作成を行うことで、$H \times W \times 3$の入力を$\displaystyle \frac{H}{4} \times \frac{W}{4} \times 48$の表現に変換することができます。この処理がPatch Partitionに対応します。

Swin Transformer論文 Figure$\, 3, (b)$

次にLinear Embedding(MLP処理)を通して$48$次元を$C$次元に変換したのちにSwin Transformer Blockを用いて処理を行います。Swin Transformer Blockの処理は上図のように表されますが、Swin Transformer Blockの詳細は次項以降で取り扱います。

ここまでのLinear Embedding・Swin Transformer Block・Patch Mergingの処理の組み合わせによって、Swin TransformerではFeature Mapの作成が行われます。

W-MSA

Swin Transformer論文 Figure$\, 3, (b)$

W-MSAはWindowを用いたMultihead Self Attention処理の略です。Windowはパッチの集合を表しており、$M \times M$の場合は$M^2$個のパッチをまとめて取り扱います。

全てのパッチ同士で計算を行うViTに対して、W-MSAではWindowの内部のパッチだけを元にAttention処理を行います。このような処理を行うことで計算量の削減が可能になります。

SW-MSA

W-MSAを用いることで計算量の削減が可能になる一方で、W-MSAではWindowの境界におけるパッチ間の相関を取り扱うことができないという課題があります。この解決にあたってW-MSAと併用されるのがSW-MSAです。

W-MSAのみを用いる場合、上図の左のように毎回Windowの境界のパッチが同じパッチになります。このような問題の解決にあたってSwin Transformerでは、上図の右のようにWindowをずらした上で(Shifted Window)Attention処理を行うSW-MSAという処理が用いられます1

Swin Transformer論文 Figure$\, 3, (b)$

このようにWindowをずらすことでWindowの境界のパッチが毎回同じ位置にならないようにすることができます。

cyclic shifting

参考

・Transformer論文:Attention is All you need[$2017$]
・Swin Transformer論文

  1. Swin TransformerのSwinがShifted Windowの略であることも合わせて抑えておくと良いです。 ↩︎

行列式と置換⑤:互換と巡回置換(cyclic permutation)

線形代数の枠組みで$n$次正方行列の行列式(determinant)を取り扱うにあたっては置換(permutation)という概念を抑えておく必要があります。当記事では互換(transpositions)と巡回置換(cyclic permutation)について取りまとめを行いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$4$章「行列式」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

互換と巡回置換

互換

順列$1, 2, \cdots , n$について$i < j, \, i, j \in {1, 2, \cdots , n }$である$i$と$j$のみを入れ替え、他をそのままにする置換を互換という。$i$と$j$のみを入れ替える置換を$\sigma$とおくと、$\sigma$は下記のように表せる。
$$
\large
\begin{align}
\sigma(i) &= j \\
\sigma(j) &= i \\
\sigma(k) &= k, \quad k \neq i \cap k \neq j
\end{align}
$$

$i$と$j$を入れ替える互換$\sigma$は$\sigma=(i \quad j)$のようにも表記される。

巡回置換

順列$1, 2, \cdots , n$から$k$個の数字を抜き出し、それぞれ$i_1 < i_2 < \cdots < i_{k}$のように表す。このとき下記のような置換$\sigma$を巡回置換という。
$$
\large
\sigma :
\begin{cases}
i_1 & \longmapsto i_2 \\
i_2 & \longmapsto i_3 \\
\cdots \\
i_{k-1} & \longmapsto i_{k} \\
i_{k} & \longmapsto i_{1}
\end{cases}
$$

巡回置換$\sigma$は$\sigma=(i_1 \quad i_2 \quad \cdots \cdots \quad i_{k})$のようにも表記される。

例題の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$057$

$$
\large
\begin{align}
(i_1 \quad i_2 \quad \cdots \cdots \quad i_{k}) = (i_1 \quad i_k)(i_1 \quad i_{k-1}) \cdots \cdots (i_1 \quad i_{2}) \quad (1)
\end{align}
$$

右辺の変換によって、$i_1 \, i_2 \, i_3 \, \cdots \cdots \, i_{k-1} \, i_k$は下記のように変換される。
$$
\large
\begin{align}
& i_1 \, i_2 \, i_3 \, \cdots \cdots \, i_{k-1} \, i_k \\
\longrightarrow & i_2 \, i_1 \, i_3 \, \cdots \cdots \, i_{k-1} \, i_k \\
\longrightarrow & i_2 \, i_3 \, i_1 \, \cdots \cdots \, i_{k-1} \, i_k \\
& \cdots \\
\longrightarrow & i_2 \, i_3 \, i_4 \, \cdots \cdots \, i_1 \, i_k \\
\longrightarrow & i_2 \, i_3 \, i_4 \, \cdots \cdots \, i_k \, i_1
\end{align}
$$

上記より$(1)$式が成立することが確認できる。

行列式と置換④:置換(permutation)の符号と偶置換・奇置換

線形代数の枠組みで$n$次正方行列の行列式(determinant)を取り扱うにあたっては置換(permutation)という概念を抑えておく必要があります。当記事では置換(permutation)の符号や符号に関連する転倒数・偶置換・奇置換について取りまとめを行いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$4$章「行列式」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

置換の符号

置換の転倒数

$n$個の数字$1, \cdots , n$の置換$\sigma$について、「$i < j, \, i, j \in \{ 1, \cdots , n \}$」かつ「$\sigma(i) > \sigma(j)$」が成立する$(i,j)$の組の個数を置換$\sigma$の転倒数という。

$n$変数の差積と置換の符号

$n$変数$x_{1}, x_{2}, \cdots , x_{n}$の差積$D(x_{1}, x_{2}, \cdots , x_{n})$を下記のように定義する。
$$
\large
\begin{align}
D(x_{1}, x_{2}, \cdots , x_{n}) &= (x_1-x_n) \times \cdots \times (x_1-x_3) \times (x_1-x_2) \\
& \times (x_2-x_n) \times \cdots \times (x_2-x_3) \\
& \times \cdots \\
& \times (x_{n-2}-x_n) \times (x_{n-2}-x_{n-1}) \\
& \times (x_{n-1}-x_n)
\end{align}
$$

ここで$n$個の数字$1, \cdots , n$の置換$\sigma$について、$\sigma$の符号$\mathrm{sgn}{(\sigma)}$を下記のように定義する。
$$
\large
\begin{align}
\mathrm{sgn}{(\sigma)} = \frac{D(x_{\sigma(1)}, x_{\sigma(2)} , \cdots , x_{\sigma(n)})}{D(x_{1}, x_{2}, \cdots , x_{n})}
\end{align}
$$

上記を置換$\sigma$の符号という。$\mathrm{sgn}{(\sigma)}$は置換$\sigma$の転倒数が偶数であるとき$\mathrm{sgn}{(\sigma)}=1$、転倒数が奇数であるとき$\mathrm{sgn}{(\sigma)}=-1$となることも合わせて抑えておくと良い。

偶置換・奇置換

転倒数が偶数・符号が$1$である置換を偶置換、転倒数が奇数・符号が$-1$である置換を奇置換という。

例題の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$056$

・$[1]$
$$
\large
\begin{align}
A = \left[ \begin{array}{cccc} 1 & 2 & 3 & 4 \\ 2 & 4 & 3 & 1 \end{array} \right]
\end{align}
$$

$1) \,$ $i=1, \sigma(i)=2$に対し、$i < j, \sigma(i) > \sigma(j)$となるのは$j=4$のみである。

$2) \,$ $i=2, \sigma(i)=4$に対し、$i < j, \sigma(i) > \sigma(j)$となるのは$j=3, 4$である。

$3) \,$ $i=3, \sigma(i)=3$に対し、$i < j, \sigma(i) > \sigma(j)$となるのは$j=4$のみである。

上記より転倒数は$1+2+1=4$であるので、符号は$\mathrm{sgn}(\sigma)=1$である。

・$[2]$
$$
\large
\begin{align}
\sigma = \left[ \begin{array}{ccccc} 1 & 2 & 3 & 4 & 5 \\ 2 & 3 & 5 & 1 & 4 \end{array} \right]
\end{align}
$$

$1) \,$ $i=1, \sigma(i)=2$に対し、$i < j, \sigma(i) > \sigma(j)$となるのは$j=4$のみである。

$2) \,$ $i=2, \sigma(i)=3$に対し、$i < j, \sigma(i) > \sigma(j)$となるのは$j=4$のみである。

$3) \,$ $i=3, \sigma(i)=5$に対し、$i < j, \sigma(i) > \sigma(j)$となるのは$j=4, 5$である。

$4) \,$ $i=4, \sigma(i)=1$に対し、$i < j, \sigma(i) > \sigma(j)$となる$j$は存在しない。

上記より転倒数は$1+1+2+0=4$であるので、符号は$\mathrm{sgn}(\sigma)=1$である。

行列式と置換③:置換(permutation)の指数法則・単位置換・逆置換

線形代数の枠組みで$n$次正方行列の行列式(determinant)を取り扱うにあたっては置換(permutation)という概念を抑えておく必要があります。当記事では置換(permutation)の指数法則や単位置換、逆置換について取りまとめを行いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$4$章「行列式」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

置換の指数法則・単位置換・逆置換

置換の指数法則

任意の置換$\sigma$と任意の整数$k, l$について下記が成立する。
$$
\large
\begin{align}
\sigma^{k+l} = \sigma^{k} \sigma^{l}
\end{align}
$$

単位置換

文字$i \in \{ 1, 2, \cdots , n \}$を$i$自身に写す置換を$e$とおくと、$e$は下記のように表される。
$$
\large
\begin{align}
e = \left[ \begin{array}{cccc} 1 & 2 & \cdots & n \\ 1 & 2 & \cdots & n \end{array} \right]
\end{align}
$$

この置換$e$を単位置換という。また、任意の置換$\sigma$について$\sigma^{0}=e$が成立すると定義する。

逆置換

任意の置換$\sigma$について$\sigma \tau = \tau \sigma = e$が成立する置換$\tau$が存在する。この置換$\tau$のことを$\sigma$の逆置換という。

具体的には$12 \cdots n$が$\sigma(1) \sigma(2) \cdots \sigma(n)$のように変換されるとき、$\tau(\sigma(1))=1, \, \tau(\sigma(2))=2, \, \cdots , \, \tau(\sigma(n))=n$が成立するように$\tau$を定めれば良い。

上記のような$\sigma$の逆行列$\tau$は$\tau = \sigma^{-1}$のように表すことができ、指数法則より$\sigma \tau = \sigma^{1} \sigma^{-1} = \sigma^{1-1} = \sigma^{0} = e$のように表すこともできる。

例題の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$054$

基本例題$055$

$\tau \sigma (\sigma^{-1} \tau^{-1})$は下記のように式変形することができる。
$$
\large
\begin{align}
\tau \sigma (\sigma^{-1} \tau^{-1}) &= \tau \sigma^{1-1} \tau^{-1} \\
&= \tau \tau^{-1} \\
&= e
\end{align}
$$

同様に$(\sigma^{-1} \tau^{-1}) \tau \sigma$は下記のように式変形できる。
$$
\large
\begin{align}
(\sigma^{-1} \tau^{-1}) \tau \sigma &= \sigma^{-1} \tau^{-1+1} \sigma \\
&= \sigma^{-1} \sigma \\
&= e
\end{align}
$$

上記より$(\tau \sigma) (\sigma^{-1} \tau^{-1}) = (\sigma^{-1} \tau^{-1}) (\tau \sigma)$が成立するので逆置換の定義より、$(\tau \sigma)^{-1} = \sigma^{-1} \tau^{-1}$が成立する。

TransformerにおけるSoftmax関数の計算量とLinear Transformer

Transformerは汎用的に用いることのできる強力なDeepLearningである一方、入力系列のトークンが多くなると計算量も増大します。当記事ではTransformerの各Attention処理でのSoftmax計算の軽減にあたっての研究である、Linear Transformer論文について取りまとめました。
作成にあたってはLinear Transformer論文や、「A Survey of Transformers」の内容を参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの仕組みの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformer(統計の森作成)

Transformerの式表現

TransformerのDot Product Attentionは下記のような式で定義されます1
$$
\large
\begin{align}
\mathrm{Attention}(Q, K, V) &= \mathrm{Softmax} \left( \frac{Q K^{\mathrm{T}}}{\sqrt{d}} \right) V \quad (1) \\
Q, K, V & \in \mathbb{R}^{n \times d}
\end{align}
$$

Linear Transformer

Linear Transformerの概要

Linear Transformerの概要は下図を確認すると理解しやすいです。

A Survey of Transformers Fig$. \, 7$:通常のTransformerとLinear Transformerの処理フローとそれぞれの計算量

上図の左側が一般的なTransformerの計算フロー、右側がLinear Transformerにおける計算フローに対応します。左側のTransformerでは$Q, K \in \mathbb{R}^{n \times d}$について$Q K^{\mathrm{T}} \in \mathbb{R}^{n \times n}$の行毎(row-wise)にソフトマックス関数を適用する際に、計算量が入力するトークン数$n$の二乗になります。図では$T=n$であり、計算量が$\mathcal{O}(T^{2})$で表現されています。オリジナルのTransformerでは$n \simeq d$を前提としている一方で、Linear TransformerではReformerなどと同様に$n >> d$の場合を仮定することに注意しておくと良いです。

Linear Transformerの論文ではこのようなソフトマックス関数の計算時の計算量を軽減するにあたって、ソフトマックス関数の$\exp$の代わりにfeature mapの$\phi$が導入されます。$\exp$を用いないことで行列の積の順序を変えることが可能になることに注意しておくと良いです。

Transformerの式の改変

$(1)$式の出力の$i$行目を縦ベクトルの$z_{i}$とおくと、$Q \in \mathbb{R}^{n \times d}$の$i$行目の$q_{i}$に対応する$z_{i}$は下記のように表すことが可能です。
$$
\large
\begin{align}
z_{i}^{\mathrm{T}} = \mathrm{Attention}(q_{i}^{\mathrm{T}}, K, V) = \mathrm{Softmax} \left( \frac{q_{i}^{\mathrm{T}} K^{\mathrm{T}}}{\sqrt{d}} \right) V \quad (2)
\end{align}
$$

ここで$K, V \in \mathbb{R}^{n \times d}$の$j$行目を$k_{j}, v_{j} \in \mathbb{R}^{d}$のように表すとき、$(2)$式は下記のように表すこともできます。
$$
\large
\begin{align}
z_{i} &= \left[ \mathrm{softmax}{\left( q_{i}^{\mathrm{T}} K^{\mathrm{T}} \right)} V \right]^{\mathrm{T}} \\
&= \sum_{k=1}^{n} \left[ \frac{\mathrm{sim}(q_{i}, k_{k})}{\sum_{j=1}^{n} \mathrm{sim}(q_{i}, k_{j})} v_{j} \right] \quad (3) \\
\mathrm{sim}(q, k) &= \exp{ \left( \frac{q^{\mathrm{T}} k}{\sqrt{d}} \right) }
\end{align}
$$

上記の$q_{i}, k_{j}, v_{j}$は$Q, K, V$の$i$行目や$j$行目を抜き出して縦ベクトルで表されたものに対応します2。出力の$z_{i}$も同様に縦ベクトルであることにご注意ください。よって、たとえば$q_{i}, k_{j}$の内積は$q_{i}^{\mathrm{T}} k_{j}$のように表されます。

Linear Transformerの数式

$(3)$式の$\mathrm{sim}(q_{i}, k_{k})$を下記のように$\phi(x)$を用いて表す場合を仮定します。
$$
\large
\begin{align}
\mathrm{sim}(q_{i}, k_{k}) = \phi(q_{i})^{\mathrm{T}} \phi(k_{k}) \quad (4)
\end{align}
$$

$(4)$式に基づいて$(3)$式は下記のように改変することができます。
$$
\large
\begin{align}
z_{i} &= \sum_{k=1}^{n} \left[ \frac{\mathrm{sim}(q_{i}, k_{k})}{\sum_{j=1}^{n} \mathrm{sim}(q_{i}, k_{j})} v_{k} \right] \quad (3) \\
&= \sum_{k=1}^{n} \left[ \frac{\phi(q_{i})^{\mathrm{T}} \phi(k_{k})}{\sum_{j=1}^{n} \phi(q_{i})^{\mathrm{T}} \phi(k_{j})} v_{k} \right] \\
&= \left[ \frac{\phi(q_{i})^{\mathrm{T}} \sum_{k=1}^{n} \phi(k_{k}) v_{k}^{\mathrm{T}}}{\phi(q_{i})^{\mathrm{T}} \sum_{j=1}^{n} \phi(k_{j})} \right]^{\mathrm{T}} \quad (5)
\end{align}
$$

$(5)$式は下記のような行列の積の形式で表すこともできます。
$$
\large
\begin{align}
\left( \phi(Q) \phi(K)^{\mathrm{T}} \right) V = \phi(Q) \left( \phi(K)^{\mathrm{T}} V \right) \quad (6)
\end{align}
$$

$(5)$式や$(6)$式の右辺の計算の計算量は$\mathcal{O}(N)$であり、「Linear Transformerの概要」の図に対応します。

また、Linear Transformerでは$\phi(x)$を下記のように定義します。
$$
\large
\begin{align}
\phi(x) = \mathrm{elu}(x) + 1
\end{align}
$$

上記の$\mathrm{elu}$関数については次項の「exponential linear unit」で詳しく確認を行います。

exponential linear unit

exponential linear unitの略であるelu関数$\mathrm{elu}(x)$は下記のような式で定義されます。
$$
\large
\mathrm{elu}(x) =
\begin{cases}
x & \mathrm{if} \, x > 0 \\
\alpha(\exp{(x)}-1) & \mathrm{if} \, x \leq 0
\end{cases}
$$

また、$\mathrm{elu}(x)$関数の$x$についての微分は下記のように表されます。
$$
\large
\frac{d}{dx} \mathrm{elu}(x) =
\begin{cases}
1 & \mathrm{if} \, x > 0 \\
\alpha \exp{(x)} = \mathrm{elu}(x) + \alpha & \mathrm{if} \, x \leq 0
\end{cases}
$$

ELUはReLUのような活性化関数に用いるにあたって考案されており、下記のようなグラフでも表すことができます。

ELU論文 Figure$\, 1$:ELUのパラメータは$\alpha=1$

NumPyを用いた計算による実験

以下$(6)$式に基づいて、$Q, K, V \in \mathbb{R}^{N \times D}$の場合のDot Product Attentionの計算時間を$N$の値を変えて計測を行います。

$N=10{,}000, D=500$

$N=10{,}000, D=500$のとき、$Q(K^{\mathrm{T}}V)$の計算時間は下記のようになります。

import numpy as np
import time

import numpy as np
import time

N = 10000
D = 500

Q = np.ones([N, D])
K = np.ones([N, D])
V = np.ones([N, D])

start = time.time()
V_1 = np.dot(Q, np.dot(K.T, V))
end = time.time()

time_diff = end - start
print("{:.3f}".format(time_diff))

・実行結果

0.209

同様に$(QK^{\mathrm{T}})V$の計算時間は下記のようになります。

start = time.time()
V_1 = np.dot(Q, np.dot(K.T, V))
end = time.time()

time_diff = end - start
print("{:.3f}".format(time_diff))

・実行結果

3.270

$N=20{,}000, D=500$

$N=20{,}000, D=500$のとき、$Q(K^{\mathrm{T}}V)$の計算時間は下記のようになります。

N = 20000
D = 500

Q = np.ones([N, D])
K = np.ones([N, D])
V = np.ones([N, D])

start = time.time()
V_1 = np.dot(Q, np.dot(K.T, V))
end = time.time()

time_diff = end - start
print("{:.3f}".format(time_diff))

・実行結果

0.389

同様に$(QK^{\mathrm{T}})V$の計算時間は下記のようになります。

start = time.time()
V_1 = np.dot(Q, np.dot(K.T, V))
end = time.time()

time_diff = end - start
print("{:.3f}".format(time_diff))

・実行結果

12.354

$N=10{,}000$、$N=20{,}000$の結果より、$N$が大きくなるにつれて$Q(K^{\mathrm{T}}V)$の計算量が$\mathcal{O}(N)$、$(QK^{\mathrm{T}})V$の計算量が$\mathcal{O}(N^2)$であることが概ね確認できます3

参考

・Linear Transformer論文
・A Survey of Transformers
・ELU(Exponential Linear Unit)論文

  1. ここではTransformerに入力するトークン数を$n$、トークンの特徴量ベクトルの次元を$d$で表しました。「Linear Transformerの概要」で用いた図の$T$は$n$と一致することにご注意ください。 ↩︎
  2. 行列の$i$行目や$j$行目を縦ベクトルで表すというのはミスリードになりやすいかもしれませんが、Linear Transformerの論文の表記に合わせました。 ↩︎
  3. 計測自体はかなり雑なので、値はそれほど参考にしないようにご注意ください。ここでは$Q(K^{\mathrm{T}}V)$の順に計算すると速いことの確認を主な目的に計測を行いました。 ↩︎

Depthwise Separable ConvolutionとMobileNets

DeepLearningの軽量化・高速化にあたって、畳み込み処理の分解などが行われることが多いです。当記事ではMobileNetsにおける点単位畳み込み(Pointwise Convolution)やチャネル別畳み込み(Channelwise Convolution)について取りまとめを行いました。当記事の作成にあたっては、MobileNets論文や「深層学習 第$2$版」の第$5$章「畳み込みニューラルネットワーク」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

畳み込み演算の概要

畳み込み演算の数式

入力のチャネル数が$C$、フィルタの枚数$C_{out}$枚のとき畳み込みによって出力の$u_{ijk}$成分は下記のように計算されます。
$$
\large
\begin{align}
u_{ijk} = \sum_{c=1}^{C} \sum_{p=0}^{W_f-1} \sum_{q=0}^{H_f-1} x_{i+p,j+q,c} h_{pqck} + b_{k}
\end{align}
$$

詳しくは下記で取り扱いました。

MobileNetsの構成

Depthwise Separable Convolution概要

MobileNetsではDepthwise Separable Convolutionという畳み込みに基づいて構成されます。Depthwise Separable Convolutionは通常の$3 \times 3$の畳み込みを「空間方向」と「チャネル方向」に分解した処理であり、Depthwise Separable Convolutionを用いることでパラメータの軽量化や計算の高速化が可能になります。

MobileNets論文 Figure$\, 2$

畳み込みの分解にあたっては上図などを元に理解すると良いです。$(a)$で表されたオーソドックスな$3 \times 3$の畳み込みを、チャネル毎に(channelwise/depthwise)畳み込みを行うのが$(b)$、位置毎に(pointwise)$1 \times 1$の畳み込みを行うのが$(c)$と解釈すれば良いです。$1 \times 1$は「VGGNet」や「ResNetのボトルネック構造」でも採用されていますが、MobileNetsではチャネル毎の畳み込みとセットで用いられている点が特徴的です。

MobileNets論文 Figure$\, 3$

また、バッチ正規化(BN)や活性化関数(ReLU)も含めた処理の流れは上図のように表されます。左がオーソドックスな畳み込み、右がDepthwise Separable Convolutionにそれぞれ対応します。

Depthwise Separable Convolutionの数式

・チャネル別畳み込み(channelwise/depthwise convolution)
位置$(i,j)$のチャネル$c$におけるチャネル別畳み込みは下記のような数式で定義されます。
$$
\large
\begin{align}
u_{ijc} = \sum_{p=0}^{W_f-1} \sum_{q=0}^{H_f-1} z_{i+p,j+q,c}^{(l-1)} h_{pqc} + b_{c}
\end{align}
$$

上記のような式に基づいて、畳み込みの出力$u_{ijc}$が計算されます。$z$は中間層と$h$はフィルタ、$b$はバイアス項がそれぞれ対応します。

・位置別畳み込み(pointwise convolution)
位置$(i,j)$における位置別畳み込みは下記のような数式で定義されます。
$$
\large
\begin{align}
u_{ijc}’ = \sum_{c=1}^{C} u_{ijc} h_{c} + b_{c}
\end{align}
$$

計算コスト

フィルタのサイズが$W_{f} \times H_{f}$、入力のチャネルが$C$、出力のチャネルが$C_{out}$のとき、位置$(i,j)$における畳み込みの計算量について以下では確認を行います1

・オーソドックスな畳み込みの積算の回数
$$
\large
\begin{align}
W_{f} \times H_{f} \times C \times C_{out}
\end{align}
$$

・Depthwise Separable Convolutionの積算の回数
$$
\large
\begin{align}
W_{f} \times H_{f} \times C + C \times C_{out}
\end{align}
$$

・(Depthwise Separable Convolutionの積算の回数)/(オーソドックスな畳み込みの積算の回数)
$$
\large
\begin{align}
\frac{W_{f} \times H_{f} \times C + C \times C_{out}}{W_{f} \times H_{f} \times C \times C_{out}} = \frac{1}{C_{out}} + \frac{1}{W_{f} H_{f}} \quad (1)
\end{align}
$$

$W_{f} = H_{f} = 3, \, C_{out}=128$のとき、$(1)$式は下記のように計算できます。
$$
\large
\begin{align}
\frac{1}{C_{out}} + \frac{1}{W_{f} H_{f}} &= \frac{1}{128} + \frac{1}{9} \\
&= 0.1189 \cdots
\end{align}
$$

MobileNetsの全体構成

MobileNets論文では下記のような$28$層構造でニューラルネットが構成されます。

MobileNets論文 Table$\, 1$

上記より、Conv層が$27$層、FC層が$1$層あることが確認できます2。基本的にはチャネル別畳み込みと位置別畳み込みが交互に行われると理解して良いと思います。

参考

・MobileNets論文

  1. MobileNets論文ではFeature Mapのサイズも定義されるが、「深層学習 第$2$版」の$5.8節の解説にあるように点$(i,j)$のみを元に計算は可能である。当記事ではシンプルさを重視するにあたって、「深層学習 第$2$版」に基づいて取りまとめました。 ↩︎
  2. Avg PoolSoftmaxはカウントしないかつ、$5 \times$の行で$10$層プラスすることで$28$層になることが確認できます。 ↩︎

サブピクセル畳み込みを用いたアップサンプリング(upsampling)処理の表現

畳み込み演算を用いて画像のセグメンテーションや生成を行う際に何らかの計算に基づいてアップサンプリング(upsampling)処理が行われます。当記事ではアップサンプリングの際に用いられるdeconvolutionを畳み込み演算を用いて表す一連の流れについて取りまとめを行いました。
当記事の作成にあたっては、「Is the deconvolution layer the same as a convolutional layer?」や「深層学習 第$2$版」の$5.9$節「アップサンプリングと畳み込み」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

転置畳み込み

1Dの畳み込み

ストライド$2$の$1$次元畳み込みの計算は下図のように表すことができる。

Shi et al., $2016$ Figure$\, 2$

上図からサイズ$8$の$\mathbf{x}$にストライド$2$の畳み込みを行うことでサイズ$5$の$\mathbf{y}$が出力されたと読み取れる。$\mathbf{x}$のグレー部分はパディングを表す。

転置畳み込み

前項の「$1D$の畳み込み」における畳み込みの入力と出力を入れ替えると下図のような計算が得られる。

Shi et al., $2016$ Figure$\, 3 \, (a)$

このような処理を転置畳み込み(transposed convolution)という。転置畳み込みは「逆関数」と同様なイメージで解釈すると良い。上図では$\mathbf{x}$のサイズが$5$、$\mathbf{y}$のサイズが$8$であり、「$1D$の畳み込み」の図と逆になったことも合わせて抑えておくと良い。

畳み込みを用いたアップサンプリング

サブピクセル畳み込み

前節の「転置畳み込み」の処理に対し、下図のようにサブピクセル(sub-pixel)を導入することができる。

Shi et al., $2016$ Figure$\, 3 \, (b)$

このような処理をサブピクセル畳み込み(sub-pixel convolution)という。$\mathbf{x}$の白のピクセルの間のグレーのピクセルがサブピクセル畳み込みで追加される「サブピクセル」を表す。

サブピクセル畳み込みはアップサンプリング(upsampling)を畳み込み演算によって表した演算であると解釈することもできる。以下、「$2$Dのサブピクセル畳み込み」が「畳み込み演算で表したアップサンプリング処理」と解釈できることについて確認を行う。

2Dのサブピクセル畳み込み

Shi et al., $2016$ Figure$\, 5$

上図のように$2D$の入力に対し、サブピクセル畳み込み(sub-pixel convolution)を行う場合を仮定する。図の$*$は畳み込み演算を表す演算子であり、$*$の「左が入力」、「右が畳み込みに用いるフィルタ」にそれぞれ対応する。グレーで表されたサブピクセルには$0$が入ると仮定する。

一般的な畳み込み演算の原理に基づいてこのサブピクセル畳み込み(sub-pixel convolution)を行うことで下図のような結果が得られる。

Shi et al., $2016$ Figure$\, 6$

紫の出力が得られるフィルタの位置からフィルタを$1$つ右にずらすと青、$1$つ下にずらすと緑、$1$つ右+$1$つ下にずらすと赤が白のピクセル位置に重なることが確認できるので、出力はこの対応に基づいて理解すると良い。サブピクセル畳み込みを用いたこの演算は下図の演算の結果と対応する。

Shi et al., $2016$ Figure$\, 7$

Figure$\, 7$のような処理を用いることでアップサンプリングを行うことができる一方で、Figure$\, 6$を用いれば同様の演算を畳み込みを用いて実現することができる。

このようにアップサンプリングも畳み込み演算で表せることは抑えておくと良い。

ソフトマックス関数への温度スケーリング(temperature scaling)の導入

DeepLearningに関連する計算にあたってソフトマックス関数(softmax function)はよく出てくる一方で、出力値が過剰になる場合もあり得ます。当記事ではこのような際に値の調整に用いられる温度スケーリング(temperature scaling)の概要と使用例について取りまとめました。
当記事の作成にあたっては、「深層学習 第$2$版」の$7.2$節「注意機構」や$8.3$節「不確かさの予測」の内容などを参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

温度スケーリングの概要

ソフトマックス関数

$\mathbf{u} = (u_1, \cdots , u_K)$が与えられたとき、ソフトマックス関数$\mathrm{Softmax}(u_k)$は下記のように定義できる。
$$
\large
\begin{align}
\mathrm{Softmax}(u_k) = \frac{\exp{(u_k)}}{\displaystyle \sum_{j=1}^{K} \exp{(u_j)}}
\end{align}
$$

上記の定義より、ソフトマックス関数について下記の式が成立する。
$$
\large
\begin{align}
\mathrm{Softmax}(u_k) & \geq 0 \\
\sum_{j=1}^{K} \mathrm{Softmax}(u_j) &= 1
\end{align}
$$

温度スケーリングによるソフトマックス関数の出力の調整

ソフトマックス関数は$\exp$を用いることで入力値の差を際立たせた出力を行う。たとえば$\mathbf{u} = (5, 6, 7)$が入力の場合、下記の出力が得られる。

import numpy as np

u = np.array([5., 6., 7.])

p_1 = u/np.sum(u)
p_2 = np.exp(u)/np.sum(np.exp(u))

print(p_1)
print(p_2)

・実行結果

[ 0.27777778  0.33333333  0.38888889]
[ 0.09003057  0.24472847  0.66524096]

計算結果を確認すると、通常の確率化の結果が[ 0.27777778 0.33333333 0.38888889]であるのに対しソフトマックス関数の結果が[ 0.09003057 0.24472847 0.66524096]であり、「緩やかなmax関数」のように解釈することができる。

このようなソフトマックス関数の性質がより必要な時もあればある程度緩和が望ましい場合もあり、このような場合に温度スケーリング(temperature scaling)が用いられる。パラメータ$T$を元に温度スケーリングを施したソフトマックス関数は下記のように定義される。
$$
\large
\begin{align}
\mathrm{Softmax}(u_k) = \frac{\exp{(u_k/T)}}{\displaystyle \sum_{j=1}^{K} \exp{(u_j/T)}} \quad (1)
\end{align}
$$

$T=1$の場合に通常のソフトマックス関数に一致するので$(1)$式はソフトマックス関数の拡張であると理解することもできる。入力値$\mathbf{u} = (5, 6, 7)$に対し$T=0.3, 1, 10$の場合を計算するとそれぞれ下記が得られる。

import numpy as np

u = np.array([5., 6., 7.])
T = np.array([0.3, 1., 10.])

p = np.zeros([T.shape[0], a.shape[0]])
for i in range(T.shape[0]):
    p[i,:] = np.exp(u/T[i])/np.sum(np.exp(u/T[i]))

print("T: {:.1f}, p: {}".format(T[0], p[0,:]))
print("T: {:.1f}, p: {}".format(T[1], p[1,:]))
print("T: {:.1f}, p: {}".format(T[2], p[2,:]))

・実行結果

T: 0.3, p: [ 0.00122729  0.03440292  0.96436979]
T: 1.0, p: [ 0.09003057  0.24472847  0.66524096]
T: 10.0, p: [ 0.30060961  0.33222499  0.3671654 ]

上記より、$T<1$を設定するとより極端な結果が、$T>1$を設定するとより一様分布に近い結果が得られることが確認できる。このようにソフトマックス関数に温度スケーリングを導入することで出力の値を調整することができる。

温度スケーリングの使用例

Transformer

温度スケーリングの導入

TransformerのDot Product Attentionでは下記のような計算を行う。
$$
\large
\begin{align}
\mathrm{Attention}(Q, K, V) &= \mathrm{Softmax} \left( \frac{Q K^{\mathrm{T}}}{\sqrt{d}} \right) V
\end{align}
$$

$\sqrt{d}$は温度スケーリング$T$と対応させて理解することができる。ここで上記の$d$はトークン毎の次元数に対応するので、「トークン毎の次元が大きい場合は出力を一様な値に調整する」と大まかに解釈することができる。

$\sqrt{d}$を用いる理由の考察

トークン$i$とトークン$j$に対応する中間層のベクトル$\mathbf{v}_{i} \in \mathbb{R}^{d}, \, \mathbf{v}_{j} \in \mathbb{R}^{d}$を下記のように定義する。
$$
\large
\begin{align}
\mathbf{v}_{i} &= \left( \begin{array}{c} v_{i1} \\ \vdots \\ v_{id} \end{array} \right) \\
\mathbf{v}_{j} &= \left( \begin{array}{c} v_{j1} \\ \vdots \\ v_{jd} \end{array} \right)
\end{align}
$$

このとき、$\mathbf{v}_{i}$と$\mathbf{v}_{j}$の内積$\mathbf{v}_{i} \cdot \mathbf{v}_{j}$は下記のように計算できる。
$$
\large
\begin{align}
\mathbf{v}_{i} \cdot \mathbf{v}_{j} &= \left( \begin{array}{c} v_{i1} \\ \vdots \\ v_{id} \end{array} \right) \cdot \left( \begin{array}{c} v_{j1} \\ \vdots \\ v_{jd} \end{array} \right) \\
&= v_{i1} v_{j1} + \cdots + v_{id} v_{jd} \\
&= x_{1} + \cdots + x_{d} = \sum_{k=1}^{d} x_{k}
\end{align}
$$

上記では式の簡略化にあたって、$x_{k} = v_{ik} v_{jk}$を導入した。ここで$x_{k} \sim \mathcal{N}(0, \sigma^{2}), \, \mathrm{i.i.d.}$を仮定すると、$\displaystyle S = \sum_{k=1}^{d} x_{k}$について正規分布の再生性に基づいて下記が成立する1
$$
\large
\begin{align}
S \sim \mathcal{N}(0, d \sigma^{2})
\end{align}
$$

よって、$S$の標準偏差は$x_{k}$の標準偏差の$\sqrt{d}$倍になる。この結果から、「Transformerにおける温度スケーリングは計算される内積の標準偏差をトークンの次元$d$に依らず一定に保つ目的で導入された」と大まかに解釈できる。

期待校正誤差に基づくソフトマックス出力の校正

DeepLearningの確信度

入力$\mathbf{x}$に対しDeepLearningのソフトマックス演算後の出力を$p(\mathcal{C}_{k}|\mathbf{x})$とおく。この$p(\mathcal{C}_{k}|\mathbf{x})$を「DeepLearningの推論における確信度(confidence)」というが、入力$\mathbf{x}$が得られた際の事後確率と解釈することもできる。

DeepLearningではこの確信度(confidence)は正答率(accuracy)に対し過剰になる場合が多いので注意が必要である。このような場合に用いられる「確信度と正答率が概ね一致するかを確認する指標」の$1$つに期待校正誤差(ECE; Expected Calibration Error)という尺度がある。

以下、期待校正誤差の定義について確認を行う。まず、$N$個のテストサンプルを$M$分割した確信度の範囲$[0,1]$の各区間に分類を行うとき、$M$分割した区間は下記のように表される。
$$
\large
\begin{align}
m = 1 &: \left[ 0, \frac{1}{M} \right] \\
m = 2 &: \left[ \frac{1}{M}, \frac{2}{M} \right] \\
m = 3 &: \left[ \frac{2}{M}, \frac{3}{M} \right] \\
& \vdots \\
m = M-1 &: \left[ \frac{M-2}{M}, \frac{M-1}{M} \right] \\
m = M &: \left[ \frac{M-1}{M}, 1 \right]
\end{align}
$$

このとき$[0,1]$を$M$分割した中の$m$番目のビンに含まれるサンプル集合を$B_{m}$、$B_{m}$に含まれるサンプルの確信度を$\mathrm{conf}(B_{m})$とおくと、$\mathrm{conf}(B_{m})$は大まかに下記の式で近似することが可能である。
$$
\large
\begin{align}
\mathrm{conf}(B_{m}) \simeq \frac{m \, – \, 1/2}{M}
\end{align}
$$

たとえば、$M=10$のとき$m=1, \cdots , 10$について$\mathrm{conf}(B_{m})$の近似値は下記のように計算できる。

$m$$\mathrm{conf}(B_{m})$の近似値
$m=1$$\displaystyle \frac{1 \, – \, 1/2}{10} = 0.05$
$m=2$$\displaystyle \frac{2 \, – \, 1/2}{10} = 0.15$
$m=3$$\displaystyle \frac{3 \, – \, 1/2}{10} = 0.25$
$\vdots$$\vdots$
$m=9$$\displaystyle \frac{9 \, – \, 1/2}{10} = 0.85$
$m=10$$\displaystyle \frac{10 \, – \, 1/2}{10} = 0.95$

$\mathrm{conf}(B_{m})$の値は上記のように近似することが可能だが、各テストサンプルの確信度の平均を計算しても良い。ここでは「深層学習 第$2$版」の内容に基づいて近似式を詳しく確認した。

期待校正誤差の定義

「DeepLearningの確信度」で定義したビン$m$に含まれるサンプル集合$B_{m}$の正答率を$\mathrm{acc}(B_{m})$、確信度を$\mathrm{conf}(B_{m})$のようにおく。このとき期待校正誤差$\mathrm{ECE}$は下記のように定義される。
$$
\large
\begin{align}
\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_{m}|}{N} \left| \mathrm{acc}(B_{m}) \, – \, \mathrm{conf}(B_{m}) \right| \quad (2)
\end{align}
$$

上記の$|B_{m}|$はビン$m$に含まれるテストサンプルの数、$N$は全テストサンプルの数にそれぞれ対応する。

温度スケーリングを用いたソフトマックス出力の調整

バリデーション2データ上で$(2)$式を計算し、$\mathrm{ECE}$の値が最小になるように下記の$T$を調整し、学習を行うことで正答率に対し確信度が過剰になることを防ぐことができる。
$$
\large
\begin{align}
\mathrm{Softmax}(u_k) = \frac{\exp{(u_k/T)}}{\displaystyle \sum_{j=1}^{K} \exp{(u_j/T)}}
\end{align}
$$

  1. 確率分布からのサンプリングを取り扱う際には「確率変数」か「観測値」かを区別して取り扱うことが多いが、ここでは内容の簡易化にあたって厳密な議論は省略した。また、確率分布には正規分布を仮定したが、再生性が成立するならば他の分布を仮定しても同様な議論が成立する。 ↩︎
  2. 学習時に「学習に用いないサンプルの正答率の計算」に用いるサンプルをバリデーション(validation)データ、学習後に「正答率の計算」に用いるサンプルをテストデータという。双方が同じ場合もあるが、バリデーションとある場合は使い分けることが多い。 ↩︎