Training Tools¶
Chainer provides a standard implementation of the training loop under the chainer.training module. It is built on top of many other core features of Chainer, including Variable and Function, Link/Chain/ChainList, Optimizer, Dataset, and Reporter/Summary. Compared to the training loop abstractions of other machine learning toolkits, Chainer’s training framework aims at maximal flexibility while keeping typical usage simple. Most components are pluggable, and users can override their definitions.
The core of the training loop abstraction is Trainer, which implements the training loop itself. The training loop consists of two parts: Updater, which actually updates the parameters being trained, and Extension, which provides arbitrary functionality other than the parameter update.
Updater and some extensions use chainer.dataset and Iterator to scan the datasets and load mini-batches. The trainer also uses Reporter to collect the observed values, and some extensions use DictSummary to accumulate them and compute statistics.
You can find many examples of how to use these training utilities in the official examples. You can also browse the extension implementations under Extensions.
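The division of work described above can be pictured with a toy loop. The sketch below is not Chainer code: ToyUpdater and ToyTrainer are hypothetical stand-ins written only for this illustration, and the integer interval argument is a simplification of Chainer's trigger mechanism.

```python
class ToyUpdater:
    """Hypothetical stand-in for an updater: one update() call is one iteration."""

    def __init__(self):
        self.iteration = 0

    def update(self):
        # A real updater would load a mini-batch and update parameters here.
        self.iteration += 1


class ToyTrainer:
    """Hypothetical stand-in for a trainer that owns the loop."""

    def __init__(self, updater, stop_iteration):
        self.updater = updater
        self.stop_iteration = stop_iteration
        self._extensions = []

    def extend(self, extension, interval=1):
        # Register a callable to be invoked every `interval` iterations.
        self._extensions.append((interval, extension))

    def run(self):
        while self.updater.iteration < self.stop_iteration:
            # Part 1: the updater performs the parameter update.
            self.updater.update()
            # Part 2: extensions do everything else (logging, evaluation, ...).
            for interval, extension in self._extensions:
                if self.updater.iteration % interval == 0:
                    extension(self)


calls = []
trainer = ToyTrainer(ToyUpdater(), stop_iteration=6)
trainer.extend(lambda t: calls.append(t.updater.iteration), interval=2)
trainer.run()
```

In this sketch the extension only records the iteration number each time it is invoked; the point is that the loop itself never needs to know what an extension does.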
Trainer¶
chainer.training.Trainer | The standard training loop in Chainer.
Updaters¶
chainer.training.Updater | Interface of updater objects for trainers.
chainer.training.updaters.StandardUpdater | Standard implementation of Updater.
chainer.training.updaters.ParallelUpdater | Implementation of a parallel GPU Updater.
chainer.training.updaters.MultiprocessParallelUpdater | Implementation of a multiprocess parallel GPU Updater.
Chainer provides two kinds of updaters for multi-GPU training. Their pros and cons are as follows:
ParallelUpdater:
- (+) Can use the same iterator for any number of GPUs
- (-) No parallelism on the CPU side
- (-) GPUs used later may be blocked because of the limited kernel-launch queue size
MultiprocessParallelUpdater:
- (+) Parallelism on the CPU side
- (+) No degradation due to the kernel-launch queue size
- (-) Needs a per-process data iterator
- (-) Reporter can collect data from only one of the devices
Extensions¶
An extension is a callable object that can perform arbitrary actions during the training loop.
Extensions can be registered to Trainer by using the Trainer.extend() method, and they are invoked when the Trigger condition is satisfied.
In addition to the built-in extensions listed below, you can define your own extension by implementing Extension or using the make_extension() decorator.
See Trainer Extensions for details.
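To make the decorator route more concrete, here is a minimal pure-Python sketch of how a decorator can turn a plain function into an extension-like callable by attaching a default trigger to it. toy_make_extension is a hypothetical name invented for this illustration; it is not Chainer's make_extension(), and the real decorator's parameters and behavior should be taken from its own reference page.

```python
def toy_make_extension(trigger=(1, "iteration")):
    """Hypothetical, simplified imitation of a make_extension-style decorator."""

    def decorator(func):
        # Store the default trigger on the function object so that a
        # training loop could read it when the function is registered.
        func.trigger = trigger
        return func

    return decorator


@toy_make_extension(trigger=(1, "epoch"))
def log_epoch(trainer):
    # `trainer` is whatever object the loop passes in; a plain dict is
    # used here purely for illustration.
    return "epoch {} finished".format(trainer["epoch"])


message = log_epoch({"epoch": 3})
```

The decorated function is still an ordinary callable, which is what allows it to be treated like any other extension.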
Common¶
chainer.training.Extension | Base class of trainer extensions.
chainer.training.make_extension | Decorator to make given functions into trainer extensions.
Evaluation and Metrics Collection¶
These extensions provide features to collect additional metrics.
A typical use case is running Evaluator on a validation dataset to compute the validation loss and accuracy.
chainer.training.extensions.Evaluator | Trainer extension to evaluate models on a validation set.
chainer.training.extensions.MicroAverage | Calculates micro-average ratio.
chainer.training.extensions.FailOnNonNumber | Trainer extension to raise RuntimeError if parameters contain NaN or Inf.
chainer.training.extensions.ParameterStatistics | Trainer extension to report parameter statistics.
chainer.training.extensions.observe_lr | Returns a trainer extension to record the learning rate.
chainer.training.extensions.observe_value | Returns a trainer extension to continuously record a value.
Optimizer Behavior Control¶
These extensions provide features to adjust optimizer behavior. The typical use case is to change the learning rate of the optimizer over time.
chainer.training.extensions.ExponentialShift | Trainer extension to exponentially shift an optimizer attribute.
chainer.training.extensions.InverseShift | Trainer extension to shift an optimizer attribute.
chainer.training.extensions.LinearShift | Trainer extension to change an optimizer attribute linearly.
chainer.training.extensions.MultistepShift | Trainer extension to shift an optimizer attribute in several steps.
chainer.training.extensions.PolynomialShift | Trainer extension to polynomially shift an optimizer attribute.
chainer.training.extensions.WarmupShift | Trainer extension to gradually initialize an optimizer attribute.
chainer.training.extensions.StepShift | Trainer extension to shift an optimizer attribute in “steps”.
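As a worked example of changing a learning rate over time, the sketch below computes the values an exponential schedule would produce: the attribute is multiplied by a fixed rate at each invocation and optionally clipped at a target. The function name exponential_schedule and the exact clipping rule are assumptions made for this illustration; it shows the arithmetic only and is not a description of how ExponentialShift is implemented.

```python
def exponential_schedule(init, rate, steps, target=None):
    """Return the attribute value after each of `steps` multiplications by `rate`.

    Illustrative assumption: if `target` is given, the value is clipped so
    that it never moves past the target.
    """
    values = []
    value = init
    for _ in range(steps):
        value *= rate
        if target is not None:
            if rate > 1:
                # Growing schedule: do not exceed the target.
                value = min(value, target)
            else:
                # Decaying schedule: do not fall below the target.
                value = max(value, target)
        values.append(value)
    return values


# Halve a learning rate of 0.1 four times, but never go below 0.02.
schedule = exponential_schedule(init=0.1, rate=0.5, steps=4, target=0.02)
```

With these inputs the schedule decays from 0.05 to 0.025 and then stays at the target of 0.02 for the remaining steps.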
Reporting¶
These extensions report metrics and various statistics to the console or to files.
chainer.training.extensions.PrintReport | Trainer extension to print the accumulated results.
chainer.training.extensions.ProgressBar | Trainer extension to print a progress bar and recent training status.
chainer.training.extensions.LogReport | Trainer extension to output the accumulated results to a log file.
chainer.training.extensions.PlotReport | Trainer extension to output plots.
chainer.training.extensions.VariableStatisticsPlot | Trainer extension to plot statistics for Variables.
chainer.training.extensions.dump_graph | Returns a trainer extension to dump a computational graph.
Snapshot¶
These extensions provide features to take snapshots of models.
chainer.training.extensions.snapshot | Returns a trainer extension to take snapshots of the trainer.
chainer.training.extensions.snapshot_object | Returns a trainer extension to take snapshots of a given object.
Triggers¶
A trigger is a callable object that decides when to process a specific event within the training loop. It takes a Trainer object as its argument and returns True if the event should be fired.
It is mainly used to determine when to call an extension. It is also used to determine when to quit the training loop.
chainer.training.get_trigger | Gets a trigger object.
chainer.training.triggers.BestValueTrigger | Trigger invoked when specific value becomes best.
chainer.training.triggers.EarlyStoppingTrigger | Trigger for early stopping.
chainer.training.triggers.IntervalTrigger | Trigger based on a fixed interval.
chainer.training.triggers.ManualScheduleTrigger | Trigger invoked at specified point(s) of iterations or epochs.
chainer.training.triggers.MaxValueTrigger | Trigger invoked when specific value becomes maximum.
chainer.training.triggers.MinValueTrigger | Trigger invoked when specific value becomes minimum.
chainer.training.triggers.TimeTrigger | Trigger based on a fixed time interval.
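The trigger contract described in this section (a callable that receives the trainer and returns True when an event should fire) can be sketched in a few lines of plain Python. ToyIntervalTrigger and FakeTrainer are hypothetical names used only here, and the `iteration` attribute on the fake trainer is an assumption of this sketch rather than the layout of Chainer's Trainer.

```python
class ToyIntervalTrigger:
    """Hypothetical trigger that fires every `period` iterations."""

    def __init__(self, period):
        self.period = period

    def __call__(self, trainer):
        # Return True only when the iteration count is a multiple of the period.
        return trainer.iteration % self.period == 0


class FakeTrainer:
    """Minimal stand-in that only tracks an iteration counter."""

    def __init__(self):
        self.iteration = 0


trigger = ToyIntervalTrigger(3)
trainer = FakeTrainer()
fired = []
for _ in range(7):
    trainer.iteration += 1
    if trigger(trainer):
        # A training loop would invoke an extension (or stop) at this point.
        fired.append(trainer.iteration)
```

Because the trigger is just a callable, the same object could be used either to schedule an extension or as a stop condition for the loop.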