ブログ

【CLIP論文まとめ】対照学習を用いたmulti-modal embeddingの取得

SimCLRのような対照学習(Contrastive Learning)の枠組みを用いることで画像のベクトル表現の抽出が可能です。当記事では画像とテキストのベクトル表現の抽出をmulti-modalに行った研究であるCLIPについて取りまとめを作成しました。
CLIPの論文である「Learning Transferable Visual Models From Natural Language Supervision」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

SimCLR

SimCLRについては下記で詳しく取りまとめました。

CLIP

CLIPの処理の概要と対照学習

CLIP(Contrastive Language-Image Pre-training)論文の処理の全体像は下図を確認すると理解しやすいです。

CLIP論文 Figure$\, 1$

上図の$(1)$は対照学習(Contrastive Learning)に基づく学習プロセス、$(2)$と$(3)$は推論プロセスにそれぞれ対応します。

対照学習を用いた学習プロセスでは$N$個の画像とテキストの組を用意し、$N \times N$の対応について類似度計算を行うようにします。このとき、それぞれの行について対角成分のみを正解と見なし学習を行います。基本的な学習の流れはCLIP論文のFigure$\, 3$の擬似コードを確認すると良いです。

CLIP論文 Figure$\, 3$

上記の計算はSimCLRの学習におけるlossの計算と基本的には同様な計算です。このような学習を行うことで、画像とテキストに関するマルチモーダルなベクトル表現の獲得が可能になります。

このような学習によって得られるベクトル表現の活用である$(2)$と$(3)$については次項で確認します。

CLIPの学習結果を用いたzero-shot分類

CLIPは画像とテキストが同じペアであるかの判定を行うことによって、zero-shot分類を行うことが可能です。

CLIP論文 Figure$\, 1$の右半分

たとえば上図のようにCLIPで学習を行ったText EncoderとImage Encoderを用いてEmbedding(ベクトル表現)をそれぞれ計算し、コサイン類似度を計算することで分類を行うことができます。$(2)$のText Encoderの使い方については「A photo of a plane.」の入力に対応するText Encoderの出力が$T_1$、「A photo of a car.」の入力に対応する出力が$T_2$などのようにテキストのベクトル表現$T_i$を生成し、$I_1$との類似度を計算すると理解すると良いです。

図の計算結果では「A photo of a car.」に対応する出力である$T_3$とImage Encoderの出力の$I_1$の類似度が高いことから、zero-shot分類の結果は「A photo of a dog.」が出力されます。

【3DETR論文まとめ】Transformerを用いた$3$D Object Detection

$3$D Object Detectionは点群の$3$D空間上の点に対してバウンディングボックス(bounding box)とそのクラスを予測するタスクです。当記事では$3$D Object DetectionにTransformerを導入した研究である$3$DETRについて取りまとめました。
DETRの論文である「End-to-End Object Detection with Transformers」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

GIoU

DETR

3DETR

処理の概要

$3$DETR(DEtection TRansformer)の処理の概要は下図より確認できます。

$3$DETR論文 Figure$\, 2 \, (\mathrm{left})$

基本的にはDETRのCNN backbone処理をPointNet++に基づくMLP処理に置き換えたと理解して良いと思います。PointNet++やPointNet++で用いられるFPS(Farthest Point Sampling)については下記で詳しく取り扱いました。

query embeddings

Faster-RCNNやSSDなどのCNNを用いたObject Detectionのネットワークではバウンディングボックスの予測にあたって、事前に設定したAnchor boxとネットワークの出力のオフセット項を元にバウンディングボックスの回帰を行います。

一方、Transformerを用いてObject Detectionを行ったDETRではTransformerのPositional Encodingを学習パラメータに設定し、Positional EncodingをAnchor boxの代わりに用います。

このPositional Encodingをquery embeddingsといい、$3$DETRでも同様な処理が用いられます。

Decoderのアウトプット

Location

Locationはバウンディングボックスの中心の位置の$(X,Y,Z)$値に対応し、$3$DETR論文では$\mathbf{c}$で表されます。query embeddingsに基づく$\mathbf{q}$とネットワークからの出力の補正項の$\Delta \mathbf{q}$を用いて$\mathbf{c}$の予測は$\hat{\mathbf{c}} = \mathbf{q} + \Delta \mathbf{q} \in \mathbb{R}^{3}$のように定義されます。

ここでground truthの「$\cdot$」に対して「$\hat{\cdot}$」は予測を表すにあたって用いられることも注意しておくと良いです。

Size

Locationはバウンディングボックスのサイズに対応し、$X,Y,Z$の三方向についてスカラー値を持ちます。Locationと同様にquery embeddingとネットワークの出力の補正項を元に$\hat{\mathbf{d}} \in \mathbb{R}^{3}$が定義されます。

Orientation

室内を取り扱った点群であるSUN RGB-Dなどでは重力方向に変換したものが用いられるので、角度の自由度は$1$つだけ取り扱えれば十分です。$3$DETRではVoteNetなどと同様に、$[0, 2 \pi)$の間から$12$個のクラスと補正項のresidualの予測を行います。

$12$個のカテゴリを$\mathbf{a}_{c}$、residualを$\mathbf{a}_{r}$と表し、角度に関する出力$\hat{\mathbf{a}}$を下記のように定義します。
$$
\large
\begin{align}
\hat{\mathbf{a}} = [\hat{\mathbf{a}}_{c}, \hat{\mathbf{a}}_{r}]
\end{align}
$$

Semantic Class

Semantic Classはバウンディングボックス内のクラス分類の結果に対応します。ground truthの$\mathbf{s}$が$K$カテゴリ+背景の$1-$hotベクトル、ネットワークの出力に基づく予測が$\hat{\mathbf{s}} \in [0, 1]^{K+1}$のベクトルでそれぞれ定義されます。

3DETRのloss

$3$DETRのlossの$\mathcal{L}_{3 \mathrm{DETR}}$は下記のように定義されます。
$$
\large
\begin{align}
\mathcal{L}_{3 \mathrm{DETR}} = \lambda_{c}|| \hat{\mathbf{c}}-\mathbf{c}||_{1} &+ \lambda_{d}|| \hat{\mathbf{d}}-\mathbf{d}||_{1} + \lambda_{ar}|| \hat{\mathbf{a}}_{r}-\mathbf{a}_{r}||_{\mathrm{huber}} \\
&- \lambda_{ac} \mathbf{a}_{c}^{\mathrm{T}} \log{\hat{\mathbf{a}}_{c}} \, – \, \lambda_{s} \mathbf{s}^{\mathrm{T}} \log{\hat{\mathbf{s}}}
\end{align}
$$

