diff options
-rw-r--r-- | tensorflow/python/framework/tensor_util.py | 3 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_util_test.py | 15 |
2 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 1663136507..922a655a70 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -574,7 +574,8 @@ def _ConstantValue(tensor): elif tensor.op.type == "Rank": input_shape = tensor.op.inputs[0].get_shape() if input_shape.ndims is not None: - return np.array([input_shape.ndims], dtype=np.int32) + return np.ndarray(shape=(), buffer=np.array([input_shape.ndims]), + dtype=np.int32) else: return None elif tensor.op.type == "Range": diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 6ce23da162..b2f535b2e2 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -23,6 +23,7 @@ import tensorflow as tf from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_state_ops @@ -577,7 +578,21 @@ class ConstantValueTest(tf.test.TestCase): def testRank(self): tf_val = tf.rank(tf.constant(0.0, shape=[1, 2, 3])) c_val = tf.contrib.util.constant_value(tf_val) + + self.assertEqual(np.ndarray, type(c_val)) + self.assertEqual((), c_val.shape) + self.assertEqual(3, c_val) + + # Repeat test using array_ops.rank_internal to avoid the optimization that + # happens in the rank function. + tf_val = array_ops.rank_internal(tf.constant(0.0, shape=[1, 2, 3]), + optimize=False) + c_val = tf.contrib.util.constant_value(tf_val) + + self.assertEqual(np.ndarray, type(c_val)) + self.assertEqual((), c_val.shape) self.assertEqual(3, c_val) + self.assertEqual([3], c_val) def testCast(self): np_val = np.random.rand(3, 4, 7).astype(np.float32) |