diff options
author | Derek Murray <mrry@google.com> | 2015-12-15 14:46:48 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2015-12-15 14:46:48 -0800 |
commit | 1b0b52430b4038a90c185bb02b687e76133d259c (patch) | |
tree | 487ac71c61737e05d64261597cc0a3f5649a2cf5 | |
parent | c435b0aa60f67af8b7cce67e1be414f823759393 (diff) |
Add a human-readable Tensor.__repr__(), and improve TensorShape.__str__().
The repr replaces the default Python representation for an object with
information about the Tensor's name, shape and dtype. TensorShape's
new str() is a more compact representation that makes it easier to
tell, at a glance, what the shape represents.
For example:
print repr(tf.placeholder(tf.qint32, shape=(32, None, 2), name="c"))
# ==> <tf.Tensor 'c:0' shape=(32, ?, 2) dtype=qint32>
Fixes #460.
Change: 110268564
-rw-r--r-- | tensorflow/python/client/session.py | 3 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_shape.py | 10 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_shape_test.py | 14 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/constant_op_test.py | 13 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/unpack_op_test.py | 2 |
6 files changed, 43 insertions, 3 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 4cc1b9809b..74358349b7 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -360,8 +360,7 @@ class BaseSession(SessionInterface): raise ValueError( 'Cannot feed value of shape %r for Tensor %r, ' 'which has shape %r' - % (np_val.shape, subfeed_t.name, - tuple(subfeed_t.get_shape().dims))) + % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val # Run request and get response. diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 00051b4a7f..d73cf47b71 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -383,6 +383,10 @@ class Tensor(object): (", dtype=%s" % self._dtype.name) if self._dtype else "", (", device=%s" % self.device) if self.device else "") + def __repr__(self): + return "<tf.Tensor '%s' shape=%s dtype=%s>" % ( + self.name, self.get_shape(), self._dtype.name) + def __hash__(self): # Necessary to support Python's collection membership operators return id(self) diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index d1fb6ddc06..c692e66577 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -424,6 +424,16 @@ class TensorShape(object): def __repr__(self): return "TensorShape(%s)" % self._dims + def __str__(self): + if self.ndims is None: + return "<unknown>" + elif self.ndims == 1: + length = self._dims[0].value + return "(%s,)" % (str(length) if length is not None else "?") + else: + return "(%s)" % ", ".join(str(d.value) if d.value is not None else "?" + for d in self._dims) + @property def dims(self): """Returns a list of Dimensions, or None if the shape is unspecified.""" diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py index ca5da2ba72..9efafb1f68 100644 --- a/tensorflow/python/framework/tensor_shape_test.py +++ b/tensorflow/python/framework/tensor_shape_test.py @@ -268,6 +268,20 @@ class ShapeTest(test_util.TensorFlowTestCase): self.assertEqual(tensor_shape.TensorShape([1, 37, 42]), tensor_shape.as_shape(proto)) + def testStr(self): + self.assertEqual("<unknown>", str(tensor_shape.unknown_shape())) + self.assertEqual("(?,)", str(tensor_shape.unknown_shape(ndims=1))) + self.assertEqual("(?, ?)", str(tensor_shape.unknown_shape(ndims=2))) + self.assertEqual("(?, ?, ?)", str(tensor_shape.unknown_shape(ndims=3))) + + self.assertEqual("()", str(tensor_shape.scalar())) + self.assertEqual("(7,)", str(tensor_shape.vector(7))) + self.assertEqual("(3, 8)", str(tensor_shape.matrix(3, 8))) + self.assertEqual("(4, 5, 2)", str(tensor_shape.TensorShape([4, 5, 2]))) + + self.assertEqual("(32, ?, 1, 9)", + str(tensor_shape.TensorShape([32, None, 1, 9]))) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 44182683ca..5bc7f2b37d 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -545,6 +545,19 @@ class PlaceholderTest(tf.test.TestCase): with self.assertRaisesOpError(r"Shape \[-1,10\] has negative dimensions"): s.eval() + def testTensorStr(self): + a = tf.placeholder(tf.float32, name="a") + self.assertEqual("<tf.Tensor 'a:0' shape=<unknown> dtype=float32>", repr(a)) + + b = tf.placeholder(tf.int32, shape=(32, 40), name="b") + self.assertEqual( + "<tf.Tensor 'b:0' shape=(32, 40) dtype=int32>", + repr(b)) + + c = tf.placeholder(tf.qint32, shape=(32, None, 2), name="c") + self.assertEqual( + "<tf.Tensor 'c:0' shape=(32, ?, 2) dtype=qint32>", + repr(c)) if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unpack_op_test.py index 47ed9e617c..621c717617 100644 --- a/tensorflow/python/kernel_tests/unpack_op_test.py +++ b/tensorflow/python/kernel_tests/unpack_op_test.py @@ -65,7 +65,7 @@ class UnpackOpTest(tf.test.TestCase): def testCannotInferNum(self): x = tf.placeholder(np.float32) with self.assertRaisesRegexp( - ValueError, r'Cannot infer num from shape TensorShape\(None\)'): + ValueError, r'Cannot infer num from shape <unknown>'): tf.unpack(x) |