aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/values_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 8e44f2fea1..91a43d4999 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
@@ -79,6 +80,30 @@ class DistributedValuesTest(test.TestCase):
with self.assertRaises(AssertionError):
v = values.DistributedValues({"/device:cpu:0": 42})
+ def testIsTensorLike(self):
+ with context.graph_mode(), \
+ ops.Graph().as_default(), \
+ ops.device("/device:CPU:0"):
+ one = constant_op.constant(1)
+ two = constant_op.constant(2)
+ v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
+ self.assertEqual(two, v.get("/device:GPU:0"))
+ self.assertEqual(one, v.get())
+ self.assertTrue(v.is_tensor_like)
+ self.assertTrue(tensor_util.is_tensor(v))
+
+ def testIsTensorLikeWithAConstant(self):
+ with context.graph_mode(), \
+ ops.Graph().as_default(), \
+ ops.device("/device:CPU:0"):
+ one = constant_op.constant(1)
+ two = 2.0
+ v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
+ self.assertEqual(two, v.get("/device:GPU:0"))
+ self.assertEqual(one, v.get())
+ self.assertFalse(v.is_tensor_like)
+ self.assertFalse(tensor_util.is_tensor(v))
+
class DistributedDelegateTest(test.TestCase):