chainer.FunctionHook

class chainer.FunctionHook
Base class of hooks for Functions.
FunctionHook is a callback object that is registered to FunctionNode. Registered function hooks are invoked before and after the forward and backward operations of each function.

Function hooks that derive from FunctionHook may override the following methods:

- added()
- deleted()
- forward_preprocess()
- forward_postprocess()
- backward_preprocess()
- backward_postprocess()

By default, these methods do nothing.
Specifically, when the __call__() method of some function is invoked, forward_preprocess() (resp. forward_postprocess()) of all function hooks registered to this function are called before (resp. after) forward propagation. Likewise, when backward() of some Variable is invoked, backward_preprocess() (resp. backward_postprocess()) of all function hooks registered to each function that the backward pass goes through are called before (resp. after) that function's backward propagation. added() and deleted() are called when the hook is registered or unregistered, respectively.

There are two ways to register FunctionHook objects to FunctionNode objects.

The first one is to use the with statement. Function hooks hooked in this way are registered to all functions within the with statement and are unregistered at the end of the with statement.

Example
The following code is a simple example in which we measure the elapsed time of a part of the forward propagation procedure with TimerHook, which is a subclass of FunctionHook.

    >>> import numpy as np
    >>> import chainer
    >>> import chainer.functions as F
    >>> import chainer.links as L
    >>> class Model(chainer.Chain):
    ...     def __init__(self):
    ...         super(Model, self).__init__()
    ...         with self.init_scope():
    ...             self.l = L.Linear(10, 10)
    ...     def __call__(self, x1):
    ...         return F.exp(self.l(x1))
    >>> model1 = Model()
    >>> model2 = Model()
    >>> x = chainer.Variable(np.zeros((1, 10), np.float32))
    >>> with chainer.function_hooks.TimerHook() as m:
    ...     _ = model1(x)
    ...     y = model2(x)
    >>> model3 = Model()
    >>> z = model3(y)
    >>> print('Total time : {}'.format(m.total_time()))
    ... # doctest:+ELLIPSIS
    Total time : ...
In this example, we measure the elapsed times for each forward propagation of all functions in model1 and model2. Note that model3 is not a target of measurement, as TimerHook is unregistered before the forward propagation of model3.

Note
Chainer stores the dictionary of registered function hooks as a thread-local object, so the set of registered function hooks differs from thread to thread.
The other one is to register it directly to a FunctionNode object by calling its add_hook() method. Function hooks registered in this way can be removed by the delete_hook() method. Unlike the former registration method, function hooks are registered only to the function whose add_hook() method is called.

If the hook is registered globally using the with statement, None is passed as the function argument of added() and deleted().

If the hook is registered to a specific function using add_hook(), the FunctionNode instance is passed as the function argument of added() and deleted().

- Parameters
name (str) – Name of this function hook.
Methods
added(function)
Callback function invoked when the function hook is registered.
- Parameters
function (FunctionNode) – Function object to which the function hook is added. None if the function hook is registered globally.
backward_postprocess(function, in_data, out_grad)
Callback function invoked after backward propagation.
- Parameters
function (FunctionNode) – Function object to which the function hook is registered.
in_data (tuple of N-dimensional array) – Input data of forward propagation.
out_grad (tuple of N-dimensional array) – Gradient data of backward propagation.
backward_preprocess(function, in_data, out_grad)
Callback function invoked before backward propagation.
- Parameters
function (FunctionNode) – Function object to which the function hook is registered.
in_data (tuple of N-dimensional array) – Input data of forward propagation.
out_grad (tuple of N-dimensional array) – Gradient data of backward propagation.
deleted(function)
Callback function invoked when the function hook is unregistered.
- Parameters
function (FunctionNode) – Function object from which the function hook is deleted. None if the function hook was registered globally.
forward_postprocess(function, in_data)
Callback function invoked after forward propagation.
- Parameters
function (FunctionNode) – Function object to which the function hook is registered.
in_data (tuple of N-dimensional array) – Input data of forward propagation.
forward_preprocess(function, in_data)
Callback function invoked before forward propagation.
- Parameters
function (FunctionNode) – Function object to which the function hook is registered.
in_data (tuple of N-dimensional array) – Input data of forward propagation.
__eq__()
Return self==value.
__ne__()
Return self!=value.
__lt__()
Return self<value.
__le__()
Return self<=value.
__gt__()
Return self>value.
__ge__()
Return self>=value.
Attributes
name = 'FunctionHook'