ブログ

【上級】データサイエンス 数学ストラテジスト 公式問題集 解答例まとめ Q.21〜30

「データサイエンス 数学ストラテジスト 上級」はデータサイエンスの基盤である、確率・統計、線形代数、微積分、機械学習、プログラミングなどを取り扱う資格試験です。当記事では「日本数学検定協会」作成の「公式問題集」の演習問題$21$〜$30$の解答例を取り扱いました。

・数学検定まとめ
https://www.hello-statisticians.com/math_certificate

演習問題

Q.21

$$
\large
\begin{align}
AB = 8, \, BC = 5, \, \angle{ABC} = 60^{\circ}
\end{align}
$$

余弦定理より下記が成立する。
$$
\large
\begin{align}
AC^{2} &= AB^{2} + BC^{2} – 2 \cdot AB \cdot BC \cdot \cos{60^{\circ}} \\
&= 8^{2} + 5^{2} – \cancel{2} \cdot 8 \cdot 5 \cdot \frac{1}{\cancel{2}} \\
&= 64 + 25 – 40 = 49
\end{align}
$$

$AC>0$より$AC=7$である。また、外接円の半径が$R$なので正弦定理より下記が成立する。
$$
\large
\begin{align}
\frac{AC}{\sin{60^{\circ}}} &= 2R \\
R &= 7 \cdot \frac{\cancel{2}}{\sqrt{3}} \cdot \frac{1}{\cancel{2}} \\
&= \frac{7}{\sqrt{3}}
\end{align}
$$

また、$\triangle{ABC}$の面積を$S$とおくと$S$は下記のように計算できる。
$$
\large
\begin{align}
S &= \frac{1}{2} AB \cdot BC \sin{60^{\circ}} \\
&= 20 \cdot \frac{\sqrt{3}}{2} = 10 \sqrt{3}
\end{align}
$$

ここで内接円の半径$r$は下記の式に基づいて得られる。
$$
\large
\begin{align}
S &= \frac{1}{2}(AB+BC+AC)r \\
10 \sqrt{3} &= \frac{1}{2}(8+5+7)r \\
r &= \sqrt{3}
\end{align}
$$

よって$\displaystyle \frac{R}{r}$は下記のように得られる。
$$
\large
\begin{align}
\frac{R}{r} &= \frac{7}{\sqrt{3}} \cdot \frac{1}{\sqrt{3}} \\
&= \frac{7}{3}
\end{align}
$$

Q.22

$p_{n}, \, n=0,1,2,3$は下記のような式で表すことができる。
$$
\large
\begin{align}
p_{n} &= \frac{{}_{5} C_{n} \cdot {}_{10} C_{3-n}}{{}_{15} C_{3}} \quad [1]
\end{align}
$$

$[1]$式より、$p_0, p_1, p_2, p_3$は下記のように計算できる。
$$
\large
\begin{align}
p_{0} &= \frac{{}_{5} C_{0} \cdot {}_{10} C_{3}}{{}_{15} C_{3}} \\
&= \frac{10 \cdot 9 \cdot 8}{15 \cdot 14 \cdot 13} \\
&= \frac{24}{91} \\
p_{1} &= \frac{{}_{5} C_{1} \cdot {}_{10} C_{2}}{{}_{15} C_{3}} \\
&= \frac{5 \cdot 10 \cdot 9}{2} \cdot \frac{3 \cdot 2 \cdot 1}{15 \cdot 14 \cdot 13} \\
&= \frac{45}{91} \\
p_{2} &= \frac{{}_{5} C_{2} \cdot {}_{10} C_{1}}{{}_{15} C_{3}} \\
&= \frac{5 \cdot 4 \cdot 10}{2} \cdot \frac{3 \cdot 2 \cdot 1}{15 \cdot 14 \cdot 13} \\
&= \frac{20}{91} \\
p_{3} &= \frac{{}_{5} C_{3} \cdot {}_{10} C_{0}}{{}_{15} C_{3}} \\
&= \frac{5 \cdot 4 \cdot 3}{15 \cdot 14 \cdot 13} \\
&= \frac{2}{91}
\end{align}
$$

上記より$p_3 < p_2 < p_0 < p_1$であるので$(1)$が正しい。

・解説
この問題は超幾何分布の確率関数の計算と対応するので、合わせて抑えておくと良いです。

Q.23

$$
\large
\begin{align}
\sin{\alpha} &= \frac{1}{3}, \, \sin{\beta} = \frac{2}{3} \\
0 < & \alpha < \frac{\pi}{2}, \, 0 < \beta < \frac{\pi}{2}
\end{align}
$$

上記より、$0 < \cos{\alpha}, \, 0 < \cos{\beta}$であるので、$\sin^{2}{\theta}+\cos^{2}{\theta}=1$に基づいて下記が得られる。
$$
\large
\begin{align}
\sin^{2}{\alpha} + \cos^{2}{\alpha} &= 1 \\
\cos^{2}{\alpha} &= 1-\frac{1}{3^2} \\
\cos{\alpha} &= \frac{2\sqrt{2}}{3} \\
\sin^{2}{\beta} + \cos^{2}{\beta} &= 1 \\
\cos^{2}{\beta} &= 1-\frac{2^2}{3^2} \\
\cos{\beta} &= \frac{\sqrt{5}}{3}
\end{align}
$$

よって$\sin{(\alpha+\beta)}$は加法定理に基づいて下記のように得られる。
$$
\large
\begin{align}
\sin{(\alpha+\beta)} &= \sin{\alpha} \cos{\beta} + \cos{\alpha} \sin{\beta} \\
&= \frac{1}{3} \cdot \frac{\sqrt{5}}{3} + \frac{2\sqrt{2}}{3} \cdot \frac{2}{3} \\
&= \frac{4 \sqrt{2} + \sqrt{5}}{9}
\end{align}
$$

上記より$(5)$が正しい。

Q.24

$$
\large
\begin{align}
x^{\log_{10}{x}} = 1000 \sqrt{x} \quad [1]
\end{align}
$$

$[1]$式の両辺に対し、底が$10$の対数を取ると下記のように方程式を解くことができる。
$$
\large
\begin{align}
\log_{10}{x^{\log_{10}{x}}} &= \log_{10}{(1000 \sqrt{x})} \quad [1]’ \\
\log_{10}{x} \cdot \log_{10}{x} &= \frac{1}{2} \log_{10}{x} + 3 \\
2 (\log_{10}{x})^{2} – \log_{10}{x} – 6 &= 0 \\
(2 \log_{10}{x} + 3)(\log_{10}{x} – 2) &= 0 \\
\log_{10}{x} &= -\frac{3}{2}, \, 2 \\
x &= \frac{\sqrt{10}}{100}, \, 100
\end{align}
$$

よって$(2)$が正しい。

Q.25

$a, b, c$が等比数列かつ$\displaystyle \frac{b}{a} = r$より、$b,c$は$a,r$を用いて下記のように表すことができる。
$$
\large
\begin{align}
b &= ar \\
c &= ar^2
\end{align}
$$

また、$a,b,c$の相加平均が$b+2$に等しいことから下記が成立する。
$$
\large
\begin{align}
\frac{1}{3}(a+b+c) &= b+2 \\
a + ar + ar^2 &= 3(ar+2) \\
a(1+r+r^2-3r) &= 6 \\
a(1-2r+r^2) &= 6 \\
a(r-1)^{2} &= 6 \quad [1]
\end{align}
$$

$[1]$式より$a$が正の整数で$\displaystyle \frac{b}{a} = r$が整数であることから、$a=6, (r-1)^{2}=1$が成立する。ここで$r \neq 0$より$r=2$が成立する。よって、与えられた式は下記のように計算できる。
$$
\large
\begin{align}
\frac{a^2+a-7}{a+1} + \frac{r^2+r-1}{r+3} &= \frac{6^2+6-7}{6+1} + \frac{2^2+2-1}{2+3} \\
&= 5 + 1 = 6
\end{align}
$$

上記より$(1)$が正しい。

Q.26

面積の和$S_1+S_2$は下記のように計算できる。
$$
\large
\begin{align}
S_1+S_2 &= \int_{0}^{5} -(x^2-5x) dx + \int_{5}^{6} (x^2-5x) dx \\
&= \left[ -\frac{1}{3}x^3 + \frac{5}{2}x^2 \right]_{0}^{5} + \left[ \frac{1}{3}x^3 – \frac{5}{2}x^2 \right]_{5}^{6} \\
&= 2 \left( -\frac{125}{3}+\frac{125}{2} \right) + \frac{6^3}{3} – \frac{5 \cdot 6^2}{2} \\
&= 2 \cdot \frac{125}{6} + 72 – 90 \\
&= \frac{125 – 3 \cdot 18}{3} \\
&= \frac{71}{3}
\end{align}
$$

上記より$(2)$が正しい。

Q.27

$$
\large
\begin{align}
f(x) &= ax^3 + bx^2 + cx + d \\
f'(x) &= 3ax^2 + 2bx + c
\end{align}
$$

$x=-2$のとき極大値$15$、$x=4$のとき極小値$-12$を取るには下記の必要条件が成立しなければならない。
$$
\large
\begin{align}
f'(-2) &= 12a – 4b + c = 0 \quad [1] \\
f(-2) &= -8a + 4b – 2c + d = 15 \quad [2] \\
f'(4) &= 48a + 8b + c = 0 \quad [3] \\
f(4) &= 64a + 16b + 4c + d = -12 \quad [4]
\end{align}
$$

$[2]-[1]$より下記が得られる。
$$
\large
\begin{align}
36a + 12b &= 0 \\
b &= -3a \quad [5]
\end{align}
$$

$[5]$式を$[1]$に代入することで下記が得られる。
$$
\large
\begin{align}
12a – 4 \cdot (-3a) + c &= 0 \quad [1]’ \\
c &= -24a \quad [6]
\end{align}
$$

$[4]-[2]$より下記が得られる。
$$
\large
\begin{align}
72a + 12b + 6c = -27 \quad [7]
\end{align}
$$

$[7]$式に$[5], [6]$を代入すると下記が得られる。
$$
\large
\begin{align}
72a + 12 \cdot (-3a) + 6 \cdot (-24a) &= -27 \quad [7]’ \\
-108a &= -27 \\
a &= \frac{1}{4} \quad [8]
\end{align}
$$

$[8]$を$[5], [6]$に代入することで下記が得られる。
$$
\large
\begin{align}
b &= -\frac{3}{4} \quad [9] \\
c &= -6 \quad [10]
\end{align}
$$

また、$[8], [9], [10]$を$[2]$に代入することで下記が得られる。
$$
\large
\begin{align}
-8 \cdot \frac{1}{4} + 4 \cdot \left( -\frac{3}{4} \right) – 2 \cdot (-6) + d &= 15 \quad [2]’ \\
-2 – 3 + 12 + d &= 15 \\
d &= 8 \quad [11]
\end{align}
$$

$[8], [9], [10], [11]$より$a+b+c+d$は下記のように計算できる。
$$
\large
\begin{align}
a + b + c + d &= \frac{1}{4} – \frac{3}{4} – 6 + 8 \\
&= -\frac{1}{2} + 2 \\
&= \frac{3}{2}
\end{align}
$$

上記より$(2)$が正しい。

Q.28

$\displaystyle X \sim \mathrm{Bin} \left( 10n, \frac{1}{2} \right)$であるので、$m=E[X], \sigma=\sqrt{V[X]}$は下記のように計算できる。
$$
\large
\begin{align}
m &= E[X] \\
&= 10n \cdot \frac{1}{2} \\
&= 5n \\
\sigma &= \sqrt{V[X]} \\
&= \sqrt{10n \cdot \frac{1}{2} \cdot \frac{1}{2}} \\
&= \frac{\sqrt{10n}}{2}
\end{align}
$$

