# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Moving average optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver


class MovingAverageOptimizer(optimizer.Optimizer):
  """Optimizer that computes a moving average of the variables.

  Empirically it has been found that using the moving average of the trained
  parameters of a deep network is better than using its trained parameters
  directly. This optimizer allows you to compute this moving average and swap
  the variables at save time so that any code outside of the training loop
  will, by default, use the averaged values instead of the original ones.

  Example of usage:

  ```python

  # Encapsulate your favorite optimizer (here the momentum one)
  # inside the MovingAverageOptimizer.
  opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
  opt = tf.contrib.opt.MovingAverageOptimizer(opt)
  # Then create your model and all its variables.
  model = build_model()
  # Add the training op that optimizes using opt.
  # This needs to be called before swapping_saver().
  opt.minimize(cost, var_list)
  # Then create your saver like this:
  saver = opt.swapping_saver()
  # Pass it to your training loop.
  slim.learning.train(
      model,
      ...
      saver=saver)
  ```

  Note that for evaluation, the normal saver should be used instead of
  swapping_saver().
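
  For example, at evaluation time one could restore with a plain saver (a
  minimal sketch; `checkpoint_path` and `eval_op` are assumed to be defined
  elsewhere):

  ```python
  # A checkpoint written by swapping_saver() already stores the averaged
  # values under the original variable names, so a regular saver picks
  # them up directly.
  eval_saver = tf.train.Saver()
  with tf.Session() as sess:
    eval_saver.restore(sess, checkpoint_path)
    sess.run(eval_op)
  ```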
  """

  def __init__(self, opt, average_decay=0.9999, num_updates=None,
               sequential_update=True):
    """Construct a new MovingAverageOptimizer.

    Args:
      opt: A tf.train.Optimizer that will be used to compute and apply
           gradients.
      average_decay: Float.  Decay to use to maintain the moving averages
                     of trained variables.
                     See tf.train.ExponentialMovingAverage for details.
      num_updates: Optional count of number of updates applied to variables.
                   See tf.train.ExponentialMovingAverage for details.
      sequential_update: Bool. If False, will compute the moving average at the
                         same time as the model is updated, potentially
                         resulting in benign data races.
                         If True, will update the moving average after gradient
                         updates.
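
    For example, a sketch wrapping Adam with these knobs set explicitly
    (`global_step` is assumed to be created elsewhere):

    ```python
    opt = tf.contrib.opt.MovingAverageOptimizer(
        tf.train.AdamOptimizer(1e-3),
        average_decay=0.999,
        num_updates=global_step,  # lets the decay ramp up early in training
        sequential_update=False)  # averages may race benignly with updates
    ```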
    """
    self._optimizer = opt
    self._ema = moving_averages.ExponentialMovingAverage(
        average_decay, num_updates=num_updates)
    self._variable_map = None
    self._sequential_update = sequential_update

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    train_op = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
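    # Only track variables that actually receive a gradient; pairs whose
    # gradient is None are skipped.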
    var_list = [x[1] for x in grads_and_vars if x[0] is not None]
    self._variable_map = {}
    if self._sequential_update:
      with ops.control_dependencies([train_op]):
        ma_op = self._ema.apply(var_list)
    else:
      ma_op = self._ema.apply(var_list)

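    # Record a bidirectional mapping between each variable and its moving
    # average so that swapping_saver() can substitute one for the other.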
    for v in var_list:
      v_avg = self._ema.average(v)
      self._variable_map[v.op.name] = v_avg
      self._variable_map[v_avg.op.name] = v
    return control_flow_ops.group(train_op, ma_op, name="train_with_avg")

  def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
    """Create a saver swapping moving averages and variables.

    You should use this saver during training.  It will save the moving averages
    of the trained parameters under the original parameter names.  For
    evaluation or inference you should use a regular saver and it will
    automatically use the moving averages for the trained variables.

    You must call this function after all variables have been created and after
    you have called Optimizer.minimize().
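
    For example, a minimal sketch of the required call order (`loss` and
    `learning_rate` are assumed to be defined elsewhere):

    ```python
    opt = tf.contrib.opt.MovingAverageOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate))
    train_op = opt.minimize(loss)       # must run before swapping_saver()
    train_saver = opt.swapping_saver()  # saves averages under original names
    ```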

    Args:
      var_list: List of variables to save, as per `Saver()`.
                If set to None, will save all the variables that have been
                created before this call.
      name: The name of the saver.
      **kwargs: Keyword arguments of `Saver()`.

    Returns:
      A `tf.train.Saver` object.

    Raises:
      RuntimeError: If apply_gradients or minimize has not been called before.
    """

    if self._variable_map is None:
      raise RuntimeError('Must call apply_gradients or minimize before '
                         'creating the swapping_saver')
    if var_list is None:
      var_list = variables.global_variables()
    if not isinstance(var_list, dict):
      var_list = saver.BaseSaverBuilder.OpListToDict(var_list)
    # Now swap variables and moving averages.
    swapped_var_list = {}
    for k, v in six.iteritems(var_list):
      v_swap = self._variable_map.get(v.op.name, None)
      if v_swap is not None:
        swapped_var_list[k] = v_swap
      else:
        swapped_var_list[k] = v
    # Build the swapping saver.
    return saver.Saver(swapped_var_list, name=name, **kwargs)