diff options
Diffstat (limited to 'tensorflow/contrib/data/python/ops/batching.py')
-rw-r--r-- | tensorflow/contrib/data/python/ops/batching.py | 27 |
1 file changed, 20 insertions, 7 deletions
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 9f059942a6..367c159dc5 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -272,9 +272,9 @@ def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): padding_value = 0 def batch_init_fn(_): - return array_ops.fill( - array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0), - constant_op.constant(padding_value, dtype=dataset.output_types)) + batch_shape = array_ops.concat( + [np.array([0], dtype=np.int32), padded_shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) def batch_reduce_fn(state, value): return array_ops.concat([state, [value]], 0) @@ -647,15 +647,17 @@ def assert_element_shape(expected_shapes): """Assert the shape of this `Dataset`. ```python - shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)] + shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])] result = dataset.apply(tf.contrib.data.assert_element_shape(shapes)) - print(result.output_shapes) # ==> "((16, 256), <unknown>)" + print(result.output_shapes) # ==> "((16, 256), (<unknown>, 2))" ``` If dataset shapes and expected_shape are fully defined, assert they match. Otherwise, add assert op that will validate the shapes when tensors are evaluated, and set shapes on tensors, respectively. + Note that unknown dimensions in `expected_shapes` will be ignored. + Args: expected_shapes: A nested structure of `tf.TensorShape` objects. 
@@ -664,20 +666,31 @@ def assert_element_shape(expected_shapes): `tf.data.Dataset.apply` """ + def _merge_output_shapes(original_shapes, expected_shapes): + flat_original_shapes = nest.flatten(original_shapes) + flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes) + flat_merged_output_shapes = [ + original_shape.merge_with(new_shape) + for original_shape, new_shape in zip(flat_original_shapes, + flat_new_shapes)] + return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes) + def _check_shape(*elements): flatten_tensors = nest.flatten(elements) flatten_shapes = nest.flatten(expected_shapes) checked_tensors = [ - with_shape(shape, tensor) + with_shape(shape, tensor) if shape else tensor # Ignore unknown shape for shape, tensor in zip(flatten_shapes, flatten_tensors) ] return nest.pack_sequence_as(elements, checked_tensors) def _apply_fn(dataset): + output_shapes = _merge_output_shapes(dataset.output_shapes, + expected_shapes) return _RestructuredDataset( dataset.map(_check_shape), dataset.output_types, - output_shapes=expected_shapes, + output_shapes=output_shapes, output_classes=dataset.output_classes) return _apply_fn |