上記より$(5)$が正しい。

・解説
二項分布の期待値$E[X]$と分散$V[X]$の式と導出は下記で取り扱ったので、合わせて抑えておくと良いです。

Q.29

・$0 \leq x \leq 1$
$$
\large
\begin{align}
f(x) = a(x-x^{2}) \quad [1]
\end{align}
$$

・$x < 0, 1 < x$
$$
\large
\begin{align}
f(x) = 0 \quad [2]
\end{align}
$$

$[1], [2]$式と確率密度関数の定義より下記が成立する。
$$
\large
\begin{align}
\int_{0}^{1} f(x) dx &= 1 \\
a \int_{0}^{1} (x-x^{2}) dx &= 1 \\
a \left[ \frac{1}{2}x^{2} – \frac{1}{3}x^{3} \right]_{0}^{1} &= 1 \\
a \left( \frac{1}{2} – \frac{1}{3} \right) &= 1 \\
a &= 6
\end{align}
$$

また、$E[X], E[X^{2}]$はそれぞれ下記のように計算できる。
$$
\large
\begin{align}
E[X] &= \int_{0}^{1} x f(x) dx = \int_{0}^{1} (x^{2}-x^{3}) dx \\
&= 6 \left[ \frac{1}{3}x^{3}-\frac{1}{4}x^{4} \right]{0}^{1} \\
&= 6 \cdot \frac{1}{12} = \frac{1}{2} E[X^{2}] \\
&= \int_{0}^{1} x^{2} f(x) dx = \int_{0}^{1} (x^{3}-x^{4}) dx \\
&= 6 \left[ \frac{1}{4}x^{3} – \frac{1}{5}x^{4}) \right]_{0}^{1} \\
&= 6 \cdot \frac{1}{20} = \frac{3}{10}
\end{align}
$$

よって分散$V[X]=E[X^{2}]-E[X]^{2}$は下記のように得られる。
$$
\large
\begin{align}
V[X] &= E[X^{2}] – E[X]^{2} \\
&= \frac{3}{10} – \left( \frac{1}{2} \right)^{2} \\
&= \frac{6-5}{20} = \frac{1}{20}
\end{align}
$$

よって$(2)$が正しい。

Q.30

$$
\large
\begin{align}
\vec{a} = \left( \begin{array}{c} 2 \\ -1 \end{array} \right), \, \vec{a} = \left( \begin{array}{c} -1 \\ 3 \end{array} \right)
\end{align}
$$

ベクトル$\vec{p}=k\vec{a}+\vec{b}$は下記のように表せる。
$$
\large
\begin{align}
\vec{p} &= k \vec{a} + \vec{b} = k \left( \begin{array}{c} 2 \\ -1 \end{array} \right) + \left( \begin{array}{c} -1 \\ 3 \end{array} \right) \\
&= \left( \begin{array}{c} 2k-1 \\ -k+3 \end{array} \right)
\end{align}
$$

上記より$|\vec{p}|$は下記のように計算できる。
$$
\large
\begin{align}
|\vec{p}| &= \sqrt{(2k-1)^{2} + (-k+3)^{2}} \\
&= \sqrt{(4k^{2}-4k+1) + (k^{2}-6k+9)} \\
&= \sqrt{5k^{2} – 10k + 10} \\
&= \sqrt{5(k^{2}-2k+1) + 5} \\
&= \sqrt{5(k-1)^{2} + 5}
\end{align}
$$

ここで$\displaystyle -1 \leq k \frac{3}{2}$であるので、$|\vec{p}|$は$k=1$のとき最小値$\sqrt{5}$を取る。よって$m^{2}=5$であるので$(3)$が正しい。

BPE(Byte Pair Encoding)アルゴリズムの仕組みとPythonプログラムの確認

Transformerに基づくLLMの学習にあたっては多くの文書を用いる一方で、単語をそのまま取り扱うとEmbedding処理のパラメータ数が増大します。当記事ではこの解決にあたって用いられる手法の$1$つであるBPE(Byte Pair Encoding)のアルゴリズムの仕組みとPythonプログラムの確認を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

BPEアルゴリズム

BPEの基本的な仕組み

BPE(Byte Pair Encoding)の基本的な仕組みは下図を元に理解することができる。

A New Algorithm for Data Compression(Philip Gage, $1994$)のFigure.$1$

上図では入力が”ABABCABCD”であり、”A”、”B”、”C”、”D”の$4$種類の文字で文字列が構成される。ここで”AB”の並びが最多の$3$回確認できるので”AB”を”H”で置き換えると、文字列は”HHCHCD”に変換できる。

次に”HHCHCD”の文字列では”HC”の並びが最多の$2$回観測されるので”G”で置き換えると文字列は”HGGD”に変換できる。BPEではこのように取得した”G”と”H”を加えた”A”、”B”、”C”、”D”、”G”、”H”で文字列を表す。

BPEの論文

BPE(Byte Pair Encoding)を用いる際に参照されることが多いのが「Neural Machine Translation of Rare Words with Subword Units(Sennrich et al., $2016$)」である。

上記はByte Pair Encodingを用いた機械翻訳の論文である一方で、Byte Pair Encodingについては”A New Algorithm for Data Compression(Gage, $1994$)”を参照している。

近年の論文ではBPEを示すにあたって新しいものが参照されることが多いので注意が必要である。当記事でも以下、近年の論文に倣い「Neural Machine Translation of Rare Words with Subword Units(Sennrich et al., $2016$)」を「BPE論文」と表す。

BPEのPythonプログラム

BPE論文のPythonプログラム

BPE論文のAlgorithm$1$では下記のようにPythonを用いたBPE処理が確認できる。

BPE論文のAlgorithm$1$
import re, collections

def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

vocab = {'l o w _' : 5, 'l o w e r _' : 2, 'n e w e s t _':6,'w i d e s t _':3}
num_merges = 10

for i in range(num_merges):
    pairs = get_stats(vocab)
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print(best)

・実行結果

('e', 's')
('es', 't')
('est', '_')
('l', 'o')
('lo', 'w')
('n', 'e')
('ne', 'w')
('new', 'est_')
('low', '_')
('w', 'i')

上記の出力結果はfor文による繰り返し処理それぞれで連結する文字列の組である。たとえば”e”と”s”は”newest”と”widest”に一度ずつ出現するので、全部で$9$回観測され、全体の中で最多である。

処理の概要については次項で詳しく確認を行う。

処理の概要

以下のプログラムではモジュールの読み込みやget_stats関数とmerge_vocab関数の定義は基本的には省略する。

まず、get_stats(vocab)の実行結果の確認を行う。

vocab = {'l o w _' : 5, 'l o w e r _' : 2, 'n e w e s t _':6,'w i d e s t _':3}

print(get_stats(vocab))

・実行結果

defaultdict(int,
            {('l', 'o'): 7,
             ('o', 'w'): 7,
             ('w', '_'): 5,
             ('w', 'e'): 8,
             ('e', 'r'): 2,
             ('r', '_'): 2,
             ('n', 'e'): 6,
             ('e', 'w'): 6,
             ('e', 's'): 9,
             ('s', 't'): 9,
             ('t', '_'): 9,
             ('w', 'i'): 3,
             ('i', 'd'): 3,
             ('d', 'e'): 3})

上記より、「e-s」、「s-t」、「t-_」の出現回数が$9$で最多であることが確認できる。プログラムでは次に下記のような処理を行う。

vocab = {'l o w _' : 5, 'l o w e r _' : 2, 'n e w e s t _':6,'w i d e s t _':3}

pairs = get_stats(vocab)
best = max(pairs, key=pairs.get)
vocab = merge_vocab(best, vocab)

print(best)
print(vocab)

・実行結果

('e', 's')
{'l o w _': 5, 'l o w e r _': 2, 'n e w es t _': 6, 'w i d es t _': 3}

上記より、出現頻度が最多の組をbestに格納し、merge_vocab(best, vocab)を実行することで”e”と”s”の連結を行う流れが確認できる。この処理を10回繰り返すことで下記のような結果が得られる。

vocab = {'l o w _' : 5, 'l o w e r _' : 2, 'n e w e s t _':6,'w i d e s t _':3}
num_merges = 10

for i in range(num_merges):
    pairs = get_stats(vocab)
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print(best)

print("===")
print(vocab)

・実行結果

('e', 's')
('es', 't')
('est', '_')
('l', 'o')
('lo', 'w')
('n', 'e')
('ne', 'w')
('new', 'est_')
('low', '_')
('w', 'i')
===
{'low_': 5, 'low e r _': 2, 'newest_': 6, 'wi d est_': 3}

また、下記のようにプログラムを改変すると10回繰り返した際の辞書を得ることができる。

vocab = {'l o w _' : 5, 'l o w e r _' : 2, 'n e w e s t _':6,'w i d e s t _':3}
num_merges = 10

bpe_dict = []
for i in range(num_merges):
    pairs = get_stats(vocab)
    best = max(pairs, key=pairs.get)
    bpe_dict.append("".join(best))
    vocab = merge_vocab(best, vocab)
    print(best)

print("===")
print(bpe_dict)

・実行結果

('e', 's')
('es', 't')
('est', '_')
('l', 'o')
('lo', 'w')
('n', 'e')
('ne', 'w')
('new', 'est_')
('low', '_')
('w', 'i')
===
['es', 'est', 'est_', 'lo', 'low', 'ne', 'new', 'newest_', 'low_', 'wi']

BPE(Byte Pair Encoding)では上記のように得られる辞書を元にsubwordにidを割り当て、以後の処理を行う。

【Transformer】LLM(Large Language Model)のパラメータ数の概算法

昨今LLM(Large Language Model)が大きな注目を集める一方で、パラメータ数がどのように決まるかについて抑えておくと理解に役立ちます。そこで当記事ではLLMの主要モジュールであるTransformerに用いられるパラメータの概算法について取りまとめを行いました。
Transformerの論文や筆者作成の『直感的に理解するTransformer』の内容などを元に取りまとめを行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

・Transformer論文
・直感的に理解するTransformer(運営者作成)

パラメータ数の概算

パラメータ数の単位

LLM(Large Language Model)関連の論文ではパラメータ数はMillionを表すMやBillionを表すBで略記されるので注意が必要です。Millionは$10^{6}$の$100$万、Billionは$10^{9}$の$10$億にそれぞれ対応します。

具体的な論文とパラメータ数の対応については、$110$Mと$340$MのBERTが$1.1$億と$3.4$億、$11$BのT$5$が$110$億、$175$BのGPT$3$が$1750$億にそれぞれ対応します。

Transformerの大まかな仕組み

LLMの基盤のアーキテクチャには基本的にTransformerが用いられます。よって、LLMのパラメータ数について解釈する際はTransformerの大まかな仕組みの理解が重要です。Transformerの大まかな仕組みについては下記で詳しくまとめました。

・直感的に理解するTransformer

パラメータ数の概算

Transformerにおけるパラメータは主にEmbedding、各層におけるMulti-Head Attention、FFN(Feed Forward Network)処理にそれぞれ用いられます。下図の赤枠がパラメータ処理に対応します。

Transformer論文のFigure$1$を改変

上記に基づいてTransformerに用いられるパラメータ数を大まかに概算することが可能です。以下、パラメータ数の計算を下記の三つにわけて概算します。

$1. \,$ Embedding処理
$2. \,$ Multi-Head Attention処理
$3. \,$ FFN処理

