diff options
Diffstat (limited to 'tensorflow/core/kernels/shuffle_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/shuffle_dataset_op.cc | 31 |
1 files changed, 14 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc index dd0ab57e9d..2146ba2aa1 100644 --- a/tensorflow/core/kernels/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/shuffle_dataset_op.cc @@ -105,7 +105,8 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { mutex_lock l(mu_); int64 start_micros = ctx->env()->NowMicros(); int64 num_log_entries = 0; - while (input_impl_ && buffer_.size() < dataset()->buffer_size_) { + while (!end_of_input_sequence_ && + buffer_.size() < dataset()->buffer_size_) { if (ctx->env()->NowMicros() > ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) { num_log_entries++; @@ -113,10 +114,9 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { << buffer_.size() << " of " << dataset()->buffer_size_; } std::vector<Tensor> input_element; - bool end_of_input_sequence; TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element, - &end_of_input_sequence)); - if (!end_of_input_sequence) { + &end_of_input_sequence_)); + if (!end_of_input_sequence_) { buffer_.emplace_back(std::move(input_element)); } else { input_impl_.reset(); @@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { std::swap(buffer_[index], buffer_.back()); buffer_.pop_back(); } else { - DCHECK(input_impl_ == nullptr); + DCHECK(end_of_input_sequence_); *end_of_sequence = true; } return Status::OK(); @@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { // Save the tensors in the buffer. TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("buffer_size"), buffer_.size())); - for (size_t i = 0; i < buffer_.size(); i++) { + for (int i = 0; i < buffer_.size(); i++) { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat("buffer_", i, "_size")), buffer_[i].size())); - for (size_t j = 0; j < buffer_[i].size(); j++) { + for (int j = 0; j < buffer_[i].size(); j++) { TF_RETURN_IF_ERROR(writer->WriteTensor( full_name(strings::StrCat("buffer_", i, "_", j)), buffer_[i][j])); @@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { // Save input iterator if it hasn't been exhausted else write // "end_of_input_sequence". - if (!input_impl_) { + if (end_of_input_sequence_) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("end_of_input_sequence"), "")); } else { @@ -180,15 +180,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { buffer_.clear(); // Restore the buffer. - size_t buffer_size; - { - int64 temp; - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("buffer_size"), &temp)); - buffer_size = static_cast<size_t>(temp); - } - buffer_.reserve(buffer_size); - for (size_t i = 0; i < buffer_size; i++) { + int64 buffer_size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("buffer_size"), &buffer_size)); + for (int i = 0; i < buffer_size; i++) { int64 list_size; TF_RETURN_IF_ERROR(reader->ReadScalar( full_name(strings::StrCat("buffer_", i, "_size")), &list_size)); @@ -210,6 +205,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { input_impl_ = dataset()->input_->MakeIterator(prefix()); TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); } else { + end_of_input_sequence_ = true; input_impl_.reset(); } return Status::OK(); @@ -234,6 +230,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel { mutex mu_; std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + bool end_of_input_sequence_ GUARDED_BY(mu_) = false; const int64 seed_ GUARDED_BY(mu_); const int64 seed2_ GUARDED_BY(mu_); random::PhiloxRandom parent_generator_ GUARDED_BY(mu_); |