diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-07-25 12:57:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-25 14:04:20 -0700 |
commit | ed281973d66d0030e58a77a05821bbb88627f5bd (patch) | |
tree | 8faafc997d6f6c441d3ba78b439bd8cef30c0004 | |
parent | 88ffd73ed90dbb02876925b63c3324c82bdb985d (diff) |
Set allow_smaller_final_batch to true when num_epochs is set so the last items
in the data can be read without missing.
Change: 128396107
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/graph_io.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py | 3 |
2 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index eee1b88353..1709e428fc 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -192,6 +192,11 @@ def read_keyed_batch_examples( enqueue_many = read_batch_size > 1 + if num_epochs is not None: + allow_smaller_final_batch = True + else: + allow_smaller_final_batch = False + # Setup batching queue given list of read example tensors. if randomize_input: if isinstance(batch_size, ops.Tensor): @@ -201,11 +206,13 @@ def read_keyed_batch_examples( queued_examples_with_keys = input_ops.shuffle_batch_join( example_list, batch_size, capacity=queue_capacity, min_after_dequeue=min_after_dequeue, - enqueue_many=enqueue_many, name=scope) + enqueue_many=enqueue_many, name=scope, + allow_smaller_final_batch=allow_smaller_final_batch) else: queued_examples_with_keys = input_ops.batch_join( example_list, batch_size, capacity=queue_capacity, - enqueue_many=enqueue_many, name=scope) + enqueue_many=enqueue_many, name=scope, + allow_smaller_final_batch=allow_smaller_final_batch) if parse_fn and isinstance(queued_examples_with_keys, dict): queued_keys = queued_examples_with_keys.pop(KEY_FEATURE_NAME) return queued_keys, queued_examples_with_keys diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 334ef425ef..f11f0a841f 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -164,7 +164,7 @@ class GraphIOTest(tf.test.TestCase): file_name_queue_name: "FIFOQueue", "%s/read/TFRecordReader" % name: "TFRecordReader", example_queue_name: "RandomShuffleQueue", - name: "QueueDequeueMany", + name: "QueueDequeueUpTo", file_name_queue_limit_name: "Variable" }, g) self.assertEqual( @@ -249,6 +249,7 @@ class GraphIOTest(tf.test.TestCase): tf.train.start_queue_runners(session, coord=coord) self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"]) + self.assertAllEqual(session.run(inputs), [b"D", b"E"]) with self.assertRaises(errors.OutOfRangeError): session.run(inputs) |