Embedding処理

Embedding処理は$1$-hotベクトルにEmbedding Matrixを左からかけることで得ることができます。このEmbedding Matrixのパラメータ数は語彙数$V$とTransformer処理における隠れ層の数$D$によって概算が可能です。

たとえば$1$万種類の単語に対し、Transformerのそれぞれの単語の隠れ層の数が$D=512$である場合、パラメータ数は下記のように概算できます。
$$
\large
\begin{align}
V \times D &= 10^{4} \times 512 \\
&= 5.12 \times 10^{6}
\end{align}
$$

上記は約$5$Mに対応します。同様に$10$万種類の単語を$D=1024$で取り扱う場合は下記のように概算できます。
$$
\large
\begin{align}
V \times D &= 10^{7} \times 1024 \\
&= 1.024 \times 10^{8}
\end{align}
$$

上記は約$100$Mに対応します。パラメータ数の概算にあたっては、桁数で大まかに把握できるので、$5.12 \times 10^{6}$や$1.024 \times 10^{8}$のようにパラメータ数を表しました。

トークンが単語単位の場合はWord$2$vecがEmbeddingに対応する一方で、トークンの種類が増大するLLMではBPE(Byte Pair Encoding)などを用いることで$3$万〜$7$万種程度の語彙(vocabulary)に集約させるのが一般的です。よって、LLMの学習時に取り扱う文章が増えても語彙数は数万程度に収まることが多いです。

Multi-Head Attention処理

$$
\large
\begin{align}
\mathrm{MultiHead}(Q,K,V) &= \mathrm{Concat}(\mathrm{head}_{1}, \cdots , \mathrm{head}_{h}) W^{O} \\
\mathrm{head}_{i} &= \mathrm{Attention}(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V}) \\
W_{i}^{O} \in \mathbb{R}^{hd_{v} \times d_{model}}, \, W_{i}^{Q} & \in \mathbb{R}^{d_{model} \times d_{k}}, \, W_{i}^{K} \in \mathbb{R}^{d_{model} \times d_{k}}, W_{i}^{V} \in \mathbb{R}^{d_{model} \times d_{v}}
\end{align}
$$

Multi-Head Attentionではアンサンブル学習と同様に各Headにおける計算の相関が低くなるようにパラメータ$W$を元に上記のような計算を行います。ここで$d_{model}$は各トークンの分散表現の次元数であり前項の$D$と同義です。また、$h$はヘッドの数を表します。このとき、パラメータ数は下記のように概算できます。
$$
\large
\begin{align}
N \times d_{model} \times (h \times (2d_{k} + d_{v}) + h d_{v}) = 2 N d_{model} h(d_{k}+d_{v}) \quad (1)
\end{align}
$$

$N=6, d_{model}=512, h=8, d_{k}=64, d_{v}=64$のとき、Multi-Head Attention処理のパラメータ数は$(1)$式を元に下記のように概算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 6 \times 512 \times 8 \times (64+64) \\
&= 6291456 = 6.29 \times 10^{6}
\end{align}
$$

上記は$6.29$Mに対応します。

FFN処理

FFN処理は単語ごとの隠れ層に対してMLP(Multi Layer Perceptron)を行うことに対応します。よって、単語数$L$、それぞれの単語の隠れ層の数が$D$である場合、$N$層のMLPにおけるパラメータ数は下記のように概算することができます。
$$
\large
\begin{align}
N \times L \times D^{2} = NLD^2
\end{align}
$$

たとえば、$N=6, L=512, D=512$の場合、パラメータ数は下記のように概算できます。
$$
\large
\begin{align}
6 \times 512 \times 512^2 &= 805306368 \\
& \simeq 8.53 \times 10^{8}
\end{align}
$$

上記は$800$Mに対応し、Transformerのパラメータ数の約$100$Mを大きく上回ります。Transformerでは一般的に同じ層の単語では同じパラメータを使うので、$L$はかけないことに注意が必要です。また、FFNの処理では「$D$次元$\to$$D$次元」ではなく、「$D$次元$\to$$4D$次元$\to$$D$次元」のような処理が行われます。中間層の$4D$は別途設定されることもありますが、$4D$が用いられることが多いです。ここまでの内容に基づいてFFN処理におけるパラメータは下記のように概算できます。
$$
\large
\begin{align}
2 \times N \times D \times 4D = 8ND^2 \quad (2)
\end{align}
$$

また、$N=6, D=512$の場合のパラメータ数は$(2)$式より下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 6 \times 512^{2} \\
&= 12582912 = 1.26 \times 10^{7}
\end{align}
$$

上記は$12.6$Mに対応します。

LLMのパラメータ数の概算

注意事項:Encoder-Decoderの場合

Transformer論文のFigure$1$を改変

前節では上図を元にEncoder部分のみを確認しましたが、論文によってEncoderとDecoderの双方を用いる場合があることに注意が必要です。この場合、Multi-Head Attentionのパラメータ数を$3$倍、FFNの処理のパラメータ数を$2$倍して概算する必要があります。

具体的にはTransformerの論文やT$5$の論文はEncoderとDecoderを用いており、BERTやGPT-$3$は片方のみが用いられます。

GPT-$3$論文のTable$D.1$より

上記はGPT-$3$の論文のTable$D.1$に対応しますが、「T$5$がencoder-decoder modelであるのでパラメータの半数のみがactive」というような注意書きが読み取れます。

このように論文毎にパラメータの概算方法が変わる場合があるので単にパラメータの総数だけでなく、大まかな概算法も合わせて抑えておくと良いです。

Transformer

Transformer論文のTable$3$より

Transformer論文のパラメータ設定は上記より確認できます。以下、Transformerのパラメータ数がEncoder-Decoderを前提に計算されることを元にbaseとbigの双方についてパラメータの概算を行います。

base

・Embedding
$V=37000, D=512$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 37000 \times 512 \\
&= 18{,}944{,}000
\end{align}
$$

・Multi-Head Attention
$N=6, d_{model}=512, h=8, d_{k}=64, d_{v}=64$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) \times 3 &= 2 \times 6 \times 512 \times 8 \times (64+64) \times 3 \\
&= 18{,}874{,}368
\end{align}
$$

・FFN
$N=6, D=512$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 \times 2 &= 8 \times 6 \times 512^{2} \times 2 \\
&= 25{,}165{,}824
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 18{,}944{,}000 + 18{,}874{,}368 + 25{,}165{,}824 \\
&= 62{,}984{,}192 = 6.3 \times 10^{7}
\end{align}
$$

上記は$63$Mなので、表の値と概ね一致することが確認できます。

big

・Embedding
$V=37000, D=1024$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 37000 \times 1024 \\
&= 37{,}888{,}000
\end{align}
$$

・Multi-Head Attention
$N=6, d_{model}=1024, h=16, d_{k}=64, d_{v}=64$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) \times 3 &= 2 \times 6 \times 1024 \times 16 \times (64+64) \times 3 \\
&= 75{,}497{,}472
\end{align}
$$

・FFN
$N=6, D=1024$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 \times 2 &= 8 \times 6 \times 1024^{2} \times 2 \\
&= 100{,}663{,}296
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 37{,}888{,}000 + 75{,}497{,}472 + 100{,}663{,}296 \\
&= 214{,}048{,}768 = 2.14 \times 10^{8}
\end{align}
$$

上記は$214$Mなので、表の値と概ね一致することが確認できます。

BERT

BERT論文のSection$3$より

BERT論文のパラメータ設定は上記より確認できます。以下、BERTのパラメータ数がEncoderのみを用いて計算されることを元にBASEとLARGEの双方についてパラメータの概算を行います。

BASE

・Embedding
$H=768$がhidden sizeであるので$V=30000, D=768$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 30000 \times 768 \\
&= 23{,}040{,}000
\end{align}
$$

・Multi-Head Attention
$N=12, d_{model}=768, h=12, d_{k}=64, d_{v}=64$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 12 \times 768 \times 12 \times (64+64) \\
&= 28{,}311{,}552
\end{align}
$$

・FFN
$N=12, D=768$を元にパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 12 \times 768^{2} \\
&= 56{,}623{,}104
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 23{,}040{,}000 + 28{,}311{,}552 + 56{,}623{,}104 \\
&= 107{,}974{,}656 = 1.08 \times 10^{8}
\end{align}
$$

上記は$108$Mなので、論文のパラメータ数と概ね一致することが確認できます。

LARGE

・Embedding
$H=1024$がhidden sizeであるので$V=30000, D=1024$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 30000 \times 1024 \\
&= 30{,}720{,}000
\end{align}
$$

・Multi-Head Attention
$N=24, d_{model}=1024, h=16, d_{k}=64, d_{v}=64$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 24 \times 1024 \times 16 \times (64+64) \\
&= 100{,}663{,}296
\end{align}
$$

・FFN
$N=24, D=1024$を元にパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 24 \times 1024^{2} \\
&= 201{,}326{,}592
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 30{,}720{,}000 + 100{,}663{,}296 + 201{,}326{,}592 \\
&= 332{,}709{,}888 = 3.33 \times 10^{8}
\end{align}
$$

上記は$333$Mなので、論文のパラメータ数と概ね一致することが確認できます。

T$5$

T$5$論文のSection$3.7$より

T$5$論文のパラメータ設定は上記より確認できます。以下、T$5$のパラメータ数がEncoder-Decoderを前提に計算されることを元に$3B$と$11B$の双方についてパラメータの概算を行います。

$3B$

GPT-$3$

GPT-$3$論文のTable$2.1$より

GPT-$3$のパラメータ設定は上記より確認できます。

GPT論文のFigure$1$より

GPT-$3$では上図のようにencoderを用いないTransformerであるTransformer decoderを用います。Transformer decoderの概要については下記で詳しく取り扱いました。

GPT論文のFigure$1$よりGPT-$3$のパラメータ数はencoderのみを用いるBERTと基本的には同様の概算を行えることが確認できます。以下、GPT-$3$のパラメータ数がEncoderのみを用いて計算されることを元に$6.7$B、$13$B、$175$Bについてパラメータの概算を行います。

$6.7$B

・Embedding
$V=40000, D=4096$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 40000 \times 4096 \\
&= 163{,}840{,}000
\end{align}
$$

$V=40000$はGPTの論文を参照しました。

・Multi-Head Attention
$N=32, d_{model}=4096, h=32, d_{k}=128, d_{v}=128$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 32 \times 4096 \times 32 \times (128+128) \\
&= 2{,}147{,}483{,}648
\end{align}
$$

・FFN
$N=32, D=4096$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 32 \times 4096^{2} \\
&= 4{,}294{,}967{,}296
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 163{,}840{,}000 + 2{,}147{,}483{,}648 + 4{,}294{,}967{,}296 \\
&= 6{,}606{,}290{,}944 = 6.6 \times 10^{9}
\end{align}
$$

上記は$6.6$Bなので、表の値と概ね一致することが確認できます。

$13$B

・Embedding
$V=40000, D=5140$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 40000 \times 5140 \\
&= 205{,}600{,}000
\end{align}
$$

$V=40000$はGPTの論文を参照しました。

・Multi-Head Attention
$N=40, d_{model}=5140, h=40, d_{k}=128, d_{v}=128$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 40 \times 5140 \times 40 \times (128+128) \\
&= 4{,}210{,}688{,}000
\end{align}
$$

・FFN
$N=40, D=5140$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 40 \times 5140^{2} \\
&= 8{,}454{,}272{,}000
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 205{,}600{,}000 + 4{,}210{,}688{,}000 + 8{,}454{,}272{,}000 \\
&= 12{,}870{,}560{,}000 = 1.29 \times 10^{10}
\end{align}
$$

