author      2018-08-23 09:07:16 -0700
committer   2018-08-23 09:11:16 -0700
commit      3e0709476c411840de9b7c016c6e0dd63e0eec78 (patch)
tree        b564b6f61480e8832c8c2a6bf539a05f0b8b6730 /tensorflow/python/training
parent      9289302ad3d7941ddb9ce2d0dff56b333cbcf208 (diff)
Allows tf.train.ExponentialMovingAverage to work with eager execution.
In the process, removes the unnecessary restriction on apply() being called
multiple times on the same variables; doing so can be necessary, for example,
to call ema.apply with different control dependencies in different calls to
session.run.
PiperOrigin-RevId: 209945355
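
For context, a minimal sketch of the eager-mode usage this change enables (assuming TF 1.x with eager execution turned on via tf.enable_eager_execution(); the decay value, variable, and loop below are illustrative, not taken from the commit):

    import tensorflow as tf

    tf.enable_eager_execution()  # TF 1.x-style eager execution

    var = tf.Variable(0.0)
    ema = tf.train.ExponentialMovingAverage(decay=0.9)

    # With eager execution enabled, each call to apply() performs one update
    # of the shadow variable, so apply() is called inside the training loop.
    for step in range(5):
      var.assign_add(1.0)    # stand-in for a training update
      ema.apply([var])       # updates the moving average in place
      print(step, ema.average(var).numpy())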
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--   tensorflow/python/training/moving_averages.py        | 55
-rw-r--r--   tensorflow/python/training/moving_averages_test.py   | 21
2 files changed, 49 insertions, 27 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 4b91d1e963..177a7ddfa5 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -363,10 +363,12 @@ class ExponentialMovingAverage(object):
     `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
     `tf.global_variables()`.
 
-    Returns an op that updates all shadow variables as described above.
+    Returns an op that updates all shadow variables from the current value of
+    their associated variables.
 
-    Note that `apply()` can be called multiple times with different lists of
-    variables.
+    Note that `apply()` can be called multiple times. When eager execution is
+    enabled each call to apply will update the variables once, so this needs to
+    be called in a loop.
 
     Args:
       var_list: A list of Variable or Tensor objects. The variables
@@ -389,31 +391,30 @@ class ExponentialMovingAverage(object):
                                       dtypes.float64]:
         raise TypeError("The variables must be half, float, or double: %s" %
                         var.name)
-      if var in self._averages:
-        raise ValueError("Moving average already computed for: %s" % var.name)
 
-      # For variables: to lower communication bandwidth across devices we keep
-      # the moving averages on the same device as the variables. For other
-      # tensors, we rely on the existing device allocation mechanism.
-      with ops.init_scope():
-        if isinstance(var, variables.Variable):
-          avg = slot_creator.create_slot(var,
-                                         var.initialized_value(),
-                                         self.name,
-                                         colocate_with_primary=True)
-          # NOTE(mrry): We only add `tf.Variable` objects to the
-          # `MOVING_AVERAGE_VARIABLES` collection.
-          ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
-        else:
-          avg = slot_creator.create_zeros_slot(
-              var,
-              self.name,
-              colocate_with_primary=(var.op.type in ["Variable",
-                                                     "VariableV2",
-                                                     "VarHandleOp"]))
-        if self._zero_debias:
-          zero_debias_true.add(avg)
-      self._averages[var] = avg
+      if var not in self._averages:
+        # For variables: to lower communication bandwidth across devices we keep
+        # the moving averages on the same device as the variables. For other
+        # tensors, we rely on the existing device allocation mechanism.
+        with ops.init_scope():
+          if isinstance(var, variables.Variable):
+            avg = slot_creator.create_slot(var,
+                                           var.initialized_value(),
+                                           self.name,
+                                           colocate_with_primary=True)
+            # NOTE(mrry): We only add `tf.Variable` objects to the
+            # `MOVING_AVERAGE_VARIABLES` collection.
+            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
+          else:
+            avg = slot_creator.create_zeros_slot(
+                var,
+                self.name,
+                colocate_with_primary=(var.op.type in ["Variable",
+                                                       "VariableV2",
+                                                       "VarHandleOp"]))
+          if self._zero_debias:
+            zero_debias_true.add(avg)
+        self._averages[var] = avg
 
     with ops.name_scope(self.name) as scope:
       decay = ops.convert_to_tensor(self._decay, name="decay")
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 3e85e6bfa7..fdb8d795c3 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_state_ops
 from tensorflow.python.ops import variable_scope
@@ -254,6 +256,25 @@ class ExponentialMovingAverageTest(test.TestCase):
       self.assertEqual(1, sess.run(v0))
       self.assertEqual([17.5], sess.run(v1_avg))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testBasicEager(self):
+    v0 = variables.Variable(1.0)
+    v1 = variables.Variable(2.0)
+
+    ema = moving_averages.ExponentialMovingAverage(0.25)
+    op = ema.apply([v0, v1])
+    if not context.executing_eagerly():
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(op)
+
+    self.evaluate(v0.assign(2.0))
+    self.evaluate(v1.assign(4.0))
+
+    self.evaluate(ema.apply([v0, v1]))
+
+    self.assertAllEqual(self.evaluate(ema.average(v0)), 1.75)
+    self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
+
   def averageVariablesNamesHelper(self, zero_debias):
     with self.test_session():
       v0 = variables.Variable(10.0, name="v0")
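
As a usage note on the removed ValueError, here is a sketch of the graph-mode pattern the commit message mentions, calling apply() on the same variable under different control dependencies; the names update_a/update_b and the assign ops are hypothetical stand-ins for real training updates:

    import tensorflow as tf

    x = tf.Variable(1.0)
    ema = tf.train.ExponentialMovingAverage(decay=0.99)

    update_a = x.assign_add(1.0)   # hypothetical training update A
    update_b = x.assign_sub(0.5)   # hypothetical training update B

    # Before this change, the second ema.apply([x]) raised
    # "Moving average already computed for: ..."; now both calls reuse the
    # same shadow variable and simply create separate update ops.
    with tf.control_dependencies([update_a]):
      ema_update_a = ema.apply([x])
    with tf.control_dependencies([update_b]):
      ema_update_b = ema.apply([x])

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(ema_update_a)   # runs update A, then the EMA update
      sess.run(ema_update_b)   # runs update B, then the EMA update
      print(sess.run(ema.average(x)))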