数式だけの解説ではわかりにくい場合もあると思われるので、統計学の手法や関連する概念をPythonのプログラミングで表現します。当記事では自動微分の理解にあたって、自動微分とそれに基づく勾配法、自動微分のモジュール化などのPythonでの実装を取り扱いました。
・Pythonを用いた統計学のプログラミングまとめ
https://www.hello-statisticians.com/stat_program
作成にあたっては「ゼロから作るDeepLearning③」を主に参考にしましたので、そちらも合わせて参照ください。
Contents
自動微分の基本処理
変数と関数のクラス化
自動微分を考えるにあたっては変数と関数のクラス化を行うと取り扱いやすい。変数と関数は下記のようにクラス化を行うことができる。
import numpy as np
class Variable:
def __init__(self, data):
self.data = data
class Function:
def __call__(self, input):
x = input.data
y = self.forward(x)
output = Variable(y)
return output
def forward(self, x):
raise NotImplementedError()
class Square(Function):
def forward(self, x):
return x**2
x = Variable(np.array(10))
f = Square()
y = f(x)
print(type(y))
print(y.data)
・実行結果
> print(type(y))
<type 'instance'>
> print(y.data)
100
上記ではクラス化を行ったFunction
関数を継承してSquare
クラスを定義し、これを用いて$10$の$2$乗の計算を行った。Square
をFunction
を継承するような構成は、関数の作成の仕方が色々とあるからだと理解することができる。
なお、記述量を減らし可読性を上げるにあたって、NumPy
の読み込みやFunction
クラスとVariable
クラスに関しては内容の変更がない限りは以下に記載を行うプログラムでは省略を行うこととする。
合成関数の順伝播の計算の作成
当項では前項の内容を元に合成関数の作成を行う。前項のSquare
と同様にExp
を定義し、$e^{-x^2}$の$x=1$における値の計算を行う。
class Square(Function):
def forward(self, x):
return x**2
class Exp(Function):
def forward(self, x):
return np.exp(x)
class Neg(Function):
def forward(self, x):
return -x
x = Variable(np.array(1))
f1 = Square()
f2 = Neg()
f3 = Exp()
y = f3(f2(f1(x)))
print("y.data: {:.3f}".format(y.data))
print("exp(-x^2): {:.3f}".format(np.exp(-1.**2)))
・実行結果
> print("y.data: {:.3f}".format(y.data))
y.data: 0.368
> print("exp(-x^2): {:.3f}".format(np.exp(-1.**2)))
exp(-x^2): 0.368
自動微分の計算
自動微分は合成関数の微分の考え方に基づく微分であり、下記のように計算できる。
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
class Function:
def __call__(self, input):
x = input.data
y = self.forward(x)
output = Variable(y)
self.input = input
return output
def forward(self, x):
raise NotImplementedError()
def backward(self, x):
raise NotImplementedError()
class Square(Function):
def forward(self, x):
return x**2
def backward(self, gy):
x = self.input.data
gx = 2*x*gy
return gx
class Exp(Function):
def forward(self, x):
return np.exp(x)
def backward(self, gy):
x = self.input.data
gx = np.exp(x)*gy
return gx
class Neg(Function):
def forward(self, x):
return -x
def backward(self, gy):
x = self.input.data
gx = -gy
return gx
x = Variable(np.array(1))
f1 = Square()
f2 = Neg()
f3 = Exp()
a = f1(x)
b = f2(a)
y = f3(b)
y.grad = np.array(1)
b.grad = f3.backward(y.grad)
a.grad = f2.backward(b.grad)
x.grad = f1.backward(a.grad)
print("x.grad: {:.3f}".format(x.grad))
print("-2x exp(-x^2): {:.3f}".format(-2*1*np.exp(-1**2)))
・実行結果
> print("x.grad: {:.3f}".format(x.grad))
x.grad: -0.736
> print("-2x exp(-x^2): {:.3f}".format(-2*1*np.exp(-1**2)))
-2x exp(-x^2): -0.736
自動微分のモジュール化
バックプロパゲーションの自動化
下記のようにcreator
とset_creator
を用いることでバックプロパゲーションの自動化を行うことができます。
class Variable:
def __init__(self, data):
self.data = data
self.grad = None
self.creator = None
def set_creator(self, func):
self.creator = func
def backward(self):
funcs = [self.creator]
while funcs:
f = funcs.pop()
x, y = f.input, f.output
x.grad = f.backward(y.grad)
if x.creator is not None:
funcs.append(x.creator)
class Function:
def __call__(self, input):
x = input.data
y = self.forward(x)
output = Variable(y)
output.set_creator(self)
self.input = input
self.output = output
return output
def forward(self, x):
raise NotImplementedError()
def backward(self, x):
raise NotImplementedError()
class Square(Function):
def forward(self, x):
return x**2
def backward(self, gy):
x = self.input.data
gx = 2*x*gy
return gx
class Exp(Function):
def forward(self, x):
return np.exp(x)
def backward(self, gy):
x = self.input.data
gx = np.exp(x)*gy
return gx
class Neg(Function):
def forward(self, x):
return -x
def backward(self, gy):
x = self.input.data
gx = -gy
return gx
x = Variable(np.array(1))
f1 = Square()
f2 = Neg()
f3 = Exp()
y = f3(f2(f1(x)))
y.grad = np.array(1)
y.backward()
print("x.grad: {:.3f}".format(x.grad))
print("-2x exp(-x^2): {:.3f}".format(-2*1*np.exp(-1**2)))
・実行結果
> print("x.grad: {:.3f}".format(x.grad))
x.grad: -0.736
> print("-2x exp(-x^2): {:.3f}".format(-2*1*np.exp(-1**2)))
-2x exp(-x^2): -0.736