上記は$12.9$Bなので、表の値と概ね一致することが確認できます。

$175$B

・Embedding
$V=40000, D=12288$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 40000 \times 12288 \\
&= 491{,}520{,}000
\end{align}
$$

$V=40000$はGPTの論文を参照しました。

・Multi-Head Attention
$N=96, d_{model}=12288, h=96, d_{k}=128, d_{v}=128$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 96 \times 12288 \times 96 \times (128+128) \\
&= 57{,}982{,}058{,}496
\end{align}
$$

・FFN
$N=96, D=12288$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 96 \times 12288^{2} \\
&= 115{,}964{,}116{,}992
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 491{,}520{,}000 + 57{,}982{,}058{,}496 + 115{,}964{,}116{,}992 \\
&= 174{,}437{,}695{,}488 = 1.74 \times 10^{11}
\end{align}
$$

上記は$174$Bなので、表の値と概ね一致することが確認できます。

Gopher

Gopher論文のパラメータ設定は上記より確認できます。基本的にはGPT$3$と同様にTransformer decoderの構成が用いられます。

$44$M

・Embedding
$V=32000, D=512$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 32000 \times 512 \\
&= 16{,}384{,}000
\end{align}
$$

・Multi-Head Attention
$N=8, d_{model}=512, h=16, d_{k}=32, d_{v}=32$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 8 \times 512 \times 16 \times (32+32) \\
&= 8{,}388{,}608
\end{align}
$$

・FFN
$N=8, D=512$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 8 \times 512^{2} \\
&= 16{,}777{,}216
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 16{,}384{,}000 + 8{,}388{,}608 + 16{,}777{,}216 \\
&= 41{,}549{,}824 = 4.2 \times 10^{7}
\end{align}
$$

上記は$42$Mなので、表の値と概ね一致することが確認できます。

$117$M

・Embedding
$V=32000, D=512$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 32000 \times 768 \\
&= 24{,}576{,}000
\end{align}
$$

・Multi-Head Attention
$N=12, d_{model}=768, h=12, d_{k}=64, d_{v}=64$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 12 \times 768 \times 12 \times (64+64) \\
&= 28{,}311{,}552
\end{align}
$$

・FFN
$N=12, D=768$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 12 \times 768^{2} \\
&= 56{,}623{,}104
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 24{,}576{,}000 + 28{,}311{,}552 + 56{,}623{,}104 \\
&= 109{,}510{,}656 = 1.1 \times 10^{8}
\end{align}
$$

上記は$110$Mなので、表の値と概ね一致することが確認できます。

$417$M

・Embedding
$V=32000, D=1536$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 32000 \times 1536 \\
&= 49{,}152{,}000
\end{align}
$$

・Multi-Head Attention
$N=12, d_{model}=1536, h=12, d_{k}=128, d_{v}=128$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 12 \times 1536 \times 12 \times (128+128) \\
&= 113{,}246{,}208
\end{align}
$$

・FFN
$N=12, D=1536$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 12 \times 1536^{2} \\
&= 226{,}492{,}416
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 49{,}152{,}000 + 113{,}246{,}208 + 226{,}492{,}416 \\
&= 388{,}890{,}624 = 3.89 \times 10^{8}
\end{align}
$$

上記は$389$Mなので、表の値と概ね一致することが確認できます。

$1.4$B

・Embedding
$V=32000, D=2048$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 32000 \times 2048 \\
&= 65{,}536{,}000
\end{align}
$$

・Multi-Head Attention
$N=24, d_{model}=2048, h=16, d_{k}=128, d_{v}=128$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 24 \times 2048 \times 16 \times (128+128) \\
&= 402{,}653{,}184
\end{align}
$$

・FFN
$N=24, D=2048$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 24 \times 2048^{2} \\
&= 805{,}306{,}368
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 65{,}536{,}000 + 402{,}653{,}184 + 805{,}306{,}368 \\
&= 1{,}273{,}495{,}552 = 1.27 \times 10^{9}
\end{align}
$$

上記は$1.27$Bなので、表の値と概ね一致することが確認できます。

$7.1$B

・Embedding
$V=32000, D=4096$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 32000 \times 4096 \\
&= 131{,}072{,}000
\end{align}
$$

・Multi-Head Attention
$N=32, d_{model}=4096, h=32, d_{k}=128, d_{v}=128$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 32 \times 4096 \times 32 \times (128+128) \\
&= 2{,}147{,}483{,}648
\end{align}
$$

・FFN
$N=32, D=4096$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 24 \times 2048^{2} \\
&= 4{,}294{,}967{,}296
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 131{,}072{,}000 + 2{,}147{,}483{,}648 + 4{,}294{,}967{,}296 \\
&= 6{,}573{,}522{,}944 = 6.57 \times 10^{9}
\end{align}
$$

上記は$6.57$Bなので、表の値と概ね一致することが確認できます。

$280B$

・Embedding
$V=32000, D=16384$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 32000 \times 16384 \\
&= 524{,}288{,}000
\end{align}
$$

・Multi-Head Attention
$N=80, d_{model}=16384, h=128, d_{k}=128, d_{v}=128$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 80 \times 16384 \times 128 \times (128+128) \\
&= 85{,}899{,}345{,}920
\end{align}
$$

・FFN
$N=32, D=16384$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 80 \times 16384^{2} \\
&= 171{,}798{,}691{,}840
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 524{,}288{,}000 + 85{,}899{,}345{,}920 + 171{,}798{,}691{,}840 \\
&= 258{,}222{,}325{,}760 = 2.58 \times 10^{11}
\end{align}
$$

上記は$258$Bなので、表の値と概ね一致することが確認できます。

GLaM

PaLM

PaLM論文 Table.$1$

PaLM論文のパラメータ設定は上記より確認できます。基本的にはGPT$3$と同様にTransformer decoderの構成が用いられます。

$8$B

・Embedding
$V=256000, D=4096$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 256000 \times 4096 \\
&= 1{,}048{,}576{,}000
\end{align}
$$

$V=256000$はPaLMの論文を参照しました。

・Multi-Head Attention
$N=32, d_{model}=4096, h=16, d_{k}=256, d_{v}=256$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 32 \times 4096 \times 16 \times (256+256) \\
&= 2{,}147{,}483{,}648
\end{align}
$$

・FFN
$N=32, D=4096$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 32 \times 4096^{2} \\
&= 4{,}294{,}967{,}296
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 1{,}048{,}576{,}000 + 2{,}147{,}483{,}648 + 4{,}294{,}967{,}296 \\
&= 7{,}491{,}026{,}944 = 7.49 \times 10^{9}
\end{align}
$$

上記は$7.49$Bなので、表の値$8$Bと概ね一致することが確認できます。

$62$B

・Embedding
$V=256000, D=8192$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 256000 \times 8192 \\
&= 2{,}097{,}152{,}000
\end{align}
$$

・Multi-Head Attention
$N=64, d_{model}=8192, h=32, d_{k}=256, d_{v}=256$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 32 \times 4096 \times 16 \times (256+256) \\
&= 17{,}179{,}869{,}184
\end{align}
$$

・FFN
$N=32, D=4096$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 64 \times 8192^{2} \\
&= 34{,}359{,}738{,}368
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 2{,}097{,}152{,}000 + 17{,}179{,}869{,}184 + 34{,}359{,}738{,}368 \\
&= 53{,}636{,}759{,}552 = 5.36 \times 10^{10}
\end{align}
$$

上記は$53.6$Bなので、表の値$62$Bと概ね一致することが確認できます。

$540$B

・Embedding
$V=256000, D=18432$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
V \times D &= 256000 \times 18432 \\
&= 4{,}718{,}592{,}000
\end{align}
$$

・Multi-Head Attention
$N=118, d_{model}=18432, h=72, d_{k}=256, d_{v}=256$を元にパラメータ数は下記のように計算できます。
$$
\large
\begin{align}
2 N d_{model} h(d_{k}+d_{v}) &= 2 \times 118 \times 18432 \times 72 \times (256+256) \\
&= 160{,}356{,}630{,}528
\end{align}
$$

・FFN
$N=118, D=18432$の場合のパラメータ数は下記のように概算できます。
$$
\large
\begin{align}
8ND^2 &= 8 \times 118 \times 18432^{2} \\
&= 320{,}713{,}261{,}056
\end{align}
$$

よって、パラメータの総数は下記のように概算できます。
$$
\large
\begin{align}
& 4{,}718{,}592{,}000 + 160{,}356{,}630{,}528 + 320{,}713{,}261{,}056 \\
&= 485{,}280{,}579{,}584 = 4.85 \times 10^{11}
\end{align}
$$

上記は$485$Bなので、表の値$540$Bと概ね一致することが確認できます。

参考

・Transformer論文:Attention is All you need$[2017]$
・BERT論文
・T$5$論文
・GPT$2$論文
・GPT$3$論文
・PaLM論文
・Transformer decoder論文
・直感的に理解するTransformer(運営者作成)

行列の列基本変形と基本行列(elementary matrix)による列基本変形の対応

行基本変形は基本行列(elementary matrix)の積による操作によって表すことができるなど、基本行列は
よく出てくるので抑えておくと良いです。当記事では列基本変形の概要と列基本変形と基本行列の対応について取り扱いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$3$節「行列の構造」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

列基本変形と基本行列

基本行列の定義

下記で詳しく取り扱った。

列基本変形

列基本変形と基本行列の対応

具体例の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$041$

$[1]$
$$
\large
\begin{align}
\left(\begin{array}{cccc} 4 & -1 & 3 & 0 \\ 1 & 2 & -4 & 2 \end{array} \right)
\end{align}
$$

上記に対し、$1$列目と$2$列目を入れ替える列基本変形を行うと下記が得られる。
$$
\large
\begin{align}
\left(\begin{array}{cccc} 4 & -1 & 3 & 0 \\ 1 & 2 & -4 & 2 \end{array} \right) \longrightarrow \left(\begin{array}{cccc} -1 & 4 & 3 & 0 \\ 2 & 1 & -4 & 2 \end{array} \right)
\end{align}
$$

$[2]$
$$
\large
\begin{align}
\left(\begin{array}{ccc} -1 & 2 & 3 \\ 2 & 1 & 1 \end{array} \right)
\end{align}
$$

上記に対し、$1$列目の$2$倍を$2$列目に加える列基本変形を行うと下記が得られる。
$$
\large
\begin{align}
\left(\begin{array}{ccc} -1 & 2 & 3 \\ 2 & 1 & 1 \end{array} \right) \longrightarrow \left(\begin{array}{ccc} -1 & 0 & 3 \\ 2 & 5 & 1 \end{array} \right)
\end{align}
$$

$[3]$
$$
\large
\begin{align}
\left(\begin{array}{cccc} 1 & 0 & -2 & 1 \\ 1 & 1 & 1 & -1 \\ -1 & 3 & -5 & 0 \end{array} \right)
\end{align}
$$

上記に対し、$1$列目の$-1$倍を$2$列目に加え、$1$列目の$1$倍を$3$列目に加える列基本変形を行うと下記が得られる。
$$
\large
\begin{align}
\left(\begin{array}{cccc} 1 & 0 & -2 & 1 \\ 1 & 1 & 1 & -1 \\ -1 & 3 & -5 & 0 \end{array} \right) \longrightarrow \left(\begin{array}{cccc} 1 & -1 & -1 & 1 \\ 1 & 0 & 2 & -1 \\ -1 & 4 & -6 & 0 \end{array} \right)
\end{align}
$$

