aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-10-09 14:11:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 14:20:07 -0700
commit4fa59ef694c19dc63d574b2d6a349cd753d9cdbd (patch)
tree014dfa5171c1065be039ecfcf206d304f2ceb323
parent5c6ea51834ee410586233d67d43bdb4f1729261f (diff)
[tf.data] Lift parameterized test parameters into lambdas if they create TF ops.
The existing code triggers parts of the TensorFlow runtime that may not have been fully initialized at the time the parameters are evaluated. Lifting into a lambda and invoking the lambda inside the test method will achieve the proper order. PiperOrigin-RevId: 216419757
-rw-r--r--tensorflow/python/data/util/structure_test.py61
1 files changed, 32 insertions, 29 deletions
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index 2982763181..630a0c912b 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -34,52 +34,56 @@ from tensorflow.python.platform import test
class StructureTest(test.TestCase, parameterized.TestCase):
- # pylint disable=protected-access
+ # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
+ # will be executed before the (eager- or graph-mode) test environment has been
+ # set up.
+ # pylint: disable=g-long-lambda,protected-access
@parameterized.parameters(
- (constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32],
- [[]]), (sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- structure.SparseTensorStructure, [dtypes.variant], [[3]]),
- ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
- structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), ({
- "a": constant_op.constant(37.0),
- "b": constant_op.constant([1, 2, 3])
- }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
- ({
- "a":
- constant_op.constant(37.0),
+ (lambda: constant_op.constant(37.0), structure.TensorStructure,
+ [dtypes.float32], [[]]),
+ (lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ structure.SparseTensorStructure, [dtypes.variant], [[3]]),
+ (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
+ structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
+ (lambda: {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
+ (lambda: {
+ "a": constant_op.constant(37.0),
"b": (sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
}, structure.NestedStructure,
[dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
- def testFlatStructure(self, value, expected_structure, expected_types,
+ def testFlatStructure(self, value_fn, expected_structure, expected_types,
expected_shapes):
+ value = value_fn()
s = structure.Structure.from_value(value)
self.assertIsInstance(s, expected_structure)
self.assertEqual(expected_types, s._flat_types)
self.assertEqual(expected_shapes, s._flat_shapes)
@parameterized.parameters(
- (constant_op.constant(37.0), [
+ (lambda: constant_op.constant(37.0), lambda: [
constant_op.constant(38.0),
array_ops.placeholder(dtypes.float32),
variables.Variable(100.0), 42.0,
np.array(42.0, dtype=np.float32)
- ], [constant_op.constant([1.0, 2.0]),
- constant_op.constant(37)]),
- (sparse_tensor.SparseTensor(
+ ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]),
+ (lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- [
+ lambda: [
sparse_tensor.SparseTensor(
indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
sparse_tensor.SparseTensorValue(
indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
array_ops.sparse_placeholder(dtype=dtypes.int32),
array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
- ], [
+ ], lambda: [
constant_op.constant(37, shape=[4, 5]),
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
@@ -88,13 +92,13 @@ class StructureTest(test.TestCase, parameterized.TestCase):
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
]),
- ({
+ (lambda: {
"a": constant_op.constant(37.0),
"b": constant_op.constant([1, 2, 3])
- }, [{
+ }, lambda: [{
"a": constant_op.constant(15.0),
"b": constant_op.constant([4, 5, 6])
- }], [{
+ }], lambda: [{
"a": constant_op.constant(15.0),
"b": constant_op.constant([4, 5, 6, 7])
}, {
@@ -108,8 +112,11 @@ class StructureTest(test.TestCase, parameterized.TestCase):
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
)
- def testIsCompatibleWithStructure(self, original_value, compatible_values,
- incompatible_values):
+ def testIsCompatibleWithStructure(
+ self, original_value_fn, compatible_values_fn, incompatible_values_fn):
+ original_value = original_value_fn()
+ compatible_values = compatible_values_fn()
+ incompatible_values = incompatible_values_fn()
s = structure.Structure.from_value(original_value)
for compatible_value in compatible_values:
self.assertTrue(
@@ -120,10 +127,6 @@ class StructureTest(test.TestCase, parameterized.TestCase):
s.is_compatible_with(
structure.Structure.from_value(incompatible_value)))
- # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
- # will be executed before the (eager- or graph-mode) test environment has been
- # set up.
- # pylint: disable=g-long-lambda
@parameterized.parameters(
(lambda: constant_op.constant(37.0),),
(lambda: sparse_tensor.SparseTensor(