ブログ

【DETR論文まとめ】Transformerを用いたObject Detection

Object Detectionタスクには従来VGGNetやResNetなどのCNNをbackboneに持つネットワークを用いることが主流であった一方で、近年Transformerの導入も行われています。当記事ではObject DetectionにTransformerを導入した研究であるDETRについて取りまとめました。
DETRの論文である「End-to-End Object Detection with Transformers」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

指示関数

指示関数(indicator function)の$\mathbb{1}_{[k \neq i]} \in \{ 0, 1 \}$は下記のように定義されます。
$$
\large
\mathbb{1}{[k \neq i]} =
\begin{cases}
1 \quad \mathrm{if} \quad k \neq i \\
0 \quad \mathrm{otherwise}
\end{cases}
$$

Hungarian algorithm

・Hungarian algorithm(Wikipedia)

DETR

処理の概要

DETR(DEtection TRansformer)の処理の全体は下図より確認できます。

DETR論文 Figure$\, 2$

図より、DETRは主に下記の三つのプロセスによって構成されることが確認できます。

・backboneネットワーク
・Transformer
・FFN

以下、それぞれのプロセスの詳細やDETRのlossについて確認を行います。まとめの都合上、「backbone+Transformerのencoder」と「Transformerのdecoder+FFN」に分けて確認を行いました。

backboneとTransformer encoder

DETRではbackboneネットワークにCNNを用います。
$$
\large
\begin{align}
X_{\mathrm{img}} \in \mathbb{R}^{H \times W \times 3}
\end{align}
$$

上記の入力に対しCNNを作用させ、下記のようなFeature mapを取得します。
$$
\large
\begin{align}
f \in \mathbb{R}^{H \times W \times C}
\end{align}
$$

上記のFeature mapに対し、$1 \times 1$のConvolutionを行いReshapeを行った結果の$Z_{0} \in \mathbb{R}^{HW \times d}$をTransformer encoderに入力します。ここで$d$は$d < C$になるように値の設定を行います。

DETRのTransformer encoderはMultiHead AttentionやFFNによって構成されるオーソドックスなTransformerの処理を基本的にそのまま用います。

Transformer decoder: object queriesとbounding boxの生成

Transformer decoderでは機械翻訳では翻訳文が入力されるTransformerのクエリにobject queriesという入力を用います。

このobject queryに基づいてObject Detectionにおけるバウンディングボックスの生成やクラスの予測が行われることから、object queryはFaster-RCNNなどにおけるAnchor boxと同様な役割であると理解すると良いと思います。

ここでobject queryは学習可能なパラメータであり、Faster-RCNN・SSDなど多くのObject Detectionで用いられるAnchor boxをDETRでは自動学習させることができると解釈することもできます。

DETRのloss

$y = \{ y_1, \cdots y_{n} \}, \, y_{n+1} = y_{n+2} \cdots = y_{N-1} = y_{N} = \varnothing$をObject集合の正解、$\hat{y} = \{ \hat{y}_{i} \}_{i=1}^{N} = \{ \hat{y}_{1}, \cdots , \hat{y}_{N} \}$をObjectの予測の集合とおくとき、$y$と$\hat{y}$の最もコストの低い対応を表すインデックスの順列を表す$\hat{\sigma} \in \mathfrak{S}_{N}$は下記のように定義されます。
$$
\large
\begin{align}
\hat{\sigma} &= \mathrm{arg}\min_{\sigma \in \mathfrak{S}_{N}} \sum_{i=1}^{N} \mathcal{L}_{\mathrm{match}}(y_{i}, \hat{y}_{\sigma(i)}) \quad (1) \\
\mathcal{L}_{\mathrm{match}}(y_{i}, \hat{y}_{\sigma(i)}) &= \mathbb{1}_{[c_i \neq \varnothing]} \hat{p}_{\sigma(i)} + \mathbb{1}_{[c_i \neq \varnothing]} \mathcal{L}_{\mathrm{box}}(b_i, \hat{b}_{\sigma(i)})
\end{align}
$$

ここで上記の式における$\mathfrak{S}_{N}$はインデックスの順列の全パターンの集合、$N$は予測の数、$n$は正解のObjectの数、$\mathbb{1}$は指示関数、$\varnothing$は$y_{n+1}$〜$y_{N}$のObjectが存在しないことをそれぞれ表します。

$(1)$式の$\hat{\sigma}$は組み合わせ最適化問題であり、Hungarian algorithmを用いることで得ることができます。ここで得た$\hat{\sigma}$を元にHungarian lossの$\mathcal{L}_{\mathrm{Hungarian}}(y,\hat{y})$は下記のように定義されます。
$$
\large
\begin{align}
\mathcal{L}_{\mathrm{Hungarian}}(y,\hat{y}) &= \sum_{i=1}^{N} \left[ -\log{\hat{p}_{\hat{\sigma}(i)}(c_i)} + \mathbb{1}_{[c_i \neq \varnothing]} \mathcal{L}_{\mathrm{box}}(b_i, \hat{b}_{\sigma(i)}) \right] \quad (2) \\
\mathcal{L}_{\mathrm{box}}(b_i, \hat{b}_{\sigma(i)}) &= \lambda_{\mathrm{IoU}} \mathcal{L}_{\mathrm{IoU}}(b_i, \hat{b}_{\sigma(i)}) + \lambda_{L1}|| b_i-\hat{b}_{\sigma(i)} ||_{1} \\
\lambda_{\mathrm{IoU}}, \lambda_{L1} & \in \mathbb{R}
\end{align}
$$

上記の$c_i, b_i$は$y_i=(c_i,b_i)$によって定義される、Objectの予測のクラスとバウンディングボックスにそれぞれ対応します1。また、$|| \cdot ||_{1}$は$L1$ノルムを表します。

$(2)$式は、「Hungarian algorithmによって得られた$\hat{\sigma}$について、クラス分類に関するCross EntropyとBounding Boxのズレに基づいてlossを定義する」のように解釈することができます。

参考

・Transformer論文:Attention is All you need[$2017$]
・DETR論文

  1. クラス間不均衡(Class imbalance)に対処するにあたって、実際には$c_i=\varnothing$の場合に$-\log{\hat{p}_{\hat{\sigma}(i)}(c_i)}$をdown-weightします。 ↩︎

【Pointformer論文まとめ】点群の処理へのTransformerの導入

PointNet++を用いた点群の処理はPointNetに階層型のプーリングを導入することで改良にはなったものの、局所領域における点間の相関を取り扱えないなどの課題があります。当記事ではこの課題の解決にあたって点群の処理にTransformerを導入したPointformerについて取りまとめました。
Pointformerの論文である「$3$D Object Detection with Pointformer」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

PointNet

PointNet++

Pointformer

処理の概要

Pointformerの処理の全体像は下図を元に掴むと良いです。

Pointformer論文 Figure$\, 2$

上図から確認できるように、Pointformerは「Pointformer Block」に基づいて基本的に構成されています。

また、同時に「Pointformer Block」が「Local Transformer」、「Local-Global Transformer」、「Global Transformer」の三つの処理から構成されていることも確認することができます。以下、三つの処理について詳しく確認を行います。

Local Transformer

