方策勾配法(Policy Gradient Method)の目的関数の定義と勾配の式の導出

方策勾配法(Policy Gradient Method)は強化学習の際に定義される方策をニューラルネットワークで定義し、勾配を用いることで方策の最適化を行う手法です。当記事では方策勾配法における目的関数の定義と勾配の式の導出について取り扱いました。
「ゼロから作るDeep Learning④ー強化学習編」の第$9$章の「方策勾配法」や付録Dの「方策勾配法の証明」の内容を参考に当記事の作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

前提知識

問題設定

強化学習にあたって得られるエピソードにおける一連の「状態、行動、報酬」からなる系列をtrajectory(軌道)という。ここでtrajectoryを$\tau$とおくと、$\tau$は下記のように表すことができる。
$$
\large
\begin{align}
\tau = (S_0, A_0, R_0, S_1, A_1, R_1, \cdots , S_{T+1})
\end{align}
$$

ここで$\tau$の収益を$G(\tau)$とおくと、$G(\tau)$は下記のように表せる。
$$
\large
\begin{align}
G(\tau) = R_0 + \gamma R_1 + \gamma^{2} R_2 + \cdots + \gamma^{T} R_{T}
\end{align}
$$

以下、$G(\tau)$を最大にするように状態$S_t$から行動$A_t$を選択する方策$\pi_{\theta}(A_t|S_t)$の最適化について取り扱う。

方策勾配法の目的関数の定義

$G(\tau)$を最大にするような方策$\pi_{\theta}(A_t|S_t)$を得るにあたっては、下記の目的関数を$\theta$について最大化すればよい。
$$
\large
\begin{align}
J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}}[G(\tau)]
\end{align}
$$

上記の式は『$\tau$が方策$\pi_{\theta}$に基づいて得られるときの収益$G(\tau)$の期待値』と解釈すればよい。要するになるべく多い収益$G(\tau)$が得られると期待できるような方策を最適化によって取得すると理解すればよい。

Log-Derivative Trick

対数関数の微分の公式に基づいて下記のような数式が成立する。
$$
\large
\begin{align}
\nabla_{\theta} \log{P(\tau|\theta)} = \frac{\nabla_{\theta} P(\tau|\theta)}{P(\tau|\theta)}
\end{align}
$$

上記の変形はLog-Derivative Trick(log勾配のトリック)といわれ、よく知られている。

勾配の式の導出

$\displaystyle \nabla_{\theta} J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ G(\tau) \nabla_{\theta} \log{P(\tau|\theta)} \right]$の導出

$$
\large
\begin{align}
\nabla_{\theta} J(\theta) &= \nabla_{\theta} \mathbb{E}_{\tau \sim \pi_{\theta}}[G(\tau)] \quad (2.1) \\
&= \nabla_{\theta} \sum_{\tau} P(\tau|\theta) G(\tau) \\
&= \sum_{\tau} \nabla_{\theta} \left( P(\tau|\theta) G(\tau) \right) \\
&= \sum_{\tau} \left( G(\tau) \nabla_{\theta} P(\tau|\theta) + P(\tau|\theta) \nabla_{\theta} G(\tau) \right) \\
&= \sum_{\tau} G(\tau) \nabla_{\theta} P(\tau|\theta) \quad (2.2)
\end{align}
$$

上記の式展開にあたっては$\nabla_{\theta} G(\tau)=\mathbf{0}$を用いた。$(2.1)$式はさらに下記のように変形することができる。
$$
\large
\begin{align}
\nabla_{\theta} J(\theta) &= \sum_{\tau} G(\tau) \nabla_{\theta} P(\tau|\theta) \quad (2.2) \\
&= \sum_{\tau} G(\tau) P(\tau|\theta) \frac{\nabla_{\theta} P(\tau|\theta)}{P(\tau|\theta)} \\
&= \sum_{\tau} G(\tau) P(\tau|\theta) \nabla_{\theta} \log{P(\tau|\theta)} \\
&= \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ G(\tau) \nabla_{\theta} \log{P(\tau|\theta)} \right] \quad (2.3)
\end{align}
$$

$\displaystyle \nabla_{\theta} J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ \sum_{t=0}^{T} G(\tau) \nabla_{\theta} \log{\pi_{\theta}(A_t|S_t)} \right]$の導出

$P(\tau|\theta)$は下記のように表すことができる。
$$
\large
\begin{align}
P(\tau|\theta) &= p(S_0) \pi_{\theta}(A_0|S_0) p(S_1|S_0,A_0) \cdots \pi_{\theta}(A_T|S_T) p(S_{T+1}|S_{T},A_{T}) \\
&= p(S_0) \prod_{t=0}^{T} \pi_{\theta}(A_t|S_t) p(S_{t+1}|S_{t},A_{t})
\end{align}
$$

$(2.4)$式の両辺の対数を取ることで下記が得られる。
$$
\large
\begin{align}
\log{P(\tau|\theta)} &= \log{ \left[ p(S_0) \prod_{t=0}^{T} \pi_{\theta}(A_t|S_t) p(S_{t+1}|S_{t},A_{t}) \right] } \quad (2.4)’ \\
&= \log{p(S_0)} + \sum_{t=0}^{T} \log{p(S_{t+1}|S_t,A_t)} + \sum_{t=0}^{T} \log{\pi_{\theta}(A_t|S_t)} \quad (2.5)
\end{align}
$$

$(2.5)$式を元に$\theta$に関する勾配$\nabla_{\theta} \log{P(\tau|\theta)}$は下記のように得られる。
$$
\large
\begin{align}
\nabla_{\theta} \log{P(\tau|\theta)} &= \nabla_{\theta} \left[ \log{p(S_0)} + \sum_{t=0}^{T} \log{p(S_{t+1}|S_t,A_t)} + \sum_{t=0}^{T} \log{\pi_{\theta}(A_t|S_t)} \right] \quad (2.5)’ \\
&= \nabla_{\theta} \sum_{t=0}^{T} \log{\pi_{\theta}(A_t|S_t)} \\
&= \sum_{t=0}^{T} \nabla_{\theta} \log{\pi_{\theta}(A_t|S_t)} \quad (2.6)
\end{align}
$$

上記の変形では$\displaystyle \nabla_{\theta} \log{p(S_0)} = \mathbf{0}$、$\displaystyle \nabla_{\theta} \sum_{t=0}^{T} \log{p(S_{t+1}|S_t,A_t)} = \mathbf{0}$であることを用いた。$(2.6)$式を$(2.3)$式に代入することで下記が得られる。
$$
\large
\begin{align}
\nabla_{\theta} J(\theta) &= \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ G(\tau) \nabla_{\theta} \log{P(\tau|\theta)} \right] \quad (2.3) \\
&= \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ \sum_{t=0}^{T} G(\tau) \nabla_{\theta} \log{\pi_{\theta}(A_t|S_t)} \right] \quad (2.7)
\end{align}
$$

「方策勾配法(Policy Gradient Method)の目的関数の定義と勾配の式の導出」への4件のフィードバック

コメントは受け付けていません。