aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/dataset.h')
-rw-r--r--tensorflow/core/kernels/dataset.h23
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index aa4f436b39..4a42ac80c3 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -306,14 +306,27 @@ class IteratorBase {
// Saves the state of this iterator.
virtual Status Save(IteratorStateWriter* writer) {
- return SaveInternal(writer);
+ if (is_exhausted_) {
+ LOG(INFO) << "Iterator exhausted.";
+ return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
+ } else {
+ return SaveInternal(writer);
+ }
}
// Restores the state of this iterator.
virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
- return RestoreInternal(ctx, reader);
+ if (reader->Contains(kIteratorExhausted)) {
+ LOG(INFO) << "Iterator exhausted. Nothing to restore.";
+ is_exhausted_ = true;
+ return Status::OK();
+ } else {
+ return RestoreInternal(ctx, reader);
+ }
}
+ static const char kIteratorExhausted[];
+
protected:
// This is needed so that sub-classes of IteratorBase can call
// `SaveInternal` on their parent iterators, e.g., in
@@ -341,6 +354,8 @@ class IteratorBase {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
+
+ bool is_exhausted_ = false; // Whether the iterator has been exhausted.
};
// Represents a (potentially infinite) range of outputs, where each
@@ -476,6 +491,10 @@ class DatasetIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
port::Tracing::TraceMe activity(params_.prefix);
+ if (is_exhausted_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
return GetNextInternal(ctx, out_tensors, end_of_sequence);
}