aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/graph_io.py11
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py3
2 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
index eee1b88353..1709e428fc 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
@@ -192,6 +192,11 @@ def read_keyed_batch_examples(
enqueue_many = read_batch_size > 1
+ if num_epochs is not None:
+ allow_smaller_final_batch = True
+ else:
+ allow_smaller_final_batch = False
+
# Setup batching queue given list of read example tensors.
if randomize_input:
if isinstance(batch_size, ops.Tensor):
@@ -201,11 +206,13 @@ def read_keyed_batch_examples(
queued_examples_with_keys = input_ops.shuffle_batch_join(
example_list, batch_size, capacity=queue_capacity,
min_after_dequeue=min_after_dequeue,
- enqueue_many=enqueue_many, name=scope)
+ enqueue_many=enqueue_many, name=scope,
+ allow_smaller_final_batch=allow_smaller_final_batch)
else:
queued_examples_with_keys = input_ops.batch_join(
example_list, batch_size, capacity=queue_capacity,
- enqueue_many=enqueue_many, name=scope)
+ enqueue_many=enqueue_many, name=scope,
+ allow_smaller_final_batch=allow_smaller_final_batch)
if parse_fn and isinstance(queued_examples_with_keys, dict):
queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME)
return queued_keys, queued_examples_with_keys
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
index 334ef425ef..f11f0a841f 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
@@ -164,7 +164,7 @@ class GraphIOTest(tf.test.TestCase):
file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "RandomShuffleQueue",
- name: "QueueDequeueMany",
+ name: "QueueDequeueUpTo",
file_name_queue_limit_name: "Variable"
}, g)
self.assertEqual(
@@ -249,6 +249,7 @@ class GraphIOTest(tf.test.TestCase):
tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
+ self.assertAllEqual(session.run(inputs), [b"D", b"E"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)