diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 16:23:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 16:23:50 -0700 |
commit | ab4ae7e0cf029896a8f679a998f23763a8c2103d (patch) | |
tree | 12c1de52844c1f7b3d1c13912fdcdfdf91874a0f /tensorflow/contrib/data | |
parent | 02163c55ae4e62495951e24c31e0a6ef96ab4e92 (diff) | |
parent | a5559a9d28bab6abfd65a9fad116ef9c6e13f8c2 (diff) |
Merge pull request #21702 from facaiy:ENH/assert_partial_shape
PiperOrigin-RevId: 210626817
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 127 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/batching.py | 21 |
2 files changed, 140 insertions, 8 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 42adfd17f0..9d8e955245 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -720,6 +720,42 @@ class RestructuredDatasetTest(test.TestCase): def test_assert_element_shape(self): + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(5).map(create_dataset) + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + self.assertEqual(expected_shapes, dataset.output_shapes) + + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def test_assert_element_shape_on_unknown_shape_dataset(self): + def create_unknown_shape_dataset(x): return script_ops.py_func( lambda _: ( # pylint: disable=g-long-lambda @@ -748,7 +784,60 @@ class RestructuredDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def test_assert_wrong_element_shape(self): + def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + def test_assert_partial_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(5).map(create_dataset) + partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape + tensor_shape.TensorShape((None, 4))) # Partial shape + result = dataset.apply( + batching.assert_element_shape(partial_expected_shape)) + # Partial shapes are merged with actual shapes: + actual_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + self.assertEqual(actual_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_partial_element_shape(self): def create_dataset(_): return (array_ops.ones(2, dtype=dtypes.float32), @@ -756,11 +845,41 @@ class RestructuredDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.range(3).map(create_dataset) wrong_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((3, 10))) + tensor_shape.TensorShape((None, 10))) with self.assertRaises(ValueError): dataset.apply(batching.assert_element_shape(wrong_shapes)) - def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): + def test_assert_partial_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((None, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self): def create_unknown_shape_dataset(x): return script_ops.py_func( @@ -776,7 +895,7 @@ class RestructuredDatasetTest(test.TestCase): self.assertEqual(unknown_shapes, dataset.output_shapes) wrong_shapes = (tensor_shape.TensorShape(2), - tensor_shape.TensorShape((3, 10))) + tensor_shape.TensorShape((None, 10))) iterator = ( dataset.apply(batching.assert_element_shape(wrong_shapes)) .make_initializable_iterator()) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 9f059942a6..9c2001c34f 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -647,15 +647,17 @@ def assert_element_shape(expected_shapes): """Assert the shape of this `Dataset`. ```python - shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)] + shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])] result = dataset.apply(tf.contrib.data.assert_element_shape(shapes)) - print(result.output_shapes) # ==> "((16, 256), <unknown>)" + print(result.output_shapes) # ==> "((16, 256), (<unknown>, 2))" ``` If dataset shapes and expected_shape, are fully defined, assert they match. Otherwise, add assert op that will validate the shapes when tensors are evaluated, and set shapes on tensors, respectively. + Note that unknown dimension in `expected_shapes` will be ignored. + Args: expected_shapes: A nested structure of `tf.TensorShape` objects. @@ -664,20 +666,31 @@ def assert_element_shape(expected_shapes): `tf.data.Dataset.apply` """ + def _merge_output_shapes(original_shapes, expected_shapes): + flat_original_shapes = nest.flatten(original_shapes) + flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes) + flat_merged_output_shapes = [ + original_shape.merge_with(new_shape) + for original_shape, new_shape in zip(flat_original_shapes, + flat_new_shapes)] + return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes) + def _check_shape(*elements): flatten_tensors = nest.flatten(elements) flatten_shapes = nest.flatten(expected_shapes) checked_tensors = [ - with_shape(shape, tensor) + with_shape(shape, tensor) if shape else tensor # Ignore unknown shape for shape, tensor in zip(flatten_shapes, flatten_tensors) ] return nest.pack_sequence_as(elements, checked_tensors) def _apply_fn(dataset): + output_shapes = _merge_output_shapes(dataset.output_shapes, + expected_shapes) return _RestructuredDataset( dataset.map(_check_shape), dataset.output_types, - output_shapes=expected_shapes, + output_shapes=output_shapes, output_classes=dataset.output_classes) return _apply_fn |