# Source code for chainer.training.trainer

import collections
import os

import six

from chainer import reporter as reporter_module
from chainer.training import extension as extension_module
from chainer.training import trigger as trigger_module

class _ExtensionEntry(object):

    """Record of one extension registered to a trainer.

    It bundles the extension callable with the settings that were resolved
    at registration time.

    Args:
        extension: The registered extension object.
        priority: Invocation priority of the extension.
        trigger: Trigger deciding at which iterations the extension runs.
        invoke_before_training: Whether the extension is also invoked once
            before the training loop starts.

    """

    def __init__(self, extension, priority, trigger, invoke_before_training):
        self.extension = extension
        self.trigger = trigger
        self.invoke_before_training = invoke_before_training
        self.priority = priority

class Trainer(object):

    """The standard training loop in Chainer.

    Trainer is an implementation of a training loop. Users can invoke the
    training by calling the :meth:`run` method.

    Each iteration of the training loop proceeds as follows.

    - Update of the parameters. It includes the forward and backward
      computations, and an execution of the update formula. These are all
      done by the updater object held by the trainer.
    - Invocation of trainer extensions in the descending order of their
      priorities. A trigger object is attached to each extension, and it
      decides at each iteration whether the extension should be executed.
      Trigger objects are callable objects that take the trainer object as
      the argument and return a boolean value indicating whether the
      extension should be called or not.

    Extensions are callable objects that take the trainer object as the
    argument. There are two ways to define custom extensions: inheriting the
    :class:`Extension` class, and decorating functions by
    :func:`make_extension`. See :class:`Extension` for more details on custom
    extensions.

    Users can register extensions to the trainer by calling the
    :meth:`extend` method, where some configurations can be added.

    - Trigger object, which is also explained above. In most cases,
      :class:`IntervalTrigger` is used, in which case users can simply
      specify a tuple of the interval length and its unit, like
      ``(1000, 'iteration')`` or ``(1, 'epoch')``.
    - The order of execution of extensions is determined by their priorities.
      Extensions of higher priorities are invoked earlier. There are three
      standard values for the priorities:

      - ``PRIORITY_WRITER``. This is the priority for extensions that write
        some records to the :attr:`observation` dictionary. It includes
        cases that the extension directly adds values to the observation
        dictionary, or the extension uses the :func:`chainer.report`
        function to report values to the observation dictionary.
      - ``PRIORITY_EDITOR``. This is the priority for extensions that edit
        the :attr:`observation` dictionary based on already reported values.
      - ``PRIORITY_READER``. This is the priority for extensions that only
        read records from the :attr:`observation` dictionary. This is also
        suitable for extensions that do not use the :attr:`observation`
        dictionary at all.

    - Extensions with ``invoke_before_training`` flag on are also invoked at
      the beginning of the training loop. Extensions that update the training
      status (e.g., changing learning rates) should have this flag to be
      ``True`` to ensure that resume of the training loop correctly recovers
      the training status.

    The current state of the trainer object and objects handled by the
    trainer can be serialized through the standard serialization protocol of
    Chainer. It enables us to easily suspend and resume the training loop.

    .. note::
       The serialization does not recover everything of the training loop.
       It only recovers the states which change over the training (e.g.
       parameters, optimizer states, the batch iterator state, extension
       states, etc.). You must initialize the objects correctly before
       deserializing the states.

       On the other hand, it means that users can change the settings on
       deserialization. For example, the exit condition can be changed on
       the deserialization, so users can train the model for some
       iterations, suspend it, and then resume it with larger number of
       total iterations.

    During the training, it also creates a :class:`~chainer.Reporter` object
    to store observed values on each update. For each iteration, it creates a
    fresh observation dictionary and stores it in the :attr:`observation`
    attribute.

    Links of the target model of each optimizer are registered to the
    reporter object as observers, where the name of each observer is
    constructed as the format ``<optimizer name><link name>``. The link name
    is given by the :meth:`chainer.Link.namedlinks` method, which represents
    the path to each link in the hierarchy. Other observers can be registered
    by accessing the reporter object via the :attr:`reporter` attribute.

    The default trainer is `plain`, i.e., it does not contain any extensions.

    Args:
        updater (~chainer.training.Updater): Updater object. It defines how
            to update the models.
        stop_trigger: Trigger that determines when to stop the training loop.
            If it is not callable, it is passed to :class:`IntervalTrigger`.
        out (str): Output directory. :meth:`run` tries to create it before
            the training loop starts.

    Attributes:
        updater: The updater object for this trainer.
        stop_trigger: Trigger that determines when to stop the training loop.
            The training loop stops at the iteration on which this trigger
            returns ``True``.
        observation: Observation of values made at the last update. See the
            :class:`Reporter` class for details.
        out: Output directory.
        reporter: Reporter object to report observed values.

    """

    def __init__(self, updater, stop_trigger=None, out='result'):
        self.updater = updater
        self.stop_trigger = trigger_module.get_trigger(stop_trigger)
        self.observation = {}
        self.out = out

        reporter = reporter_module.Reporter()
        for name, optimizer in six.iteritems(updater.get_all_optimizers()):
            # Register the target link itself under the optimizer name, and
            # every descendant link under <optimizer name><link name>.
            reporter.add_observer(name, optimizer.target)
            reporter.add_observers(
                name, optimizer.target.namedlinks(skipself=True))
        self.reporter = reporter

        self._done = False
        # Ordered so that extensions of equal priority keep their
        # registration order when sorted in run().
        self._extensions = collections.OrderedDict()

        updater.connect_trainer(self)

    def extend(self, extension, name=None, trigger=None, priority=None,
               invoke_before_training=None):
        """Registers an extension to the trainer.

        :class:`Extension` is a callable object which is called after each
        update unless the corresponding trigger object decides to skip the
        iteration. The order of execution is determined by priorities:
        extensions with higher priorities are called earlier in each
        iteration. Extensions with the same priority are invoked in the order
        of registrations.

        If two or more extensions with the same name are registered, suffixes
        are added to the names of the second to last extensions. The suffix is
        ``_N`` where N is the ordinal of the extensions.

        See :class:`Extension` for the interface of extensions.

        Args:
            extension: Extension to register.
            name (str): Name of the extension. If it is omitted, the ``name``
                attribute of the extension is used, falling back to its
                ``default_name`` attribute. Note that the name would be
                suffixed by an ordinal in case of duplicated names as
                explained above.
            trigger (tuple or Trigger): Trigger object that determines when to
                invoke the extension. If it is ``None``, ``extension.trigger``
                is used instead. If the trigger is not callable, it is passed
                to :class:`IntervalTrigger` to build an interval trigger.
            priority (int): Invocation priority of the extension. Extensions
                are invoked in the descending order of priorities in each
                iteration. If this is ``None``, ``extension.priority`` is used
                instead.
            invoke_before_training (bool): If ``True``, the extension is also
                invoked just before entering the training loop. If this is
                ``None``, ``extension.invoke_before_training`` is used
                instead. This option is mainly used for extensions that alter
                the training configuration (e.g., learning rates); in such a
                case, resuming from snapshots requires the call of extension
                to recover the configuration before any updates.

        Raises:
            TypeError: If no name is given and the extension has neither a
                ``name`` nor a ``default_name`` attribute.
            ValueError: If the resolved name is ``'training'``.

        """
        if name is None:
            name = getattr(extension, 'name', None)
            if name is None:
                name = getattr(extension, 'default_name', None)
                if name is None:
                    raise TypeError('name is not given for the extension')
        if name == 'training':
            raise ValueError(
                'the name "training" is prohibited as an extension name')

        if trigger is None:
            trigger = getattr(extension, 'trigger', None)
        trigger = trigger_module.get_trigger(trigger)

        if priority is None:
            priority = getattr(
                extension, 'priority', extension_module.PRIORITY_READER)

        if invoke_before_training is None:
            invoke_before_training = getattr(
                extension, 'invoke_before_training', False)

        # Disambiguate duplicated names with an ordinal suffix.
        modified_name = name
        ordinal = 0
        while modified_name in self._extensions:
            ordinal += 1
            modified_name = '%s_%d' % (name, ordinal)

        extension.name = modified_name
        self._extensions[modified_name] = _ExtensionEntry(
            extension, priority, trigger, invoke_before_training)

    def get_extension(self, name):
        """Returns the extension of a given name.

        Args:
            name (str): Name of the extension.

        Returns:
            Extension.

        Raises:
            ValueError: If no extension is registered under ``name``.

        """
        extensions = self._extensions
        if name in extensions:
            return extensions[name].extension
        raise ValueError('extension %s not found' % name)

    def run(self):
        """Executes the training loop.

        This method is the core of ``Trainer``. It executes the whole loop of
        training the models.

        Note that this method cannot run multiple times for one trainer
        object.

        Raises:
            RuntimeError: If the training loop has already run to completion
                on this trainer.

        """
        if self._done:
            raise RuntimeError('cannot run training loop multiple times')

        # Best effort: the directory may already exist.
        try:
            os.makedirs(self.out)
        except OSError:
            pass

        # sort extensions by priorities (the sort is stable, so extensions
        # of equal priority keep their registration order)
        extension_order = sorted(
            self._extensions.keys(),
            key=lambda name: self._extensions[name].priority, reverse=True)
        extensions = [(name, self._extensions[name])
                      for name in extension_order]

        # invoke extensions before the loop
        for _, entry in extensions:
            if entry.invoke_before_training:
                entry.extension(self)

        update = self.updater.update
        reporter = self.reporter
        stop_trigger = self.stop_trigger

        # main training loop
        try:
            while not stop_trigger(self):
                self.observation = {}
                with reporter.scope(self.observation):
                    update()
                    for name, entry in extensions:
                        if entry.trigger(self):
                            entry.extension(self)
        finally:
            for _, entry in extensions:
                # Plain callables registered as extensions need not define
                # finalize, so look it up defensively.
                finalize = getattr(entry.extension, 'finalize', None)
                if finalize:
                    finalize()
            self.updater.finalize()

        self._done = True

    def serialize(self, serializer):
        """Serializes the state of the trainer.

        The updater is always serialized. The stop trigger, the extensions
        and the extension triggers are serialized only when they provide a
        ``serialize`` method.

        Args:
            serializer: Serializer object; child serializers are obtained
                from it by key.

        """
        self.updater.serialize(serializer['updater'])
        if hasattr(self.stop_trigger, 'serialize'):
            self.stop_trigger.serialize(serializer['stop_trigger'])

        s = serializer['extensions']
        t = serializer['extension_triggers']
        for name, entry in six.iteritems(self._extensions):
            if hasattr(entry.extension, 'serialize'):
                entry.extension.serialize(s[name])
            if hasattr(entry.trigger, 'serialize'):
                entry.trigger.serialize(t[name])