aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/tensor_util_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-15 10:54:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-15 11:03:31 -0800
commit8eda77255d7b123a1ce3dd0b1d873cde3b2ebc96 (patch)
treeac2e08097848bd4be874c2a05635bc6de78a2ece /tensorflow/python/framework/tensor_util_test.py
parent89edbedc143952f2ce245f251f5d244db08ed1cc (diff)
Change constant_value for rank to return a scalar ndarray instead
of a numpy array. Change: 139220642
Diffstat (limited to 'tensorflow/python/framework/tensor_util_test.py')
-rw-r--r--tensorflow/python/framework/tensor_util_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 6ce23da162..b2f535b2e2 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -23,6 +23,7 @@ import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
@@ -577,7 +578,21 @@ class ConstantValueTest(tf.test.TestCase):
def testRank(self):
tf_val = tf.rank(tf.constant(0.0, shape=[1, 2, 3]))
c_val = tf.contrib.util.constant_value(tf_val)
+
+ self.assertEqual(np.ndarray, type(c_val))
+ self.assertEqual((), c_val.shape)
+ self.assertEqual(3, c_val)
+
+ # Repeat test using array_ops.rank_internal to avoid the optimization that
+ # happens in the rank function.
+ tf_val = array_ops.rank_internal(tf.constant(0.0, shape=[1, 2, 3]),
+ optimize=False)
+ c_val = tf.contrib.util.constant_value(tf_val)
+
+ self.assertEqual(np.ndarray, type(c_val))
+ self.assertEqual((), c_val.shape)
self.assertEqual(3, c_val)
+ self.assertEqual([3], c_val)
def testCast(self):
np_val = np.random.rand(3, 4, 7).astype(np.float32)