基本例題$042$

基本例題$043$

【Word2vecなどの出力層高速化】雑音対照推定(NCE)・負例サンプリング

分布仮説に基づくWord$2$vecなどの学習にあたっては、出力層が語彙の数に対応する分類問題に対応するので、そのまま取り扱うと巨大なソフトマックス関数の取り扱いが必要になります。当記事はNCEや負例サンプリング(Negative Sampling)を用いた解決策について取り扱いました。
「深層学習による自然言語処理」$4.3$節の「出力層の高速化」などを参考に当記事の作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

語彙が大きい場合のソフトマックス関数の取り扱い

Word$2$vecや翻訳・文書要約・対話などの文章の生成タスクなどの学習を行う際は、入力文$\mathbf{x} \in \mathcal{X}$と予測対象の単語$y \in \mathcal{Y}$を元に、下記のように交差エントロピー損失関数を定義する。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) = -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1)
\end{align}
$$

上記に対し、下記のように$s(y)$と$Z(\mathcal{Y})$を定義する。
$$
\large
\begin{align}
s(y) &= f_{\boldsymbol{\theta}}(\mathbf{x},y) \quad (1.2) \\
Z(\mathcal{Y}) &= \sum_{\tilde{y} \in \mathcal{Y}} \exp{(s(\tilde{y}))} = \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y})))} \quad (1.3)
\end{align}
$$

$(1.2), \, (1.3)$式を元に$(1.1)$式は下記のように表せる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1) \\
&= -s(y) + \log{Z(\mathcal{Y})} \quad (1.4)
\end{align}
$$

ここで$(1.4)$式の両辺をパラメータベクトル$\boldsymbol{\theta} \in \mathbb{R}^{p}$で方向微分すると下記が得られる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) = – \nabla s(y) + \nabla \log{Z(\mathcal{Y})} \quad (1.5)
\end{align}
$$

ここまでの詳しい式変形は下記で取り扱った。

$(1.5)$式の取り扱いの際に語彙数$|\mathcal{Y}|$が大きいと出力層の計算量が大きくなるので、何らかの解決策があると良いが、上記の記事では重点サンプリングを用いた手法について取り扱った。当記事ではNCEと負例サンプリングに基づく手法を次節で取り扱う。

NEC・負例サンプリング

学習にあたっての基本的な方針

上記の重点サンプリングでは分配関数$Z$の近似を行ったが、当節では$Z$を未知のパラメータとみなして学習させる方法について取り扱う。この学習にあたっては、通常の最尤推定ではない目的関数を設定し、パラメータの推定を行う。

具体的には「訓練データ」と「無作為に生成したノイズ」を区別するような目的関数を用いることで$Z$の推定を行う。

以下、上記に基づく手法である「雑音対照推定(NCE; Noise Contrasive Estimation)」、「負例サンプリング(Negative Sampling)」についてそれぞれ取り扱った。

NCE

NCE(Noise Contrasive Estimation)では「訓練データ」と「ノイズの分布$q$から生成された標本」の識別を行うような分類器の学習を行う。分類を取り扱うにあたって、確率変数$D$を用いて「訓練データ」を$D=1$、「ノイズからの標本」を$D=0$で定義する。

ここで$D=0$が$D=1$の$k$倍出現しやすいと仮定するとき、$D$と単語に対応する$Y$の同時確率は$D=1$と$D=0$に対してそれぞれ下記のように表せる。
$$
\large
\begin{align}
P(D=1, Y=y) &= \frac{1}{k+1} p(y) \quad (2.1) \\
P(D=0, Y=y) &= \frac{k}{k+1} q(y) \quad (2.2)
\end{align}
$$

このとき、単語に関する周辺確率$P(Y=y)$は下記のように表せる。
$$
\large
\begin{align}
P(Y=y) &= P(D=1, Y=y) + P(D=0, Y=y) \\
&= \frac{1}{k+1} p(y) + \frac{k}{k+1} q(y) \quad (2.3)
\end{align}
$$

上記に基づいて単語$Y=y$がノイズからの標本である確率$P(D=0|Y=y)$は条件付き確率の定義式を元に下記のように得られる。
$$
\large
\begin{align}
P(D=0|Y=y) &= \frac{P(D=0, Y=y)}{P(Y=y)} \\
&= \frac{\displaystyle \frac{k}{\cancel{k+1}} q(y)}{\displaystyle \frac{1}{\cancel{k+1}} p(y) + \frac{k}{\cancel{k+1}} q(y)} \\
&= \frac{k q(y)}{p(y) + k q(y)} \quad (2.4)
\end{align}
$$

同様に単語$Y=y$が訓練データである確率$P(D=1|Y=y)$は下記のように得られる。
$$
\large
\begin{align}
P(D=1|Y=y) &= \frac{P(D=1, Y=y)}{P(Y=y)} \\
&= \frac{\displaystyle \frac{1}{\cancel{k+1}} p(y)}{\displaystyle \frac{1}{\cancel{k+1}} p(y) + \frac{k}{\cancel{k+1}} q(y)} \\
&= \frac{p(y)}{p(y) + k q(y)} \quad (2.5)
\end{align}
$$

ここで$1$つの訓練データ$y$に対して、ノイズ分布$q$からの$k$個の標本$\tilde{\mathcal{D}}=(\tilde{y}_{1}, \cdots , \tilde{y}_{k})$の無作為抽出を行う。この$k+1$個の事例に対して下記の関数$l_{\boldsymbol{\theta}}^{\mathrm{NCE}}(y)$を定義する。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{NCE}}(y) &= -\log{P(D=1|Y=y)} – \sum_{\tilde{y} \in \mathcal{D}} \log{P(D=0|Y=\tilde{y})} \\
&= -\log{\frac{p(y)}{p(y) + k q(y)}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{k q(\tilde{y})}{p(\tilde{y}) + k q(\tilde{y})}} \quad (2.6)
\end{align}
$$

上記の$(2.6)$式は$k+1$個の事例に対する$D$の負の対数尤度であり、$\boldsymbol{\theta}$は関数$\displaystyle p(y)=\frac{\exp{(s(y))}}{Z}=\exp{(s(y)+c)}$で用いられるパラメータである。

$\displaystyle p(y)=\frac{\exp{(s(y))}}{Z}=\exp{(s(y)+c)}$の$c$は下記のように解釈することができる。
$$
\large
\begin{align}
\exp{(s(y)+c)} &= \exp{(s(y))} \exp{(c)} \\
&= \frac{\exp{(s(y))}}{\exp{(-c)}} = \frac{\exp{(s(y))}}{Z} \\
Z &= \exp{(-c)} \quad (2.7)
\end{align}
$$

$(2.6)$式を$\boldsymbol{\theta}$について最適化する際に、このように導入した$c$を同時に学習させることで分配関数$Z$の計算を省略することができる。また、$Z=\exp{(-c)}=1$とおいても結果が変わらないという実験結果もあり、この場合は下記の$(2.8)$式のような目的関数を用いる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{NCE}}(y) &= -\log{\frac{p(y)}{p(y) + k q(y)}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{k q(\tilde{y})}{p(\tilde{y}) + k q(\tilde{y})}} \quad (2.6) \\
&= -\log{\frac{\exp{(s(y))}}{\exp{(s(y))} + k q(y)}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{k q(\tilde{y})}{\exp{(s(\tilde{y}))} + k q(\tilde{y})}} \quad (2.8)
\end{align}
$$

雑音対照推定(NCE)は上記の$(2.6)$式や$(2.8)$式を用いてパラメータの推定を行う手法である。

負例サンプリング

負例サンプリング(Negative Sampling)はNCEの損失関数の$(2.6)$式の確率にシグモイド関数を用いて簡略化を行った手法である。負例サンプリングの目的関数$l_{\boldsymbol{\theta}}^{\mathrm{NS}}(y)$は下記のように定義される。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{NS}}(y) &= -\log{[\mathrm{sigmoid}(s(y))]} – \sum_{\tilde{y} \in \mathcal{D}} \log{[1-\mathrm{sigmoid}(s(y))]} \\
&= \log{\frac{\exp{(s(y))}}{\exp{(s(y))}+1}} – \sum_{\tilde{y} \in \mathcal{D}} \log{\frac{1}{\exp{(s(\tilde{y}))}+1}} \quad (2.9) \\
\mathrm{sigmoid}(x) &= \frac{\exp{(x)}}{\exp{(x)}+1}
\end{align}
$$

$(2.9)$式の$\tilde{y}$が一様分布によってサンプリングされる場合は、$(2.8)$式で$k=|\mathcal{Y}|$かつ$\displaystyle q(\tilde{y}) = \frac{1}{|\mathcal{Y}|}$である場合の式に一致することは抑えておくと良い。

一方で、負例サンプリングの論文では、単語の出現頻度に比例した確率であるユニグラム確率の$\displaystyle \frac{3}{4}$を使用すると単純なユニグラム確率や一様分布を使用する場合以上に性能が上がるとされる。この辺の設定は実験の前提にもよるので、参考に確認しておくで十分であると思われる。

参考文献

・NCE論文
・負例サンプリング論文

複数の行基本変形と基本行列(elementary matrix)の積の対応

行基本変形は基本行列(elementary matrix)の積による操作によって表すことができるなど、基本行列はよく出てくるので抑えておくと良いです。当記事では複数の行基本変形と基本行列の積の対応について取り扱いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$3$節「行列の構造」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

複数の行基本変形と基本行列の積

基本行列の定義

下記で詳しく取り扱った。

複数の行基本変形と基本行列の積の対応

「複数の行基本変形を行うこと」は「対応する基本行列を左から次々に掛ける」ことに対応する。具体例は次節で取り扱った。

基本行列の具体例の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$040$

・$[1]$
$2$行目に$3$を掛けたのちに$2$行目と$3$行目を入れ替える操作は行列$A$に$P_{23}P_{2}(3)$を左から掛ける演算に対応する。

・$[2]$
$2$行目に$3$行目の$-1$倍を加えたのちに$4$行目に$1$行目の$2$倍を加える操作は行列$A$に$P_{41}(2)P_{23}(-1)$を左から掛ける演算に対応する。

・$[3]$
$3$行目に$1$行目の$2$倍を加え、$2$行目と$4$行目を入れ替えたのちに$4$行目に$3$行目の$-3$倍を加える操作は行列$A$に$P_{43}(-3)P_{24}P_{31}(2)$を左から掛ける演算に対応する。

【Word2vecなどの出力層高速化】巨大なソフトマックス関数の課題と重点サンプリングによる解決

分布仮説(distributional hypothesis)に基づくWord$2$vecなどの学習にあたっては、出力層が語彙の数に対応する分類問題に対応するので、そのまま取り扱うと巨大なソフトマックス関数の取り扱いが必要になります。当記事はこの課題の重点サンプリングを用いた解決策について取りまとめました。
「深層学習による自然言語処理」$4.3$節の「出力層の高速化」などを参考に当記事の作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

ソフトマックス関数の巨大化

交差エントロピー損失関数とソフトマックス関数

Word$2$vecや翻訳・文書要約・対話などの文章の生成タスクの文章の生成タスクの学習を行う際は、入力文$\mathbf{x} \in \mathcal{X}$と予測対象の単語$y \in \mathcal{Y}$を元に、下記のように交差エントロピー損失関数を定義する。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) = -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1)
\end{align}
$$

