aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/moving_averages_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/moving_averages_test.py')
-rw-r--r--tensorflow/python/training/moving_averages_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 53d524a325..d5de85febd 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -6,9 +6,11 @@ from __future__ import print_function
import tensorflow.python.platform
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import types
from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import moving_averages
@@ -130,6 +132,19 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
+ def testAverageVariablesDeviceAssignment(self):
+ with ops.device("dev_v0"):
+ v0 = variables.Variable(10.0, name="v0")
+ with ops.device("dev_v1"):
+ v1 = state_ops.variable_op(shape=[1], dtype=types.float32, name="v1")
+ tensor2 = v0 + v1
+ ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+ with ops.device("default"):
+ ema.apply([v0, v1, tensor2])
+ self.assertEqual("dev_v0", ema.average(v0).device)
+ self.assertEqual("dev_v1", ema.average(v1).device)
+ self.assertEqual("default", ema.average(tensor2).device)
+
if __name__ == "__main__":
googletest.main()