Diffstat (limited to 'tensorflow/contrib/data/python/ops/batching.py')
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py | 89
1 file changed, 22 insertions(+), 67 deletions(-)
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index d4ade7adfd..e6e5f716b6 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -24,7 +25,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
@@ -103,42 +103,6 @@ def unbatch():
return _apply_fn
-def filter_irregular_batches(batch_size):
- """Transformation that filters out batches that are not of size batch_size."""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- tensor_batch_size = ops.convert_to_tensor(
- batch_size, dtype=dtypes.int64, name="batch_size")
-
- flattened = _RestructuredDataset(dataset,
- tuple(nest.flatten(dataset.output_types)))
-
- def _predicate(*xs):
- """Return `True` if this element is a full batch."""
- # Extract the dynamic batch size from the first component of the flattened
- # batched element.
- first_component = xs[0]
- first_component_batch_size = array_ops.shape(
- first_component, out_type=dtypes.int64)[0]
-
- return math_ops.equal(first_component_batch_size, tensor_batch_size)
-
- filtered = flattened.filter(_predicate)
-
- maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
-
- def _set_first_dimension(shape):
- return shape.merge_with(
- tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
-
- known_shapes = nest.map_structure(_set_first_dimension,
- dataset.output_shapes)
- return _RestructuredDataset(filtered, dataset.output_types, known_shapes)
-
- return _apply_fn
-
-
def batch_and_drop_remainder(batch_size):
"""A batching transformation that omits the final small batch (if present).
@@ -171,43 +135,34 @@ def batch_and_drop_remainder(batch_size):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- batched = dataset.batch(batch_size)
- return filter_irregular_batches(batch_size)(batched)
-
- return _apply_fn
+ tensor_batch_size = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+ batched = dataset.batch(tensor_batch_size)
+ flattened = _RestructuredDataset(batched,
+ tuple(nest.flatten(batched.output_types)))
-def padded_batch_and_drop_remainder(batch_size,
- padded_shapes,
- padding_values=None):
- """A batching and padding transformation that omits the final small batch.
+ def _predicate(*xs):
+ """Return `True` if this element is a full batch."""
+ # Extract the dynamic batch size from the first component of the flattened
+ # batched element.
+ first_component = xs[0]
+ first_component_batch_size = array_ops.shape(
+ first_component, out_type=dtypes.int64)[0]
- Like @{tf.data.Dataset.padded_batch}, this transformation combines
- consecutive elements of this dataset into batches. However, if the batch
- size does not evenly divide the input dataset size, this transformation will
- drop the final smaller element.
+ return math_ops.equal(first_component_batch_size, tensor_batch_size)
- See `@{tf.contrib.data.batch_and_drop_remainder}` for more details.
+ filtered = flattened.filter(_predicate)
- Args:
- batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
- consecutive elements of this dataset to combine in a single batch.
- padded_shapes: A nested structure of `tf.TensorShape` or
- `tf.int64` vector tensor-like objects. See
- @{tf.data.Dataset.padded_batch} for details.
- padding_values: (Optional.) A nested structure of scalar-shaped
- `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details.
+ maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
- Returns:
- A `Dataset` transformation function, which can be passed to
- @{tf.data.Dataset.apply}
- """
+ def _set_first_dimension(shape):
+ return shape.merge_with(
+ tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- batched = dataset.padded_batch(
- batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
- return filter_irregular_batches(batch_size)(batched)
+ known_shapes = nest.map_structure(_set_first_dimension,
+ batched.output_shapes)
+ return _RestructuredDataset(filtered, batched.output_types, known_shapes)
return _apply_fn
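
For reference, a minimal usage sketch of the transformation this diff rewrites, assuming the graph-mode `tf.contrib.data` API of this revision (session-based execution with a one-shot iterator):

    import tensorflow as tf

    # Ten elements batched in fours: the trailing partial batch of two is
    # filtered out, so every emitted batch has a statically known size.
    dataset = tf.data.Dataset.range(10)
    dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(4))

    print(dataset.output_shapes)  # (4,) -- static, rather than (?,)

    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    with tf.Session() as sess:
        print(sess.run(next_batch))  # [0 1 2 3]
        print(sess.run(next_batch))  # [4 5 6 7]
        # A third run raises tf.errors.OutOfRangeError: elements 8 and 9
        # do not form a full batch and are dropped by the filter predicate.

The static first dimension is what the inlined `_set_first_dimension` merge buys: when `tensor_util.constant_value` can recover a Python integer from the batch-size tensor, the restructured dataset reports `(4,)` instead of `(?,)`, which downstream ops that require fully defined shapes can rely on.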