上記の$||\cdot||_{1}$は$L1$ノルム、$||\cdot||_{\mathrm{huber}}$はHuber lossと同様の計算に基づくノルムをそれぞれ表します。また、$\mathbf{a}_{c}$や$\mathbf{s}$はそれぞれベクトルであることから、上記のようにクロスエントロピーによる分類のlossが計算されます。

Bipartite Matching

前項で取り扱った$3$DETRのlossの定義にあたっては、「どの予測とどのバウンディングボックスを対応させるか」という前段階の処理が必要です。このBipartite Matchingの解決にあたってはHungarian algorithmが用いられます。

参考

・Transformer論文
・DETR論文
・$3$DETR論文

GIoU(Generalized IoU)の数式と指標の解釈

Object Detectionタスクなどにおけるバウンディングボックスの予測にあたっては予測結果とground truthとの当てはまりの指標が必要でこの際にIoU(Intersection over Union)が一般的に用いられます。当記事ではIoUを改良した指標であるGIoUについてまとめました。
GIoUの論文である「Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

IoUの概要と数式

IoU(Intersection over Union)は「二つの図形がどのくらい類似しているかの指標」であり、Object Detectionタスクにおけるバウンディングボックスの予測結果の評価などにあたってよく用いられます。

領域$X$と領域$Y$のIoUを$\mathrm{IoU}(X,Y)$とおくと、$\mathrm{IoU}(X,Y)$は$X$と$Y$の和集合$X \cup Y$と$X$と$Y$の積集合$X \cap Y$を用いて下記のような式で表されます。
$$
\large
\begin{align}
\mathrm{IoU}(X,Y) = \frac{X \cap Y}{X \cup Y}
\end{align}
$$

たとえば$X$と$Y$が一致している場合$X \cup Y = X \cap Y$であるので$\mathrm{IoU}(X,Y) = 1$、$X$と$Y$が共通部分を持たない場合は$X \cap Y = 0$であるので$\mathrm{IoU}(X,Y) = 0$のように計算できます。

IoUの課題

「二つの図形の類似度を計算を行う」際にIoUは有力な手法である一方で、「$X \cap Y = 0$の際に$X$と$Y$がどのくらい離れているかの指標にはならない」という点で課題があります。

たとえばObject Detectionタスクでは多くの予測がされるので、ground truthと予測のバウンディングボックスが共通部分を持たない可能性もあり、このような場合にIoUをそのまま用いるとどのくらい離れているかに基づいて学習を行うことができません。

GIoU(Generalized Intersection over Union)はこの課題の解決にあたって導入される指標です。GIoUについて詳しくは次節で確認します。

GIoU

GIoUの定義式

集合$X$と集合$Y$を全て含む最小の集合を$C$とおくとき、$X$と$Y$のGIoUは下記のように計算することができます。
$$
\large
\begin{align}
\mathrm{IoU}(X,Y) &= \frac{X \cap Y}{X \cup Y} \\
\mathrm{GIoU}(X,Y) &= \mathrm{IoU}(X,Y) \, – \, \frac{|C-(X \cup Y)|}{|C|} \quad (1)
\end{align}
$$

上記の式は、$|C|$が$|X \cup Y|$に対して大きくなればなるほど$\mathrm{GIoU}(X,Y)$は小さくなると理解すると良いです。また、$\mathrm{GIoU}(X,Y)$の最小値は$\mathrm{IoU}(X,Y)=0$、$\frac{|X \cup Y|}{|C|} \to 0$のとき漸近的に下記のように得られます。
$$
\large
\begin{align}
\lim_{\frac{|X \cup Y|}{|C|} \to 0} \mathrm{GIoU}(X,Y) = -1
\end{align}
$$

GIoUの論文では$(1)$式が下記のように$\mathrm{Algorithm} \, 1$で記載されています。

GIoU論文 $\mathrm{Algorithm} \, 1$

バウンディングボックス回帰におけるGIoUの計算

バウンディングボックス回帰におけるGIoUの計算はGIoU論文の$\mathrm{Algorithm} \, 2$で取り扱われています。

GIoU論文 $\mathrm{Algorithm} \, 2$

上記は一見複雑に見えるかもしれませんが、predictionとground truthの積集合(Intersection)を$\mathcal{I}$、和集合(Union)を$\mathcal{U}$、predictionとground truthを含む最小の領域を$A^{c}$で表すことから逆に計算を辿ると理解しやすいと思います。

【SegFormer】Transformerを用いたシンプルかつ効率的なセグメンテーション

局所的な特徴量の抽出に適したCNNに対して、大域的な特徴量の抽出に適したTransformerはViT以降、多くのComputer Visionのタスクに用いられます。当記事ではTransformerを用いてシンプルかつ効率的なセグメンテーションを実現したSegFormerについて取りまとめました。
SegFormerの論文である「SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

SegFormer

処理の概要

SegFormerの処理の全体は下図を元に掴むと良いです。

SegFormer論文 Figure$\, 2$

Hierarchical Feature Representation

$H \times W \times 3$の入力に対し、ステージ$i$におけるFeature mapを$F_{i}$とおくと、$F_{i}$のサイズは下記のように定義されます。
$$
\large
\begin{align}
F_{i} & \in \mathbb{R}^{\frac{H}{2^{i+1}} \times \frac{W}{2^{i+1}} \times C_i} \quad (1) \\
i &= \{ 1, 2, 3, 4 \}
\end{align}
$$

$i=1$のとき$2^{i+1}=2^{2}=4$、$i=2$のとき$2^{i+1}=2^{3}=8$なので、$(1)$式はSegFormer論文のFigure$\, 2$と対応することが確認できます。

また、$i$が大きくなるにつれて$C_i$は大きくなるので$C_{i} < C_{i+1}$が成立します。ここでの処理はVGGNetやResNetのbackboneネットワークによるFeature mapの作成と同様なものであると理解しておくと良いです。

Overlapped Patch Merging

パッチの特徴量の作成にあたっては、「パッチに含まれるピクセルや特徴量の値をそのまま用いる」というのが基本的である一方で、このようにパッチ特徴量の作成を行うと「パッチの境界における相関」をうまく取り扱うことができません1

この解決にあたってSegFormerではOverlapped Patch Mergingが導入されます。Overlapped Patch Mergingはパッチ作成時にフィルタの大きなCNNを用い、パッチ作成にあたって用いる領域を重複させる手法です。

このようにOverlapped Patch Mergingを用いることでパッチ間の境界領域の相関も特徴量抽出にうまく反映させることができ、パフォーマンスの向上に役立ちます。

Efficient Self-Attention

Efficient Self-AttentionではPyramid Vision Transformer(PVT)のSpatial Reduction Attention(SRA)と同様の処理を行うことで、ViTのボトルネックである計算量の改良を実現します。

Fix-FFN

