aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/repeat_dataset_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/repeat_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/repeat_dataset_op.cc50
1 files changed, 3 insertions, 47 deletions
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index 0167b9ea64..9813e99a70 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -95,15 +95,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- return Status::OK();
- }
- Status RestoreInternal(OpKernelContext* ctx,
- IteratorStateReader* reader) override {
- return Status::OK();
- }
};
class FiniteIterator : public DatasetIterator<Dataset> {
@@ -117,10 +108,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
- if (!input_impl_) {
- *end_of_sequence = true;
- return Status::OK();
- }
while (i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -131,6 +118,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
input_impl_ = dataset()->input_->MakeIterator(prefix());
}
*end_of_sequence = true;
+ is_exhausted_ = true;
input_impl_.reset();
return Status::OK();
}
@@ -139,12 +127,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
- if (!input_impl_) {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("input_impl_empty"), ""));
- } else {
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- }
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
return Status::OK();
}
@@ -152,11 +135,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
- if (!reader->Contains(full_name("input_impl_empty"))) {
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- } else {
- input_impl_.reset();
- }
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
return Status::OK();
}
@@ -204,29 +183,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
} while (true);
}
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- if (input_impl_)
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- else
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("uninitialized"), ""));
- return Status::OK();
- }
-
- Status RestoreInternal(OpKernelContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- if (reader->Contains(full_name("uninitialized"))) {
- input_impl_.reset();
- } else {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- }
- return Status::OK();
- }
-
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);