aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/shuffle_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/shuffle_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/shuffle_dataset_op.cc31
1 files changed, 14 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc
index dd0ab57e9d..2146ba2aa1 100644
--- a/tensorflow/core/kernels/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/shuffle_dataset_op.cc
@@ -105,7 +105,8 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
int64 start_micros = ctx->env()->NowMicros();
int64 num_log_entries = 0;
- while (input_impl_ && buffer_.size() < dataset()->buffer_size_) {
+ while (!end_of_input_sequence_ &&
+ buffer_.size() < dataset()->buffer_size_) {
if (ctx->env()->NowMicros() >
((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
num_log_entries++;
@@ -113,10 +114,9 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
<< buffer_.size() << " of " << dataset()->buffer_size_;
}
std::vector<Tensor> input_element;
- bool end_of_input_sequence;
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
- &end_of_input_sequence));
- if (!end_of_input_sequence) {
+ &end_of_input_sequence_));
+ if (!end_of_input_sequence_) {
buffer_.emplace_back(std::move(input_element));
} else {
input_impl_.reset();
@@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
std::swap(buffer_[index], buffer_.back());
buffer_.pop_back();
} else {
- DCHECK(input_impl_ == nullptr);
+ DCHECK(end_of_input_sequence_);
*end_of_sequence = true;
}
return Status::OK();
@@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// Save the tensors in the buffer.
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
- for (size_t i = 0; i < buffer_.size(); i++) {
+ for (int i = 0; i < buffer_.size(); i++) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("buffer_", i, "_size")),
buffer_[i].size()));
- for (size_t j = 0; j < buffer_[i].size(); j++) {
+ for (int j = 0; j < buffer_[i].size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
full_name(strings::StrCat("buffer_", i, "_", j)),
buffer_[i][j]));
@@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
// Save input iterator if it hasn't been exhausted else write
// "end_of_input_sequence".
- if (!input_impl_) {
+ if (end_of_input_sequence_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("end_of_input_sequence"), ""));
} else {
@@ -180,15 +180,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
buffer_.clear();
// Restore the buffer.
- size_t buffer_size;
- {
- int64 temp;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("buffer_size"), &temp));
- buffer_size = static_cast<size_t>(temp);
- }
- buffer_.reserve(buffer_size);
- for (size_t i = 0; i < buffer_size; i++) {
+ int64 buffer_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("buffer_size"), &buffer_size));
+ for (int i = 0; i < buffer_size; i++) {
int64 list_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
@@ -210,6 +205,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
input_impl_ = dataset()->input_->MakeIterator(prefix());
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
+ end_of_input_sequence_ = true;
input_impl_.reset();
}
return Status::OK();
@@ -234,6 +230,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
mutex mu_;
std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
const int64 seed_ GUARDED_BY(mu_);
const int64 seed2_ GUARDED_BY(mu_);
random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);