入力される点群の集合を$\mathcal{P} = \{ x_1, \cdots , x_{N} \}$とおくとき、Local TransformerではPointNet++と同様にFPS(Furthest Point Sampling)を用いて下記のような中心点の集合を取得します。
$$
\large
\begin{align}
\{ x_{c_1}, \cdots , x_{c_{N’}} \}, \quad N’ < N
\end{align}
$$

このとき、上記の中心点の周辺の半径に基づいたball queryを用いることで$j$番目の中心点に対して$K_t(t=1, \cdots , N’)$個の点を獲得し、Transformerへ入力を行います。

Pointformer論文 Figure$\, 3$

FPSは上図における左側の青丸から赤丸を取得する処理、ball queryを用いた処理は赤丸の中心に基づいて円を描き、円の内部の点をグルーピングする処理にそれぞれ対応します。

$t$番目の中心点の周囲の局所領域における$i$番目の点について位置情報と特徴量を$\{ x_{i}, f_{i} \}_{t}, \, x_{i} \in \mathbb{R}^{1 \times 3}, \, f_{i} \in \mathbb{R}^{1 \times C}$のように表すとき、$t$番目の領域におけるTransformer処理は下記のように表すことができます。
$$
\large
\begin{align}
f_{i}^{(0)} &= \mathrm{FFN}(f_{i}), \quad {}^{\forall} i \in N(x_{c_{t}}) \\
F^{(l+1)} &= \mathrm{Transformer Block}(F^{(l)}, \mathrm{PE}(X)) \\
F^{(l)} &= \left( \begin{array}{c} f_{1}^{(l)} \\ \vdots \\ f_{K_t}^{(l)} \end{array} \right) , \quad X = \left( \begin{array}{c} x_{1} \\ \vdots \\ x_{K_t} \end{array} \right)
\end{align}
$$

ここで$N(x_{c_{t}})$の$N$は中心点$x_{c_{t}}$と同じ領域にグルーピングされた点のインデックスの集合を表します1。$\mathrm{Transformer Block}$は一般的なTransformerのそれぞれの層における計算に対応します。また、論文では$F^{(l)},X$がやや唐突に出てきたので上記の$3$式目で具体化しました。

ここまでがLocal Transformerの基本処理です。一方で、Local Transformerには基本処理に加えてFPSによってサンプリングされた中心点の位置の補正を行うCoordinate Refinementも合わせて用いられます。以下、Coordinate Refinementについて詳しく確認します。

Coordinate RefinementはLocal Transformerで計算を行ったAttention Matrix2に基づいて中心点の位置の調整を行う手法です。Attention Matrixの$m$番目のheadのAttention Matrixを$A^{(m)}$、$t$番目の中心点に対応する行を$A_{c}^{(m)} \in \mathbb{R}^{1 \times K_t}$とおくとき、行ベクトル$\mathbf{w} \in \mathbb{R}^{1 \times K_t}$を下記のように定義します。
$$
\large
\begin{align}
\mathbf{w} = \frac{1}{M} \sum_{m=1}^{M} A_{c}^{(m)}
\end{align}
$$

上記の$M$はMultiHead Attentionにおけるheadの数を表します。このとき、行ベクトル$\mathbf{w}$の$k$番目の要素を$w_{k}$のように表すとき、中心点$x_{c_{t}}$を下記のような式に基づいて得られる$x’_{c_{t}}$に移動させます3
$$
\large
\begin{align}
x_{c_{t}}’ &= \sum_{k=1}^{K_{t}} w_{k} x_{k} \\
w_{k} & \in \mathbb{R}, \, x_{c_{t}}’ \in \mathbb{R}^{1 \times 3}, \, x_{k} \in \mathbb{R}^{1 \times 3}
\end{align}
$$

Local-Global Transformer

Pointformer論文 Figure$\, 2$の一部

Local-Transformerでは中心点に基づく局所領域を元に点群のダウンサンプリングを行います。この処理では局所領域におけるattentionしか計算しないので、ダウンサンプリング後のそれぞれの点に対して大域的な情報を反映させられることが望ましいです。

Local-Global Transformerではダウンサンプリング後の各点に対して大域的な情報を反映させるにあたって、Local Transformerの出力をquery、Pointformer blockへの入力をkey・valueとするcross-attention処理を行います。Pointformer blockへの入力は高解像度なので点のインデックスの集合を$\mathcal{P}^{h}$、Local Transformerの出力低解像度なので点のインデックスの集合を$\mathcal{P}^{l}$とおくと、Local-Global Transformerの演算を下記のように表すことができます。
$$
\large
\begin{align}
f_{i}^{(0)} &= \mathrm{FFN}(f_{i}), \quad {}^{\forall} i \in \mathcal{P}^{l} \\
f_{j}’ &= \mathrm{FFN}(f_{j}), \quad {}^{\forall} j \in \mathcal{P}^{h} \\
F^{(l+1)} &= \mathrm{Transformer Block}(F^{(l)}, F’, \mathrm{PE}(X)) \\
F^{(l)} &= \left( \begin{array}{c} f_{1}^{(l)} \\ \vdots \\ f_{K_t}^{(l)} \end{array} \right), \quad X = \left( \begin{array}{c} x_{1} \\ \vdots \\ x_{K_t} \end{array} \right) \\
F’ &= \left( \begin{array}{c} f_{1}’ \\ \vdots \\ f_{N}’ \end{array} \right)
\end{align}
$$

上記の式の$\mathrm{Transformer Block}$はself-attentionではなくcross-attentionの式であることに注意が必要です。

Global Transformer

Pointformer論文 Figure$\, 2$の一部

Global TransformerではLocal-Global Transformerの出力の全ての点のインデックスの集合$\mathcal{P}$についてTransformer処理を行います。
$$
\large
\begin{align}
f_{i}^{(0)} &= \mathrm{FFN}(f_{i}), \quad {}^{\forall} i \in \mathcal{P} \\
F^{(l+1)} &= \mathrm{Transformer Block}(F^{(l)}, \mathrm{PE}(X))
\end{align}
$$

Global Transformerの式表記はLocal Transformerの$N(x_{c_{t}})$を$\mathcal{P}$に変えることで表されます。

Pointformerの構成

indoorデータセットとKITTIに対するPointformerの構成は下記の表より確認することができます。

indoorタスク用のPointformer:Pointformer論文 Table$\, 9$
KITTIのタスク用のPointformer:Pointformer論文 Table$\, 10$

参考

・Transformer論文:Attention is All you need[$2017$]
・Pointformer論文

  1. グラフ理論ではノード$v$に隣接するノードの集合を$N(v)$のように表しますが、同様な表現と見なして良いです。論文では$N$ではなく$\mathcal{N}$が用いられていますが、$\mathcal{N}$は正規分布に用いられることが多いので、グラフ理論の表記を参考に当記事では$N$を用いました。 ↩︎
  2. ここではSoftmax関数などで行方向に正規化されたAttention Matrixを用います。 ↩︎
  3. $x’{c{t}}$の特徴量の計算についての記載が見当たらなかったので要確認。$x’{c{t}}$と同じく加重平均で計算すること自体はおそらく可能。 ↩︎

【PoinNet++論文まとめ】階層化グルーピングによるPointNetの改良

点群にDeepLearningを導入したPointNetは有力な手法である一方で、max poolingを一度しか行わないことで局所的な構造をなかなか抽出できないという課題があります。当記事ではこの解決にあたって階層化グルーピングを用いた研究であるPointNet++について取りまとめを行いました。
PointNet++の論文である「PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