Mix-FFNでは下記のような式に基づいてFeed Forward Network(FFN)処理が行われます。
$$
\large
\begin{align}
X_{out} &= \mathrm{MLP}(\mathrm{GELU}(\mathrm{Conv}_{3 \times 3}(\mathrm{MLP}(X_{in})))) + X_{in} \\
\mathrm{GELU}(x) &= x \Phi(x) \\
\Phi(x) &= \int_{-\infty}^{x} \frac{1}{\sqrt{2 \pi}} \exp{ \left[ -\frac{t^{2}}{2} \right] } dt
\end{align}
$$

Mix-FFNでは上記の式に基づいて、$3 \times 3$の「Depth-wise Convolution」処理が途中で実行されます。また、Mix-FFNに畳み込み処理を導入することで位置情報を反映できるので、セグメンテーションタスクではPositional Encodingは必ずしも必要ではないとSegFormer論文に記載があります(We argue that positional encoding is actually not necessary for semantic segmentation.2)。

参考

・Transformer論文:Attention is All you need[$2017$]
・SegFormer論文

  1. もちろん「パッチの境界における相関」を取り扱うにあたってViTの$16 \times 16$サイズのパッチを階層型のViTでは$4 \times 4$まで抑えることで改良させることは可能ですが、あくまで改良に過ぎず直接的な解決にはならないことに注意が必要です。 ↩︎
  2. SegFormer論文 $3.1$ Hierarchical Transformer Encoder Mix-FFN.の一文 ↩︎

【DETR論文まとめ】Transformerを用いたObject Detection

Object Detectionタスクには従来VGGNetやResNetなどのCNNをbackboneに持つネットワークを用いることが主流であった一方で、近年Transformerの導入も行われています。当記事ではObject DetectionにTransformerを導入した研究であるDETRについて取りまとめました。
DETRの論文である「End-to-End Object Detection with Transformers」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

指示関数

指示関数(indicator function)の$\mathbb{1}_{[k \neq i]} \in \{ 0, 1 \}$は下記のように定義されます。
$$
\large
\mathbb{1}{[k \neq i]} =
\begin{cases}
1 \quad \mathrm{if} \quad k \neq i \\
0 \quad \mathrm{otherwise}
\end{cases}
$$

Hungarian algorithm

・Hungarian algorithm(Wikipedia)

DETR

処理の概要

DETR(DEtection TRansformer)の処理の全体は下図より確認できます。

DETR論文 Figure$\, 2$

図より、DETRは主に下記の三つのプロセスによって構成されることが確認できます。

・backboneネットワーク
・Transformer
・FFN

以下、それぞれのプロセスの詳細やDETRのlossについて確認を行います。まとめの都合上、「backbone+Transformerのencoder」と「Transformerのdecoder+FFN」に分けて確認を行いました。

backboneとTransformer encoder

DETRではbackboneネットワークにCNNを用います。
$$
\large
\begin{align}
X_{\mathrm{img}} \in \mathbb{R}^{H \times W \times 3}
\end{align}
$$

上記の入力に対しCNNを作用させ、下記のようなFeature mapを取得します。
$$
\large
\begin{align}
f \in \mathbb{R}^{H \times W \times C}
\end{align}
$$

上記のFeature mapに対し、$1 \times 1$のConvolutionを行いReshapeを行った結果の$Z_{0} \in \mathbb{R}^{HW \times d}$をTransformer encoderに入力します。ここで$d$は$d < C$になるように値の設定を行います。

DETRのTransformer encoderはMultiHead AttentionやFFNによって構成されるオーソドックスなTransformerの処理を基本的にそのまま用います。

Transformer decoder: object queriesとbounding boxの生成

Transformer decoderでは機械翻訳では翻訳文が入力されるTransformerのクエリにobject queriesという入力を用います。

このobject queryに基づいてObject Detectionにおけるバウンディングボックスの生成やクラスの予測が行われることから、object queryはFaster-RCNNなどにおけるAnchor boxと同様な役割であると理解すると良いと思います。

ここでobject queryは学習可能なパラメータであり、Faster-RCNN・SSDなど多くのObject Detectionで用いられるAnchor boxをDETRでは自動学習させることができると解釈することもできます。

DETRのloss

$y = \{ y_1, \cdots y_{n} \}, \, y_{n+1} = y_{n+2} \cdots = y_{N-1} = y_{N} = \varnothing$をObject集合の正解、$\hat{y} = \{ \hat{y}_{i} \}_{i=1}^{N} = \{ \hat{y}_{1}, \cdots , \hat{y}_{N} \}$をObjectの予測の集合とおくとき、$y$と$\hat{y}$の最もコストの低い対応を表すインデックスの順列を表す$\hat{\sigma} \in \mathfrak{S}_{N}$は下記のように定義されます。
$$
\large
\begin{align}
\hat{\sigma} &= \mathrm{arg}\min_{\sigma \in \mathfrak{S}_{N}} \sum_{i=1}^{N} \mathcal{L}_{\mathrm{match}}(y_{i}, \hat{y}_{\sigma(i)}) \quad (1) \\
\mathcal{L}_{\mathrm{match}}(y_{i}, \hat{y}_{\sigma(i)}) &= \mathbb{1}_{[c_i \neq \varnothing]} \hat{p}_{\sigma(i)} + \mathbb{1}_{[c_i \neq \varnothing]} \mathcal{L}_{\mathrm{box}}(b_i, \hat{b}_{\sigma(i)})
\end{align}
$$

ここで上記の式における$\mathfrak{S}_{N}$はインデックスの順列の全パターンの集合、$N$は予測の数、$n$は正解のObjectの数、$\mathbb{1}$は指示関数、$\varnothing$は$y_{n+1}$〜$y_{N}$のObjectが存在しないことをそれぞれ表します。

$(1)$式の$\hat{\sigma}$は組み合わせ最適化問題であり、Hungarian algorithmを用いることで得ることができます。ここで得た$\hat{\sigma}$を元にHungarian lossの$\mathcal{L}_{\mathrm{Hungarian}}(y,\hat{y})$は下記のように定義されます。
$$
\large
\begin{align}
\mathcal{L}_{\mathrm{Hungarian}}(y,\hat{y}) &= \sum_{i=1}^{N} \left[ -\log{\hat{p}_{\hat{\sigma}(i)}(c_i)} + \mathbb{1}_{[c_i \neq \varnothing]} \mathcal{L}_{\mathrm{box}}(b_i, \hat{b}_{\sigma(i)}) \right] \quad (2) \\
\mathcal{L}_{\mathrm{box}}(b_i, \hat{b}_{\sigma(i)}) &= \lambda_{\mathrm{IoU}} \mathcal{L}_{\mathrm{IoU}}(b_i, \hat{b}_{\sigma(i)}) + \lambda_{L1}|| b_i-\hat{b}_{\sigma(i)} ||_{1} \\
\lambda_{\mathrm{IoU}}, \lambda_{L1} & \in \mathbb{R}
\end{align}
$$

