aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2015-12-15 14:46:48 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-15 14:46:48 -0800
commit1b0b52430b4038a90c185bb02b687e76133d259c (patch)
tree487ac71c61737e05d64261597cc0a3f5649a2cf5
parentc435b0aa60f67af8b7cce67e1be414f823759393 (diff)
Add a human-readable Tensor.__repr__(), and improve TensorShape.__str__().
The repr replaces the default Python representation for an object with information about the Tensor's name, shape and dtype. TensorShape's new str() is a more compact representation that makes it easier to tell, at a glance, what the shape represents. For example: print repr(tf.placeholder(tf.qint32, shape=(32, None, 2), name="c")) # ==> <tf.Tensor 'c:0' shape=(32, ?, 2) dtype=qint32> Fixes #460. Change: 110268564
-rw-r--r--tensorflow/python/client/session.py3
-rw-r--r--tensorflow/python/framework/ops.py4
-rw-r--r--tensorflow/python/framework/tensor_shape.py10
-rw-r--r--tensorflow/python/framework/tensor_shape_test.py14
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/unpack_op_test.py2
6 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 4cc1b9809b..74358349b7 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -360,8 +360,7 @@ class BaseSession(SessionInterface):
raise ValueError(
'Cannot feed value of shape %r for Tensor %r, '
'which has shape %r'
- % (np_val.shape, subfeed_t.name,
- tuple(subfeed_t.get_shape().dims)))
+ % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val
# Run request and get response.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 00051b4a7f..d73cf47b71 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -383,6 +383,10 @@ class Tensor(object):
(", dtype=%s" % self._dtype.name) if self._dtype else "",
(", device=%s" % self.device) if self.device else "")
+ def __repr__(self):
+ return "<tf.Tensor '%s' shape=%s dtype=%s>" % (
+ self.name, self.get_shape(), self._dtype.name)
+
def __hash__(self):
# Necessary to support Python's collection membership operators
return id(self)
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index d1fb6ddc06..c692e66577 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -424,6 +424,16 @@ class TensorShape(object):
def __repr__(self):
return "TensorShape(%s)" % self._dims
+ def __str__(self):
+ if self.ndims is None:
+ return "<unknown>"
+ elif self.ndims == 1:
+ length = self._dims[0].value
+ return "(%s,)" % (str(length) if length is not None else "?")
+ else:
+ return "(%s)" % ", ".join(str(d.value) if d.value is not None else "?"
+ for d in self._dims)
+
@property
def dims(self):
"""Returns a list of Dimensions, or None if the shape is unspecified."""
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index ca5da2ba72..9efafb1f68 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -268,6 +268,20 @@ class ShapeTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.TensorShape([1, 37, 42]),
tensor_shape.as_shape(proto))
+ def testStr(self):
+ self.assertEqual("<unknown>", str(tensor_shape.unknown_shape()))
+ self.assertEqual("(?,)", str(tensor_shape.unknown_shape(ndims=1)))
+ self.assertEqual("(?, ?)", str(tensor_shape.unknown_shape(ndims=2)))
+ self.assertEqual("(?, ?, ?)", str(tensor_shape.unknown_shape(ndims=3)))
+
+ self.assertEqual("()", str(tensor_shape.scalar()))
+ self.assertEqual("(7,)", str(tensor_shape.vector(7)))
+ self.assertEqual("(3, 8)", str(tensor_shape.matrix(3, 8)))
+ self.assertEqual("(4, 5, 2)", str(tensor_shape.TensorShape([4, 5, 2])))
+
+ self.assertEqual("(32, ?, 1, 9)",
+ str(tensor_shape.TensorShape([32, None, 1, 9])))
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 44182683ca..5bc7f2b37d 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -545,6 +545,19 @@ class PlaceholderTest(tf.test.TestCase):
with self.assertRaisesOpError(r"Shape \[-1,10\] has negative dimensions"):
s.eval()
+ def testTensorStr(self):
+ a = tf.placeholder(tf.float32, name="a")
+ self.assertEqual("<tf.Tensor 'a:0' shape=<unknown> dtype=float32>", repr(a))
+
+ b = tf.placeholder(tf.int32, shape=(32, 40), name="b")
+ self.assertEqual(
+ "<tf.Tensor 'b:0' shape=(32, 40) dtype=int32>",
+ repr(b))
+
+ c = tf.placeholder(tf.qint32, shape=(32, None, 2), name="c")
+ self.assertEqual(
+ "<tf.Tensor 'c:0' shape=(32, ?, 2) dtype=qint32>",
+ repr(c))
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py
index 47ed9e617c..621c717617 100644
--- a/tensorflow/python/kernel_tests/unpack_op_test.py
+++ b/tensorflow/python/kernel_tests/unpack_op_test.py
@@ -65,7 +65,7 @@ class UnpackOpTest(tf.test.TestCase):
def testCannotInferNum(self):
x = tf.placeholder(np.float32)
with self.assertRaisesRegexp(
- ValueError, r'Cannot infer num from shape TensorShape\(None\)'):
+ ValueError, r'Cannot infer num from shape <unknown>'):
tf.unpack(x)