PointNet

PointNet++

処理の概要

PointNet++の処理の全体像は下図を元に掴むと良いです。

PointNet++論文 Figure$\, 2$

上図から確認できるように、PointNet++はVGGNetやResNetにおけるbackbone処理とタスク特化処理の二つによって大まかに構成されます。

backbone処理ではSet Abstraction、SegmentationタスクではFeature Propagationをそれぞれ理解すると良いです。Set Abstractionは次項、Feature Propagationについては次々項でそれぞれ詳しく取り扱いました。

Set Abstraction

PointNet++論文 Figure$\, 2$の一部

上図から確認できるように、Set Abstractionは以下の$3$つのkey layerで構成されます。

・Sampling layer
・Grouping layer
・PointNet layer

上記の$3$つのlayerは、「Sampling layer」で局所領域における中心点(centroid)を決め、「Grouping layer」で局所領域における中心点の近傍にある点をグルーピングし、「PointNet layer」でグルーピングした点の集合から特徴量の抽出を行う、とそれぞれ解釈すると良いです。PointNet論文ではこの$3$つのkey layerの処理をまとめてset abstractionと表しており、Figure$\, 2$でもset abstractionの記載が確認できます。

以下、「Sampling layer」、「Grouping layer」、「PointNet layer」のそれぞれについて詳しく確認を行います。

Sampling layer

Sampling layerでは$N$個の点集合$\{ x_1, \cdots , x_{N} \}$に対しFPS(farthest point sampling)を行うことで、$N’$個の点集合$\{ x_{i_{1}}, \cdots , x_{i_{N’}} \}$をの抽出を行います。

ここで$i_{j}$は順序$i$1の$j$番目に対応するインデックスを表します2。ここでFPSは$x_{i_{j}}$の選定にあたって、$j-1$の点集合$\{ x_{i_{1}}, \cdots , x_{i_{j-1}} \}$から最も遠い点を選択する手法であると理解すると良いです。

中心点(centroid)の選択にあたってはランダムサンプリングも可能ですが、定数$K$が固定されるときFPSを用いることでより全体の点集合をカバーすることができるなど有用です。FPSを用いたこのような処理はCNNにおけるreceptive fieldと対応させて解釈すると良いです。

Grouping layer

$N$個の点群のそれぞれの点の位置情報が$d$次元、特徴量が$C$次元で表されるとき、これらをまとめて$x_{i} \in \mathbb{R}^{1 \times (d+C)}$のように表せると仮定します。このとき、下記が成立します。
$$
\large
\begin{align}
\left( \begin{array}{c} x_1 \\ \vdots \\ x_N \end{array} \right) & \in \mathbb{R}^{N \times (d+C)} \\
\left( \begin{array}{c} x_{i_{1}} \\ \vdots \\ x_{i_{N’}} \end{array} \right) & \in \mathbb{R}^{N’ \times (d+C)}
\end{align}
$$

ここで$j$番目の中心点(centroid)である$x_{i_{j}}$から半径$R$以内の点のインデックスの集合を$\{ \sigma_{1}, \sigma_{2}, \cdots , \sigma_{K_{j}} \}$とおくと、このインデックスに基づく行列のサイズは下記のように表すことができます3
$$
\large
\begin{align}
\left( \begin{array}{c} x_{\sigma_{1}} \\ \vdots \\ x_{\sigma_{K_{j}}} \end{array} \right) & \in \mathbb{R}^{K_{j} \times (d+C)} \quad (1)
\end{align}
$$

Grouping layerでは$j$番目の中心点に対し、半径$R$の点を$(1)$式のように抽出を行います。ここで注意が必要なのが「半径$R$」を定義するには計量空間(metric space)が定義される必要があるということです。

計量空間(metric space)には「$\mathbb{R}^{n}$の標準内積」に基づいて定義される計量空間を用いることが一般的で、論文のSection$4.1$のEuclidean Metric Spaceに対応します。

一方でSection$4.3$のようにNon-Euclidean Metric Spaceを用いる場合もあります。

PointNet layer

PointNet layerではそれぞれのグループに対し、PointNetの処理が実行されます。PointNetに$(1)$式を入力した際に$x’_{j} \in \mathbb{R}^{1 \times (d+C’)}$が出力されます。$N’$個の全てのグループについての特徴量は下記のような行列で表されます。
$$
\large
\begin{align}
\left( \begin{array}{c} x’_1 \\ \vdots \\ x’_j \\ \vdots \\ x’_N \end{array} \right) & \in \mathbb{R}^{N’ \times (d+C’)}
\end{align}
$$

Grouping layerで各中心点ごとにサンプル数が可変長になったものをPointNet layerでmax poolingを行うことで固定長に戻すことができることは注意して抑えておくと良いと思います。

Feature Propagation

Feature PropagationはSegmentationタスクを解くにあたって用いられる処理の仕組みです。

PointNet++論文 Figure$\, 2$

上図から確認できるように、プーリングによってダウンサンプリングされた結果をアップサンプリングすることでセグメンテーションタスクの予測が行われます。アップサンプリング時の点の復元にあたっては、$N_{1} \geq N_{2} \geq \cdots N_{l-1} \geq N_{l} \geq \cdots$のようにダウンサンプリングした点を逆に辿って復元を行います。$N_{l}$から$N_{l-1}$の復元にあたっては位置情報を持つ$d$次元はダウンサンプリング時のものをそのまま保持しておき、特徴量の値は補間(interpolation)処理に基づいて取得します。

PointNet++で用いられる補間処理は下記で表した$k$近傍($k$-nearest neighbors)の加重平均の式で定義されます。
$$
\large
\begin{align}
f^{(j)}(x) &= \frac{\sum_{i=1}^{k} w_{i}(x) f_{i}^{(j)}}{\sum_{i=1}^{k} w_{i}(x)}, \, j=1, \cdots , C \quad (2) \\
w_{i}(x) &= \frac{1}{d(x, x_i)^{p}}
\end{align}
$$

上記の式は点$x$の$j$番目の特徴量を$^{(j)}(x)$とおくとき、Centroidの$x_i$との近さに基づいて特徴量を加重平均によって得ると解釈すれば良いです。ここでは論文の表記に合わせて$j$番目の要素について式定義した一方で、$C$次元のベクトルと見なして定義しても良いと思います。また、$d$は$x$と$x_i$のdistanceを意味しており、$p$は次元を表します。補間処理に用いる近傍点の数$k$と次元の$p$については$k=3$、$p=2$がデフォルトで用いられます。

次に、上記のように行う補間処理は一度ダウンサンプリングしたものを用いると高周波成分が消失する場合があり得るので、U-Netのように対応するSet Abstractionのレイヤーの特徴量の反映処理が行われます。この処理はskip link concatenationのように表されます。skip link concatenationについては基本的にはU-Netと同じ仕組みだと理解すると良いと思います。このような処理が行われることから、Feature Propagationの数はSet Abstractionの数と原則一致することも合わせて抑えておくと良いです。

各領域における点の密度の取り扱い

点群における各点は一様分布に基づいていないので、領域ごとに密度の偏りがあります。密(dense)な領域に基づいて学習した結果が疎(sparse)な領域にそのまま一般化することはできないので注意が必要です。

各領域の点の密度の取り扱いにおける「半径$R$を用いる」アプローチの改良にあたって、PointNet++では「MSG(Multi-scale grouping)」と「MRG(Multi-resolution grouping)」の二つの手法が紹介されています。以下それぞれについて詳しく確認します。

