about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/data/python/ops/grouping.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/ops/grouping.py')
-rw-r--r-- tensorflow/contrib/data/python/ops/grouping.py 10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index ca9540bf13..5d9640a768 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -149,9 +149,9 @@ def bucket_by_sequence_length(element_length_func,
@{tf.data.Dataset.padded_batch}. Defaults to padding with 0.
pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
size to maximum length in batch. If `True`, will pad dimensions with
- unknown size to bucket boundary, and caller must ensure that the source
- `Dataset` does not contain any elements with length longer than
- `max(bucket_boundaries)`.
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements with length `max(bucket_boundaries)` or longer.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -203,7 +203,7 @@ def bucket_by_sequence_length(element_length_func,
none_filler = None
if pad_to_bucket_boundary:
err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length <= max(bucket_boundaries).")
+ "length < max(bucket_boundaries).")
check = check_ops.assert_less(
bucket_id,
constant_op.constant(len(bucket_batch_sizes) - 1,
@@ -213,7 +213,7 @@ def bucket_by_sequence_length(element_length_func,
boundaries = constant_op.constant(bucket_boundaries,
dtype=dtypes.int64)
bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary
+ none_filler = bucket_boundary - 1
shapes = make_padded_shapes(
padded_shapes or grouped_dataset.output_shapes,
none_filler=none_filler)