「強化学習」×「Transformer」①:Decision Transformer

Transformerは系列モデリングの学習にあたって様々な用途に用いられており、近年では「強化学習」分野へのTransformerの応用も研究されています。当記事ではTransformerを強化学習に応用した論文の一つであるDecision Transformerについての大枠を取りまとめました。
Decision Transformer論文や「ゼロから作るDeep Learning④ー強化学習編」の内容を参考に作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提の確認

Transformer

Dot Product Attentionに主に基づくTransformerの仕組みについては既知である前提で当記事はまとめました。下記などに解説コンテンツを作成しましたので、合わせて参照ください。

・直感的に理解するTransformerの仕組み(統計の森作成)

強化学習の基本知識

Deep Learningを用いた強化学習を理解するにあたって知っておくとよい内容を下記で取りまとめました。

・仕組みから理解するChatGPT(統計の森作成)

強化学習の基本トピックについて詳しく学ぶ際には日本語では「ゼロから作るDeep Learning④ー強化学習編」、英語では「Reinforcement Learning, second edition: An Introduction」がおすすめです。

Offline RL

Deep Q NetworkやAlphaZeroなど、多くの強化学習のアルゴリズムではエージェントと環境の相互作用によって軌道(trajectory)を生成し、「価値関数の近似」や「方策の最適化」を通して意思決定の最適化を行います。

一方で、予め既に得た軌道(trajectory)を元に教師あり学習のような学習を行うことで意思決定の最適化を行うことができます。このような強化学習の枠組みはオフライン強化学習(Offline RL)といわれます。

Decision TransformerではこのOffline RLの枠組みを用いるので、このような強化学習の研究分野があることは抑えておくと良いです。Offline RLについて詳しくは下記などを参照すると良いと思います。

・Offline RL論文
・A Survey on Transformers in Reinforcement Learning

Decision Transformer

Causal Transformer

Decision Transformer論文 Figure$\, 1$

Decision Transformerの処理概要は上図のように表されます。図より、メインの処理がCausal Transformerによって実現できることが確認できますが、Decision Transformer論文のSection.$3$のMethodより、基本的にはGPTと同様の処理を用いることが確認できます。

GPT(Generative Pre-Training)は自己回帰型(Auto Regressive)の処理に基づきます。処理の詳細はGPTの論文が参照するTransformer Decoderの解説で取り扱いました。

returns-to-go

Decision Transformerでは通常の強化学習で用いるRewardの$r_{t}$に基づいて下記のように定義される報酬和$\hat{R}_{t}$を用います。
$$
\large
\begin{align}
\hat{R}_{t} = \sum_{t’=t}^{T} r_{t’} \quad (1)
\end{align}
$$

上記をDecision Transformer論文では「returns-to-go」のような用語で表されます。一般的な強化学習では軌道(trajectory)の$\tau$を$\tau=(r_0,s_0,a_0,r_1,s_1,a_1, \cdots)$のように表すことが多い一方で、Decision Transformerでは$(1)$式で定義した「returns-to-go」を用いて下記のように軌道$\tau$を定義します。
$$
\large
\begin{align}
\tau=(\hat{R}_0,s_0,a_0,\hat{R}_1,s_1,a_1, \cdots , \hat{R}_{t}, s_{t}, a_{t}, \cdots)
\end{align}
$$

また、上記の$s_{t}$は時点$t$における状態、$a_{t}$は時点$t$における行動を表します。

Decision Transformerの擬似コード

Decision Transformer論文のAlgorithm$\, 1$にDecision Transformerの擬似コードが掲載されています。

# R, s, a, t: returns-to-go, states, actions, or timesteps
# transformer: transformer with causal masking (GPT)
# embed_s , embed_a , embed_R: linear embedding layers
# embed_t: learned episode positional embedding
# pred_a: linear action prediction layer

# main model
def DecisionTransformer(R, s, a, t):
    # compute embeddings for tokens
    pos_embedding = embed_t(t) # per-timestep (note: not per-token) s_embedding = embed_s(s) + pos_embedding
    a_embedding = embed_a(a) + pos_embedding
    R_embedding = embed_R(R) + pos_embedding

    # interleave tokens as (R_1, s_1, a_1, ..., R_K, s_K)
    input_embeds = stack(R_embedding , s_embedding , a_embedding)

    # use transformer to get hidden states
    hidden_states = transformer(input_embeds=input_embeds)

    # select hidden states for action prediction tokens
    a_hidden = unstack(hidden_states).actions

    # predict action
    return pred_a(a_hidden)

# training loop
for (R, s, a, t) in dataloader: # dims: (batch_size, K, dim)
    a_preds = DecisionTransformer(R, s, a, t)
    loss = mean((a_preds - a)**2) # L2 loss for continuous actions optimizer.zero_grad(); loss.backward(); optimizer.step()

# evaluation loop
target_return = 1 # for instance , expert -level return
R, s, a, t, done = [target_return], [env.reset()], [], [1], False
while not done: # autoregressive generation/sampling
    # sample next action
    action = DecisionTransformer(R, s, a, t)[-1] # for cts actions new_s, r, done, _ = env.step(action)
    
    # append new tokens to sequence
    R = R + [R[-1] - r] # decrement returns-to-go with reward s, a, t = s + [new_s], a + [action], t + [len(R)]
    R, s, a, t = R[-K:], ... # only keep context length of K

上記は28行目のa_preds = DecisionTransformer(R, s, a, t)がDecision Transformerを用いたActionの予測に対応するところから理解すると理解しやすいです。Decision Transformerは「状態」、「行動」、「returns-to-go」の入力に基づいて行動の予測を行います。

Decision Transformerの学習

Decision Transformerの学習では学習に用いる「軌道」のサンプルセットが最適である前提で、同様の振る舞いをTransformerが模倣するように学習を行います。

予測結果の$a_t$と「学習に用いる軌道サンプル」を元に、行動が離散値の場合はCross Entropy、連続値の場合は平均二乗誤差(mean-squared error)がlossに用いられます。前項で取り扱った擬似コードでは平均二乗誤差が計算されます。

Evaluations

Decision Transformer論文 Figure$\, 3$

参考

・Transformer論文:Attention is All you need[2017]
・Decision Transformer論文
・Offline RL論文
・A Survey on Transformers in Reinforcement Learning
・Transformer Decoder論文