重点サンプリング(Importance Sampling)の数式表記とPythonを用いた計算例

特定の確率分布の期待値を別の確率分布からサンプリングした値に基づいて計算する手法を重点サンプリング(Importance Sampling)といいます。当記事では重点サンプリングの数式表記とPythonを用いた計算例の確認をそれぞれ行いました。
「ゼロから作るDeep Learning④ー強化学習編」の$5.5.2$「重点サンプリング」〜$5.5.3$節「分散を小さくするには」の内容などを参考に当記事の作成を行いました。

・用語/公式解説
https://www.hello-statisticians.com/explain-terms

重点サンプリングの基本事項

サンプリングによる期待値の近似

離散確率分布$\pi(x)$に基づいて得られる確率変数$x$の期待値を$E_{\pi}[x]$とおくと、期待値は下記のように定義される。
$$
\large
\begin{align}
\mathbb{E}_{\pi}[x] = \sum x \pi(x)
\end{align}
$$

このとき、確率分布が具体的に判明しなくても確率分布からのサンプリングが行うことができる場合は、確率分布に基づいて得られたサンプルから下記のように期待値$E_{\pi}[x]$を近似することができる。
$$
\large
\begin{align}
& \mathrm{sampling} : \, x^{(i)} \sim \pi \quad i=1,2,\cdots,n \\
& E_{\pi}[x] \simeq \frac{1}{n} \sum_{i=1}^{n} x^{(i)}
\end{align}
$$

ここで上記のインデックス$i$は確率分布$\pi(x)$から得られた$i$番目のサンプルを表すことを抑えておくと良い。

重点サンプリングの仕組み

通常のサンプリングによる近似では前項の『サンプリングによる期待値の近似』のように近似を行う。一方で重点サンプリングは確率分布$\pi(x)$の期待値を別の確率分布$b(x)$から得られたサンプルを元に計算する手法であり、下記のような変形に基づく。
$$
\large
\begin{align}
\mathbb{E}_{\pi}[x] &= \sum x \pi(x) \\
&= \sum x \frac{b(x)}{b(x)} \pi(x) \\
&= \sum x \frac{\pi(x)}{b(x)} b(x) = \mathbb{E}_{b} \left[ \frac{\pi(x)}{b(x)} x \right]
\end{align}
$$

上記より、$\displaystyle \mathbb{E}_{\pi}[x] = \mathbb{E}_{b} \left[ \frac{\pi(x)}{b(x)} x \right]$が成立するので、確率分布$b(x)$に基づいて得られたサンプルを元に確率分布$\pi(x)$の期待値を計算することができる。

Pythonを用いた計算例

確率分布$\pi$に基づくサンプリング

確率分布$\pi(x)$からのサンプリングに基づく期待値$E_{\pi}[x]$の近似値は下記のように計算することができる。

import numpy as np

x = np.array([1, 2, 3])
pi = np.array([0.1, 0.1, 0.8])

# 期待値
e = np.sum(x * pi)
print("E_pi[x]: {}".format(e))

np.random.seed(100)

# モンテカルロ法
n = 100
samples = []
for i in range(n):
    s = np.random.choice(x, p=pi)
    samples.append(s)
    
print("mean: {:.2f}, var: {:.2f}".format(np.mean(samples), np.var(samples)))

・実行結果

E_pi[x]: 2.7
mean: 2.71, var: 0.39

確率分布$b$に基づく重点サンプリング

確率分布$b(x)$からのサンプリングに基づく$\displaystyle \mathbb{E}_{b} \left[ \frac{\pi(x)}{b(x)} x \right]$の近似値は下記のように計算することができる。

np.random.seed(25)

b = np.array([1/3, 1/3, 1/3])
n = 100
samples = []

for i in range(n):
    idx = np.arange(len(b))
    i = np.random.choice(idx, p=b)
    s = x[i]
    rho = pi[i] / b[i]
    samples.append(rho * s)
    
print("mean: {:.2f}, var: {:.2f}".format(np.mean(samples), np.var(samples)))

・実行結果

mean: 2.56, var: 9.70

上記より、重点サンプリングによる近似値はある程度妥当である一方で分散が大きいことが確認できる。次節ではこの分散が大きくなる原因とその対応について確認を行う。

重点サンプリングと分散

分散が大きくなる原因

重点サンプリングを用いた際に分散が大きくなるのは確率分布$\pi(x)$と$b(x)$の差が大きい場合に$\displaystyle \mathbb{E}_{b} \left[ \frac{\pi(x)}{b(x)} x \right]$の$\displaystyle \frac{\pi(x)}{b(x)}$が$1$から大きく離れた値になり、$\displaystyle \frac{\pi(x)}{b(x)} x$の値が安定しないことに起因する。

前節の例では$\pi: (0.1,0.1,0.8)$に対し、$\displaystyle b: \left( \frac{1}{3},\frac{1}{3},\frac{1}{3} \right)$を用いたが、$\displaystyle \frac{\pi(x)}{b(x)}$はそれぞれ下記のように計算される。

・$x=1,2$の場合
$$
\large
\begin{align}
\frac{\pi(x)}{b(x)} &= \frac{0.1}{1/3} \\
&= 0.3
\end{align}
$$

・$x=3$の場合
$$
\large
\begin{align}
\frac{\pi(x)}{b(x)} &= \frac{0.8}{1/3} \\
&= 2.4
\end{align}
$$

上記より、$x=1,2$が得られた場合$\displaystyle \frac{\pi(x)}{b(x)} x$がそれぞれ$0.3, 0.6$であるのに対し、$x=3$が得られた場合$\displaystyle \frac{\pi(x)}{b(x)} x$が$7.2$の値を取る。このように$\displaystyle \frac{\pi(x)}{b(x)}$が$1$から大きく離れた値になることで、$\displaystyle \frac{\pi(x)}{b(x)} x$の値のばらつきが大きくなる。

この解決にあたっては、$\displaystyle \frac{\pi(x)}{b(x)}$を$1$に近づければ良いので、$\pi(x)$に類似の$b(x)$を用いればよい。

また、$\displaystyle \frac{\pi(x)}{b(x)}$を下記のように何らかの文字で表す場合もあるので、下記のような表記も抑えておくと良い。
$$
\large
\begin{align}
\rho(x) = \frac{\pi(x)}{b(x)}
\end{align}
$$

Pythonを用いた計算例

以下では$b:(0.2,0.2,0.6)$を用いて前節と同様に$\displaystyle \mathbb{E}_{b} \left[ \frac{\pi(x)}{b(x)} x \right]$の近似値と分散の計算を行った。

np.random.seed(25)

b = np.array([0.2, 0.2, 0.6])
n = 100
samples = []

for i in range(n):
    idx = np.arange(len(b))
    i = np.random.choice(idx, p=b)
    s = x[i]
    rho = pi[i] / b[i]
    samples.append(rho * s)
    
print("mean: {:.2f}, var: {:.2f}".format(np.mean(samples), np.var(samples)))

・実行結果

mean: 2.82, var: 2.50

実行結果より分散が前節の重点サンプリングの結果より小さくなったことが確認できる。

「重点サンプリング(Importance Sampling)の数式表記とPythonを用いた計算例」への3件のフィードバック

  1. […] $(2.1)$式は$(1.1)$式の$G(tau)$をアドバンテージ関数で置き換え、さらに重点サンプリング(importance sampling)を行うにあたって式変形を行なったものである。$(2.1)$式のように定義する目的関数をTRPOやPPOの論文ではsurrogate objectiveということも抑えておくと良い。 […]

コメントは受け付けていません。