上記の$c_i, b_i$は$y_i=(c_i,b_i)$によって定義される、Objectの予測のクラスとバウンディングボックスにそれぞれ対応します1。また、$|| \cdot ||_{1}$は$L1$ノルムを表します。

$(2)$式は、「Hungarian algorithmによって得られた$\hat{\sigma}$について、クラス分類に関するCross EntropyとBounding Boxのズレに基づいてlossを定義する」のように解釈することができます。

参考

・Transformer論文:Attention is All you need[$2017$]
・DETR論文

  1. クラス間不均衡(Class imbalance)に対処するにあたって、実際には$c_i=\varnothing$の場合に$-\log{\hat{p}_{\hat{\sigma}(i)}(c_i)}$をdown-weightします。 ↩︎

【Pointformer論文まとめ】点群の処理へのTransformerの導入

PointNet++を用いた点群の処理はPointNetに階層型のプーリングを導入することで改良にはなったものの、局所領域における点間の相関を取り扱えないなどの課題があります。当記事ではこの課題の解決にあたって点群の処理にTransformerを導入したPointformerについて取りまとめました。
Pointformerの論文である「$3$D Object Detection with Pointformer」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

PointNet

PointNet++

Pointformer

処理の概要

Pointformerの処理の全体像は下図を元に掴むと良いです。

Pointformer論文 Figure$\, 2$

上図から確認できるように、Pointformerは「Pointformer Block」に基づいて基本的に構成されています。

また、同時に「Pointformer Block」が「Local Transformer」、「Local-Global Transformer」、「Global Transformer」の三つの処理から構成されていることも確認することができます。以下、三つの処理について詳しく確認を行います。

Local Transformer

入力される点群の集合を$\mathcal{P} = \{ x_1, \cdots , x_{N} \}$とおくとき、Local TransformerではPointNet++と同様にFPS(Furthest Point Sampling)を用いて下記のような中心点の集合を取得します。
$$
\large
\begin{align}
\{ x_{c_1}, \cdots , x_{c_{N’}} \}, \quad N’ < N
\end{align}
$$

このとき、上記の中心点の周辺の半径に基づいたball queryを用いることで$j$番目の中心点に対して$K_t(t=1, \cdots , N’)$個の点を獲得し、Transformerへ入力を行います。

Pointformer論文 Figure$\, 3$

FPSは上図における左側の青丸から赤丸を取得する処理、ball queryを用いた処理は赤丸の中心に基づいて円を描き、円の内部の点をグルーピングする処理にそれぞれ対応します。

$t$番目の中心点の周囲の局所領域における$i$番目の点について位置情報と特徴量を$\{ x_{i}, f_{i} \}_{t}, \, x_{i} \in \mathbb{R}^{1 \times 3}, \, f_{i} \in \mathbb{R}^{1 \times C}$のように表すとき、$t$番目の領域におけるTransformer処理は下記のように表すことができます。
$$
\large
\begin{align}
f_{i}^{(0)} &= \mathrm{FFN}(f_{i}), \quad {}^{\forall} i \in N(x_{c_{t}}) \\
F^{(l+1)} &= \mathrm{Transformer Block}(F^{(l)}, \mathrm{PE}(X)) \\
F^{(l)} &= \left( \begin{array}{c} f_{1}^{(l)} \\ \vdots \\ f_{K_t}^{(l)} \end{array} \right) , \quad X = \left( \begin{array}{c} x_{1} \\ \vdots \\ x_{K_t} \end{array} \right)
\end{align}
$$

ここで$N(x_{c_{t}})$の$N$は中心点$x_{c_{t}}$と同じ領域にグルーピングされた点のインデックスの集合を表します1。$\mathrm{Transformer Block}$は一般的なTransformerのそれぞれの層における計算に対応します。また、論文では$F^{(l)},X$がやや唐突に出てきたので上記の$3$式目で具体化しました。

ここまでがLocal Transformerの基本処理です。一方で、Local Transformerには基本処理に加えてFPSによってサンプリングされた中心点の位置の補正を行うCoordinate Refinementも合わせて用いられます。以下、Coordinate Refinementについて詳しく確認します。

Coordinate RefinementはLocal Transformerで計算を行ったAttention Matrix2に基づいて中心点の位置の調整を行う手法です。Attention Matrixの$m$番目のheadのAttention Matrixを$A^{(m)}$、$t$番目の中心点に対応する行を$A_{c}^{(m)} \in \mathbb{R}^{1 \times K_t}$とおくとき、行ベクトル$\mathbf{w} \in \mathbb{R}^{1 \times K_t}$を下記のように定義します。
$$
\large
\begin{align}
\mathbf{w} = \frac{1}{M} \sum_{m=1}^{M} A_{c}^{(m)}
\end{align}
$$

上記の$M$はMultiHead Attentionにおけるheadの数を表します。このとき、行ベクトル$\mathbf{w}$の$k$番目の要素を$w_{k}$のように表すとき、中心点$x_{c_{t}}$を下記のような式に基づいて得られる$x’_{c_{t}}$に移動させます3
$$
\large
\begin{align}
x_{c_{t}}’ &= \sum_{k=1}^{K_{t}} w_{k} x_{k} \\
w_{k} & \in \mathbb{R}, \, x_{c_{t}}’ \in \mathbb{R}^{1 \times 3}, \, x_{k} \in \mathbb{R}^{1 \times 3}
\end{align}
$$

Local-Global Transformer

Pointformer論文 Figure$\, 2$の一部

Local-Transformerでは中心点に基づく局所領域を元に点群のダウンサンプリングを行います。この処理では局所領域におけるattentionしか計算しないので、ダウンサンプリング後のそれぞれの点に対して大域的な情報を反映させられることが望ましいです。

Local-Global Transformerではダウンサンプリング後の各点に対して大域的な情報を反映させるにあたって、Local Transformerの出力をquery、Pointformer blockへの入力をkey・valueとするcross-attention処理を行います。Pointformer blockへの入力は高解像度なので点のインデックスの集合を$\mathcal{P}^{h}$、Local Transformerの出力低解像度なので点のインデックスの集合を$\mathcal{P}^{l}$とおくと、Local-Global Transformerの演算を下記のように表すことができます。
$$
\large
\begin{align}
f_{i}^{(0)} &= \mathrm{FFN}(f_{i}), \quad {}^{\forall} i \in \mathcal{P}^{l} \\
f_{j}’ &= \mathrm{FFN}(f_{j}), \quad {}^{\forall} j \in \mathcal{P}^{h} \\
F^{(l+1)} &= \mathrm{Transformer Block}(F^{(l)}, F’, \mathrm{PE}(X)) \\
F^{(l)} &= \left( \begin{array}{c} f_{1}^{(l)} \\ \vdots \\ f_{K_t}^{(l)} \end{array} \right), \quad X = \left( \begin{array}{c} x_{1} \\ \vdots \\ x_{K_t} \end{array} \right) \\
F’ &= \left( \begin{array}{c} f_{1}’ \\ \vdots \\ f_{N}’ \end{array} \right)
\end{align}
$$

