aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/framework/tensor_util.py3
-rw-r--r--tensorflow/python/framework/tensor_util_test.py15
2 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 1663136507..922a655a70 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -574,7 +574,8 @@ def _ConstantValue(tensor):
elif tensor.op.type == "Rank":
input_shape = tensor.op.inputs[0].get_shape()
if input_shape.ndims is not None:
- return np.array([input_shape.ndims], dtype=np.int32)
+ return np.ndarray(shape=(), buffer=np.array([input_shape.ndims]),
+ dtype=np.int32)
else:
return None
elif tensor.op.type == "Range":
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 6ce23da162..b2f535b2e2 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -23,6 +23,7 @@ import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
@@ -577,7 +578,21 @@ class ConstantValueTest(tf.test.TestCase):
def testRank(self):
tf_val = tf.rank(tf.constant(0.0, shape=[1, 2, 3]))
c_val = tf.contrib.util.constant_value(tf_val)
+
+ self.assertEqual(np.ndarray, type(c_val))
+ self.assertEqual((), c_val.shape)
+ self.assertEqual(3, c_val)
+
+ # Repeat test using array_ops.rank_internal to avoid the optimization that
+ # happens in the rank function.
+ tf_val = array_ops.rank_internal(tf.constant(0.0, shape=[1, 2, 3]),
+ optimize=False)
+ c_val = tf.contrib.util.constant_value(tf_val)
+
+ self.assertEqual(np.ndarray, type(c_val))
+ self.assertEqual((), c_val.shape)
self.assertEqual(3, c_val)
+ self.assertEqual([3], c_val)
def testCast(self):
np_val = np.random.rand(3, 4, 7).astype(np.float32)