Theory and Applications of Automatic Differentiation (5): Transformers and Dot Product Attention

Automatic differentiation is a technique used for computations such as backpropagation when training deep learning models, which are large-scale neural networks. In this article we implement a simplified version of the Transformer based on Dot Product Attention.
The implementation mainly draws on Chapter $5$, "Recurrent Neural Networks," of 「ゼロから作るDeep Learning②」 (Deep Learning from Scratch 2).

・Explanations of terms and formulas
https://www.hello-statisticians.com/explain-terms

・An intuitive explanation of how the Transformer works

Dot Product Attention

Softmax function

import numpy as np

def softmax(a):    # a is a 2D array of shape (batch_size, dim)
    c = np.max(a, axis=1)    # row-wise max for numerical stability
    exp_a = np.exp(a.T-c)
    sum_exp_a = np.sum(exp_a, axis=0)
    y = (exp_a / sum_exp_a).T
    return y

def softmax_3D(a):    # a is a 3D array of shape (N, L, H)
    N, L, H = a.shape
    c = np.max(a, axis=2)    # max over the last axis for numerical stability
    exp_a = np.exp(a-c.reshape([N, L, 1]).repeat(H, axis=2))
    sum_exp_a = np.sum(exp_a, axis=2)
    y = exp_a / sum_exp_a.reshape([N, L, 1]).repeat(H, axis=2)
    return y

class Softmax:
    def __init__(self):
        self.loss = None
        self.y = None

    def forward(self, x):
        self.y = softmax_3D(x)
        return self.y

    def backward(self, dout):
        # element-wise derivative y * (1 - y) (the diagonal of the softmax Jacobian)
        dx = dout * self.y * (1-self.y)
        return dx
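
As a quick sanity check with arbitrary illustrative inputs, the entries of the softmax output along the last axis sum to $1$ for both functions.

np.random.seed(0)
a_2d = np.random.randn(2, 3)
a_3d = np.random.randn(2, 4, 3)

print(np.sum(softmax(a_2d), axis=1))       # each entry is 1
print(np.sum(softmax_3D(a_3d), axis=2))    # each entry is 1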

Dot product computation

The Transformer paper defines the Dot Product Attention computation as follows.
$$
\large
\begin{align}
\mathrm{Attention}(Q, K, V) = \mathrm{softmax} \left( \frac{Q K^{\mathrm{T}}}{\sqrt{d_k}} \right) V
\end{align}
$$

The operation $\displaystyle \mathrm{softmax} \left( \frac{Q K^{\mathrm{T}}}{\sqrt{d_k}} \right)$ above can be implemented as follows.

class CalcAttentionWeight:
    def __init__(self):
        self.h = None
        self.graph = None
        self.softmax = Softmax()

    def forward(self, h):
        self.h = h
        N, L, H = self.h.shape
        scaled_dp = np.zeros([N, L, L])
        for i in range(N):
            # scaled dot product h h^T / sqrt(H) for each sample
            scaled_dp[i, :, :] = np.dot(h[i, :, :], h[i, :, :].T) / np.sqrt(H)

        self.graph = self.softmax.forward(scaled_dp)

        return self.graph

    def backward(self, dgraph):
        N, L, H = self.h.shape
        ds = self.softmax.backward(dgraph)
        dh = np.zeros_like(self.h)
        for i in range(N):
            # gradient via the matrix derivative d(X X^T)/dX = 2X explained below
            dh[i, :, :] = 2*np.dot(dgraph[i, :, :], self.h[i, :, :]) / np.sqrt(H)
        return dh

In computing the backward method above, we used the fact that the following matrix derivative holds for a matrix $X$.
$$
\large
\begin{align}
\frac{\partial}{\partial X} (X X^{\mathrm{T}}) = 2 X = \frac{\partial}{\partial X} (X^{\mathrm{T}} X)
\end{align}
$$
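
As a quick element-wise check of this identity, consider the scalar $s = \mathrm{tr}(X X^{\mathrm{T}})$ and differentiate with respect to each entry $x_{ij}$:
$$
\large
\begin{align}
s = \mathrm{tr}(X X^{\mathrm{T}}) = \sum_{i, j} x_{ij}^2, \quad \frac{\partial s}{\partial x_{ij}} = 2 x_{ij} \quad \Longrightarrow \quad \frac{\partial s}{\partial X} = 2X
\end{align}
$$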

Because the Dot Product Attention of this Transformer uses the hidden states for both of the matrices $Q$ and $K$, the computation above essentially corresponds to this case.
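
As a usage sketch of CalcAttentionWeight, a minimal example with an arbitrary hidden-state array h is shown below; the forward pass returns one $(L, L)$ weight matrix per sample, and each row sums to $1$ because of the softmax.

np.random.seed(0)
h = np.random.randn(2, 3, 4)        # (N, L, H)

calc_weight = CalcAttentionWeight()
graph = calc_weight.forward(h)

print(graph.shape)                  # (2, 3, 3)
print(np.sum(graph, axis=2))        # each row of the attention weights sums to 1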

Dot Product Attention

class MessagePassing3D:
    def __init__(self, adj_mat_3D):
        self.params, self.grads = [], []
        self.graph = adj_mat_3D
        self.h = None

    def forward(self, h):
        self.h = h
        N, L, H = h.shape
        m_t = np.zeros_like(h)
        for i in range(self.graph.shape[1]):
            # weight the L hidden vectors by the i-th row of the attention weights and sum them
            ar = self.graph[:, i, :].reshape(N, L, 1).repeat(H, axis=2)
            t = h * ar
            m_t[:, i, :] = np.sum(t, axis=1)
        return m_t

    def backward(self, dm_t):
        N, L, H = self.h.shape
        dh = np.zeros_like(self.h)
        dar = np.zeros_like(self.graph)
        for i in range(self.graph.shape[1]):
            ar = self.graph[:, i, :].reshape(N, L, 1).repeat(H, axis=2)
            dh[:, i, :] = np.sum(dm_t * ar, axis=1)
        for i in range(self.graph.shape[0]):
            dar[i, :, :] = np.dot(dm_t[i, :, :], self.h[i, :, :].T)
        return dh, dar

The implementation is basically the same as for a graph neural network, but the array corresponding to the graph's adjacency matrix was changed from (L, L) to (N, L, L) so that the similarities computed via dot products can be reflected for each sample (a small usage sketch follows).
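
The sketch below (with an arbitrary illustrative array h) shows how CalcAttentionWeight and MessagePassing3D combine: each output vector is the attention-weighted sum of the $L$ hidden vectors of the same sample.

np.random.seed(0)
h = np.random.randn(2, 3, 4)                   # (N, L, H)

graph = CalcAttentionWeight().forward(h)       # (N, L, L) attention weights
m_t = MessagePassing3D(graph).forward(h)

print(m_t.shape)                                                        # (2, 3, 4)
print(np.allclose(m_t[0, 0, :], np.dot(graph[0, 0, :], h[0, :, :])))    # True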

Transformer

Computing the Transformer layer

Affine class

class Affine:
    def __init__(self, W, b):
        self.W = W
        self.b = b
        self.x = None
        self.dW = None
        self.db = None

    def forward(self, x):
        self.x = x
        out = np.dot(x, self.W) + self.b
        return out

    def backward(self, dout):
        dx = np.dot(dout, self.W.T)
        self.dW = np.dot(self.x.T, dout)
        self.db = np.sum(dout, axis=0)
        return dx
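
As a minimal usage sketch with illustrative shapes: forward maps a (batch_size, input_dim) array to (batch_size, output_dim), and backward returns a gradient with the input's shape while storing dW and db.

np.random.seed(0)
W = 0.01 * np.random.randn(4, 3)
b = np.zeros(3)

affine = Affine(W, b)
x = np.random.randn(2, 4)                        # (batch_size, input_dim)

out = affine.forward(x)                          # (2, 3)
dx = affine.backward(np.ones_like(out))

print(out.shape, dx.shape, affine.dW.shape)      # (2, 3) (2, 4) (4, 3)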

The Affine class is covered in detail in the following article.

Transformer_Layer class

class Transformer_Layer:
    def __init__(self, W, b):
        self.W = W
        self.b = b
        self.dW = np.zeros_like(W)
        self.db = np.zeros_like(b)
        self.h = None
        self.affines = []
        self.graph = None
        self.calc_mp = None
        self.calc_weight = CalcAttentionWeight()

    def forward(self, h):
        self.h = h
        self.graph = self.calc_weight.forward(self.h)    # (N, L, L) attention weights
        self.calc_mp = MessagePassing3D(self.graph)
        m_t = self.calc_mp.forward(h)                    # attention-weighted sums of the hidden vectors
        h_next = np.zeros_like(h)
        for i in range(self.graph.shape[1]):
            # apply the shared affine transformation at each position
            affine = Affine(self.W, self.b)
            h_next[:, i, :] = affine.forward(m_t[:, i, :])
            self.affines.append(affine)
            
        return h_next

    def backward(self, dh_next):
        N, T, H = self.h.shape
        
        dh_affine = np.zeros_like(self.h)
        for i in range(self.graph.shape[1]):
            dh_affine[:, i, :] = self.affines[i].backward(dh_next[:, i, :])
            self.dW += self.affines[i].dW
            self.db += self.affines[i].db
            
        dh_mp, dgraph = self.calc_mp.backward(dh_affine)
        dh = self.calc_weight.backward(dgraph)
        return dh_mp+dh
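
As a shape check of Transformer_Layer (the weight shapes below are illustrative and chosen so that the hidden dimension is preserved), forward keeps the (N, L, H) shape and backward returns a gradient of the same shape.

np.random.seed(0)
W = 0.01 * np.random.randn(4, 4)
b = np.zeros(4)

layer = Transformer_Layer(W, b)
h = np.random.randn(2, 3, 4)                 # (N, L, H)

h_next = layer.forward(h)                    # (2, 3, 4)
dh = layer.backward(np.ones_like(h_next))

print(h_next.shape, dh.shape)                # (2, 3, 4) (2, 3, 4)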

$2$-Layer Transformer

Output layer

class Aggregate:
    def __init__(self):
        self.h = None
        self.y = None

    def forward(self, h):
        self.h = h
        self.y = np.sum(h, axis=1)    # aggregate the L position vectors by summation
        return self.y

    def backward(self, dy):
        N, T, H = self.h.shape
        dh = dy.reshape(N, 1, H).repeat(T, axis=1)
        return dh

def cross_entropy_error(y, t):
    delta = 1e-7    # avoid log(0)
    return -np.sum(t * np.log(y+delta))

class SoftmaxWithLoss:
    def __init__(self):
        self.loss = None
        self.y = None
        self.t = None

    def forward(self, x, t):
        self.t = t
        self.y = softmax(x)
        self.loss = cross_entropy_error(self.y, self.t)
        return self.loss

    def backward(self, dout=1.):
        batch_size = self.t.shape[0]
        dx = (self.y - self.t) / batch_size
        return dx
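
The output layer can be checked with the minimal sketch below (the arrays are illustrative; here the number of classes equals the hidden dimension because Aggregate only sums over the $L$ positions).

np.random.seed(0)
h = np.random.randn(2, 3, 4)                              # (N, L, H)
t = np.array([[1., 0., 0., 0.], [0., 1., 0., 0.]])        # one-hot targets

aggregate = Aggregate()
last_layer = SoftmaxWithLoss()

y = aggregate.forward(h)                   # (N, H): sum over the L positions
loss = last_layer.forward(y, t)
dh = aggregate.backward(last_layer.backward())

print(loss, dh.shape)                      # scalar loss and a (2, 3, 4) gradient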

$2$-Layer Transformer

from collections import OrderedDict

class TwoLayerTransformer:
    def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):
        self.params = {}
        self.params["W1"] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params["b1"] = np.zeros(hidden_size)
        self.params["W2"] = weight_init_std * np.random.randn(hidden_size, output_size)
        self.params["b2"] = np.zeros(output_size)

        # generate layers
        self.layers = OrderedDict()
        self.layers["Transformer_layer1"] = Transformer_Layer(self.params["W1"], self.params["b1"])
        self.layers["Softmax1"] = Softmax()
        self.layers["Transformer_layer2"] = Transformer_Layer(self.params["W2"], self.params["b2"])
        self.layers["Aggregate"] = Aggregate()
        self.lastLayer = SoftmaxWithLoss()

    def predict(self, x):
        for layer in self.layers.values():
            x = layer.forward(x)
        return x

    def loss(self, x, t):
        y = self.predict(x)
        return self.lastLayer.forward(y, t)

    def calc_gradient(self, x, t):
        # forward
        self.loss(x, t)

        # backward
        dout = self.lastLayer.backward(1.)
        layers = list(self.layers.values())
        layers.reverse()
        for layer in layers:
            dout = layer.backward(dout)

        # output
        grads = {}
        grads["W1"] = self.layers["Transformer_layer1"].dW
        grads["b1"] = self.layers["Transformer_layer1"].db
        grads["W2"] = self.layers["Transformer_layer2"].dW
        grads["b2"] = self.layers["Transformer_layer2"].db

        return grads

Execution example

np.random.seed(0)

alpha = 0.1    # learning rate

x = np.ones([2, 5, 2])                 # two samples, each a length-5 sequence of 2-dim features
x[0, :, :1] = -1.                      # set the first feature of sample 0 to -1 so the two samples differ
t = np.array([[1., 0.], [0., 1.]])     # one-hot targets

network = TwoLayerTransformer(2, 2, 2)

for i in range(50):
    grad = network.calc_gradient(x, t)

    for key in ("W1", "b1", "W2", "b2"):
        network.params[key] -= alpha * grad[key]

    if (i+1)%10==0:
        print(softmax(network.predict(x)))

・Execution results

[[0.46977806 0.53022194]
 [0.46352429 0.53647571]]
[[0.73811463 0.26188537]
 [0.30103311 0.69896689]]
[[0.96535871 0.03464129]
 [0.03273874 0.96726126]]
[[0.99548783 0.00451217]
 [0.00521023 0.99478977]]
[[9.99530595e-01 4.69404946e-04]
 [1.19280865e-03 9.98807191e-01]]