diff options
author | 2016-11-15 10:54:03 -0800 | |
---|---|---|
committer | 2016-11-15 11:03:31 -0800 | |
commit | 8eda77255d7b123a1ce3dd0b1d873cde3b2ebc96 (patch) | |
tree | ac2e08097848bd4be874c2a05635bc6de78a2ece /tensorflow/python/framework/tensor_util_test.py | |
parent | 89edbedc143952f2ce245f251f5d244db08ed1cc (diff) |
Change constant_value for rank to return a scalar ndarray instead
of a numpy array.
Change: 139220642
Diffstat (limited to 'tensorflow/python/framework/tensor_util_test.py')
-rw-r--r-- | tensorflow/python/framework/tensor_util_test.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 6ce23da162..b2f535b2e2 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -23,6 +23,7 @@ import tensorflow as tf from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_state_ops @@ -577,7 +578,21 @@ class ConstantValueTest(tf.test.TestCase): def testRank(self): tf_val = tf.rank(tf.constant(0.0, shape=[1, 2, 3])) c_val = tf.contrib.util.constant_value(tf_val) + + self.assertEqual(np.ndarray, type(c_val)) + self.assertEqual((), c_val.shape) + self.assertEqual(3, c_val) + + # Repeat test using array_ops.rank_internal to avoid the optimization that + # happens in the rank function. + tf_val = array_ops.rank_internal(tf.constant(0.0, shape=[1, 2, 3]), + optimize=False) + c_val = tf.contrib.util.constant_value(tf_val) + + self.assertEqual(np.ndarray, type(c_val)) + self.assertEqual((), c_val.shape) self.assertEqual(3, c_val) + self.assertEqual([3], c_val) def testCast(self): np_val = np.random.rand(3, 4, 7).astype(np.float32) |