diff options
author | Jing Jun Yin <jingjunyin@gmail.com> | 2017-06-13 19:58:21 -0400 |
---|---|---|
committer | Martin Wicke <martin.wicke@gmail.com> | 2017-06-13 16:58:21 -0700 |
commit | 70ade1b64f65d0a2275672d27129627ff116a997 (patch) | |
tree | ef3b0a68a6dd24dae6fdddb592e08ae435f52d8b /tensorflow | |
parent | 847484e39485dc727dd72a0970d5bfb5c2d5e538 (diff) |
Fix defect: shuffle_batch gives ZeroDivisionError when computing capacity stat (#10477)
* Fix defect: shuffle_batch gives ZeroDivisionError when computing capacity stat
* Cover < case in error checking
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/training/input.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 1755167938..21183823c2 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -762,6 +762,9 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, tensor_list = _as_tensor_list(tensors) with ops.name_scope(name, "shuffle_batch", list(tensor_list) + [keep_input]) as name: + if capacity <= min_after_dequeue: + raise ValueError("capacity %d must be bigger than min_after_dequeue %d." + % (capacity, min_after_dequeue)) tensor_list = _validate(tensor_list) keep_input = _validate_keep_input(keep_input, enqueue_many) tensor_list, sparse_info = _store_sparse_tensors( |