Multi-scale grouping

MSG(Multi-scale grouping)は同一の中心点(centroid)に対し、複数のスケールを用いてグルーピングを行う手法です。

PointNet++論文 Figure$\, 3 \, (\mathrm{a})$

上図には三つのグレースケールがありますが、濃淡と半径$R$の大きさが対応します。計算の詳細についてはPointNet++論文のAppendix.B.$1$のNetwork Architecturesを確認すると良いです。

PointNet++論文 Appendix.B.$1$ Network Architecturesより

上記には同じ$R$のみを用いるSSG(Single-Scale Grouping)と複数の$R$を用いるMSG(Multi-scale grouping)の$2$つのネットワーク構造について記載があります。$\mathrm{SA}$、$\mathrm{FC}$やそれぞれのハイパーパラメータについては下記にまとめました。

記号意味
$\mathrm{SA}(K, r, [l_{1}, \cdots , l_{d}])$SSGにおけるset abstractionの演算。$K$は「中心点/局所領域」の数、$r$はグルーピングの際の半径、$d$はPointNetの層の数、$l_i$は$i$番目の層における出力層の次元を表す。
$\mathrm{SA}([l_{1}, \cdots , l_{d}])$max poolingを伴うglobal set abstractionの演算。
$\mathrm{SA}(K, [r^{(1)}, \cdots , r^{(m)}], [[l_{1}^{(1)}, \cdots , l_{d}^{(1)}], \cdots , [l_{1}^{(m)}, \cdots , l_{d}^{(m)}]])$MSGにおけるset abstractionの演算。$m$は半径のスケールの数、$r^{(j)}$は$j$番目のスケールにおける半径、$l_{i}^{(j)}$は$j$番目のスケールの$i$番目の層における出力層の次元を表す。
$\mathrm{FC}(l, dp)$全結合(Fully Connected)層の演算、MLPと同義。$l$は出力層の次元、$dp$はDropoutの比率を表す。
$\mathrm{FP}(l_1, \cdots , l_{d})$feature propagation層

上記の表を元に「PointNet++論文 Appendix.B.$1$ Network Architectures」を確認すると、SSGによる分類では$K=512, r=0.2$で$64 \to 64 \to 128$のMLP演算が行われたのちに、$K=128, r=0.4$で$128 \to 128 \to 256$のMLP演算が行われることが読み取れます。さらにこの後Globalにmax poolingを実行し、$256 \to 256 \to 512$でMLP処理を行い、$512$次元のベクトルに対し全結合層の計算を何度か実行したのちに$K$クラス分類問題に用いる$K$次元のベクトルを出力することが読み取れます。

上記と同様にMSGの処理も読み取ることができます。

Multi-resolution grouping

PointNet++論文 Figure$\, 3 \, (\mathrm{b})$

参考

・PointNet++論文

  1. たとえば$N=5$、$i_{1}=2, i_{2}=5, i_{3}=4, i_{4}=1, i_{5}=3$のとき、$N’=3$であれば$\{ x_{i_{1}}, \cdots , x_{j_{N’}} \}$は$\{ x_2, x_5, x_4 \}$に対応します。 ↩︎
  2. 文字$i$はインデックスに用いられることが多いので紛らわしいですが、PointNet++論文で$i$が用いられているのでそのまま用いました。また、論文の$n, m$についてはGrouping layerと対応させるにあたって当記事では$N, N’$で置き換えました。 ↩︎
  3. PointNet++の論文では$N’$個の中心点に対し、$N’ \times K \times (d+C)$のように定義されますが、$K$が中心点のインデックスに対し定数ではなく可変であり紛らわしいので、当記事では$j$番目の中心点について式定義を行いました。 ↩︎

【ViT論文まとめ】Computer Vision分野へのTransformerの応用

Transformerは元々機械翻訳タスクに対して考案された一方で、大域的な特徴量を取り扱うことのできる強力なモジュールであることから様々なタスクに応用されます。当記事ではTransformerを画像処理に応用した初期の研究であるViT(Vision Transformer)について取りまとめました。
ViTの論文である「An image is worth 16×16 words: Transformers for image recognition at scale.」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

処理の概要

Vision Transformer(ViT)の処理の全体像は下図を元に掴むと良いです。

Vision Transformer論文 Figure$\, 1$

Vision Transformerでは入力画像を$16 \times 16$ピクセルのように一定サイズで区切りパッチ(patch)を作成し、それぞれのパッチをトークンと見なしてTransformerの処理を実行します。上図では左下の入力画像を$9$つに分割し、Transformer Encoderに入力されます。

この処理における「パッチのトークン化」「パッチへのPositional Embeddingの追加」については次項以降で詳しく確認します。

パッチのトークン化

ViTではまず入力画像を$X \in \mathbb{R}^{H \times W \times C}$をパッチ集合に対応する行列$X_{p} \in \mathbb{R}^{N \times (P^{2} \cdot C)}$へリサイズを行います。ここで$N$は下記のように計算されます。
$$
\large
\begin{align}
H \times W \times C &= N \times (P^{2} \cdot C) \\
N &= \frac{HW \cancel{C}}{P^{2} \cancel{C}} \\
&= \frac{HW}{P^{2}}
\end{align}
$$

上記の$\displaystyle N = \frac{HW}{P^{2}}$は「パッチの数はグレースケールにおける全ピクセル数($HW$)を各パッチのピクセル数$P^{2}$で割った数に一致する」と解釈するとわかりやすいと思います1

ここでリサイズによって得た$X_{p}$の$i$行目を$\mathbf{x}_{p}^{i} \in \mathbb{R}^{1 \times (P^{2} C)}$のようにおくと、$i$番目のパッチの$D$次元のEmbeddingは下記のように定義されます。
$$
\large
\begin{align}
\mathbf{x}_{p}^{i} E & \in \mathbb{R}^{1 \times D} \\
\mathbf{x}_{p}^{i} & \in \mathbb{R}^{1 \times (P^{2} C)}, \, E \in \mathbb{R}^{(P^{2} C) \times D}
\end{align}
$$

ここでBERTと同様にTransformer演算から全体の特徴量を抽出する用のパッチ$\mathbf{x}_{cls} \in \mathbb{R}^{1 \times D}$とPositional Encodeingの$E_{pos} \in \mathbb{R}^{(N+1) \times D}$を元に下記のように$\mathbf{z}_{0}$を定義します。
$$
\large
\begin{align}
\mathbf{z}_{0} = \left( \begin{array}{c} \mathbf{x}_{cls} \\ \mathbf{x}_{p}^{1} E \\ \mathbf{x}_{p}^{2} E \\ \vdots \\ \mathbf{x}_{p}^{N-1} E \\ \mathbf{x}_{p}^{N} E \end{array} \right) + E_{pos} \quad (1)
\end{align}
$$

上記はViT論文の$(1)$式に対応します。また、$\mathbf{x}_{cls}, \, E, \, E_{pos}$はどれも学習時に同時に学習されるパラメータです。ViTでは$(1)$式で定義される$\mathbf{z}_{0} \in \mathbb{R}^{(N+1) \times D}$に対し、Transformerの処理を適用し、特徴量の抽出を行います。

パッチのPositonal Embedding

