aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-08-23 20:26:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 21:32:33 -0700
commit2269152197182ad4b9413d9948e6a248bf109204 (patch)
tree2a827193ef0208c16586d99b1d5a0562b358dac4
parentdba3200515e71ef99d93209b43345ffbd7051390 (diff)
Prevent non-termination when iterating over an unknown shape.
Change: 131133261
-rw-r--r--tensorflow/python/framework/tensor_shape.py7
-rw-r--r--tensorflow/python/framework/tensor_shape_test.py7
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 6d3fdb6d80..dcc31f6701 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -487,6 +487,13 @@ class TensorShape(object):
# Python 3 wants __bool__, Python 2.7 wants __nonzero__
__nonzero__ = __bool__
+ def __iter__(self):
+ """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
+ if self._dims is None:
+ raise ValueError("Cannot iterate over a shape with unknown rank.")
+ else:
+ return iter(self._dims)
+
def __getitem__(self, key):
"""Returns the value of a dimension or a shape, depending on the key.
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index 502be1df7d..01bcd2e2f2 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -187,6 +187,9 @@ class ShapeTest(test_util.TensorFlowTestCase):
len(s)
self.assertFalse(s)
self.assertIs(None, s.dims)
+ with self.assertRaises(ValueError):
+ for _ in tensor_shape.TensorShape(None):
+ pass
def testFullyDefinedShape(self):
s = tensor_shape.TensorShape([tensor_shape.Dimension(
@@ -205,6 +208,8 @@ class ShapeTest(test_util.TensorFlowTestCase):
self.assertEqual([3, 4, 7], s.as_list())
s.assert_is_compatible_with([3, 4, 7])
s.assert_same_rank([6, 3, 7])
+ for d1, d2 in zip(s, [3, 4, 7]):
+ assert d1.value == d2
def testPartiallyDefinedShape(self):
s = tensor_shape.TensorShape([tensor_shape.Dimension(
@@ -219,6 +224,8 @@ class ShapeTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.Dimension(None).value, s[1].value)
self.assertEqual(tensor_shape.Dimension(7), s[2])
s.assert_same_rank([6, 3, 7])
+ for d1, d2 in zip(s, [3, None, 7]):
+ assert d1.value == d2
def testMergeFullShapes(self):
self.assertEqual([3, 4, 7],