文章の全体の生成における学習については$(1.1)$式を各単語について和を計算すれば良いが、式が複雑になるので以下では$1$つの単語の予測のみを取り扱う。ここで$(1.1)$式は条件付き確率分布$P(y|\mathbf{x})$やソフトマックス関数$\mathrm{softmax}(x)$を用いて下記のように表すこともできる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= -\log{P(y|\mathbf{x})} \quad (1.2) \\
P(y|\mathbf{x}) &= \mathrm{softmax}(f_{\boldsymbol{\theta}}(\mathbf{x},y)) = \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} \quad (1.3)
\end{align}
$$

$(1.2)$式は負の対数尤度に一致するので、損失関数の最大化は対数尤度の最大化に対応する。また、$(1.3)$式はソフトマックス関数の定義式を含む。

分配関数の定義

前項「交差エントロピー損失関数とソフトマックス関数」$(1.1)$式の$y$の取り扱いにあたって、下記のように関数$s(y)$を定義する。
$$
\large
\begin{align}
s(y) = f_{\boldsymbol{\theta}}(\mathbf{x},y) \quad (1.4)
\end{align}
$$

また、下記のように分配関数(partition function)の$Z(\mathcal{Y})$を定義する。
$$
\large
\begin{align}
Z(\mathcal{Y}) = \sum_{\tilde{y} \in \mathcal{Y}} \exp{(s(\tilde{y}))} = \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y})))} \quad (1.5)
\end{align}
$$

このとき$(1.1)$式は$(1.4), \, (1.5)$式を用いて下記のように表すことができる。
$$
\large
\begin{align}
l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= -\log{ \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} } \quad (1.1) \\
&= -\log{ \frac{\exp{(s(y))}}{Z(\mathcal{Y})} } \\
&= -s(y) + \log{Z(\mathcal{Y})} \quad (1.6)
\end{align}
$$

パラメータに関する勾配

$(1.6)$式の両辺をパラメータベクトル$\boldsymbol{\theta} \in \mathbb{R}^{p}$で方向微分すると下記が得られる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= – \nabla s(y) + \nabla \log{Z(\mathcal{Y})} \quad (1.6)’ \\
\nabla &= \frac{\partial}{\partial \boldsymbol{\theta}} = \left(\begin{array}{c} \displaystyle \frac{\partial}{\partial \theta_{1}} \\ \vdots \\ \displaystyle \frac{\partial}{\partial \theta_{p}} \end{array} \right)
\end{align}
$$

ここで$(1.6)’$式の第$2$項の$\nabla \log{Z(\mathcal{Y})}$は合成関数の微分の公式を元に下記のように変形できる。
$$
\large
\begin{align}
\nabla \log{Z(\mathcal{Y})} &= \sum_{\tilde{y} \in \mathcal{Y}} \frac{\exp{(s(\tilde{y}))}}{Z(\mathcal{Y})} s'(\tilde{y}) \\
&= \sum_{\tilde{y} \in \mathcal{Y}} p(\tilde{y}) s'(\tilde(y)) \\
&= \mathbb{E}_{Y \sim p}[s'(Y)] \quad (1.7)
\end{align}
$$

$(1.7)$式を元に$(1.6)’$式は下記のように表せる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= – \nabla s(y) + \nabla \log{Z(\mathcal{Y})} \quad (1.6)’ \\
&= – \nabla s(y) + \mathbb{E}_{Y \sim p}[s'(Y)] \quad (1.8)
\end{align}
$$

$(1.8)$式の計算にあたって$\mathcal{Y}$が巨大な場合は、$\mathbb{E}_{Y \sim p}[s'(Y)]$は全ての語彙に対して計算する必要があることで計算量が大きくなる。この際に$\mathcal{Y}$に比例しない計算量で近似を行う手法にはいくつかあるが、次節では重点サンプリングを用いた手法を確認する。

重点サンプリングを用いた出力層の高速化

モンテカルロ法

$(1.8)$式における$\mathbb{E}_{Y \sim p}[s'(Y)]$のモンテカルロ法を用いた近似式は下記のように表せる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] & \simeq \frac{1}{n} \sum_{i=1}^{n} s'(\tilde{y}_{i}) \quad (2.1) \\
\tilde{y}_{1}, \, \cdots , & \, \tilde{y}_{n} \sim p \quad (2.2)
\end{align}
$$

ここで$(2.2)$式の$p$は$(1.3)$式に対応するので、ここでの$\tilde{y}_{i}$のサンプリングを行うにあたっては分配関数$Z$の計算が必要である。よって、次項では$p$を用いずにサンプリングを行う重点サンプリングを確認する。

重点サンプリング

対象の分布からの無作為抽出が難しい場合に重点サンプリング(importance sampling)はよく用いられる。ここでは$(1.3)$式で表される$p$の代わりに提案分布$q$を用いて重点サンプリングを行うことを考える。このとき$\mathbb{E}_{Y \sim p}[s'(Y)]$は下記のように表すことができる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] &= \sum_{\tilde{y} \in \mathcal{Y}} s'(\tilde{y})p(\tilde{y}) \\
&= \sum_{\tilde{y} \in \mathcal{Y}} s'(\tilde{y}) \frac{p(\tilde{y})}{q(\tilde{y})} q(\tilde{y}) \\
&= \mathbb{E}_{Y’ \sim q} \left[ s'(Y) \frac{p(Y’)}{q(Y’)} \right] \quad (2.3)
\end{align}
$$

重点サンプリングの数式表記と計算例については下記でも取り扱った。

ここで$(2.3)$式に基づくモンテカルロ近似の式は下記のように表せる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] = \mathbb{E}_{Y’ \sim q} \left[ s'(Y) \frac{p(Y’)}{q(Y’)} \right] \simeq \frac{1}{n} \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})} \quad (2.4)
\end{align}
$$

$(2.4)$式を用いることで提案分布$q$に基づくサンプリングを行える一方で、$(2.4)$式の$p(\tilde{y}_{i})$の計算には分配関数$Z$の計算が必要である。そこで次項では分配関数$Z$の近似値の取得について取り扱う。

分配関数の近似

単語の集合を$\mathcal{Y}$で表したので、語彙数は$|\mathcal{Y}|$で表せる。ここで下記のような一様分布$u$を定義する。
$$
\large
\begin{align}
u(x) = \frac{1}{|\mathcal{Y}|} \quad (2.5)
\end{align}
$$

$(2.5)$式で表した一様分布$u$に従う確率変数$X \sim u$について$\exp{(s(X))}$の期待値は下記のように表すことができる。
$$
\large
\begin{align}
\mathbb{E}_{X \sim u}[\exp{(s(X))}] = \frac{1}{|\mathcal{Y}|} \sum_{\tilde{x} \in \mathcal{Y}} \exp{(s(\tilde{x}))} \quad (2.6)
\end{align}
$$

ここで$(2.6)$式に$(1.5)$式を用いることで、下記のような変形を行うことができる。
$$
\large
\begin{align}
\mathbb{E}_{X \sim u}[\exp{(s(X))}] &= \frac{1}{|\mathcal{Y}|} \sum_{\tilde{x} \in \mathcal{Y}} \exp{(s(\tilde{x}))} \quad (2.6) \\
\mathbb{E}_{X \sim u}[\exp{(s(X))}] &= \frac{1}{|\mathcal{Y}|} Z(\mathcal{Y}) \\
Z(\mathcal{Y}) &= |\mathcal{Y}| \mathbb{E}_{X \sim u}[\exp{(s(X))}] \quad (2.7)
\end{align}
$$

上記の$(2.7)$式に対し、下記の導出を元に提案分布$q$を用いて重点サンプリングを行う。
$$
\large
\begin{align}
Z(\mathcal{Y}) &= |\mathcal{Y}| \mathbb{E}_{X \sim u}[\exp{(s(X))}] \quad (2.7) \\
&= |\mathcal{Y}| \mathbb{E}_{Y’ \sim q} \left[\exp{(s(Y’))} \frac{u(Y’)}{q(Y’)} \right] \\
&= |\mathcal{Y}| \mathbb{E}_{Y’ \sim q} \left[ \exp{(s(Y’))} \frac{1}{|\mathcal{Y}| q(Y’)} \right] \\
&= \frac{\cancel{|\mathcal{Y}|}}{\cancel{|\mathcal{Y}|}} \mathbb{E}_{Y’ \sim q} \left[ \frac{\exp{(s(Y’))}}{q(Y’)} \right] \\
&= \mathbb{E}_{Y’ \sim q} \left[ \frac{\exp{(s(Y’))}}{q(Y’)} \right] \\
& \simeq \frac{1}{n} \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})} \\
&= \hat{Z} \quad (2.8)
\end{align}
$$

勾配の式の導出

$(1.3)$式、$(1.5)$式、$(2.8)$式に基づいて下記が成立する。
$$
\large
\begin{align}
p(y) &= \frac{\exp{(f_{\boldsymbol{\theta}}(\mathbf{x},y))}}{\displaystyle \sum_{\tilde{y} \in \mathcal{Y}} \exp{(f_{\boldsymbol{\theta}}(\mathbf{x},\tilde{y}))}} \quad (1.3)’ \\
&= \frac{\exp{(s(y))}}{Z(\mathcal{Y})} \quad (1.5)’ \\
& \simeq \frac{\exp{(s(y))}}{\hat{Z}} \quad (2.9)
\end{align}
$$

$(2.9)$式を元に下記が成立する。
$$
\large
\begin{align}
s'(y) \frac{p(y)}{q(y)} & \simeq s'(y) \frac{\exp{(s(y))}/\hat{Z}}{q(y)} \\
&= s'(y) \frac{\exp{(s(y))}/q(y)}{\hat{Z}} \\
&= s'(y) \frac{\displaystyle \frac{\exp{(s(y))}}{q(y)}}{\displaystyle \frac{1}{n} \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})}} \quad (2.10)
\end{align}
$$

$(2.3), \, (2.4)$式の重点サンプリングの式に$(2.10)$式を適用すると、サンプル$\tilde{q}_{1}, \cdots , \tilde{q}_{n} \sim q$を元に下記が得られる。
$$
\large
\begin{align}
\mathbb{E}_{Y \sim p}[s'(Y)] &= \mathbb{E}_{Y’ \sim q} \left[ s'(Y) \frac{p(Y’)}{q(Y’)} \right] \quad (2.3) \\
& \simeq \frac{1}{n} \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})} \times \left( \frac{1}{n} \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})} \right)^{-1} \\
&= \frac{\displaystyle \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})}}{\displaystyle \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})}} \quad (2.11)
\end{align}
$$