前項で取り扱ったようにViTでは学習によってPositonal Embeddingの値を取得します。このようなPositonal Embeddingの取得はBERTやGPTなどでも用いられています。

ViT-Base/ViT-Large/ViT-Huge

ViT論文ではViT-Base/ViT-Large/ViT-Hugeの三つが主に用いられます。それぞれのハイパーパラメータは下記の表より確認できます。

Vision Transformer論文 Table$\, 1$

BERTと同様な規模感で把握しておくと良いと思います。パッチサイズは基本的に$P=16$が用いられます。

参考

・Transformer論文:Attention is All you need[$2017$]
・ViT論文

  1. RGBであれば$3HW$と$3P^{2}$、チャネル数$C$の場合は$HWC$と$P^{2} C$がそれぞれ対応します。解釈のわかりやすさを優先するにあたって、グレースケールを例に出しました。 ↩︎

【PointNet論文まとめ】DeepLearningを用いた点群の分類・セグメンテーション

PointNetは点群(point clouds)の分類(classification)や点単位のセグメンテーション(segmentation)にMLP(Multi Layer Perceptron)を導入した研究です。当記事ではPointNetの論文を元にPointNetの処理手順を取りまとめました。
PointNetの論文である「PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation」や「深層学習 第$2$版」第$7$章の「集合・グラフのためのネットワークと注意機構」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

点群とInvariance

点群は並べ替えに不変であるという特徴を持ちます。同変性を表すequivarenceと不変性を表すinvarianceについては下記で詳しく取り扱ったので当記事では省略します。

点群の取り扱いにあたっては並び替え操作に関する不変性を念頭に置いた上で処理を組む必要があり、PointNetも例外ではないことに注意しておくと良いです。

PointNet

PointNetの処理概要

PointNetの処理の全体像は下図を元に掴むことができます。

PointNet論文 Figure$\, 2$

基本的には入力された$n$個の点ごとにMLP(Multi Layer Perceptron)処理を行うことに注意して上図は確認を行うと良いです。図のMLP演算の所に「shared」と記載があるのはMLPの計算に用いるパラメータを共有したと解釈するのが適切だと思われます1

また、$n \times 1024$にmax poolingを施した結果、$1024$次元の特徴量が得られ、$512$次元$\longrightarrow$$256$次元$\longrightarrow$$k$次元とMLPによって計算され、$k$クラス分類問題に用いられる一連の流れを確認することもできます2

Max PoolingとInvariance

前節で取り扱ったように点群は並べ替えに対して不変である必要があるので、CNNにおけるダウンサンプリングのように近傍の点と段階的に統合することはできません。したがって、PointNetでは$n \times 1024$の特徴量を点方向にmax poolingすることで大域特徴量(Global Feature)の抽出を行います。

Local and Global Information Aggregation

分類(Classification)タスクを解くにあたってはmax poolingを行うことで得たGlobal Featureをそのまま用いれば良い一方で、セグメンテーション(Segmentation)タスクでは点単位での分類が必要になります。

したがってPointNetでは$n \times 64$の点ごとの特徴量にそれぞれ$1024$次元のGlobal Featureを連結する(concatenate)ことで$n \times 1088$とし、その後の処理を行います。MLPにより、$n \times 1088$は下記のように推移します。
$$
\large
\begin{align}
n \times 1088 & \longrightarrow n \times 512 \longrightarrow n \times 256 \\
& \longrightarrow n \times 128 \longrightarrow n \times 128 \longrightarrow n \times m
\end{align}
$$

上記に基づいて点ごとに$m$クラス分類を行うことで、セグメンテーションを実現することができます。

Joint Alignment Network

点群はrigid transformationのように特定のgeometric transformationsに対して不変(invariant)である必要があります。PointNetではこの対応にあたってアフィン変換を実現する小さなニューラルネットワークであるT-netが導入されます。

参考

・PointNet論文

  1. 論文内に該当の記載が見つからないので要出典。 ↩︎
  2. $(512,256,k)$のmlpの四角形の高さが紛らわしいですが、どれもglobal featureであるので単に$1024 \to 512 \to 256 \to k$のように推移すると理解すれば良いです。 ↩︎

【CvT論文まとめ】畳み込みとMultiHead Attentionの対応

ViTなどのComputer Vision分野へのTransformerの導入は強力なアプローチである一方で、Transformerをそのまま用いる場合は局所相関を生かせないなどの課題があります。当記事ではViTに畳み込みを導入した手法であるCvTを題材に畳み込みとMultiHead Attentionの対応について取りまとめました。
CvTの論文である「CvT: Introducing Convolutions to Vision Transformers」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

Convolutional vision Transformer

処理の概要

Convolutional vision Transformer(CvT)の処理の全体像は下図を元に掴むと良いです。

Convolutional Vision Transformer論文 Figure$\, 2$

CvTは主に「Convolutional Token Embedding」と「Convolutional Transformer Block」の二つによって構成され、どちらの処理にも畳み込みが用いられます。

「Convolutional Transformer Block」は「Convolutional Projection」という手法を用いてTransformerモジュールの入力であるQuery、Key、Valueの作成を行い、Transformer計算を行います。上図の$(\mathrm{b})$で示されるように、「Convolutional Projection」以外の「Convolutional Transformer Block」の処理はオーソドックスなTransformerであるので、以下では「Convolutional Token Embedding」と「Convolutional Projection」について詳しく取り扱います。

Convolutional Token Embedding

ステージ$i-1$の出力を$X_{i-1} \in \mathbb{R}^{H_{i-1} \times W_{i-1} \times C_{i-1}}$とおき、下記が成立するようにチャネル数$C_i$、フィルタのサイズが$s \times s$の畳み込み演算に対応する関数$f$を定義します。
$$
\large
\begin{align}
f(X_{i-1}) \in \mathbb{R}^{H_{i} \times W_{i} \times C_{i}}
\end{align}
$$

このとき、上記で表される関数の$f$が「Convolutional Token Embedding」に対応します。ステージ毎のフィルタサイズ$s \times s$は下記の表から確認できます。

Convolutional Vision Transformer論文 Table$\, 2$を改変

赤枠で囲った部分がフィルタサイズに対応します。また、「Convolutional Token Embedding」がステージ毎に$1$度であるのに対し、「Convolutional Transformer Block」は複数回繰り返される場合があることも合わせて抑えておくと良いです。

Convolutional Projection

Convolutional Vision Transformer論文 Figure$\, 3$

「Convolutional Projection」の処理概要は上図の$(\mathrm{b})$を元に理解すると良いです。パッチを$2D$に並べ、畳み込みを行なうことでQuery、Key、Valueを計算します。ここでの畳み込みではオーソドックスな畳み込みではなくDepth-wise Convolution」と「Point-wise Convolution」を組み合わせた「Depth-wise Separable Convolution」が用いられます。「Depth-wise Separable Convolution」については下記で詳しく取り扱ったので当記事では省略します。

また、処理概要を掴むにあたっては図の$(\mathrm{b})$が適している一方で、実際には$(\mathrm{c})$のSqueezed convolutional projectionが用いられます。Squeezed convolutional projectionはKeyとValueの解像度を低くすることでSpatial Reduction Attention(SRA)と同様に計算の軽量化を行なったと解釈すると良いと思います。Spatial Reduction Attentionについては下記で詳しく取り扱ったので当記事では省略します。

畳み込みとMultiHead Attentionの対応

