diff options
author | Derek Murray <mrry@google.com> | 2018-10-09 14:11:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 14:20:07 -0700 |
commit | 4fa59ef694c19dc63d574b2d6a349cd753d9cdbd (patch) | |
tree | 014dfa5171c1065be039ecfcf206d304f2ceb323 | |
parent | 5c6ea51834ee410586233d67d43bdb4f1729261f (diff) |
[tf.data] Lift parameterized test parameters into lambdas if they create TF ops.
The existing code triggers parts of the TensorFlow runtime that may not have been fully
initialized at the time the parameters are evaluated. Lifting into a lambda and invoking
the lambda inside the test method will achieve the proper order.
PiperOrigin-RevId: 216419757
-rw-r--r-- | tensorflow/python/data/util/structure_test.py | 61 |
1 files changed, 32 insertions, 29 deletions
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index 2982763181..630a0c912b 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -34,52 +34,56 @@ from tensorflow.python.platform import test class StructureTest(test.TestCase, parameterized.TestCase): - # pylint disable=protected-access + # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they + # will be executed before the (eager- or graph-mode) test environment has been + # set up. + # pylint: disable=g-long-lambda,protected-access @parameterized.parameters( - (constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32], - [[]]), (sparse_tensor.SparseTensor( - indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), - structure.SparseTensorStructure, [dtypes.variant], [[3]]), - ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])), - structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), ({ - "a": constant_op.constant(37.0), - "b": constant_op.constant([1, 2, 3]) - }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), - ({ - "a": - constant_op.constant(37.0), + (lambda: constant_op.constant(37.0), structure.TensorStructure, + [dtypes.float32], [[]]), + (lambda: sparse_tensor.SparseTensor( + indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), + structure.SparseTensorStructure, [dtypes.variant], [[3]]), + (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), + structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), + (lambda: { + "a": constant_op.constant(37.0), + "b": constant_op.constant([1, 2, 3]) + }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), + (lambda: { + "a": constant_op.constant(37.0), "b": (sparse_tensor.SparseTensor( indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, structure.NestedStructure, [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]])) - def testFlatStructure(self, value, expected_structure, expected_types, + def testFlatStructure(self, value_fn, expected_structure, expected_types, expected_shapes): + value = value_fn() s = structure.Structure.from_value(value) self.assertIsInstance(s, expected_structure) self.assertEqual(expected_types, s._flat_types) self.assertEqual(expected_shapes, s._flat_shapes) @parameterized.parameters( - (constant_op.constant(37.0), [ + (lambda: constant_op.constant(37.0), lambda: [ constant_op.constant(38.0), array_ops.placeholder(dtypes.float32), variables.Variable(100.0), 42.0, np.array(42.0, dtype=np.float32) - ], [constant_op.constant([1.0, 2.0]), - constant_op.constant(37)]), - (sparse_tensor.SparseTensor( + ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]), + (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), - [ + lambda: [ sparse_tensor.SparseTensor( indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), sparse_tensor.SparseTensorValue( indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), array_ops.sparse_placeholder(dtype=dtypes.int32), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None]) - ], [ + ], lambda: [ constant_op.constant(37, shape=[4, 5]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[5, 6]), @@ -88,13 +92,13 @@ class StructureTest(test.TestCase, parameterized.TestCase): sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]) ]), - ({ + (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) - }, [{ + }, lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6]) - }], [{ + }], lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6, 7]) }, { @@ -108,8 +112,11 @@ class StructureTest(test.TestCase, parameterized.TestCase): indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3]) }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]), ) - def testIsCompatibleWithStructure(self, original_value, compatible_values, - incompatible_values): + def testIsCompatibleWithStructure( + self, original_value_fn, compatible_values_fn, incompatible_values_fn): + original_value = original_value_fn() + compatible_values = compatible_values_fn() + incompatible_values = incompatible_values_fn() s = structure.Structure.from_value(original_value) for compatible_value in compatible_values: self.assertTrue( @@ -120,10 +127,6 @@ class StructureTest(test.TestCase, parameterized.TestCase): s.is_compatible_with( structure.Structure.from_value(incompatible_value))) - # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they - # will be executed before the (eager- or graph-mode) test environment has been - # set up. - # pylint: disable=g-long-lambda @parameterized.parameters( (lambda: constant_op.constant(37.0),), (lambda: sparse_tensor.SparseTensor( |