【SimCLR】対照学習(Contrastive Learning)に基づくベクトル表現の取得①

SimCLR(Simple Framework for Contrastive Learning of Visual Representations)は対照学習(Contrastive Learning)を用いて画像のベクトル表現を抽出する手法です。当記事ではSimCLRの一連の学習手順について取りまとめを行いました。
SimCLRの論文の「A Simple Framework for Contrastive Learning of Visual Representations」を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

指示関数

指示関数(indicator function)の$\mathbb{1}_{[k \neq i]} \in \{ 0, 1 \}$は下記のように定義される。
$$
\large
\mathbb{1}_{[k \neq i]} =
\begin{cases}
1 \quad \mathrm{if} \quad k \neq i \\
0 \quad \mathrm{otherwise}
\end{cases}
$$

対照学習

SimCLR

SimCLRの全容

SimCLRは下記の$4$つの主要要素によって構成される。
・確率的データ拡張(A stochastic data augmentation)
・エンコーダ(A neural network base encoder $f(·)$)
・A small neural network projection head $g(·)$
・対照損失関数(A contrastive loss function)

上記の$4$つの主要要素は下図を元に抑えておくと良い。

SimCLR論文 Figure$\, 2$

上図の$\mathbf{x}$から$\tilde{\mathbf{x}}_{i}$や$\tilde{\mathbf{x}}_{j}$を作成するプロセスが確率的データ拡張(A stochastic data augmentation)、$\tilde{\mathbf{x}}_{i}, \, \tilde{\mathbf{x}}_{j}$から$\mathbf{h}_{i}, \, \mathbf{h}_{j}$を計算するプロセスがエンコーダ、$\mathbf{h}_{i}, \, \mathbf{h}_{j}$から$\mathbf{z}_{i}, \, \mathbf{z}_{j}$を計算するプロセスがprojection head、$\mathbf{z}_{i}, \, \mathbf{z}_{j}$を用いて定義される損失関数が対照損失関数(A contrastive loss function)にそれぞれ対応する。

以下、$4$つの主要要素についてそれぞれ確認を行う。

stochastic data augmentation

データ拡張(data augmentation)の主要な手法は下図に基づいて把握することができる。

SimCLR論文 Figure$\, 4$

SimCLRでは$(\mathrm{c})$の「crop(切り抜き)+リサイズ+反転(flip)」、$(\mathrm{d}), \, (\mathrm{e})$の「color distortion」、$(\mathrm{i})$の「Gaussian blur」の$3$つが用いられ、cropとcolor distortionが有効であったと報告されている。

$(\mathrm{c})$の切り抜きにあたっては「切り抜く場所」や「切り抜くサイズ」にランダム性があることから、SimCLRにおけるデータ拡張は確率的データ拡張(stochastic data augmentation)と表されていることも合わせて抑えておくと良い。

ニューラルネットの構成①:encoder

encoderの$f(\cdot)$は入力の$\tilde{\mathbf{x}}_{i}$や$\tilde{\mathbf{x}}_{j}$からベクトル表現(Visual Representation)を抽出する関数に対応する。SimCLRでは下記の数式で表されるようにResNetが用いられる。
$$
\large
\begin{align}
\mathbf{h}_{i} &= f(\mathbf{x}_{i}) = \mathrm{ResNet}(\mathbf{x}_{i}) \\
\mathbf{h}_{i} & \in \mathbb{R}^{d}
\end{align}
$$

$d$は抽出する画像のベクトル表現の次元に対応する。

ニューラルネットの構成②:projection head

projection headの$g(\cdot)$は抽出したベクトル表現の$\mathbf{h}_{i}$を対照損失(contrastive loss)の計算用に変換する処理に対応する。原理的には$\mathbf{h}_{i}$をそのまま用いてlossの計算を行うことは可能であるが、SimCLRの論文では$g$を用いることがbeneficialとされる。

SimCLRでは$g$に二層のMLP(Multi Layer Perceptron)が用いられている。このMLPは下記のような数式で表される。
$$
\large
\begin{align}
\mathbf{z}_{i} &= g(\mathbf{h}_{i}) = W^{(2)} \mathrm{ReLU}(W^{(1)} \mathbf{h}_{i}) \\
\mathrm{ReLU}(x) &= \max(0, x)
\end{align}
$$

loss function

$(i,j)$を正例の組とおき、$k$組用意するとき、サンプルは$2k$個となる。このとき、$(i,j)$に関するloss functionを下記のように定義する。
$$
\large
\begin{align}
l_{i,j} &= -\log{\left[ \frac{\exp{(\mathrm{sim}(\mathbf{z}_{i},\mathbf{z}_{j}))/\tau}}{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp{(\mathrm{sim}(\mathbf{z}_{i},\mathbf{z}_{k}))/\tau}} \right]} \\
\mathrm{sim}(\mathbf{u},\mathbf{v}) &= \frac{\mathbf{u}^{\mathrm{T}} \mathbf{v}}{||\mathbf{u}|| ||\mathbf{v}||}
\end{align}
$$

$\tau$は温度パラメータである。上記の式に基づいてSimCLRの学習を行う。

参考

・SimCLR論文

「【SimCLR】対照学習(Contrastive Learning)に基づくベクトル表現の取得①」への2件のフィードバック

コメントは受け付けていません。