前項では「Convolutional Transformer Block」におけるTransformerモジュールへの入力となるQuery、Key、Valueの作成を行う「Convolutional Projection」について確認を行いました。当項ではこの「Convolutional Projection」の解釈についてMultiHead Attentionの処理と対応させることで確認を行います。まず、Point-wise Convolutionに一致する$1 \times 1$の畳み込みはMultiHead Attentionにおけるlinear projectionと同じ処理を表すことを抑えておくと良いです。

「Point-wise Convolution」と「MultiHead Attentionにおけるlinear projection」の対応について詳しく確認するにあたって、Feature Mapの$X \in \mathbb{R}^{H \times W \times C_1}$とフィルタの$w_{i} \in \mathbb{R}^{C_1}$を定義します。

このとき、$X$の空間の$2D$を$1D$に直す(Flatten)ことで$X’ \in \mathbb{R}^{HW \times C_1}$を得ることができます。ここで入力のチャネル数の$C_1$に対応する出力のチャネル数を$C_2$とおき、下記のようにフィルタのパラメータを行列$\mathcal{W}$で定義します。
$$
\large
\begin{align}
\mathcal{W} = \left( \begin{array}{ccccc} w_1 & \cdots & w_i & \cdots & w_{C_2} \end{array} \right)
\end{align}
$$

上記を元に$X’ \mathcal{W}$を計算し、$1D$から$2D$に戻す処理はPoint-wise Convolutionの演算に一致します1

一方、$X’$を$Q, K, V$、$\mathcal{W}$を$W^{Q}, W^{K}, W^{V}$に読み換えた処理はMultiHead Attentionの「linear projection」の定義式そのものであることも同時に確認できます。したがって、「Point-wise Convolution」と「MultiHead Attentionにおけるlinear projection」は基本的に同一な処理であることが確認できます。

この処理の対応から、「Convolutional Projection」は新規の処理というよりはMultiHead Attentionの拡張であると解釈することができます。

参考

・Transformer論文:Attention is All you need[$2017$]
・Convolutional vision Transformer論文

  1. CNNの実装では畳み込みをそのまま実装するとループ計算が多くなるので、フィルタと行列演算ができるようにここで取り扱ったような空間方向に展開してから元に戻すような実装がよく用いられます。DeepLearningのフレームワークではim2colなどの名称でこのような機能が提供されることが多いです。 ↩︎

Pyramid ViTとSpatial Reduction Attention

Transformerを用いてセグメンテーション(Segmentation)やObject DetectionのようなDense Predictionタスクを学習させるには解像度を高くする必要がある一方で、ViTでは解像度を高くすると計算量の問題が生じます。当記事ではこの解決にあたって考案されたSRA(Spatial Reduction Attention)に基づくPyramid Vision Transformer(PVT)の論文について取りまとめを行いました。PVTの論文である「Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

Pyramid Vision Transformer

処理の概要

Pyramid Vision Transformer(PVT)の処理の全体像は下図を元に掴むと良いです。

Pyramid Vision Transformer論文 Figure$\, 3$

基本的にはViTの処理と同じである一方で、$16 \times 16$のピクセルをパッチと見なすViTに対してPyramid Vision Transformerではまず$4 \times 4$のピクセルをパッチと見なし、ステージ(stage)単位で徐々に解像度を粗くします。

各ステージの処理は「Patch embedding」と「Transformer encoder」に分けられます。大まかには前処理とTransformer処理のように解釈しておくと良いと思います。

ハイパーパラメータの定義と基本的な設定

Pyramid Vision Transformerではステージ単位で処理を行います。ステージ$i$におけるハイパーパラメータはそれぞれ下記のように定義されます。

ハイパーパラメータ意味
$F_i$ステージ$i$におけるFeature Mapの縮小率。$F_1=4, F_2=8, F_3=16, F_4=32$である。
$P_i$ステージ$i$におけるパッチのサイズ。基本的には$P_1=4$、$i \geq 2$では$P_i=2$。ステージ$i-1$の出力をベースにパッチのサイズを定義することに注意。
$C_i$ステージ$i$の出力のチャネル数。解像度が低くなるにつれて$C_i$は大きくなる。
$L_i$ステージ$i$におけるTransformer encoderの層の数。特に$i=2, 3$のときにモデルが大きくなるにつれて$L_i$は大きくなる。
$R_i$ステージ$i$のSpatial Reduction Attentionにおけるreduction ratio。$R_i$が大きくなるにつれてAttentionにおける解像度が粗くなる。
$N_i$ステージ$i$におけるAttentionのヘッドの数。
$E_i$FFNレイヤーにおける隠れ層の倍率1。一般的なTransformerでは$4$が多い。

上記のハイパーパラメータと具体的なモデルの対応は下記を元に確認することができます。

Pyramid Vision Transformer論文 Table$\, 1$

Patch embedding

Pyramid Vision Transformer論文 Figure$\, 3$の一部

Patch embeddingでは入力画像や$i-1$番目のステージの出力を元に$i$番目のステージにおける入力の調整を行います。
$$
\large
\begin{align}
H_{i-1} \times W_{i-1} \times C_{i-1}
\end{align}
$$

ステージ$i-1$の出力のサイズは上記のように表されますが、ステージ$i$ではまず始めに近隣パッチを連結することにより、下記のようなサイズのFeature Mapを得ます。
$$
\large
\begin{align}
\frac{H_{i-1}}{P_{i}} \times \frac{W_{i-1}}{P_{i}} \times (P_{i}^{2} C_{i-1}) \quad (1)
\end{align}
$$

上記はステージ$i$ではステージ$i-1$の出力を$P_{i} \times P_{i}$単位でパッチの連結を行い、その分チャネル(TransformerのMLP層のサイズに対応)が増えたと理解できます。

Patch embedding層では$(1)$式における$P_{i}^{2} C_{i-1}$のチャネルに対しMLP処理を行うことで、チャネル数を$C_{i}$に変えます。
$$
\large
\begin{align}
\frac{H_{i-1}}{P_{i}} \times \frac{W_{i-1}}{P_{i}} \times C_{i} \quad (1)’
\end{align}
$$

さらにReshape処理を行うことで下記のようなサイズを持つTransformer encoder層の入力が作成されます。
$$
\large
\begin{align}
\frac{H_{i-1} W_{i-1}}{P_{i}^{2}} \times C_{i}
\end{align}
$$

また、$H_{i-1}, W_{i-1}$と$H_{i}, W_{i}$について下記が成立することも合わせて抑えておくと良いです。
$$
\large
\begin{align}
\frac{H_{i-1}}{P_{i}} &= H_{i} \\
\frac{W_{i-1}}{P_{i}} &= W_{i}
\end{align}
$$

Transformer encoder

Pyramid Vision Transformer論文 Figure$\, 3$の一部

Transformer encoderではPatch embeddingによって作成されたTransformerの入力を元にTransformerの演算が行われます。ステージ$i$では$1$度のPatch embedding処理に対し、$L_i$回のTransformer処理が行われることに注意が必要です。

Spatial Reduction Attention(SRA)はCNNと同様に画像の局所的な相関を活用することで処理の効率化を行う手法です。詳しくは次項で取り扱います。

Spatial Reduction Attention(SRA)

Spatial Reduction Attention(SRA)はAttentionの入力であるKeyとValueの解像度を低くすることで計算の効率化を実現する手法です。