$(2.11)$式を$(1.8)$式に代入することで下記が得られる。
$$
\large
\begin{align}
\nabla l_{\boldsymbol{\theta}}^{\mathrm{softmax}}(\mathbf{x},y) &= – \nabla s(y) + \mathbb{E}_{Y \sim p}[s'(Y)] \quad (1.8) \\
& \simeq – \nabla s(y) + \frac{\displaystyle \sum_{i=1}^{n} s'(\tilde{y}_{i}) \frac{p(\tilde{y}_{i})}{q(\tilde{y}_{i})}}{\displaystyle \sum_{i=1}^{n} \frac{\exp{(s(\tilde{y}_{i}))}}{q(\tilde{y}_{i})}} \quad (2.12)
\end{align}
$$

$(2.12)$式を用いることで、語彙数$|\mathcal{Y}|$の大きさに関わらず勾配の近似値の計算を行うことができる。

式の解釈

当記事で確認を行った導出は複雑かつ難解であるので、「語彙数$|\mathcal{Y}|$が大きくなる際の勾配計算の高速化」が主目的であることは常に念頭に置く必要がある。

語彙数$|\mathcal{Y}|$が大きくなると分配関数$Z(\mathcal{Y})$の計算量が増加するが、分配関数$Z(\mathcal{Y})$の計算なしでは$(2.1), \, (2.2)$式に基づいてモンテカルロ法をそのまま適用することができない。

上記に対し、「重点サンプリング」と「分配関数の近似」を用いて勾配の近似を行うというのが当節における導出の流れである。特に分配関数の近似にあたっては、ニューラルネットワークの出力層に対応する$s(\tilde{y}_{i})$を提案分布$q$からサンプリングされた$\tilde{y}_{i}$を元にいくつか計算することで、分配関数の大体の大きさを推測することができると解釈しておくとよい。ソフトマックス関数における分配関数$Z(\mathcal{Y})$の役割は値の正規化であり、$(2.8)$式のような近似で大まかな値を得ることができる。

また、提案分布$q$に$u$と同じく一様分布を用いる場合は$(2.8)$が下記のように表される。
$$
\large
\begin{align}
\hat{Z} \simeq \frac{1}{n} \sum_{i=1}^{n} |\mathcal{Y}| \exp{(s(\tilde{y}_{i}))} \quad (2.8)’
\end{align}
$$

上記を$1$つのサンプルを語彙数倍しサンプル数で割ったと解釈すると、サンプリング結果に基づく分配関数の近似式であることが理解しやすい。

基本行列(elementary matrix)の判定と基本行列による行基本操作の確認

行基本変形は基本行列(elementary matrix)の積による操作によって表すことができるなど、基本行列はよく出てくるので抑えておくと良いです。当記事では基本行列の定義や基本行列かどうかの判定、基本行列と行基本変形の対応について取り扱いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$3$節「行列の構造」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

基本行列の概要

基本行列の定義

基本行列を下記の$[1]$〜$[3]$で定義する。

・$[1]$
基本行列$P_{ij}, \, i \neq j$を下記のように定義する。
$$
\large
\begin{align}
P_{ij} &= \left(\begin{array}{ccccccccccc} 1 & & & & & & & & & & \\ & \ddots & & & & & & & & & \\ & & 1 & & & & & & & & \\ & & & 0 & \cdots & \cdots & \cdots & p_{ij} & & & \\ & & & \vdots & 1 & & & \vdots & & & \\ & & & \vdots & & \ddots & & \vdots & & & \\ & & & \vdots & & & 1 & \vdots & & & \\ & & & p_{ji} & \cdots & \cdots & \cdots & 0 & & & \\ & & & & & & & & 1 & & \\ & & & & & & & & & \ddots & \\ & & & & & & & & & & 1 \end{array} \right) \quad (1) \\
p_{ij} &= p_{ji} = 1
\end{align}
$$

$(1)$式で$1$や$p_{ij}, p_{ji}$で表さなかった成分は全て$0$である。

・$[2]$
基本行列$P_i(c), \, c \neq 0$を下記のように定義する。
$$
\large
\begin{align}
P_i(c) &= \left(\begin{array}{ccccccc} 1 & & & & & & \\ & \ddots & & & & & \\ & & 1 & & & & \\ & & & p_{ii} & & & \\ & & & & 1 & & \\ & & & & & \ddots & \\ & & & & & & 1 \end{array} \right) \quad (2) \\
p_{ii} &= c
\end{align}
$$

$(2)$式で$1$や$p_{ii}$で表さなかった成分は全て$0$である。

・$[3]$
基本行列$P_{ij}(a), \, i \neq j, \, a \neq 0$を下記のように定義する。
$$
\large
\begin{align}
P_{ij}(a) &= \left(\begin{array}{ccccccc} 1 & & & & & & \\ & \ddots & & & & & \\ & & 1 & \cdots & p_{ij} & & \\ & & & \ddots & \vdots & & \\ & & & & 1 & & \\ & & & & & \ddots & \\ & & & & & & 1 \end{array} \right) \quad (3) \\
p_{ij} &= a
\end{align}
$$

$(2)$式で$1$や$p_{ij}$で表さなかった成分は全て$0$である。

基本行列であるかの判定

前項の「基本行列の定義」に合致するかで判定を行えば良い。

基本行列と行基本変形の対応

$[1] \,$ $m \times n$行列$A$の$i$行と$j$列を入れ替える操作は$m$次の基本行列$P_{ij}$を用いて$P_{ij}A$を計算することに一致する。
$[2] \,$ $m \times n$行列$A$の$i$行を$c$倍する操作は$m$次の基本行列$P_{ii}(c)$を用いて$P_{ii}(c)A$を計算することに一致する。
$[3] \,$ $m \times n$行列$A$の$i$行に$j$行の$a$倍を加える操作は$m$次の基本行列$P_{ij}(a)$を用いて$P_{ij}(a)A$を計算することに一致する。

基本行列の具体例の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$038$

・$[1]$
$$
\large
\begin{align}
\left(\begin{array}{cc} 0 & 1 \\ 1 & 0 \end{array} \right)
\end{align}
$$

上記は基本行列$P_{12}$に対応する。

・$[2]$
$$
\large
\begin{align}
\left(\begin{array}{ccc} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 0 & 0 & 1 \end{array} \right)
\end{align}
$$

上記は基本行列$P_{21}(1)$に対応する。

・$[3]$
$$
\large
\begin{align}
\left(\begin{array}{cccc} 2 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{array} \right)
\end{align}
$$

上記は基本行列$P_{11}(2)$に対応する。

・$[4]$
$$
\large
\begin{align}
\left(\begin{array}{cccc} 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 2 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{array} \right)
\end{align}
$$

上記は基本行列ではない。

基本例題$039$

$[1]$
$$
\large
\begin{align}
X &= \left(\begin{array}{cccc} 4 & -1 & 3 & 0 \\ 1 & 2 & -4 & 2 \end{array} \right) \\
P_{12} &= \left(\begin{array}{cc} 0 & 1 \\ 1 & 0 \end{array} \right)
\end{align}
$$

上記より$P_{12} X$は下記のように計算できる。
$$
\large
\begin{align}
P_{12} X &= \left(\begin{array}{cc} 0 & 1 \\ 1 & 0 \end{array} \right) \left(\begin{array}{cccc} 4 & -1 & 3 & 0 \\ 1 & 2 & -4 & 2 \end{array} \right) \\
&= \left(\begin{array}{cccc} 1 & 2 & -4 & 2 \\ 4 & -1 & 3 & 0 \end{array} \right)
\end{align}
$$

「行列の$1$行目と$2$行目を入れ替えることが、行列に基本行列$P_{12}$を左からかけることに対応する」ことが上記の例では確認できる。

列基本変形を用いた行列の簡約階段形(reduced echelon form)から標準形への変換

行列の標準形は階段形から行基本変形を行なって導出した簡約階段形(reduced echelon form)に列基本変形を行うことで得ることができます。当記事では列基本変形を用いた簡約階段形から標準形への変換などについて、概要と具体例を取り扱いました。
作成にあたっては「チャート式シリーズ 大学教養 線形代数」の第$3$節「行列の構造」を主に参考にしました。

・数学まとめ
https://www.hello-statisticians.com/math_basic

簡約階段形の概要

標準形の定義

$m \times n$行列$A$について$r=\mathrm{rank}A$であるとき、行列$A$に行基本変形と列基本変形を行うことで下記の行列$X$に変形することができる。
$$
\large
\begin{align}
X &= \left(\begin{array}{ccccccc} x_{11} & 0 & \cdots & 0 & 0 & \cdots & 0 \\ 0 & x_{22} & \cdots & 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots & \vdots & \vdots & \vdots \\ 0 & 0 & \cdots & x_{rr} & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 & 0 & \ddots & 0 \\ 0 & 0 & \cdots & 0 & 0 & \cdots & 0 \end{array} \right) \\
x_{ii} &= 1, \, i \leq r
\end{align}
$$

$A$に行基本変形と列基本変形を行うことで得られる上記の$X$の形式の行列を行列$A$の標準形という。

行列の標準形への変形

標準形の取得にあたっては行基本変形によって簡約階段形に変形し、簡約階段形に列基本変形を行えばよい。簡約階段形については下記で詳しく取り扱った。

簡約階段形から標準形を得るにあたっては列の入れ替えによって、対角成分に$r=\mathrm{rank}A$個の$1$を並べ、それ以外の列は掃き出し法の要領で全ての要素が$0$になるように変形すればよい。

簡約階段形の判定法と簡約階段化の手順の具体例の確認

以下、「チャート式シリーズ 大学教養 線形代数」の例題の確認を行う。

基本例題$044$

$(1)$
$$
\large
\begin{align}
\left(\begin{array}{cccc} 4 & -1 & 3 & 0 \\ 1 & 2 & -4 & 2 \end{array} \right)
\end{align}
$$

上記の行列は下記のように行基本変形を元に簡約階段化できる。
$$
\large
\begin{align}
\left(\begin{array}{cccc} 4 & -1 & 3 & 0 \\ 1 & 2 & -4 & 2 \end{array} \right) & \rightarrow \left(\begin{array}{cccc} 1 & 2 & -4 & 2 \\ 4 & -1 & 3 & 0 \end{array} \right) \rightarrow \left(\begin{array}{cccc} 1 & 2 & -4 & 2 \\ 0 & -9 & 19 & -8 \end{array} \right) \\
& \rightarrow \left(\begin{array}{cccc} 1 & 2 & -4 & 2 \\ 0 & 1 & \displaystyle -\frac{19}{9} & \displaystyle \frac{8}{9} \end{array} \right) \rightarrow \left(\begin{array}{cccc} 1 & 0 & \displaystyle \frac{2}{9} & \displaystyle \frac{2}{9} \\ 0 & 1 & \displaystyle -\frac{19}{9} & \displaystyle \frac{8}{9} \end{array} \right)
\end{align}
$$

上記について列基本変形を行うことで下記のような標準形が得られる。
$$
\large
\begin{align}
\left(\begin{array}{cccc} 1 & 0 & \displaystyle \frac{2}{9} & \displaystyle \frac{2}{9} \\ 0 & 1 & \displaystyle -\frac{19}{9} & \displaystyle \frac{8}{9} \end{array} \right) \rightarrow \left(\begin{array}{cccc} 1 & 0 & 0 & 0 \\ 0 & 1 & \displaystyle -\frac{19}{9} & \displaystyle \frac{8}{9} \end{array} \right) \rightarrow \left(\begin{array}{cccc} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{array} \right)
\end{align}
$$

重要例題$017$

InstructGPTの概要まとめ 〜GPT3、RLHF、RewardModel〜

近年大きな注目を集めるChatGPTの学習にあたっては、強化学習に基づくRLHF(Reinforcement Learning from Human Feedback)がfinetuningに用いられます。当記事では同様の枠組みを取り扱ったInstructGPTの概要をまとめました。
作成にあたってはInstructGPTの論文である「Training language models to follow instructions with human feedback」を主に参考にしました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

・InstructGPT論文
・仕組みから理解するChatGPT(筆者作成)

前提の確認

Transformer

下記で詳しく取り扱った。
・直感的に理解するTransformer

GPT-3

下記で詳しく取り扱った。

PPOを用いた強化学習

PPO(Proximal Policy Optimization)は方策勾配法(Policy Gradient)の学習の安定化にあたって、繰り返し演算におけるパラメータの修正幅を制限する手法である。詳しくは下記で取り扱った。

InstructGPT

大まかな流れ

大まかな流れは下記のInstructGPT論文Figure$2$を元に理解すると良い。

InstructGPT論文 Figure$2$

上図より、InstructGPTの大まかな流れは下記の$3$つのステップで表される。

$1. \,$ プロンプトの入力例に対し人間が正解例を作成し、教師あり学習の形式でGPT-$3$のfinetuningを行う。
$2. \,$ 複数のプロンプトの出力結果に対し、人間がランク付けを行い、RewardModelを学習させる。
$3. \,$ 学習させたRewardModelを元にプロンプトの出力結果に対しRewardを出力し、この値に応じて文の生成における方策をPPOを用いて強化学習させる。

上記の『プロンプト』はGPT-$3$に対しfinetuningと方策の最適化を行なったInstructGPTの入出力に対応することに注意が必要である。以下では$1.$〜$3.$についてそれぞれ詳しく確認を行う。

$1.$ Supervised Fine-Tuning

InstructGPTでは学習済みのGPT-$3$に対し、教師ありFine-Tuning(SFT; Supervised Fine-Tuning)を行う。具体的には人間(labeler)がプロンプトの入力例に対し回答を作成し、その内容に基づいてFine-Tuningを行う。

SFTでは基本的に教師あり学習と同様の手順で学習を行うが、学習済みのGPT-$3$を用いることから教師なし学習を十分に行ったのちの処理であることは注意して抑えておくと良い。

$2.$ RewardModel

RewardModel論文

InstructGPTのRewardModelには「Learning to summarize from human feedback」のRewardModelと同様なものが用いられる。よって以下ではこのRewardModel論文を元に取りまとめを行う。

RewardModel論文 Figure$2$

上記はRewardModel論文のFigure$2$であるが、$2.$と$3.$はInstructGPT論文の図と概ね同様であることが確認できる。一方で、RewardModel論文の$2.$には下記の数式でRewardModelのlossの記載があることにも注意しておくと良い。
$$
\large
\begin{align}
\mathrm{loss} = \log{[\sigma(r_j-r_k)]} \quad (1)
\end{align}
$$

$(1)$式における$r_j$は人間が選んだもの、$r_k$はそうでないものがそれぞれ対応する。論文の本文では$(1)$式と同じlossが下記のように表される。
$$
\large
\begin{align}
\mathrm{loss}(r_{\theta}) = -\mathbb{E}_{(x,y_0,y_1,i) \sim D} \left[ \log{(\sigma[r_{\theta}(x,y_i)-r_{\theta}(x,y_{1-i})])} \right] \quad (1)’
\end{align}
$$

RewardModelのlossの解釈

当項では以下、前項の「RewardModel」のlossである$(1)$式がクロスエントロピーに対応することに関して確認を行う。二値分類におけるクロスエントロピー誤差関数は下記で表すベルヌーイ分布$\mathrm{Bern}(p)$の確率関数$f(x)$から導出できる。
$$
\large
\begin{align}
f(x) = p^{x} (1-p)^{1-x}
\end{align}
$$

上記をパラメータ$p$に関する尤度$L(p)$と見なすと、$-\log{L(p)}$は下記のように表せる。
$$
\large
\begin{align}
-\log{L(p)} &= -\log{(p^{x} (1-p)^{1-x})} \\
&= – x \log{p} \, – \, (1-x) \log{(1-p)} \quad (2)
\end{align}
$$

$(2)$式は二値分類におけるクロスエントロピー誤差関数に一致する。ここで$x=1$が観測されたと仮定すると$1-x=0$であるので、$(2)$式は下記のように表すことができる。
$$
\large
\begin{align}
-\log{L(p)} &= – 1 \cdot \log{p} \, – \, (1-1) \log{(1-p)} \quad (2)’ \\
&= -\log{p} \quad (3)
\end{align}
$$

ここで一般化線形モデル(GLM; Generalized Linear Model)と同様の要領で、$p$をニューラルネットワークの出力と対応させることを考える。RewardModel論文では下記のような式に基づいて確率パラメータ$p$の予測を行う。
$$
\large
\begin{align}
p &= \sigma(r_j-r_k) \quad (4) \\
\sigma(x) &= \frac{1}{1+\exp(-x)}
\end{align}
$$

$\sigma(x)$はシグモイド関数に対応する。$(4)$式の解釈にあたって、シグモイド関数の定義に基づいて下記のような変換を行う。
$$
\large
\begin{align}
p &= \sigma(r_j-r_k) \quad (4) \\
&= \frac{1}{1+\exp{[-(r_j-r_k)]}} \\
&= \frac{1}{1+\exp{(-r_j+r_k)}} \\
&= \frac{\exp{(r_j)}}{\exp{(r_j)}+\exp{(r_j-r_j+r_k)}} \\
&= \frac{\exp{(r_j)}}{\exp{(r_j)}+\exp{(r_k)}} \quad (5)
\end{align}
$$

$(5)$式は出力層で$r_j, r_k$が得られた際にソフトマックス関数を計算することに対応する。$(3),(5)$式より、$(1)$式がクロスエントロピー誤差関数であり、かつ報酬を出力するネットワークをソフトマックス関数と同様の式に基づいて学習させると解釈できる。

$3.$ Reinforcement learning

目的関数

InstructGPTにおける強化学習では強化学習によって得られる方策の$\pi_{\phi}^{\mathrm{RL}}$とSupervised Fine-Tuningによって得られた$\pi_{\phi}^{\mathrm{SFT}}$を元に、下記のような目的関数を用いて学習を行う。
$$
\large
\begin{align}
\mathrm{Objective}(\phi) = E_{(x,y) \sim D’} \left[ r_{\theta}(x,y) – \beta \log{\frac{\pi_{\phi}^{\mathrm{RL}}(y|x)}{\pi_{\phi}^{\mathrm{SFT}}(y|x)}} \right] + \gamma E_{x \sim D} \left[ \log{(\pi_{\phi}^{\mathrm{RL}}(x))} \right] \quad (6)
\end{align}
$$

上記の$D’$は新たなプロンプトの入力である$x$と強化学習の結果生成される$y$に対応し、$r_{\theta}(x,y)$はRewardModelの出力、$\displaystyle \beta \log{ \frac{\pi_{\phi}^{\mathrm{RL}}(y|x)}{\pi_{\phi}^{\mathrm{SFT}}(y|x)} }$はPPO論文などのKL penaltyにそれぞれ対応する。また、$D$はpre-trainingの際に用いたコーパスであり、$\displaystyle E_{x \sim D} \left[ \log{(\pi_{\phi}^{\mathrm{RL}}(x))} \right]$は元々のpre-trainの結果から大きく変わった結果が得られないように設定される。

ここで$\gamma$は事前学習+SFTの結果との一貫性に対応する係数であり、InstructGPT論文では$\gamma=0$のときを”PPO”、$\gamma \neq 0$のときを”PPO-ptx”と表す。また、InstructGPT論文ではInstructGPTが”PPO-ptx”に対応するとされる。

a per-token KL penalty

$$
\large
\begin{align}
\log{ \frac{\pi_{\phi}^{\mathrm{RL}}(y|x)}{\pi_{\phi}^{\mathrm{SFT}}(y|x)} } \quad (7)
\end{align}
$$

上記の式はKL penaltyの期待値の内部に対応するが、$y$が系列であるので以下、per-tokenの形式への変形を行う。式変形にあたって、$y=(y_1, \cdots , y_N)$のように表す。このとき、$y_i$より前の系列を$\mathbf{y}_{:i}$とおくと、$(7)$式は下記のように変形できる。
$$
\large
\begin{align}
\log{ \frac{\pi_{\phi}^{\mathrm{RL}}(y|x)}{\pi_{\phi}^{\mathrm{SFT}}(y|x)} } &= \log{ \frac{\displaystyle \prod_{i=1}^{N} \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \prod_{i=1}^{N} \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})} } \\
&= \sum_{i=1}^{N} \log{ \frac{\displaystyle \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})} } \quad (8)
\end{align}
$$

