diff options
Diffstat (limited to 'tensorflow/core/kernels/dataset.h')
-rw-r--r-- | tensorflow/core/kernels/dataset.h | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index aa4f436b39..4a42ac80c3 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -306,14 +306,27 @@ class IteratorBase { // Saves the state of this iterator. virtual Status Save(IteratorStateWriter* writer) { - return SaveInternal(writer); + if (is_exhausted_) { + LOG(INFO) << "Iterator exhausted."; + return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted); + } else { + return SaveInternal(writer); + } } // Restores the state of this iterator. virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) { - return RestoreInternal(ctx, reader); + if (reader->Contains(kIteratorExhausted)) { + LOG(INFO) << "Iterator exhausted. Nothing to restore."; + is_exhausted_ = true; + return Status::OK(); + } else { + return RestoreInternal(ctx, reader); + } } + static const char kIteratorExhausted[]; + protected: // This is needed so that sub-classes of IteratorBase can call // `SaveInternal` on their parent iterators, e.g., in @@ -341,6 +354,8 @@ class IteratorBase { IteratorStateReader* reader) { return errors::Unimplemented("RestoreInternal"); } + + bool is_exhausted_ = false; // Whether the iterator has been exhausted. }; // Represents a (potentially infinite) range of outputs, where each @@ -476,6 +491,10 @@ class DatasetIterator : public IteratorBase { Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) final { port::Tracing::TraceMe activity(params_.prefix); + if (is_exhausted_) { + *end_of_sequence = true; + return Status::OK(); + } return GetNextInternal(ctx, out_tensors, end_of_sequence); } |