Pyramid Vision Transformer論文 Figure$\, 4$

Spatial Reduction Attention処理は下記のような数式で表されます2
$$
\large
\begin{align}
\mathrm{SRA}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_{1}, \cdots , \mathrm{head}_{N_{i}})W^{O} \quad (2) \\
\mathrm{head}_{j} &= \mathrm{Attention}(Q W_{j}^{Q}, \mathrm{SR}(K) W_{j}^{K}, \mathrm{SR}(V) W_{j}^{V}) \quad (3) \\
W_{j}^{Q} & \in \mathbb{R}^{C_i \times d_{head}}, \, W_{j}^{K} \in \mathbb{R}^{C_i \times d_{head}}, \, W_{j}^{V} \in \mathbb{R}^{C_i \times d_{head}} \\
W^{O} & \in \mathbb{R}^{C_i \times C_i} \\
d_{head} &= \frac{C_i}{N_i} \quad (4)
\end{align}
$$

$(2)$式は一般的なMultiHead Attentionの計算と同じである一方で、$(3)$式の$\mathrm{head}_{j}$の計算過程に$\mathrm{SR}(K)$が出てくることに注意が必要です。行列$X$についてこの$\mathrm{SR}(X)$の計算は下記のように定義されます。
$$
\large
\begin{align}
\mathrm{SR}(X) &= \mathrm{Norm}(\mathrm{Reshape}(X, R_i)W^{S}) \quad (5) \\
X \in \mathbb{R}^{(H_i W_i) \times C_i}, \,\, & \mathrm{Reshape}(X, R_i) \in \mathbb{R}^{\frac{H_i W_i}{R_i^{2}} \times (R_i^{2} C_i)}, \,\, W^{S} \in \mathbb{R}^{(R_i^{2} C_i) \times C_i} \quad (6)
\end{align}
$$

$\displaystyle \mathrm{Reshape}(X, R_i) \in \mathbb{R}^{\frac{H_i W_i}{R_i^{2}} \times (R_i^{2} C_i)}$で表したように、$(5)$式のReshape処理はPatch embeddingと同様な処理が行われるので対応させて抑えておくと良いと思います。

$(5)$式と$(6)$式より、$\mathrm{SR}(X)$は$\displaystyle \mathrm{SR}(X) \in \mathbb{R}^{\frac{H_i W_i}{R_i^{2}} \times C_i}$で表される行列であり、$X \in \mathbb{R}^{(H_i W_i) \times C_i}$の行数が$\displaystyle \frac{1}{R_i^{2}}$になったことが確認できます。この処理は「Queryに対応して重み付け和を計算するValueの空間方向の解像度を低くすることで、CNNにおけるダウンサンプリングと同様な効果が得られる」と解釈することができます3

参考

・Transformer論文:Attention is All you need[$2017$]
・Pyramid Vision Transformer論文

  1. expansion ratioの厳密な説明が論文内に見当たらなかったので、要出典。 ↩︎
  2. 論文では$d_{head}$と表記がありますが、$(4)$式で定義されるようにステージ$i$毎に$d_{head}$の値が変わる可能性があることは注意しておくと良いです。 ↩︎
  3. TransformerにおけるAttentionでは$K=V$かつ$Q, K, V$の列数が一致する必要がある一方で、$Q$と$K=V$の行数が必ずしも一致しないことは抑えておくと良いと思います。Encoderでは$Q=K=V$であることが多い一方で、Decoderでは$Q \neq K$の計算も行われます。 ↩︎

【SimCLR】対照学習(Contrastive Learning)に基づくベクトル表現の取得①

SimCLR(Simple Framework for Contrastive Learning of Visual Representations)は対照学習(Contrastive Learning)を用いて画像のベクトル表現を抽出する手法です。当記事ではSimCLRの一連の学習手順について取りまとめを行いました。
SimCLRの論文の「A Simple Framework for Contrastive Learning of Visual Representations」を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

指示関数

指示関数(indicator function)の$\mathbb{1}_{[k \neq i]} \in \{ 0, 1 \}$は下記のように定義される。
$$
\large
\mathbb{1}_{[k \neq i]} =
\begin{cases}
1 \quad \mathrm{if} \quad k \neq i \\
0 \quad \mathrm{otherwise}
\end{cases}
$$

対照学習

SimCLR

SimCLRの全容

SimCLRは下記の$4$つの主要要素によって構成される。
・確率的データ拡張(A stochastic data augmentation)
・エンコーダ(A neural network base encoder $f(·)$)
・A small neural network projection head $g(·)$
・対照損失関数(A contrastive loss function)

上記の$4$つの主要要素は下図を元に抑えておくと良い。

SimCLR論文 Figure$\, 2$

上図の$\mathbf{x}$から$\tilde{\mathbf{x}}_{i}$や$\tilde{\mathbf{x}}_{j}$を作成するプロセスが確率的データ拡張(A stochastic data augmentation)、$\tilde{\mathbf{x}}_{i}, \, \tilde{\mathbf{x}}_{j}$から$\mathbf{h}_{i}, \, \mathbf{h}_{j}$を計算するプロセスがエンコーダ、$\mathbf{h}_{i}, \, \mathbf{h}_{j}$から$\mathbf{z}_{i}, \, \mathbf{z}_{j}$を計算するプロセスがprojection head、$\mathbf{z}_{i}, \, \mathbf{z}_{j}$を用いて定義される損失関数が対照損失関数(A contrastive loss function)にそれぞれ対応する。

以下、$4$つの主要要素についてそれぞれ確認を行う。

stochastic data augmentation

データ拡張(data augmentation)の主要な手法は下図に基づいて把握することができる。

SimCLR論文 Figure$\, 4$

SimCLRでは$(\mathrm{c})$の「crop(切り抜き)+リサイズ+反転(flip)」、$(\mathrm{d}), \, (\mathrm{e})$の「color distortion」、$(\mathrm{i})$の「Gaussian blur」の$3$つが用いられ、cropとcolor distortionが有効であったと報告されている。

$(\mathrm{c})$の切り抜きにあたっては「切り抜く場所」や「切り抜くサイズ」にランダム性があることから、SimCLRにおけるデータ拡張は確率的データ拡張(stochastic data augmentation)と表されていることも合わせて抑えておくと良い。

ニューラルネットの構成①:encoder

encoderの$f(\cdot)$は入力の$\tilde{\mathbf{x}}_{i}$や$\tilde{\mathbf{x}}_{j}$からベクトル表現(Visual Representation)を抽出する関数に対応する。SimCLRでは下記の数式で表されるようにResNetが用いられる。
$$
\large
\begin{align}
\mathbf{h}_{i} &= f(\mathbf{x}_{i}) = \mathrm{ResNet}(\mathbf{x}_{i}) \\
\mathbf{h}_{i} & \in \mathbb{R}^{d}
\end{align}
$$

$d$は抽出する画像のベクトル表現の次元に対応する。

ニューラルネットの構成②:projection head

projection headの$g(\cdot)$は抽出したベクトル表現の$\mathbf{h}_{i}$を対照損失(contrastive loss)の計算用に変換する処理に対応する。原理的には$\mathbf{h}_{i}$をそのまま用いてlossの計算を行うことは可能であるが、SimCLRの論文では$g$を用いることがbeneficialとされる。