上記の式の$\mathrm{Transformer Block}$はself-attentionではなくcross-attentionの式であることに注意が必要です。

Global Transformer

Pointformer論文 Figure$\, 2$の一部

Global TransformerではLocal-Global Transformerの出力の全ての点のインデックスの集合$\mathcal{P}$についてTransformer処理を行います。
$$
\large
\begin{align}
f_{i}^{(0)} &= \mathrm{FFN}(f_{i}), \quad {}^{\forall} i \in \mathcal{P} \\
F^{(l+1)} &= \mathrm{Transformer Block}(F^{(l)}, \mathrm{PE}(X))
\end{align}
$$

Global Transformerの式表記はLocal Transformerの$N(x_{c_{t}})$を$\mathcal{P}$に変えることで表されます。

Pointformerの構成

indoorデータセットとKITTIに対するPointformerの構成は下記の表より確認することができます。

indoorタスク用のPointformer:Pointformer論文 Table$\, 9$
KITTIのタスク用のPointformer:Pointformer論文 Table$\, 10$

参考

・Transformer論文:Attention is All you need[$2017$]
・Pointformer論文

  1. グラフ理論ではノード$v$に隣接するノードの集合を$N(v)$のように表しますが、同様な表現と見なして良いです。論文では$N$ではなく$\mathcal{N}$が用いられていますが、$\mathcal{N}$は正規分布に用いられることが多いので、グラフ理論の表記を参考に当記事では$N$を用いました。 ↩︎
  2. ここではSoftmax関数などで行方向に正規化されたAttention Matrixを用います。 ↩︎
  3. $x’{c{t}}$の特徴量の計算についての記載が見当たらなかったので要確認。$x’{c{t}}$と同じく加重平均で計算すること自体はおそらく可能。 ↩︎

【PoinNet++論文まとめ】階層化グルーピングによるPointNetの改良

点群にDeepLearningを導入したPointNetは有力な手法である一方で、max poolingを一度しか行わないことで局所的な構造をなかなか抽出できないという課題があります。当記事ではこの解決にあたって階層化グルーピングを用いた研究であるPointNet++について取りまとめを行いました。
PointNet++の論文である「PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

PointNet

PointNet++

処理の概要

PointNet++の処理の全体像は下図を元に掴むと良いです。

PointNet++論文 Figure$\, 2$

上図から確認できるように、PointNet++はVGGNetやResNetにおけるbackbone処理とタスク特化処理の二つによって大まかに構成されます。

backbone処理ではSet Abstraction、SegmentationタスクではFeature Propagationをそれぞれ理解すると良いです。Set Abstractionは次項、Feature Propagationについては次々項でそれぞれ詳しく取り扱いました。

Set Abstraction

PointNet++論文 Figure$\, 2$の一部

上図から確認できるように、Set Abstractionは以下の$3$つのkey layerで構成されます。

・Sampling layer
・Grouping layer
・PointNet layer

上記の$3$つのlayerは、「Sampling layer」で局所領域における中心点(centroid)を決め、「Grouping layer」で局所領域における中心点の近傍にある点をグルーピングし、「PointNet layer」でグルーピングした点の集合から特徴量の抽出を行う、とそれぞれ解釈すると良いです。PointNet論文ではこの$3$つのkey layerの処理をまとめてset abstractionと表しており、Figure$\, 2$でもset abstractionの記載が確認できます。

以下、「Sampling layer」、「Grouping layer」、「PointNet layer」のそれぞれについて詳しく確認を行います。

Sampling layer

Sampling layerでは$N$個の点集合$\{ x_1, \cdots , x_{N} \}$に対しFPS(farthest point sampling)を行うことで、$N’$個の点集合$\{ x_{i_{1}}, \cdots , x_{i_{N’}} \}$をの抽出を行います。

ここで$i_{j}$は順序$i$1の$j$番目に対応するインデックスを表します2。ここでFPSは$x_{i_{j}}$の選定にあたって、$j-1$の点集合$\{ x_{i_{1}}, \cdots , x_{i_{j-1}} \}$から最も遠い点を選択する手法であると理解すると良いです。

中心点(centroid)の選択にあたってはランダムサンプリングも可能ですが、定数$K$が固定されるときFPSを用いることでより全体の点集合をカバーすることができるなど有用です。FPSを用いたこのような処理はCNNにおけるreceptive fieldと対応させて解釈すると良いです。

Grouping layer

$N$個の点群のそれぞれの点の位置情報が$d$次元、特徴量が$C$次元で表されるとき、これらをまとめて$x_{i} \in \mathbb{R}^{1 \times (d+C)}$のように表せると仮定します。このとき、下記が成立します。
$$
\large
\begin{align}
\left( \begin{array}{c} x_1 \\ \vdots \\ x_N \end{array} \right) & \in \mathbb{R}^{N \times (d+C)} \\
\left( \begin{array}{c} x_{i_{1}} \\ \vdots \\ x_{i_{N’}} \end{array} \right) & \in \mathbb{R}^{N’ \times (d+C)}
\end{align}
$$

ここで$j$番目の中心点(centroid)である$x_{i_{j}}$から半径$R$以内の点のインデックスの集合を$\{ \sigma_{1}, \sigma_{2}, \cdots , \sigma_{K_{j}} \}$とおくと、このインデックスに基づく行列のサイズは下記のように表すことができます3
$$
\large
\begin{align}
\left( \begin{array}{c} x_{\sigma_{1}} \\ \vdots \\ x_{\sigma_{K_{j}}} \end{array} \right) & \in \mathbb{R}^{K_{j} \times (d+C)} \quad (1)
\end{align}
$$

Grouping layerでは$j$番目の中心点に対し、半径$R$の点を$(1)$式のように抽出を行います。ここで注意が必要なのが「半径$R$」を定義するには計量空間(metric space)が定義される必要があるということです。

計量空間(metric space)には「$\mathbb{R}^{n}$の標準内積」に基づいて定義される計量空間を用いることが一般的で、論文のSection$4.1$のEuclidean Metric Spaceに対応します。

一方でSection$4.3$のようにNon-Euclidean Metric Spaceを用いる場合もあります。

PointNet layer

