SPMDのメモリ制約とMesh-TensorFlowを用いたModel-Parallel処理の実装

DeepLearningにおける分散処理ではSPMDに基づいてバッチ分割を行うことが多い一方で、大規模モデルを取り扱うにあたってはメモリの制約などの課題があります。当記事では上記の解決にあたって用いられるModel-Parallel処理の原理やMesh-TensorFlowライブラリの概要について取り扱いました。
Mesh-TensorFlowの論文である”Mesh-TensorFlow: Deep Learning for Supercomputers”やSwitch Transformer論文などを参考に取りまとめを行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

・Mesh-TensorFlow論文
・Switch Transformer論文

Mesh-TensorFlowの概要

DeepLearningとSPMD

SPMD(Single-Program-Multiple-Data)は分散処理を行う際の基本的な方針であり、「同じプログラムを複数のデータに作用させる」ということを意味する。DeepLearningではミニバッチを分割し複数のGPUやTPUで処理することに対応する。

このようなGPU/TPUを用いた分散処理を用いることで学習の高速化が可能になる。ResNetなどに基づく画像処理ではFoward処理とBackProp処理を分散して行い、勾配を用いたパラメータのUpdateをまとめて行うことで高速化を実現できる。

このようにDeepLearningではSPMD(Single-Program-Multiple-Data)に基づいて複数のGPU/TPUを用いて分散処理を行うことで学習の高速化を行うことができる。

Mesh-TensorFlowの概要

Mesh-TensorFlowはSPMDに基づくバッチ分割(batch splitting)以外の分散処理も行えるように実装されたライブラリである。詳しくは次節の「Mesh-TensorFlowの設定と活用」で取り扱った。

Mesh-TensorFlowの使用例:2層MLP

入力層に対応する行列$x \in \mathbb{R}^{b \times d_{io}}$、中間層に対応する行列$h \in \mathbb{R}^{b \times d_{h}}$、出力層に対応する行列$y \in \mathbb{R}^{b \times d_{io}}$を元に下記のように$2$層MLPを定義する。
$$
\large
\begin{align}
y &= \mathrm{ReLU}(xw + \mathrm{bias}) v \quad (1) \\
w & \in \mathbb{R}^{d_{io} \times d_{h}}, \, v \in \mathbb{R}^{d_{h} \times d_{io}}, \, \mathrm{bias} \in \mathbb{R}^{b \times d_{h}}
\end{align}
$$

上記の$b$はバッチサイズ、$d_{io}$は$2$層MLPの入力層と出力層のベクトルの要素数、$d_{h}$は$2$層MLPの隠れ層のベクトルの要素数に対応する。このとき$(1)$式の演算をMesh-TensorFlowでは下記のように表現する。

batch = mtf.Dimension("batch", b)
io = mtf.Dimension("io", d_io)
hidden = mtf.Dimension("hidden", d_h)
# x.shape == [batch, io]
w = mtf.get_variable("w", shape=[io, hidden])
bias = mtf.get_variable("bias", shape=[hidden])
v = mtf.get_variable("v", shape=[hidden, io])
h = mtf.relu(mtf.einsum(x, w, output_shape=[batch, hidden]) + bias)
y = mtf.einsum(h, v, output_shape=[batch, io])

8行目と9行目で隠れ層の計算、出力層の計算がそれぞれ行われることに注意して上記は確認すると良い。

Mesh-TensorFlowの設定と活用

Data-Parallel Layout

Data-Parallel LayoutはMesh-TensorFlowでSPMD(Single-Program-Multiple-Data)処理を行う際のLayoutに対応する。Mesh-TensorFlowでは下記のような表記でData-Parallel Layoutを表す。

mesh_shape = [("all", n)]
computation_layout = [("batch", "all")]

Mesh-TensorFlowでは上記のcomputation_layoutbatchを指定することでData-Parallel処理を表す。

Model-Parallel Layout

Mesh-TensorFlowはバッチ分割ではなく、MLP(Multi Layer Perceptron)の演算の分割も行うことができる。前節で取り扱った$2$層MLPの例では入力層の$x$と出力層の$y$を全てのGPU/TPUに載せ、隠れ層の演算のみを分割すると分散処理が行える。

たとえば入力層・出力層の次元を$50$、隠れ層の次元を$100$に設定し、$5$つの演算ノードで分散処理を行う場合、通常では$50 \times 100$と$100 \times 50$の行列を用いて行列演算を行うが、$50 \times 20$と$20 \times 50$の行列演算を$5$つ行うことで全体の演算を実現できる。

このような処理を行うことで、MLP処理における隠れ層の次元が大きい場合も分割して処理を行うことができ、大規模モデルの構築も無理なく行うことができる。たとえば近年注目を集めるGPT$3$やPaLMなどのLLMでは$10{,}000$次元以上が用いられることがあるので、このような演算は有効な手段になり得る。

mesh_shape = [("all", n)]
computation_layout = [("hidden", "all")]

Mesh-TensorFlowでは上記のcomputation_layouthiddenを指定することでData-Parallel処理を表す。

前項のbatchと当項hiddenの対応に注意しておくと良い。

Data-Parallel, Model-Parallel Layouts

DataとModelの方向にそれぞれ対応する$r \times c$の$2$次元でメッシュ化を行う場合、Mesh-TensorFlowでは下記のように処理を表現する。

mesh_shape = [("rows", r), ("cols", c)]
computation_layout = [("batch", "rows"), ("hidden", "cols")]

上記はmesh_shaperowsrcolscを設定し、computation_layoutでそれぞれ辞書オブジェクトのように指定すると理解すればよい。

Mesh-TensorFlowとTransformer

Mesh-TensorFlowはTransformerにも活用することができる。

mesh_shape = [("all", n)]
computation_layout = [
  ("vocab", "all"), ("d_ff", "all"), ("heads", "all")]

上記はTransformerにModel-Parallel Layoutsを用いる際の表記である。n個のメッシュを用意し、語彙のサイズ$d_{model}$、隠れ層のサイズ$d_{ff}$、Attention_Headsの数にnを対応させる。

Model-Parallel LayoutsだけでなくData-Parallel Layoutも用いる場合は、下記のように$r \times c$個のメッシュを設定する。

mesh_shape = [("rows", r), ("cols", c")]
computation_layout = [("batch", "rows"), ("vocab", "cols"),
                      ("d_ff", "cols"), ("heads", "cols")]

Mesh-TensorFlowの活用①:Switch Transformer

Switch Transformerでは、Mesh-TensorFlowを用いる際の項を下記のように定義する。

Switch Transformer論文 Section.$5$より

Mesh-TensorFlowの活用②:Pathways・PaLM