SimCLRでは$g$に二層のMLP(Multi Layer Perceptron)が用いられている。このMLPは下記のような数式で表される。
$$
\large
\begin{align}
\mathbf{z}_{i} &= g(\mathbf{h}_{i}) = W^{(2)} \mathrm{ReLU}(W^{(1)} \mathbf{h}_{i}) \\
\mathrm{ReLU}(x) &= \max(0, x)
\end{align}
$$

loss function

$(i,j)$を正例の組とおき、$k$組用意するとき、サンプルは$2k$個となる。このとき、$(i,j)$に関するloss functionを下記のように定義する。
$$
\large
\begin{align}
l_{i,j} &= -\log{\left[ \frac{\exp{(\mathrm{sim}(\mathbf{z}_{i},\mathbf{z}_{j}))/\tau}}{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp{(\mathrm{sim}(\mathbf{z}_{i},\mathbf{z}_{k}))/\tau}} \right]} \\
\mathrm{sim}(\mathbf{u},\mathbf{v}) &= \frac{\mathbf{u}^{\mathrm{T}} \mathbf{v}}{||\mathbf{u}|| ||\mathbf{v}||}
\end{align}
$$

$\tau$は温度パラメータである。上記の式に基づいてSimCLRの学習を行う。

参考

・SimCLR論文

Swin Transformer: 階層型Vision Transformer まとめ

Transformerの画像処理への応用にあたってはViT(Vision Transformer)などが有名である一方で、画像の局所特徴量の抽出の観点からは少々処理が非効率です。当記事では階層型のAttentionを用いることで改善を行なった研究であるSwin Transformerについて取りまとめを行いました。
「Swin Transformer: Hierarchical Vision Transformer using Shifted Windows」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

Swin Transformer

パッチの作成と連結

Swin Transformerでは$4 \times 4$のピクセル単位でパッチの作成を行います。基本的にはRGB値を持つカラー画像が入力であるので、それぞれのパッチが持つ要素数は$4 \times 4 \times 3 = 48$となります。

このように作成を行なったパッチを処理を行うにつれて隣接するパッチと連結を行い、Feature Mapの作成を行います。

この「パッチの連結によるFeature Mapの作成」はVGGNetやResNetにおけるpoolingを用いたFeature Mapの作成と対応づけて抑えておくと良いです。

Swin Transformerの処理概要

Swin Transformer論文 Figure$\, 3, (a)$

Swin Transformerの処理の全体は上図のように表されます。まず$H \times W \times 3$の入力を$4 \times 4$単位でパッチの作成を行います。

$4 \times 4$単位でそれぞれ$48$の要素を持つパッチの作成を行うことで、$H \times W \times 3$の入力を$\displaystyle \frac{H}{4} \times \frac{W}{4} \times 48$の表現に変換することができます。この処理がPatch Partitionに対応します。

Swin Transformer論文 Figure$\, 3, (b)$

次にLinear Embedding(MLP処理)を通して$48$次元を$C$次元に変換したのちにSwin Transformer Blockを用いて処理を行います。Swin Transformer Blockの処理は上図のように表されますが、Swin Transformer Blockの詳細は次項以降で取り扱います。

ここまでのLinear Embedding・Swin Transformer Block・Patch Mergingの処理の組み合わせによって、Swin TransformerではFeature Mapの作成が行われます。

W-MSA

Swin Transformer論文 Figure$\, 3, (b)$

W-MSAはWindowを用いたMultihead Self Attention処理の略です。Windowはパッチの集合を表しており、$M \times M$の場合は$M^2$個のパッチをまとめて取り扱います。

全てのパッチ同士で計算を行うViTに対して、W-MSAではWindowの内部のパッチだけを元にAttention処理を行います。このような処理を行うことで計算量の削減が可能になります。

SW-MSA

W-MSAを用いることで計算量の削減が可能になる一方で、W-MSAではWindowの境界におけるパッチ間の相関を取り扱うことができないという課題があります。この解決にあたってW-MSAと併用されるのがSW-MSAです。

W-MSAのみを用いる場合、上図の左のように毎回Windowの境界のパッチが同じパッチになります。このような問題の解決にあたってSwin Transformerでは、上図の右のようにWindowをずらした上で(Shifted Window)Attention処理を行うSW-MSAという処理が用いられます1

Swin Transformer論文 Figure$\, 3, (b)$

このようにWindowをずらすことでWindowの境界のパッチが毎回同じ位置にならないようにすることができます。

cyclic shifting

参考

・Transformer論文:Attention is All you need[$2017$]
・Swin Transformer論文

  1. Swin TransformerのSwinがShifted Windowの略であることも合わせて抑えておくと良いです。 ↩︎

行列式と置換⑤:互換と巡回置換(cyclic permutation)

線形代数の枠組みで$n$次正方行列の行列式(determinant)を取り扱うにあたっては置換(permutation)という概念を抑えておく必要があります。当記事では互換(transpositions)と巡回置換(cyclic permutation)について取りまとめを行いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$4$章「行列式」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

互換と巡回置換

互換

順列$1, 2, \cdots , n$について$i < j, \, i, j \in {1, 2, \cdots , n }$である$i$と$j$のみを入れ替え、他をそのままにする置換を互換という。$i$と$j$のみを入れ替える置換を$\sigma$とおくと、$\sigma$は下記のように表せる。
$$
\large
\begin{align}
\sigma(i) &= j \\
\sigma(j) &= i \\
\sigma(k) &= k, \quad k \neq i \cap k \neq j
\end{align}
$$

$i$と$j$を入れ替える互換$\sigma$は$\sigma=(i \quad j)$のようにも表記される。

巡回置換

順列$1, 2, \cdots , n$から$k$個の数字を抜き出し、それぞれ$i_1 < i_2 < \cdots < i_{k}$のように表す。このとき下記のような置換$\sigma$を巡回置換という。
$$
\large
\sigma :
\begin{cases}
i_1 & \longmapsto i_2 \\
i_2 & \longmapsto i_3 \\
\cdots \\
i_{k-1} & \longmapsto i_{k} \\
i_{k} & \longmapsto i_{1}
\end{cases}
$$

巡回置換$\sigma$は$\sigma=(i_1 \quad i_2 \quad \cdots \cdots \quad i_{k})$のようにも表記される。

例題の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$057$

$$
\large
\begin{align}
(i_1 \quad i_2 \quad \cdots \cdots \quad i_{k}) = (i_1 \quad i_k)(i_1 \quad i_{k-1}) \cdots \cdots (i_1 \quad i_{2}) \quad (1)
\end{align}
$$

右辺の変換によって、$i_1 \, i_2 \, i_3 \, \cdots \cdots \, i_{k-1} \, i_k$は下記のように変換される。
$$
\large
\begin{align}
& i_1 \, i_2 \, i_3 \, \cdots \cdots \, i_{k-1} \, i_k \\
\longrightarrow & i_2 \, i_1 \, i_3 \, \cdots \cdots \, i_{k-1} \, i_k \\
\longrightarrow & i_2 \, i_3 \, i_1 \, \cdots \cdots \, i_{k-1} \, i_k \\
& \cdots \\
\longrightarrow & i_2 \, i_3 \, i_4 \, \cdots \cdots \, i_1 \, i_k \\
\longrightarrow & i_2 \, i_3 \, i_4 \, \cdots \cdots \, i_k \, i_1
\end{align}
$$

上記より$(1)$式が成立することが確認できる。