author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-07-11 12:57:41 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-07-11 13:02:14 -0700
commit     1de8cce6643bb75edd1441fa01ffb6a0e5258c5f (patch)
tree       ed6be23a96aac6b182595889d23f944ffbdc441b
parent     75b936e4c467af836623c7c72ff84fb0d458e5e6 (diff)
fix #11372, #11396
PiperOrigin-RevId: 161569039
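Both fixes address hangs in RecordYielder, visible in the diff below: a file pattern that matches nothing makes MainLoop() exit without waking a consumer blocked in YieldOne(), and an epoch that supplies fewer records than opts_.bufsize means BufEnough() can never become true. The patch therefore re-checks status_.ok() in YieldOne()'s wait loop, notifies buf_enough_ before MainLoop() breaks out on error, and clamps opts_.bufsize to the number of records the epoch actually added. A minimal standalone sketch of the condition-variable pattern involved (hypothetical names, not the TensorFlow code):

// Sketch of the deadlock fixed here: a consumer waiting on a condition
// variable must also re-check an error flag, and the producer must notify
// after setting it.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

std::mutex mu;
std::condition_variable buf_enough;
bool enough = false;  // stands in for BufEnough()
bool failed = false;  // stands in for !status_.ok()

void Consumer() {
  std::unique_lock<std::mutex> l(mu);
  // Pre-fix behavior: `while (!enough) buf_enough.wait(l);` never wakes up
  // if the producer fails without ever making `enough` true.
  while (!enough && !failed) {
    buf_enough.wait(l);
  }
  std::cout << (failed ? "error propagated\n" : "record yielded\n");
}

void Producer() {
  std::lock_guard<std::mutex> l(mu);
  failed = true;            // e.g. the glob matched no files (NotFound)
  buf_enough.notify_all();  // without this, an already-waiting consumer hangs
}

int main() {
  std::thread consumer(Consumer);
  std::thread producer(Producer);
  consumer.join();
  producer.join();
  return 0;
}

Whichever thread runs first, the consumer either sees failed already set or is woken by the notify; dropping either the predicate term or the notify_all() reintroduces the hang.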
-rw-r--r--  tensorflow/core/kernels/record_yielder.cc            | 22
-rw-r--r--  tensorflow/core/kernels/record_yielder.h             |  1
-rw-r--r--  tensorflow/python/kernel_tests/record_input_test.py  | 38
3 files changed, 57 insertions(+), 4 deletions(-)
diff --git a/tensorflow/core/kernels/record_yielder.cc b/tensorflow/core/kernels/record_yielder.cc
index 3af555da1a..1370295144 100644
--- a/tensorflow/core/kernels/record_yielder.cc
+++ b/tensorflow/core/kernels/record_yielder.cc
@@ -46,7 +46,7 @@ RecordYielder::~RecordYielder() {
 
 Status RecordYielder::YieldOne(string* value) {
   mutex_lock l(mu_);
-  while (!BufEnough()) {
+  while (!BufEnough() && status_.ok()) {
     buf_enough_.wait(l);
   }
   if (status_.ok()) {
@@ -98,17 +98,22 @@ void RecordYielder::MainLoop() {
   while (true) {
     ++epoch_;
     num_records_yielded_in_epoch_ = 0;
+    num_records_added_in_epoch_ = 0;
 
     // Finds all files.
     std::vector<string> filenames;
     Status s = MatchFiles(opts_.file_pattern, &filenames);
-    if (ShouldFinish(s)) break;
 
     if (filenames.empty()) {
       s = errors::NotFound("Found no files at ", opts_.file_pattern);
-      if (ShouldFinish(s)) break;
+      if (ShouldFinish(s)) {
+        buf_enough_.notify_all();
+        break;
+      }
     }
 
+    if (ShouldFinish(s)) break;
+
     // Shuffles these files according to the epoch # and random seed.
     std::mt19937_64 shuffle_rnd(
         Hash64(reinterpret_cast<char*>(&epoch_), sizeof(epoch_), opts_.seed));
@@ -139,7 +144,15 @@ void RecordYielder::MainLoop() {
       shards[i].done.WaitForNotification();
       s.Update(shards[i].status);
     }
-    if (ShouldFinish(s)) break;
+
+    if (num_records_added_in_epoch_ < opts_.bufsize) {
+      opts_.bufsize = num_records_added_in_epoch_;
+    }
+
+    if (ShouldFinish(s)) {
+      buf_enough_.notify_all();
+      break;
+    }
 
     // Starts the next epoch once all buffered records are consumed.
     {
@@ -173,6 +186,7 @@ bool RecordYielder::Add(std::vector<string>* values) {
       buf_[index] = std::move(values->back());
     }
     values->pop_back();
+    num_records_added_in_epoch_++;
   }
   if (BufEnough()) {
     buf_enough_.notify_all();
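(A note on the opts_.bufsize clamp above: BufEnough() — defined in record_yielder.h, not shown in this diff — is roughly a check that the buffer has accumulated at least opts_.bufsize records while an epoch is still filling. If an epoch can never supply that many records, e.g. a buffer_size of 2000 against only 100 records on disk, the predicate never becomes true and YieldOne() waits forever. Clamping opts_.bufsize to num_records_added_in_epoch_ makes the wake-up condition attainable.)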
diff --git a/tensorflow/core/kernels/record_yielder.h b/tensorflow/core/kernels/record_yielder.h
index 44f7c9511f..d86cb75c15 100644
--- a/tensorflow/core/kernels/record_yielder.h
+++ b/tensorflow/core/kernels/record_yielder.h
@@ -119,6 +119,7 @@ class RecordYielder {
   // True iff we are draining an epoch.
   bool epoch_end_ = false;
 
+  int64 num_records_added_in_epoch_ = 0;
   int64 num_records_yielded_in_epoch_ = 0;
 
   // Trigger when the main loop has exited.
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
index 8fec2affa5..1ec48ac361 100644
--- a/tensorflow/python/kernel_tests/record_input_test.py
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
 
 import os
 
+from tensorflow.python.framework.errors_impl import NotFoundError
 from tensorflow.python.lib.io import tf_record
 from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -96,6 +98,42 @@ class RecordInputOpTest(test.TestCase):
       for _ in range(50):
         sess.run(yield_op)
 
+  def testEmptyGlob(self):
+    with self.test_session() as sess:
+      record_input = data_flow_ops.RecordInput(file_pattern="foo")
+      yield_op = record_input.get_yield_op()
+      sess.run(variables.global_variables_initializer())
+      with self.assertRaises(NotFoundError):
+        sess.run(yield_op)
+
+  def testBufferTooSmall(self):
+    files = 10
+    records_per_file = 10
+    batches = 2
+    with self.test_session() as sess:
+      self.generateTestData("basic", files, records_per_file)
+
+      records = data_flow_ops.RecordInput(
+          file_pattern=os.path.join(self.get_temp_dir(), "basic.*"),
+          parallelism=2,
+          buffer_size=2000,
+          batch_size=1,
+          shift_ratio=0.33,
+          seed=10,
+          name="record_input",
+          batches=batches)
+
+      yield_op = records.get_yield_op()
+
+      # Cycle over 3 epochs and make sure we never duplicate within an epoch.
+      for _ in range(3):
+        epoch_set = set()
+        for _ in range(int(files * records_per_file / batches)):
+          op_list = sess.run(yield_op)
+          self.assertEqual(len(op_list), batches)
+          for r in op_list:
+            self.assertTrue(r[0] not in epoch_set)
+            epoch_set.add(r[0])
+
 if __name__ == "__main__":
   test.main()
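The two new tests track the two fixes: testEmptyGlob checks that a pattern matching no files now surfaces as a NotFoundError in Python instead of hanging the session, and testBufferTooSmall requests a buffer (2000) far larger than the total record count (10 files × 10 records) and verifies over three epochs that no record is yielded twice within an epoch.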