aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/data/util/structure_test.py61
1 files changed, 32 insertions, 29 deletions
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index 2982763181..630a0c912b 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -34,52 +34,56 @@ from tensorflow.python.platform import test
class StructureTest(test.TestCase, parameterized.TestCase):
- # pylint disable=protected-access
+ # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
+ # will be executed before the (eager- or graph-mode) test environment has been
+ # set up.
+ # pylint: disable=g-long-lambda,protected-access
@parameterized.parameters(
- (constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32],
- [[]]), (sparse_tensor.SparseTensor(
- indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- structure.SparseTensorStructure, [dtypes.variant], [[3]]),
- ((constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
- structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), ({
- "a": constant_op.constant(37.0),
- "b": constant_op.constant([1, 2, 3])
- }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
- ({
- "a":
- constant_op.constant(37.0),
+ (lambda: constant_op.constant(37.0), structure.TensorStructure,
+ [dtypes.float32], [[]]),
+ (lambda: sparse_tensor.SparseTensor(
+ indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
+ structure.SparseTensorStructure, [dtypes.variant], [[3]]),
+ (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
+ structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
+ (lambda: {
+ "a": constant_op.constant(37.0),
+ "b": constant_op.constant([1, 2, 3])
+ }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
+ (lambda: {
+ "a": constant_op.constant(37.0),
"b": (sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
}, structure.NestedStructure,
[dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
- def testFlatStructure(self, value, expected_structure, expected_types,
+ def testFlatStructure(self, value_fn, expected_structure, expected_types,
expected_shapes):
+ value = value_fn()
s = structure.Structure.from_value(value)
self.assertIsInstance(s, expected_structure)
self.assertEqual(expected_types, s._flat_types)
self.assertEqual(expected_shapes, s._flat_shapes)
@parameterized.parameters(
- (constant_op.constant(37.0), [
+ (lambda: constant_op.constant(37.0), lambda: [
constant_op.constant(38.0),
array_ops.placeholder(dtypes.float32),
variables.Variable(100.0), 42.0,
np.array(42.0, dtype=np.float32)
- ], [constant_op.constant([1.0, 2.0]),
- constant_op.constant(37)]),
- (sparse_tensor.SparseTensor(
+ ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]),
+ (lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- [
+ lambda: [
sparse_tensor.SparseTensor(
indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
sparse_tensor.SparseTensorValue(
indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
array_ops.sparse_placeholder(dtype=dtypes.int32),
array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
- ], [
+ ], lambda: [
constant_op.constant(37, shape=[4, 5]),
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
@@ -88,13 +92,13 @@ class StructureTest(test.TestCase, parameterized.TestCase):
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
]),
- ({
+ (lambda: {
"a": constant_op.constant(37.0),
"b": constant_op.constant([1, 2, 3])
- }, [{
+ }, lambda: [{
"a": constant_op.constant(15.0),
"b": constant_op.constant([4, 5, 6])
- }], [{
+ }], lambda: [{
"a": constant_op.constant(15.0),
"b": constant_op.constant([4, 5, 6, 7])
}, {
@@ -108,8 +112,11 @@ class StructureTest(test.TestCase, parameterized.TestCase):
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
)
- def testIsCompatibleWithStructure(self, original_value, compatible_values,
- incompatible_values):
+ def testIsCompatibleWithStructure(
+ self, original_value_fn, compatible_values_fn, incompatible_values_fn):
+ original_value = original_value_fn()
+ compatible_values = compatible_values_fn()
+ incompatible_values = incompatible_values_fn()
s = structure.Structure.from_value(original_value)
for compatible_value in compatible_values:
self.assertTrue(
@@ -120,10 +127,6 @@ class StructureTest(test.TestCase, parameterized.TestCase):
s.is_compatible_with(
structure.Structure.from_value(incompatible_value)))
- # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
- # will be executed before the (eager- or graph-mode) test environment has been
- # set up.
- # pylint: disable=g-long-lambda
@parameterized.parameters(
(lambda: constant_op.constant(37.0),),
(lambda: sparse_tensor.SparseTensor(