chainer.gradient_check.check_backward

chainer.gradient_check.check_backward(func, x_data, y_grad, params=(), eps=0.001, atol=1e-05, rtol=0.0001, no_grads=None, dtype=None)

Test backward procedure of a given function.
This function automatically checks the backward process of a given function to ensure that the computed gradients are approximately correct. For example, assuming you have defined a FunctionNode class MyFunc that takes two arguments and returns one value, you can wrap it in an ordinary function and check its gradient computations as follows:

    def test_my_func(self):

        def func(*xs):
            y, = MyFunc().apply(xs)
            return y

        x1_data = xp.array(...)
        x2_data = xp.array(...)
        gy_data = xp.array(...)
        check_backward(func, (x1_data, x2_data), gy_data)
This method creates Variable objects from x_data and calls func with those Variables to get its result as a Variable. It then sets the y_grad array to the grad attribute of the result and calls the backward method to obtain the gradients of the inputs. To check the correctness of these gradients, the function calls numerical_grad() to compute the gradients numerically and compares the two sets of gradients with chainer.testing.assert_allclose().

To reduce computational time, it uses a directional derivative along a random vector. A function \(g: \mathbb{R} \rightarrow \mathbb{R}^n\) is defined as \(g(\delta) = f(x + \delta r)\), where \(\delta \in \mathbb{R}\), \(r \in \mathbb{R}^n\) is a random vector, and \(f\) is the function you want to test. Its gradient is

\[g'(\delta) = f'(x + \delta r) \cdot r.\]

Therefore, \(g'(0) = f'(x) \cdot r\), so we can check the correctness of backpropagation through \(f\) indirectly by comparing the numerically calculated gradient of \(g\) with \(f'(x) \cdot r\) computed by backprop. If \(r\) is chosen from a uniform distribution, we can conclude with high probability that the gradient of \(f\) itself is correct.
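For intuition, the identity \(g'(0) = f'(x) \cdot r\) can be reproduced with plain NumPy. The following sketch is purely illustrative (it is not part of the Chainer API) and uses an elementwise \(f\) whose Jacobian is known:

    import numpy

    f = numpy.sin                         # f: R^n -> R^n (elementwise), Jacobian is diag(cos x)
    x = numpy.random.rand(5)
    r = numpy.random.rand(5)              # random direction
    eps = 1e-3

    # Numerical g'(0) by central differences along r ...
    numerical = (f(x + eps * r) - f(x - eps * r)) / (2 * eps)
    # ... compared against f'(x) . r obtained from the known Jacobian.
    analytic = numpy.cos(x) * r

    assert numpy.allclose(numerical, analytic, atol=1e-6)

check_backward() performs the same kind of comparison, except that the analytic side comes from the backward pass of the function under test.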
If input objects (x1_data and/or x2_data in this example) represent integer variables, their gradients are ignored.

You can simplify the test when MyFunc takes only one argument:

    check_backward(func, x1_data, gy_data)
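For concreteness, a single-argument check against a built-in function such as chainer.functions.tanh could look like the following illustrative sketch (it assumes chainer and numpy are installed and is not part of the original docstring):

    import numpy
    import chainer.functions as F
    from chainer.gradient_check import check_backward

    x_data = numpy.random.uniform(-1, 1, (3, 4))     # float64 input
    gy_data = numpy.random.uniform(-1, 1, (3, 4))
    check_backward(F.tanh, x_data, gy_data)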
If MyFunc is a loss function which returns a zero-dimensional array, pass None to gy_data. In this case, the grad attribute of the result is set to 1:

    check_backward(my_loss_func, (x1_data, x2_data), None)
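As a concrete illustration (not from the original docstring), chainer.functions.sum reduces its input to a zero-dimensional array and can therefore be checked with y_grad set to None:

    import numpy
    import chainer.functions as F
    from chainer.gradient_check import check_backward

    x_data = numpy.random.rand(3, 4)
    check_backward(F.sum, x_data, None)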
If MyFunc returns multiple outputs, pass the gradients of all outputs as a tuple:

    gy1_data = xp.array(...)
    gy2_data = xp.array(...)
    check_backward(func, x1_data, (gy1_data, gy2_data))
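For example, chainer.functions.split_axis returns multiple outputs, so its check takes a tuple of output gradients (an illustrative sketch under the same assumptions as above):

    import numpy
    import chainer.functions as F
    from chainer.gradient_check import check_backward

    x_data = numpy.random.rand(4, 6)
    gy1_data = numpy.random.rand(4, 3)
    gy2_data = numpy.random.rand(4, 3)
    check_backward(lambda x: F.split_axis(x, 2, axis=1),
                   x_data, (gy1_data, gy2_data))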
You can also test a Link. To check the gradients of the parameters of the link, pass a tuple of the parameters as the params argument:

    check_backward(my_link, (x1_data, x2_data), gy_data,
                   (my_link.W, my_link.b))
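A fuller, runnable variant of this snippet (an illustrative sketch, not part of the docstring) checks a chainer.links.Linear link together with its parameters; looser tolerances are used because the link's parameters are float32:

    import numpy
    import chainer.links as L
    from chainer.gradient_check import check_backward

    link = L.Linear(3, 2)                               # W: (2, 3), b: (2,), float32
    x_data = numpy.random.rand(4, 3).astype(numpy.float32)
    gy_data = numpy.random.rand(4, 2).astype(numpy.float32)

    check_backward(link, x_data, gy_data, (link.W, link.b),
                   atol=1e-3, rtol=1e-3)                # float32-friendly tolerances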
Note that params are not ndarrays, but Variables.

Function objects are acceptable as the func argument:

    check_backward(lambda x1, x2: f(x1, x2),
                   (x1_data, x2_data), gy_data)
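A concrete two-argument variant (illustrative; same assumptions as the sketches above) passes chainer.functions.matmul directly as func:

    import numpy
    import chainer.functions as F
    from chainer.gradient_check import check_backward

    x1_data = numpy.random.rand(2, 3)
    x2_data = numpy.random.rand(3, 4)
    gy_data = numpy.random.rand(2, 4)
    check_backward(F.matmul, (x1_data, x2_data), gy_data)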
Note

func is called many times to get numerical gradients for all inputs. This function does not work correctly when func behaves randomly, because it would then return different gradients on different calls.
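One way around this limitation (an illustrative sketch, not part of the API) is to draw any random numbers once, outside func, so that every call sees the same values. The dropout-like function below uses a fixed, precomputed mask:

    import numpy
    from chainer.gradient_check import check_backward

    x_data = numpy.random.rand(3, 4)
    gy_data = numpy.random.rand(3, 4)

    # The mask is sampled once, so func itself is deterministic.
    mask = (numpy.random.rand(3, 4) > 0.5).astype(x_data.dtype)

    def dropout_like(x):
        return x * (mask / 0.5)

    check_backward(dropout_like, x_data, gy_data)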
Parameters:

- func (callable) – A function which takes Variables and returns Variables. func must return a tuple of Variables or a single Variable. You can use a Function, FunctionNode or Link object, or any other function satisfying this condition.
- x_data (ndarray or tuple of ndarrays) – A set of ndarrays to be passed to func. If x_data is a single ndarray object, it is treated as (x_data,).
- y_grad (ndarray or tuple of ndarrays or None) – A set of ndarrays representing the gradients of the return values of func. If y_grad is a single ndarray object, it is treated as (y_grad,). If func is a loss function, y_grad should be set to None.
- params (Variable or tuple of Variables) – A set of Variables whose gradients are checked. When func is a Link object, set its parameters as params. If params is a single Variable object, it is treated as (params,).
- eps (float) – Epsilon value to be passed to numerical_grad().
- atol (float) – Absolute tolerance to be passed to chainer.testing.assert_allclose().
- rtol (float) – Relative tolerance to be passed to chainer.testing.assert_allclose().
- no_grads (list of bool) – Flags to skip variables in the gradient assertion. The list should have the same length as x_data.
- dtype (dtype) – x_data, y_grad and params are cast to this dtype when calculating numerical gradients. Only float types and None are allowed (see the sketch after this list).
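For low-precision inputs, casting during numerical differentiation can avoid spurious failures. The following sketch (illustrative only; it assumes chainer.functions.sigmoid and float32 data) shows the dtype argument in use:

    import numpy
    import chainer.functions as F
    from chainer.gradient_check import check_backward

    x_data = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)
    gy_data = numpy.random.uniform(-1, 1, (3, 4)).astype(numpy.float32)

    # Per the dtype argument, the data is cast to float64 while the
    # numerical gradients are estimated.
    check_backward(F.sigmoid, x_data, gy_data, dtype=numpy.float64)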
See also

numerical_grad()