PointNet layerではそれぞれのグループに対し、PointNetの処理が実行されます。PointNetに$(1)$式を入力した際に$x’_{j} \in \mathbb{R}^{1 \times (d+C’)}$が出力されます。$N’$個の全てのグループについての特徴量は下記のような行列で表されます。
$$
\large
\begin{align}
\left( \begin{array}{c} x’_1 \\ \vdots \\ x’_j \\ \vdots \\ x’_N \end{array} \right) & \in \mathbb{R}^{N’ \times (d+C’)}
\end{align}
$$

Grouping layerで各中心点ごとにサンプル数が可変長になったものをPointNet layerでmax poolingを行うことで固定長に戻すことができることは注意して抑えておくと良いと思います。

Feature Propagation

Feature PropagationはSegmentationタスクを解くにあたって用いられる処理の仕組みです。

PointNet++論文 Figure$\, 2$

上図から確認できるように、プーリングによってダウンサンプリングされた結果をアップサンプリングすることでセグメンテーションタスクの予測が行われます。アップサンプリング時の点の復元にあたっては、$N_{1} \geq N_{2} \geq \cdots N_{l-1} \geq N_{l} \geq \cdots$のようにダウンサンプリングした点を逆に辿って復元を行います。$N_{l}$から$N_{l-1}$の復元にあたっては位置情報を持つ$d$次元はダウンサンプリング時のものをそのまま保持しておき、特徴量の値は補間(interpolation)処理に基づいて取得します。

PointNet++で用いられる補間処理は下記で表した$k$近傍($k$-nearest neighbors)の加重平均の式で定義されます。
$$
\large
\begin{align}
f^{(j)}(x) &= \frac{\sum_{i=1}^{k} w_{i}(x) f_{i}^{(j)}}{\sum_{i=1}^{k} w_{i}(x)}, \, j=1, \cdots , C \quad (2) \\
w_{i}(x) &= \frac{1}{d(x, x_i)^{p}}
\end{align}
$$

上記の式は点$x$の$j$番目の特徴量を$^{(j)}(x)$とおくとき、Centroidの$x_i$との近さに基づいて特徴量を加重平均によって得ると解釈すれば良いです。ここでは論文の表記に合わせて$j$番目の要素について式定義した一方で、$C$次元のベクトルと見なして定義しても良いと思います。また、$d$は$x$と$x_i$のdistanceを意味しており、$p$は次元を表します。補間処理に用いる近傍点の数$k$と次元の$p$については$k=3$、$p=2$がデフォルトで用いられます。

次に、上記のように行う補間処理は一度ダウンサンプリングしたものを用いると高周波成分が消失する場合があり得るので、U-Netのように対応するSet Abstractionのレイヤーの特徴量の反映処理が行われます。この処理はskip link concatenationのように表されます。skip link concatenationについては基本的にはU-Netと同じ仕組みだと理解すると良いと思います。このような処理が行われることから、Feature Propagationの数はSet Abstractionの数と原則一致することも合わせて抑えておくと良いです。

各領域における点の密度の取り扱い

点群における各点は一様分布に基づいていないので、領域ごとに密度の偏りがあります。密(dense)な領域に基づいて学習した結果が疎(sparse)な領域にそのまま一般化することはできないので注意が必要です。

各領域の点の密度の取り扱いにおける「半径$R$を用いる」アプローチの改良にあたって、PointNet++では「MSG(Multi-scale grouping)」と「MRG(Multi-resolution grouping)」の二つの手法が紹介されています。以下それぞれについて詳しく確認します。

Multi-scale grouping

MSG(Multi-scale grouping)は同一の中心点(centroid)に対し、複数のスケールを用いてグルーピングを行う手法です。

PointNet++論文 Figure$\, 3 \, (\mathrm{a})$

上図には三つのグレースケールがありますが、濃淡と半径$R$の大きさが対応します。計算の詳細についてはPointNet++論文のAppendix.B.$1$のNetwork Architecturesを確認すると良いです。

PointNet++論文 Appendix.B.$1$ Network Architecturesより

上記には同じ$R$のみを用いるSSG(Single-Scale Grouping)と複数の$R$を用いるMSG(Multi-scale grouping)の$2$つのネットワーク構造について記載があります。$\mathrm{SA}$、$\mathrm{FC}$やそれぞれのハイパーパラメータについては下記にまとめました。

記号意味
$\mathrm{SA}(K, r, [l_{1}, \cdots , l_{d}])$SSGにおけるset abstractionの演算。$K$は「中心点/局所領域」の数、$r$はグルーピングの際の半径、$d$はPointNetの層の数、$l_i$は$i$番目の層における出力層の次元を表す。
$\mathrm{SA}([l_{1}, \cdots , l_{d}])$max poolingを伴うglobal set abstractionの演算。
$\mathrm{SA}(K, [r^{(1)}, \cdots , r^{(m)}], [[l_{1}^{(1)}, \cdots , l_{d}^{(1)}], \cdots , [l_{1}^{(m)}, \cdots , l_{d}^{(m)}]])$MSGにおけるset abstractionの演算。$m$は半径のスケールの数、$r^{(j)}$は$j$番目のスケールにおける半径、$l_{i}^{(j)}$は$j$番目のスケールの$i$番目の層における出力層の次元を表す。
$\mathrm{FC}(l, dp)$全結合(Fully Connected)層の演算、MLPと同義。$l$は出力層の次元、$dp$はDropoutの比率を表す。
$\mathrm{FP}(l_1, \cdots , l_{d})$feature propagation層

上記の表を元に「PointNet++論文 Appendix.B.$1$ Network Architectures」を確認すると、SSGによる分類では$K=512, r=0.2$で$64 \to 64 \to 128$のMLP演算が行われたのちに、$K=128, r=0.4$で$128 \to 128 \to 256$のMLP演算が行われることが読み取れます。さらにこの後Globalにmax poolingを実行し、$256 \to 256 \to 512$でMLP処理を行い、$512$次元のベクトルに対し全結合層の計算を何度か実行したのちに$K$クラス分類問題に用いる$K$次元のベクトルを出力することが読み取れます。

上記と同様にMSGの処理も読み取ることができます。

Multi-resolution grouping

PointNet++論文 Figure$\, 3 \, (\mathrm{b})$

参考

・PointNet++論文

  1. たとえば$N=5$、$i_{1}=2, i_{2}=5, i_{3}=4, i_{4}=1, i_{5}=3$のとき、$N’=3$であれば$\{ x_{i_{1}}, \cdots , x_{j_{N’}} \}$は$\{ x_2, x_5, x_4 \}$に対応します。 ↩︎
  2. 文字$i$はインデックスに用いられることが多いので紛らわしいですが、PointNet++論文で$i$が用いられているのでそのまま用いました。また、論文の$n, m$についてはGrouping layerと対応させるにあたって当記事では$N, N’$で置き換えました。 ↩︎
  3. PointNet++の論文では$N’$個の中心点に対し、$N’ \times K \times (d+C)$のように定義されますが、$K$が中心点のインデックスに対し定数ではなく可変であり紛らわしいので、当記事では$j$番目の中心点について式定義を行いました。 ↩︎

