Function hooks
Chainer provides a function-hook mechanism that enriches the behavior of forward and backward propagation of Function.
Base class
class chainer.function.FunctionHook
Base class of hooks for Functions.
FunctionHook is a callback object that is registered to a Function. Registered function hooks are invoked before and after the forward and backward operations of each function.
Function hooks that derive from FunctionHook are required to implement four methods: forward_preprocess(), forward_postprocess(), backward_preprocess(), and backward_postprocess(). By default, these methods do nothing.
Specifically, when the __call__() method of a function is invoked, forward_preprocess() (resp. forward_postprocess()) of every function hook registered to this function is called before (resp. after) forward propagation.
Likewise, when backward() of a Variable is invoked, backward_preprocess() (resp. backward_postprocess()) of every function hook registered to the function which holds this variable as a gradient is called before (resp. after) backward propagation.
There are two ways to register FunctionHook objects to Function objects.
The first is to use a with statement. Function hooks registered in this way apply to all functions called within the with statement and are unregistered at the end of the with statement.
Example
The following code is a simple example in which we measure the elapsed time of part of a forward propagation procedure with TimerHook, which is a subclass of FunctionHook.
>>> import numpy as np
>>> import chainer
>>> import chainer.functions as F
>>> import chainer.links as L
>>> from chainer import function_hooks
>>> class Model(chainer.Chain):
...     def __call__(self, x1):
...         return F.exp(self.l(x1))
>>> model1 = Model(l=L.Linear(10, 10))
>>> model2 = Model(l=L.Linear(10, 10))
>>> x = chainer.Variable(np.zeros((1, 10), 'f'))
>>> with chainer.function_hooks.TimerHook() as m:
...     _ = model1(x)
...     y = model2(x)
...     print("Total time : " + str(m.total_time()))
Total time : ...
>>> # model3 runs after the with block, so TimerHook no longer applies.
>>> model3 = Model(l=L.Linear(10, 10))
>>> z = model3(y)
In this example, we measure the elapsed time of each forward propagation of all functions in model1 and model2 (specifically, LinearFunction and Exp of model1 and model2). Note that model3 is not a target of measurement, as TimerHook is unregistered before forward propagation of model3.
Note
Chainer stores the dictionary of registered function hooks as a thread-local object, so the set of registered function hooks differs from thread to thread.
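The thread-local behavior described in the note can be illustrated with a small plain-Python sketch. It does not use Chainer and is not Chainer's implementation: get_hooks and SketchHook are hypothetical names, chosen only to show how a registry kept in threading.local makes a hook registered in one thread invisible to another.

```python
import threading
from collections import OrderedDict

# Hypothetical stand-in for a thread-local hook registry; not Chainer's code.
_local = threading.local()


def get_hooks():
    """Return the hook dictionary of the current thread, creating it lazily."""
    if not hasattr(_local, "hooks"):
        _local.hooks = OrderedDict()
    return _local.hooks


class SketchHook(object):
    """Context manager that registers itself for the duration of a with block."""

    name = "SketchHook"

    def __enter__(self):
        get_hooks()[self.name] = self
        return self

    def __exit__(self, *exc_info):
        # Unregister at the end of the with statement.
        del get_hooks()[self.name]


seen_in_worker = []


def worker():
    # Runs in another thread, so it sees that thread's own (empty) registry.
    seen_in_worker.extend(get_hooks().keys())


with SketchHook():
    seen_in_main = list(get_hooks().keys())
    t = threading.Thread(target=worker)
    t.start()
    t.join()

after_exit = list(get_hooks().keys())
```

In this sketch the main thread sees the hook only inside the with block, and the worker thread never sees it, which mirrors the behavior the note describes.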
The other way is to register a hook directly to a Function object with the add_hook() method. Function hooks registered in this way can be removed with the delete_hook() method. In contrast to the former registration method, function hooks are registered only to the function whose add_hook() is called.
Parameters: name (str) – Name of this function hook.
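The difference between the two registration methods may be easier to see in a toy model. The following plain-Python sketch assumes hypothetical classes SketchFunction and RecordingHook, and the add_hook(hook, name), delete_hook(name), and forward_*(function, in_data) signatures used here are assumptions made for illustration, not signatures taken from this page. It only shows that a hook attached to one function object is not invoked for another, and that it stops firing once deleted.

```python
class SketchFunction(object):
    """Hypothetical stand-in for a function object with its own hooks."""

    def __init__(self, label):
        self.label = label
        self.local_hooks = {}

    def add_hook(self, hook, name):
        # Assumed signature for this sketch only.
        self.local_hooks[name] = hook

    def delete_hook(self, name):
        del self.local_hooks[name]

    def __call__(self, x):
        hooks = list(self.local_hooks.values())
        for hook in hooks:
            hook.forward_preprocess(self, (x,))
        y = x * 2  # placeholder for the real forward computation
        for hook in hooks:
            hook.forward_postprocess(self, (x,))
        return y


class RecordingHook(object):
    """Records which callbacks were invoked, and for which function."""

    def __init__(self):
        self.events = []

    def forward_preprocess(self, function, in_data):
        self.events.append(("pre", function.label))

    def forward_postprocess(self, function, in_data):
        self.events.append(("post", function.label))


f = SketchFunction("f")
g = SketchFunction("g")
hook = RecordingHook()

f.add_hook(hook, "recorder")
f(1)  # hooked: records a "pre" and a "post" event
g(1)  # not hooked: records nothing
events_while_registered = list(hook.events)

f.delete_hook("recorder")
f(1)  # no longer hooked
events_after_delete = list(hook.events)
```

Only the calls to f made while the hook is attached leave a trace, in pre-then-post order.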
backward_postprocess(function, in_data, out_grad)
Callback function invoked after backward propagation.
Parameters:
- function (Function) – Function object to which the function hook is registered.
- in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
- out_grad (tuple of numpy.ndarray or tuple of cupy.ndarray) – Gradient data of backward propagation.
backward_preprocess(function, in_data, out_grad)
Callback function invoked before backward propagation.
Parameters:
- function (Function) – Function object to which the function hook is registered.
- in_data (tuple of numpy.ndarray or tuple of cupy.ndarray) – Input data of forward propagation.
- out_grad (tuple of numpy.ndarray or tuple of cupy.ndarray) – Gradient data of backward propagation.
Concrete function hooks
class chainer.function_hooks.PrintHook(sep='', end='\n', file=<open file '<stdout>', mode 'w'>, flush=True)
Function hook that prints debug information.
This function hook outputs debug information about the input arguments of the forward and backward methods of the hooked functions at preprocessing time (that is, just before each method is called).
Unlike the simple “debug print” technique, where users insert print functions at every function to be inspected, this hook shows the information of all functions involved with a single with statement.
Further, this hook enables us to show the information of backward methods without inserting print functions into Chainer’s library code.
Variables:
- sep – Separator passed to the print function.
- end – Character to be added at the end of the print function's output.
- file – File-like object to which the output is written.
- flush – If True, this hook forcibly flushes the text stream at the end of preprocessing.
Example
The basic usage is to use it with a with statement.
>>> import numpy as np
>>> import chainer
>>> import chainer.functions as F
>>> import chainer.links as L
>>> from chainer import function_hooks
>>> l = L.Linear(10, 10)
>>> x = chainer.Variable(np.zeros((1, 10), 'f'))
>>> with chainer.function_hooks.PrintHook():
...     y = l(x)
...     z = F.sum(y)
...     z.backward()
In this example, PrintHook shows the debug information of forward propagation of LinearFunction (which is implicitly called by l) and Sum (called by F.sum), and of backward propagation of z and y.
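The sep, end, file, and flush variables listed for PrintHook have the same names as arguments of Python's built-in print function. As a rough illustration of how such options could be forwarded, the following sketch uses a hypothetical SketchPrintHook class written in plain Python. It is not PrintHook's actual implementation, and the text it prints is invented for the example.

```python
import io
import sys


class SketchPrintHook(object):
    """Hypothetical hook that forwards sep, end and file to print()."""

    def __init__(self, sep="", end="\n", file=sys.stdout, flush=True):
        self.sep = sep
        self.end = end
        self.file = file
        self.flush = flush

    def backward_preprocess(self, function, in_data, out_grad):
        # Write one debug line just before the (imaginary) backward call.
        print("backward", function, sep=self.sep, end=self.end, file=self.file)
        if self.flush:
            # Force the text stream to be flushed after preprocessing.
            self.file.flush()


# Redirect the output to an in-memory file-like object instead of stdout.
buf = io.StringIO()
hook = SketchPrintHook(sep=" | ", file=buf)
hook.backward_preprocess("Sum", (1.0,), (0.5,))
captured = buf.getvalue()
```

Passing a file-like object such as io.StringIO is one way to capture the debug text for later inspection rather than writing it to the terminal.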