diff options
Diffstat (limited to 'tensorflow/python/training/moving_averages_test.py')
-rw-r--r-- | tensorflow/python/training/moving_averages_test.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index 53d524a325..d5de85febd 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -6,9 +6,11 @@ from __future__ import print_function import tensorflow.python.platform from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.framework import types from tensorflow.python.ops import constant_op +from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.training import moving_averages @@ -130,6 +132,19 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase): self.assertEqual(ema.average_name(v1), ema.average(v1).op.name) self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name) + def testAverageVariablesDeviceAssignment(self): + with ops.device("dev_v0"): + v0 = variables.Variable(10.0, name="v0") + with ops.device("dev_v1"): + v1 = state_ops.variable_op(shape=[1], dtype=types.float32, name="v1") + tensor2 = v0 + v1 + ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg") + with ops.device("default"): + ema.apply([v0, v1, tensor2]) + self.assertEqual("dev_v0", ema.average(v0).device) + self.assertEqual("dev_v1", ema.average(v1).device) + self.assertEqual("default", ema.average(tensor2).device) + if __name__ == "__main__": googletest.main() |