【ViT論文まとめ】Computer Vision分野へのTransformerの応用

Transformerは元々機械翻訳タスクに対して考案された一方で、大域的な特徴量を取り扱うことのできる強力なモジュールであることから様々なタスクに応用されます。当記事ではTransformerを画像処理に応用した初期の研究であるViT(Vision Transformer)について取りまとめました。
ViTの論文である「An image is worth 16×16 words: Transformers for image recognition at scale.」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

処理の概要

Vision Transformer(ViT)の処理の全体像は下図を元に掴むと良いです。

Vision Transformer論文 Figure$\, 1$

Vision Transformerでは入力画像を$16 \times 16$ピクセルのように一定サイズで区切りパッチ(patch)を作成し、それぞれのパッチをトークンと見なしてTransformerの処理を実行します。上図では左下の入力画像を$9$つに分割し、Transformer Encoderに入力されます。

この処理における「パッチのトークン化」「パッチへのPositional Embeddingの追加」については次項以降で詳しく確認します。

パッチのトークン化

ViTではまず入力画像を$X \in \mathbb{R}^{H \times W \times C}$をパッチ集合に対応する行列$X_{p} \in \mathbb{R}^{N \times (P^{2} \cdot C)}$へリサイズを行います。ここで$N$は下記のように計算されます。
$$
\large
\begin{align}
H \times W \times C &= N \times (P^{2} \cdot C) \\
N &= \frac{HW \cancel{C}}{P^{2} \cancel{C}} \\
&= \frac{HW}{P^{2}}
\end{align}
$$

上記の$\displaystyle N = \frac{HW}{P^{2}}$は「パッチの数はグレースケールにおける全ピクセル数($HW$)を各パッチのピクセル数$P^{2}$で割った数に一致する」と解釈するとわかりやすいと思います1

ここでリサイズによって得た$X_{p}$の$i$行目を$\mathbf{x}_{p}^{i} \in \mathbb{R}^{1 \times (P^{2} C)}$のようにおくと、$i$番目のパッチの$D$次元のEmbeddingは下記のように定義されます。
$$
\large
\begin{align}
\mathbf{x}_{p}^{i} E & \in \mathbb{R}^{1 \times D} \\
\mathbf{x}_{p}^{i} & \in \mathbb{R}^{1 \times (P^{2} C)}, \, E \in \mathbb{R}^{(P^{2} C) \times D}
\end{align}
$$

ここでBERTと同様にTransformer演算から全体の特徴量を抽出する用のパッチ$\mathbf{x}_{cls} \in \mathbb{R}^{1 \times D}$とPositional Encodeingの$E_{pos} \in \mathbb{R}^{(N+1) \times D}$を元に下記のように$\mathbf{z}_{0}$を定義します。
$$
\large
\begin{align}
\mathbf{z}_{0} = \left( \begin{array}{c} \mathbf{x}_{cls} \\ \mathbf{x}_{p}^{1} E \\ \mathbf{x}_{p}^{2} E \\ \vdots \\ \mathbf{x}_{p}^{N-1} E \\ \mathbf{x}_{p}^{N} E \end{array} \right) + E_{pos} \quad (1)
\end{align}
$$

上記はViT論文の$(1)$式に対応します。また、$\mathbf{x}_{cls}, \, E, \, E_{pos}$はどれも学習時に同時に学習されるパラメータです。ViTでは$(1)$式で定義される$\mathbf{z}_{0} \in \mathbb{R}^{(N+1) \times D}$に対し、Transformerの処理を適用し、特徴量の抽出を行います。

パッチのPositonal Embedding

前項で取り扱ったようにViTでは学習によってPositonal Embeddingの値を取得します。このようなPositonal Embeddingの取得はBERTやGPTなどでも用いられています。

ViT-Base/ViT-Large/ViT-Huge

ViT論文ではViT-Base/ViT-Large/ViT-Hugeの三つが主に用いられます。それぞれのハイパーパラメータは下記の表より確認できます。

Vision Transformer論文 Table$\, 1$

BERTと同様な規模感で把握しておくと良いと思います。パッチサイズは基本的に$P=16$が用いられます。

参考

・Transformer論文:Attention is All you need[$2017$]
・ViT論文

  1. RGBであれば$3HW$と$3P^{2}$、チャネル数$C$の場合は$HWC$と$P^{2} C$がそれぞれ対応します。解釈のわかりやすさを優先するにあたって、グレースケールを例に出しました。 ↩︎

【PointNet論文まとめ】DeepLearningを用いた点群の分類・セグメンテーション

PointNetは点群(point clouds)の分類(classification)や点単位のセグメンテーション(segmentation)にMLP(Multi Layer Perceptron)を導入した研究です。当記事ではPointNetの論文を元にPointNetの処理手順を取りまとめました。
PointNetの論文である「PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation」や「深層学習 第$2$版」第$7$章の「集合・グラフのためのネットワークと注意機構」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

点群とInvariance

点群は並べ替えに不変であるという特徴を持ちます。同変性を表すequivarenceと不変性を表すinvarianceについては下記で詳しく取り扱ったので当記事では省略します。

点群の取り扱いにあたっては並び替え操作に関する不変性を念頭に置いた上で処理を組む必要があり、PointNetも例外ではないことに注意しておくと良いです。

PointNet

PointNetの処理概要

PointNetの処理の全体像は下図を元に掴むことができます。

PointNet論文 Figure$\, 2$

基本的には入力された$n$個の点ごとにMLP(Multi Layer Perceptron)処理を行うことに注意して上図は確認を行うと良いです。図のMLP演算の所に「shared」と記載があるのはMLPの計算に用いるパラメータを共有したと解釈するのが適切だと思われます1

また、$n \times 1024$にmax poolingを施した結果、$1024$次元の特徴量が得られ、$512$次元$\longrightarrow$$256$次元$\longrightarrow$$k$次元とMLPによって計算され、$k$クラス分類問題に用いられる一連の流れを確認することもできます2

Max PoolingとInvariance

前節で取り扱ったように点群は並べ替えに対して不変である必要があるので、CNNにおけるダウンサンプリングのように近傍の点と段階的に統合することはできません。したがって、PointNetでは$n \times 1024$の特徴量を点方向にmax poolingすることで大域特徴量(Global Feature)の抽出を行います。

Local and Global Information Aggregation

分類(Classification)タスクを解くにあたってはmax poolingを行うことで得たGlobal Featureをそのまま用いれば良い一方で、セグメンテーション(Segmentation)タスクでは点単位での分類が必要になります。

