aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-08-23 09:07:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 09:11:16 -0700
commit3e0709476c411840de9b7c016c6e0dd63e0eec78 (patch)
treeb564b6f61480e8832c8c2a6bf539a05f0b8b6730 /tensorflow/python/training
parent9289302ad3d7941ddb9ce2d0dff56b333cbcf208 (diff)
Allows tf.train.ExponentialMovingAverage to work with eager execution.
In the process removes unnecessary restriction on it being called multiple times on the same variables; it might be necessary to do so if you want to call ema.apply with different control dependencies in different calls to session.run, for example. PiperOrigin-RevId: 209945355
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/moving_averages.py55
-rw-r--r--tensorflow/python/training/moving_averages_test.py21
2 files changed, 49 insertions, 27 deletions
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 4b91d1e963..177a7ddfa5 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -363,10 +363,12 @@ class ExponentialMovingAverage(object):
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
`tf.global_variables()`.
- Returns an op that updates all shadow variables as described above.
+ Returns an op that updates all shadow variables from the current value of
+ their associated variables.
- Note that `apply()` can be called multiple times with different lists of
- variables.
+ Note that `apply()` can be called multiple times. When eager execution is
+ enabled each call to apply will update the variables once, so this needs to
+ be called in a loop.
Args:
var_list: A list of Variable or Tensor objects. The variables
@@ -389,31 +391,30 @@ class ExponentialMovingAverage(object):
dtypes.float64]:
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
- if var in self._averages:
- raise ValueError("Moving average already computed for: %s" % var.name)
- # For variables: to lower communication bandwidth across devices we keep
- # the moving averages on the same device as the variables. For other
- # tensors, we rely on the existing device allocation mechanism.
- with ops.init_scope():
- if isinstance(var, variables.Variable):
- avg = slot_creator.create_slot(var,
- var.initialized_value(),
- self.name,
- colocate_with_primary=True)
- # NOTE(mrry): We only add `tf.Variable` objects to the
- # `MOVING_AVERAGE_VARIABLES` collection.
- ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
- else:
- avg = slot_creator.create_zeros_slot(
- var,
- self.name,
- colocate_with_primary=(var.op.type in ["Variable",
- "VariableV2",
- "VarHandleOp"]))
- if self._zero_debias:
- zero_debias_true.add(avg)
- self._averages[var] = avg
+ if var not in self._averages:
+ # For variables: to lower communication bandwidth across devices we keep
+ # the moving averages on the same device as the variables. For other
+ # tensors, we rely on the existing device allocation mechanism.
+ with ops.init_scope():
+ if isinstance(var, variables.Variable):
+ avg = slot_creator.create_slot(var,
+ var.initialized_value(),
+ self.name,
+ colocate_with_primary=True)
+ # NOTE(mrry): We only add `tf.Variable` objects to the
+ # `MOVING_AVERAGE_VARIABLES` collection.
+ ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
+ else:
+ avg = slot_creator.create_zeros_slot(
+ var,
+ self.name,
+ colocate_with_primary=(var.op.type in ["Variable",
+ "VariableV2",
+ "VarHandleOp"]))
+ if self._zero_debias:
+ zero_debias_true.add(avg)
+ self._averages[var] = avg
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 3e85e6bfa7..fdb8d795c3 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variable_scope
@@ -254,6 +256,25 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertEqual(1, sess.run(v0))
self.assertEqual([17.5], sess.run(v1_avg))
+ @test_util.run_in_graph_and_eager_modes
+ def testBasicEager(self):
+ v0 = variables.Variable(1.0)
+ v1 = variables.Variable(2.0)
+
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ op = ema.apply([v0, v1])
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(op)
+
+ self.evaluate(v0.assign(2.0))
+ self.evaluate(v1.assign(4.0))
+
+ self.evaluate(ema.apply([v0, v1]))
+
+ self.assertAllEqual(self.evaluate(ema.average(v0)), 1.75)
+ self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
+
def averageVariablesNamesHelper(self, zero_debias):
with self.test_session():
v0 = variables.Variable(10.0, name="v0")