ここで$(8)$式の期待値を取ることでa per-token KL penaltyを表すことができる。また、上記では$i=1$のとき$\mathbf{y}_{:i}$は存在せず、$i=2$のとき$\mathbf{y}_{:i}=(y_1)$、$i \geq 3$のとき$\mathbf{y}_{:i}=(y_1, \cdots , y_{i-1})$が対応することに注意が必要である。

目的関数の勾配

以下、$(6)$式の勾配の計算について取り扱う。
$$
\large
\begin{align}
\mathrm{Objective}(\phi) = E_{(x,y) \sim D’} \left[ r_{\theta}(x,y) – \beta \log{\frac{\pi_{\phi}^{\mathrm{RL}}(y|x)}{\pi_{\phi}^{\mathrm{SFT}}(y|x)}} \right] + \gamma E_{x \sim D} \left[ \log{(\pi_{\phi}^{\mathrm{RL}}(x))} \right] \quad (6)
\end{align}
$$

まず、$\displaystyle E_{(x,y) \sim D’} \left[ r_{\theta}(x,y) \right]$の$\phi$に関する勾配の計算は$r_{\theta}(x,y)$が定数であることに基づいて、下記のように得られる。
$$
\large
\begin{align}
\nabla_{\phi} E_{(x,y) \sim D’} \left[ r_{\theta}(x,y) \right] = E_{(x,y) \sim D’} \left[ \sum_{i=1}^{N} r_{\theta}(x,y) \nabla_{\phi} \log{\pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})} \right]
\end{align}
$$

上記は方策勾配法の基本的な勾配の計算と同様である。式の理解にあたっては収益$r_{\theta}(x,y)$の大きさに応じてパラメータ$\phi$の修正量を調整すると解釈すればよい。詳しい導出の流れや式の解釈は下記で取り扱った。

次にa per-token KL penaltyの勾配の計算を行う。式の簡略化にあたって、下記のように表した$(6), (8)$式に基づく$1$トークン分のKL penaltyの勾配を計算する。
$$
\large
\begin{align}
E_{(x,y) \sim D’} \left[ \log{ \frac{\displaystyle \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})} } \right] = \sum_{y_i} \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i}) \log{ \frac{\displaystyle \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})} } \quad (9)
\end{align}
$$

$(9)$式の$\displaystyle \sum_{y_i}$の中の項に関して$\nabla_{\phi}$を用いて勾配は下記のように計算できる。
$$
\large
\begin{align}
& \nabla_{\phi} \left[ \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i}) \log{ \frac{\displaystyle \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})} } \right] \\
&= \nabla_{\phi} \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i}) \log{ \frac{\displaystyle \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})} } + \cancel{\pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})} \cdot \left( \frac{\displaystyle \cancel{\pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}}{\displaystyle \cancel{\pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})}} \right)^{-1} \cdot \frac{\displaystyle \nabla_{\phi} \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \cancel{\pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})}} \\
&= \nabla_{\phi} \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i}) \left[ 1 + \log{ \frac{\displaystyle \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})} } \right] \quad (10)
\end{align}
$$

計算にあたってはSFTのパラメータ$\phi$が固定であることから定数であるとみなした。$(8)$式と$(10)$式を元に$(6)$式を見ると、$(10)$式の勾配の逆向きにパラメータ$\phi$をUpdateすることが確認できる。ここで$\displaystyle \frac{\displaystyle \pi_{\phi}^{\mathrm{RL}}(y_{i}|x,\mathbf{y}_{:i})}{\displaystyle \pi_{\phi}^{\mathrm{SFT}}(y_i|x,\mathbf{y}_{:i})}$の大きさに応じて勾配に基づく修正量を調整することが確認できるので、RLとSFTの確率が同様な場合は修正量が小さくなりペナルティが小さいと解釈できる。

InstructGPTまとめ

InstructGPTは①SFT、②RewardModel、③ReinforcementLearning(PPO, KL penalty)に基づいて学習済みのGPT-$3$に対し追加の学習を行う手法である。