Diffstat (limited to 'tensorflow/contrib/data/python/ops/batching.py')
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py | 89
1 file changed, 22 insertions, 67 deletions
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index d4ade7adfd..e6e5f716b6 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import dtypes
@@ -24,7 +25,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import math_ops
 
 
@@ -103,42 +103,6 @@ def unbatch():
   return _apply_fn
 
 
-def filter_irregular_batches(batch_size):
-  """Transformation that filters out batches that are not of size batch_size."""
-
-  def _apply_fn(dataset):
-    """Function from `Dataset` to `Dataset` that applies the transformation."""
-    tensor_batch_size = ops.convert_to_tensor(
-        batch_size, dtype=dtypes.int64, name="batch_size")
-
-    flattened = _RestructuredDataset(dataset,
-                                     tuple(nest.flatten(dataset.output_types)))
-
-    def _predicate(*xs):
-      """Return `True` if this element is a full batch."""
-      # Extract the dynamic batch size from the first component of the flattened
-      # batched element.
-      first_component = xs[0]
-      first_component_batch_size = array_ops.shape(
-          first_component, out_type=dtypes.int64)[0]
-
-      return math_ops.equal(first_component_batch_size, tensor_batch_size)
-
-    filtered = flattened.filter(_predicate)
-
-    maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
-
-    def _set_first_dimension(shape):
-      return shape.merge_with(
-          tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
-
-    known_shapes = nest.map_structure(_set_first_dimension,
-                                      dataset.output_shapes)
-    return _RestructuredDataset(filtered, dataset.output_types, known_shapes)
-
-  return _apply_fn
-
-
 def batch_and_drop_remainder(batch_size):
   """A batching transformation that omits the final small batch (if present).
 
@@ -171,43 +135,34 @@ def batch_and_drop_remainder(batch_size):
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-    batched = dataset.batch(batch_size)
-    return filter_irregular_batches(batch_size)(batched)
-
-  return _apply_fn
+    tensor_batch_size = ops.convert_to_tensor(
+        batch_size, dtype=dtypes.int64, name="batch_size")
 
+    batched = dataset.batch(tensor_batch_size)
+    flattened = _RestructuredDataset(batched,
+                                     tuple(nest.flatten(batched.output_types)))
 
 
-def padded_batch_and_drop_remainder(batch_size,
-                                    padded_shapes,
-                                    padding_values=None):
-  """A batching and padding transformation that omits the final small batch.
+    def _predicate(*xs):
+      """Return `True` if this element is a full batch."""
+      # Extract the dynamic batch size from the first component of the flattened
+      # batched element.
+      first_component = xs[0]
+      first_component_batch_size = array_ops.shape(
+          first_component, out_type=dtypes.int64)[0]
 
-  Like @{tf.data.Dataset.padded_batch}, this transformation combines
-  consecutive elements of this dataset into batches. However, if the batch
-  size does not evenly divide the input dataset size, this transformation will
-  drop the final smaller element.
+      return math_ops.equal(first_component_batch_size, tensor_batch_size)
 
-  See `@{tf.contrib.data.batch_and_drop_remainder}` for more details.
+    filtered = flattened.filter(_predicate)
 
-  Args:
-    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
-      consecutive elements of this dataset to combine in a single batch.
-    padded_shapes: A nested structure of `tf.TensorShape` or
-      `tf.int64` vector tensor-like objects. See
-      @{tf.data.Dataset.padded_batch} for details.
-    padding_values: (Optional.) A nested structure of scalar-shaped
-      `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details.
+    maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
 
-  Returns:
-    A `Dataset` transformation function, which can be passed to
-    @{tf.data.Dataset.apply}
-  """
+    def _set_first_dimension(shape):
+      return shape.merge_with(
+          tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
 
-  def _apply_fn(dataset):
-    """Function from `Dataset` to `Dataset` that applies the transformation."""
-    batched = dataset.padded_batch(
-        batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
-    return filter_irregular_batches(batch_size)(batched)
+    known_shapes = nest.map_structure(_set_first_dimension,
+                                      batched.output_shapes)
+    return _RestructuredDataset(filtered, batched.output_types, known_shapes)
 
   return _apply_fn
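For context, a minimal usage sketch of the transformation this diff rewrites (not part of the commit), assuming a TensorFlow 1.x runtime where `tf.contrib.data.batch_and_drop_remainder` is available, as the docstring above indicates:

# Usage sketch, assuming TensorFlow 1.x with tf.contrib.data available.
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
# Applied via Dataset.apply, per the docstring in this file.
dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(3))

# Because irregular batches are filtered out, the batch dimension is
# statically known: output_shapes is (3,) rather than (?,).
print(dataset.output_shapes)

iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
  print(sess.run(next_batch))  # [0 1 2]
  print(sess.run(next_batch))  # [3 4 5]
  print(sess.run(next_batch))  # [6 7 8]
  # The final partial batch [9] is dropped; the next sess.run raises
  # tf.errors.OutOfRangeError.

The static batch dimension is what the `tensor_util.constant_value` / `merge_with` logic in the new `_apply_fn` provides: when `batch_size` is a Python constant, the leading dimension of every output shape can be fixed at graph-construction time.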