diff options
-rw-r--r-- | tensorflow/python/data/util/structure_test.py | 61 |
1 files changed, 32 insertions, 29 deletions
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index 2982763181..630a0c912b 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -34,52 +34,56 @@ from tensorflow.python.platform import test class StructureTest(test.TestCase, parameterized.TestCase): - # pylint disable=protected-access + # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they + # will be executed before the (eager- or graph-mode) test environment has been + # set up. + # pylint: disable=g-long-lambda,protected-access @parameterized.parameters( - (constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32], - [[]]), (sparse_tensor.SparseTensor( - indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), - structure.SparseTensorStructure, [dtypes.variant], [[3]]), - ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])), - structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), ({ - "a": constant_op.constant(37.0), - "b": constant_op.constant([1, 2, 3]) - }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), - ({ - "a": - constant_op.constant(37.0), + (lambda: constant_op.constant(37.0), structure.TensorStructure, + [dtypes.float32], [[]]), + (lambda: sparse_tensor.SparseTensor( + indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), + structure.SparseTensorStructure, [dtypes.variant], [[3]]), + (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), + structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), + (lambda: { + "a": constant_op.constant(37.0), + "b": constant_op.constant([1, 2, 3]) + }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), + (lambda: { + "a": constant_op.constant(37.0), "b": (sparse_tensor.SparseTensor( indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, structure.NestedStructure, [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]])) - def testFlatStructure(self, value, expected_structure, expected_types, + def testFlatStructure(self, value_fn, expected_structure, expected_types, expected_shapes): + value = value_fn() s = structure.Structure.from_value(value) self.assertIsInstance(s, expected_structure) self.assertEqual(expected_types, s._flat_types) self.assertEqual(expected_shapes, s._flat_shapes) @parameterized.parameters( - (constant_op.constant(37.0), [ + (lambda: constant_op.constant(37.0), lambda: [ constant_op.constant(38.0), array_ops.placeholder(dtypes.float32), variables.Variable(100.0), 42.0, np.array(42.0, dtype=np.float32) - ], [constant_op.constant([1.0, 2.0]), - constant_op.constant(37)]), - (sparse_tensor.SparseTensor( + ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]), + (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), - [ + lambda: [ sparse_tensor.SparseTensor( indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), sparse_tensor.SparseTensorValue( indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), array_ops.sparse_placeholder(dtype=dtypes.int32), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None]) - ], [ + ], lambda: [ constant_op.constant(37, shape=[4, 5]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[5, 6]), @@ -88,13 +92,13 @@ class StructureTest(test.TestCase, parameterized.TestCase): sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]) ]), - ({ + (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) - }, [{ + }, lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6]) - }], [{ + }], lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6, 7]) }, { @@ -108,8 +112,11 @@ class StructureTest(test.TestCase, parameterized.TestCase): indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3]) }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]), ) - def testIsCompatibleWithStructure(self, original_value, compatible_values, - incompatible_values): + def testIsCompatibleWithStructure( + self, original_value_fn, compatible_values_fn, incompatible_values_fn): + original_value = original_value_fn() + compatible_values = compatible_values_fn() + incompatible_values = incompatible_values_fn() s = structure.Structure.from_value(original_value) for compatible_value in compatible_values: self.assertTrue( @@ -120,10 +127,6 @@ class StructureTest(test.TestCase, parameterized.TestCase): s.is_compatible_with( structure.Structure.from_value(incompatible_value))) - # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they - # will be executed before the (eager- or graph-mode) test environment has been - # set up. - # pylint: disable=g-long-lambda @parameterized.parameters( (lambda: constant_op.constant(37.0),), (lambda: sparse_tensor.SparseTensor( |