diff options
author | 2016-08-23 20:26:32 -0800 | |
---|---|---|
committer | 2016-08-23 21:32:33 -0700 | |
commit | 2269152197182ad4b9413d9948e6a248bf109204 (patch) | |
tree | 2a827193ef0208c16586d99b1d5a0562b358dac4 | |
parent | dba3200515e71ef99d93209b43345ffbd7051390 (diff) |
Prevent non-termination when iterating over an unknown shape.
Change: 131133261
-rw-r--r-- | tensorflow/python/framework/tensor_shape.py | 7 | ||||
-rw-r--r-- | tensorflow/python/framework/tensor_shape_test.py | 7 |
2 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 6d3fdb6d80..dcc31f6701 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -487,6 +487,13 @@ class TensorShape(object): # Python 3 wants __bool__, Python 2.7 wants __nonzero__ __nonzero__ = __bool__ + def __iter__(self): + """Returns `self.dims` if the rank is known, otherwise raises ValueError.""" + if self._dims is None: + raise ValueError("Cannot iterate over a shape with unknown rank.") + else: + return iter(self._dims) + def __getitem__(self, key): """Returns the value of a dimension or a shape, depending on the key. diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py index 502be1df7d..01bcd2e2f2 100644 --- a/tensorflow/python/framework/tensor_shape_test.py +++ b/tensorflow/python/framework/tensor_shape_test.py @@ -187,6 +187,9 @@ class ShapeTest(test_util.TensorFlowTestCase): len(s) self.assertFalse(s) self.assertIs(None, s.dims) + with self.assertRaises(ValueError): + for _ in tensor_shape.TensorShape(None): + pass def testFullyDefinedShape(self): s = tensor_shape.TensorShape([tensor_shape.Dimension( @@ -205,6 +208,8 @@ class ShapeTest(test_util.TensorFlowTestCase): self.assertEqual([3, 4, 7], s.as_list()) s.assert_is_compatible_with([3, 4, 7]) s.assert_same_rank([6, 3, 7]) + for d1, d2 in zip(s, [3, 4, 7]): + assert d1.value == d2 def testPartiallyDefinedShape(self): s = tensor_shape.TensorShape([tensor_shape.Dimension( @@ -219,6 +224,8 @@ class ShapeTest(test_util.TensorFlowTestCase): self.assertEqual(tensor_shape.Dimension(None).value, s[1].value) self.assertEqual(tensor_shape.Dimension(7), s[2]) s.assert_same_rank([6, 3, 7]) + for d1, d2 in zip(s, [3, None, 7]): + assert d1.value == d2 def testMergeFullShapes(self): self.assertEqual([3, 4, 7], |