aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Jing Jun Yin <jingjunyin@gmail.com>2017-06-13 19:58:21 -0400
committerGravatar Martin Wicke <martin.wicke@gmail.com>2017-06-13 16:58:21 -0700
commit70ade1b64f65d0a2275672d27129627ff116a997 (patch)
treeef3b0a68a6dd24dae6fdddb592e08ae435f52d8b /tensorflow
parent847484e39485dc727dd72a0970d5bfb5c2d5e538 (diff)
Fix defect: shuffle_batch gives ZeroDivisionError when computing capacity stat (#10477)
* Fix defect: shuffle_batch gives ZeroDivisionError when computing capacity stat * Cover < case in error checking
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/training/input.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 1755167938..21183823c2 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -762,6 +762,9 @@ def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
tensor_list = _as_tensor_list(tensors)
with ops.name_scope(name, "shuffle_batch",
list(tensor_list) + [keep_input]) as name:
+ if capacity <= min_after_dequeue:
+ raise ValueError("capacity %d must be bigger than min_after_dequeue %d."
+ % (capacity, min_after_dequeue))
tensor_list = _validate(tensor_list)
keep_input = _validate_keep_input(keep_input, enqueue_many)
tensor_list, sparse_info = _store_sparse_tensors(