したがってPointNetでは$n \times 64$の点ごとの特徴量にそれぞれ$1024$次元のGlobal Featureを連結する(concatenate)ことで$n \times 1088$とし、その後の処理を行います。MLPにより、$n \times 1088$は下記のように推移します。
$$
\large
\begin{align}
n \times 1088 & \longrightarrow n \times 512 \longrightarrow n \times 256 \\
& \longrightarrow n \times 128 \longrightarrow n \times 128 \longrightarrow n \times m
\end{align}
$$

上記に基づいて点ごとに$m$クラス分類を行うことで、セグメンテーションを実現することができます。

Joint Alignment Network

点群はrigid transformationのように特定のgeometric transformationsに対して不変(invariant)である必要があります。PointNetではこの対応にあたってアフィン変換を実現する小さなニューラルネットワークであるT-netが導入されます。

参考

・PointNet論文

  1. 論文内に該当の記載が見つからないので要出典。 ↩︎
  2. $(512,256,k)$のmlpの四角形の高さが紛らわしいですが、どれもglobal featureであるので単に$1024 \to 512 \to 256 \to k$のように推移すると理解すれば良いです。 ↩︎

【CvT論文まとめ】畳み込みとMultiHead Attentionの対応

ViTなどのComputer Vision分野へのTransformerの導入は強力なアプローチである一方で、Transformerをそのまま用いる場合は局所相関を生かせないなどの課題があります。当記事ではViTに畳み込みを導入した手法であるCvTを題材に畳み込みとMultiHead Attentionの対応について取りまとめました。
CvTの論文である「CvT: Introducing Convolutions to Vision Transformers」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformerの概要

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

ViT

Convolutional vision Transformer

処理の概要

Convolutional vision Transformer(CvT)の処理の全体像は下図を元に掴むと良いです。

Convolutional Vision Transformer論文 Figure$\, 2$

CvTは主に「Convolutional Token Embedding」と「Convolutional Transformer Block」の二つによって構成され、どちらの処理にも畳み込みが用いられます。

「Convolutional Transformer Block」は「Convolutional Projection」という手法を用いてTransformerモジュールの入力であるQuery、Key、Valueの作成を行い、Transformer計算を行います。上図の$(\mathrm{b})$で示されるように、「Convolutional Projection」以外の「Convolutional Transformer Block」の処理はオーソドックスなTransformerであるので、以下では「Convolutional Token Embedding」と「Convolutional Projection」について詳しく取り扱います。

Convolutional Token Embedding

ステージ$i-1$の出力を$X_{i-1} \in \mathbb{R}^{H_{i-1} \times W_{i-1} \times C_{i-1}}$とおき、下記が成立するようにチャネル数$C_i$、フィルタのサイズが$s \times s$の畳み込み演算に対応する関数$f$を定義します。
$$
\large
\begin{align}
f(X_{i-1}) \in \mathbb{R}^{H_{i} \times W_{i} \times C_{i}}
\end{align}
$$

このとき、上記で表される関数の$f$が「Convolutional Token Embedding」に対応します。ステージ毎のフィルタサイズ$s \times s$は下記の表から確認できます。

Convolutional Vision Transformer論文 Table$\, 2$を改変

赤枠で囲った部分がフィルタサイズに対応します。また、「Convolutional Token Embedding」がステージ毎に$1$度であるのに対し、「Convolutional Transformer Block」は複数回繰り返される場合があることも合わせて抑えておくと良いです。

Convolutional Projection

Convolutional Vision Transformer論文 Figure$\, 3$

「Convolutional Projection」の処理概要は上図の$(\mathrm{b})$を元に理解すると良いです。パッチを$2D$に並べ、畳み込みを行なうことでQuery、Key、Valueを計算します。ここでの畳み込みではオーソドックスな畳み込みではなくDepth-wise Convolution」と「Point-wise Convolution」を組み合わせた「Depth-wise Separable Convolution」が用いられます。「Depth-wise Separable Convolution」については下記で詳しく取り扱ったので当記事では省略します。

また、処理概要を掴むにあたっては図の$(\mathrm{b})$が適している一方で、実際には$(\mathrm{c})$のSqueezed convolutional projectionが用いられます。Squeezed convolutional projectionはKeyとValueの解像度を低くすることでSpatial Reduction Attention(SRA)と同様に計算の軽量化を行なったと解釈すると良いと思います。Spatial Reduction Attentionについては下記で詳しく取り扱ったので当記事では省略します。

畳み込みとMultiHead Attentionの対応

前項では「Convolutional Transformer Block」におけるTransformerモジュールへの入力となるQuery、Key、Valueの作成を行う「Convolutional Projection」について確認を行いました。当項ではこの「Convolutional Projection」の解釈についてMultiHead Attentionの処理と対応させることで確認を行います。まず、Point-wise Convolutionに一致する$1 \times 1$の畳み込みはMultiHead Attentionにおけるlinear projectionと同じ処理を表すことを抑えておくと良いです。

「Point-wise Convolution」と「MultiHead Attentionにおけるlinear projection」の対応について詳しく確認するにあたって、Feature Mapの$X \in \mathbb{R}^{H \times W \times C_1}$とフィルタの$w_{i} \in \mathbb{R}^{C_1}$を定義します。

このとき、$X$の空間の$2D$を$1D$に直す(Flatten)ことで$X’ \in \mathbb{R}^{HW \times C_1}$を得ることができます。ここで入力のチャネル数の$C_1$に対応する出力のチャネル数を$C_2$とおき、下記のようにフィルタのパラメータを行列$\mathcal{W}$で定義します。
$$
\large
\begin{align}
\mathcal{W} = \left( \begin{array}{ccccc} w_1 & \cdots & w_i & \cdots & w_{C_2} \end{array} \right)
\end{align}
$$

上記を元に$X’ \mathcal{W}$を計算し、$1D$から$2D$に戻す処理はPoint-wise Convolutionの演算に一致します1

一方、$X’$を$Q, K, V$、$\mathcal{W}$を$W^{Q}, W^{K}, W^{V}$に読み換えた処理はMultiHead Attentionの「linear projection」の定義式そのものであることも同時に確認できます。したがって、「Point-wise Convolution」と「MultiHead Attentionにおけるlinear projection」は基本的に同一な処理であることが確認できます。

この処理の対応から、「Convolutional Projection」は新規の処理というよりはMultiHead Attentionの拡張であると解釈することができます。

参考

・Transformer論文:Attention is All you need[$2017$]
・Convolutional vision Transformer論文

  1. CNNの実装では畳み込みをそのまま実装するとループ計算が多くなるので、フィルタと行列演算ができるようにここで取り扱ったような空間方向に展開してから元に戻すような実装がよく用いられます。DeepLearningのフレームワークではim2colなどの名称でこのような機能が提供されることが多いです。 ↩︎