diff options
Diffstat (limited to 'tensorflow/core/kernels/repeat_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/repeat_dataset_op.cc | 50 |
1 files changed, 3 insertions, 47 deletions
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc index 0167b9ea64..9813e99a70 100644 --- a/tensorflow/core/kernels/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/repeat_dataset_op.cc @@ -95,15 +95,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - return Status::OK(); - } - Status RestoreInternal(OpKernelContext* ctx, - IteratorStateReader* reader) override { - return Status::OK(); - } }; class FiniteIterator : public DatasetIterator<Dataset> { @@ -117,10 +108,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. - if (!input_impl_) { - *end_of_sequence = true; - return Status::OK(); - } while (i_ < dataset()->count_) { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -131,6 +118,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { input_impl_ = dataset()->input_->MakeIterator(prefix()); } *end_of_sequence = true; + is_exhausted_ = true; input_impl_.reset(); return Status::OK(); } @@ -139,12 +127,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); - if (!input_impl_) { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("input_impl_empty"), "")); - } else { - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); - } + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); return Status::OK(); } @@ -152,11 +135,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); - if (!reader->Contains(full_name("input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); - } else { - input_impl_.reset(); - } + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); return Status::OK(); } @@ -204,29 +183,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } while (true); } - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - if (input_impl_) - TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); - else - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("uninitialized"), "")); - return Status::OK(); - } - - Status RestoreInternal(OpKernelContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (reader->Contains(full_name("uninitialized"))) { - input_impl_.reset(); - } else { - input_impl_ = dataset()->input_->MakeIterator(prefix()); - TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); - } - return Status::OK(); - } - private: mutex mu_; std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); |