author     Saurabh Saxena <srbs@google.com>    2017-08-30 00:30:26 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-08-30 00:34:39 -0700
commit     c82d70d383a97045293a59a77e5280c2d3cebaab (patch)
tree       ae3f420ced68b996d679a0a0e19c6ae23c90e5d1
parent     822a9208fcad065c632f5dac7f6efc0686b5892b (diff)
Add support for saving/restoring states of iterators.
All iterators that need state saving will have to implement SaveState and
RestoreState methods. Add SaveIterator and RestoreIterator ops. Add
implementations for RangeDataset, RepeatDataset and FixedLengthRecordDataset,
and related tests. More to follow in future CLs.

PiperOrigin-RevId: 166959426
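For orientation, here is a minimal sketch of how the new ops are driven from Python, mirroring the pattern used by the tests in this CL. Note that `gen_dataset_ops.save_iterator`/`restore_iterator` are generated op wrappers and `iterator._iterator_resource` is a private attribute, so this illustrates the mechanism rather than a stable public API; the `/tmp` checkpoint prefix is an arbitrary choice for the example.

import os

from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops

path = os.path.join("/tmp", "iterator")  # checkpoint prefix (illustrative)

with ops.Graph().as_default():
  iterator = dataset_ops.Dataset.range(10).make_initializable_iterator()
  get_next = iterator.get_next()
  # The save/restore ops operate on the iterator's resource handle.
  save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
  restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
                                                path)
  with session.Session() as sess:
    sess.run(iterator.initializer)
    for _ in range(5):
      sess.run(get_next)       # consume elements 0..4
    sess.run(save_op)          # checkpoint the iterator's position
    sess.run(restore_op)       # rewind to the saved position
    print(sess.run(get_next))  # prints 5: iteration resumes where saved

SaveIterator writes the state as a tensor bundle at `path` (or an "exhausted" marker once the iterator has run out), and RestoreIterator reads it back into the same resource.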
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py   | 202
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 132
-rw-r--r--  tensorflow/core/kernels/BUILD                                          |   1
-rw-r--r--  tensorflow/core/kernels/dataset.cc                                     |   2
-rw-r--r--  tensorflow/core/kernels/dataset.h                                      | 129
-rw-r--r--  tensorflow/core/kernels/iterator_ops.cc                                |  55
-rw-r--r--  tensorflow/core/kernels/range_dataset_op.cc                            |  20
-rw-r--r--  tensorflow/core/kernels/reader_dataset_ops.cc                          |  46
-rw-r--r--  tensorflow/core/kernels/repeat_dataset_op.cc                           |  18
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc                                     |  18
10 files changed, 615 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index a8edbbd20c..87bab43ccf 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -17,17 +17,29 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
class RangeDatasetTest(test.TestCase):
+  def tearDown(self):
+    # Remove all checkpoint files created by the tests.
+    prefix = self._iterator_checkpoint_prefix()
+    pattern = prefix + "*"
+    for f in gfile.Glob(pattern):
+      gfile.Remove(f)
+
  def testStop(self):
    stop = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
@@ -175,6 +187,196 @@ class RangeDatasetTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
+  def _iterator_checkpoint_prefix(self):
+    return os.path.join(self.get_temp_dir(), "iterator")
+
+  def testSaveRestore(self):
+
+    def _build_graph(start, stop):
+      iterator = dataset_ops.Dataset.range(start,
+                                           stop).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      path = self._iterator_checkpoint_prefix()
+      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
+      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
+                                                    path)
+      return init_op, get_next, save_op, restore_op
+
+    # Saving and restoring in different sessions.
+    start = 2
+    stop = 10
+    break_point = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+    # Saving and restoring in the same session.
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testMultipleSaves(self):
+
+    def _build_graph(start, stop):
+      iterator = dataset_ops.Dataset.range(start,
+                                           stop).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      path = self._iterator_checkpoint_prefix()
+      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
+      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
+                                                    path)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    break_point1 = 5
+    break_point2 = 7
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point1):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point1, break_point2):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    break_point2 = 7
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point2, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testSaveRestoreWithRepeat(self):
+
+    def _build_graph(start, stop, num_epochs):
+      iterator = dataset_ops.Dataset.range(
+          start, stop).repeat(num_epochs).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      path = self._iterator_checkpoint_prefix()
+      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
+      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
+                                                    path)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    num_epochs = 5
+    break_range = 5
+    break_epoch = 3
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(
+          start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        # Note: no checkpoint has been saved yet, so restoring raises a
+        # NotFoundError.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for _ in range(break_epoch - 1):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        for i in range(start, break_range):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_range, stop):
+          self.assertEqual(i, sess.run(get_next))
+        for _ in range(break_epoch, num_epochs):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testSaveRestoreExhaustedIterator(self):
+
+    def _build_graph(start, stop, num_epochs):
+      iterator = dataset_ops.Dataset.range(
+          start, stop).repeat(num_epochs).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      path = self._iterator_checkpoint_prefix()
+      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
+      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
+                                                    path)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    num_epochs = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(
+          start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        # Note: no checkpoint has been saved yet, so restoring raises a
+        # NotFoundError.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for _ in range(num_epochs):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
if __name__ == "__main__":
  test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 72a3ff1789..d631fbc76e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -255,7 +256,6 @@ class FixedLengthRecordReaderTest(test.TestCase):
  def testFixedLengthRecordDatasetBuffering(self):
    test_filenames = self._createFiles()
-
    dataset = dataset_ops.FixedLengthRecordDataset(
        test_filenames,
        self._record_bytes,
@@ -271,6 +271,124 @@ class FixedLengthRecordReaderTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())
+  def _build_iterator_graph(self, num_epochs):
+    filenames = self._createFiles()
+    path = os.path.join(self.get_temp_dir(), "iterator")
+    dataset = (dataset_ops.FixedLengthRecordDataset(
+        filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
+               .repeat(num_epochs))
+    iterator = dataset.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next_op = iterator.get_next()
+    save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
+    restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
+                                                  path)
+    return init_op, get_next_op, save_op, restore_op
+
+  def testSaveRestore(self):
+    num_epochs = 10
+    epoch_break = 5
+    file_break = self._num_files // 2
+    record_break = self._num_records // 2
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: no checkpoint has been saved yet, so restoring raises a
+        # NotFoundError.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch == epoch_break and f == file_break and
+                  r == record_break):
+                sess.run(save_op)
+                break
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+            else:
+              continue  # record loop finished without hitting the break point
+            break  # propagate the break out of the file loop
+          else:
+            continue  # file loop finished without hitting the break point
+          break  # propagate the break out of the epoch loop
+        else:  # no break point was hit; the iterator is fully exhausted
+          with self.assertRaises(errors.OutOfRangeError):
+            sess.run(get_next_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch < epoch_break or
+                  (epoch == epoch_break and f < file_break) or
+                  (epoch == epoch_break and f == file_break and
+                   r < record_break)):
+                continue
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
+  def testRestoreUnusedIterator(self):
+    num_epochs = 10
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: no checkpoint has been saved yet, so restoring raises a
+        # NotFoundError.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        # Save the unused iterator.
+        sess.run(save_op)
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for _ in range(num_epochs * self._num_files * self._num_records):
+          sess.run(get_next_op)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
+  def testRestoreExhaustedIterator(self):
+    num_epochs = 10
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: no checkpoint has been saved yet, so restoring raises a
+        # NotFoundError.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for _ in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
class TFRecordDatasetTest(test.TestCase):
@@ -558,8 +676,8 @@ class ReadBatchFeaturesTest(test.TestCase):
  def testRead(self):
    for batch_size in [1, 2]:
      for num_epochs in [1, 10]:
-        with ops.Graph().as_default():
-          with self.test_session(graph=ops.get_default_graph()) as sess:
+        with ops.Graph().as_default() as g:
+          with self.test_session(graph=g) as sess:
            # Basic test: read from file 0.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[0],
@@ -569,8 +687,8 @@ class ReadBatchFeaturesTest(test.TestCase):
          with self.assertRaises(errors.OutOfRangeError):
            self._next_actual_batch(sess)

-        with ops.Graph().as_default():
-          with self.test_session(graph=ops.get_default_graph()) as sess:
+        with ops.Graph().as_default() as g:
+          with self.test_session(graph=g) as sess:
            # Basic test: read from file 1.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames[1],
@@ -580,8 +698,8 @@ class ReadBatchFeaturesTest(test.TestCase):
          with self.assertRaises(errors.OutOfRangeError):
            self._next_actual_batch(sess)

-        with ops.Graph().as_default():
-          with self.test_session(graph=ops.get_default_graph()) as sess:
+        with ops.Graph().as_default() as g:
+          with self.test_session(graph=g) as sess:
            # Basic test: read from both files.
            self.outputs = self._read_batch_features(
                filenames=self.test_filenames,
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 79a7b8a8b9..0893a01204 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5568,6 +5568,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/util/tensor_bundle",
],
)
diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc
index f99684b1ca..2bfbdc1cd9 100644
--- a/tensorflow/core/kernels/dataset.cc
+++ b/tensorflow/core/kernels/dataset.cc
@@ -52,4 +52,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
  MakeDataset(ctx, input, another_input, output);
}

+const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED";
+
}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index 9bfc5c1e96..aa97d34041 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -19,7 +19,11 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/tensor_bundle/naming.h"
+#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
// Polymorphic datasets should support all primitive TensorFlow
// types. Use this macro to expand `m(T)` once for each primitive type
@@ -86,6 +90,10 @@ class IteratorContext {
// range of outputs is typically represented by a `DatasetBase`,
// defined below.
class IteratorBase {
+ protected:
+  class IteratorBundleReader;
+  class IteratorBundleWriter;
+
 public:
  virtual ~IteratorBase() {}
@@ -115,6 +123,118 @@ class IteratorBase {
  // (and possibly partially defined) shapes of each tuple component
  // in the outputs of this iterator.
  virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+
+  // Saves the state of this iterator.
+  virtual Status SaveState(OpKernelContext* ctx, StringPiece path) {
+    BundleWriter bundle_writer(ctx->env(), path);
+    IteratorBundleWriter writer(&bundle_writer);
+    if (is_exhausted_) {
+      LOG(INFO) << "Iterator exhausted. Nothing to save.";
+      TF_RETURN_IF_ERROR(
+          writer.WriteScalar<string>(kIteratorExhausted, kIteratorExhausted));
+    } else {
+      TF_RETURN_IF_ERROR(SaveStateInternal(ctx, &writer));
+    }
+    TF_RETURN_IF_ERROR(bundle_writer.Finish());
+    return Status::OK();
+  }
+
+  // Restores the state of this iterator.
+  virtual Status RestoreState(OpKernelContext* ctx, StringPiece path) {
+    if (!ctx->env()->FileExists(MetaFilename(path)).ok()) {
+      return errors::NotFound(
+          "Failed to restore Iterator state. No file found at ",
+          MetaFilename(path));
+    }
+    BundleReader bundle_reader(ctx->env(), path);
+    if (bundle_reader.Contains(kIteratorExhausted)) {
+      LOG(INFO) << "Iterator exhausted. Nothing to restore.";
+      is_exhausted_ = true;
+      return Status::OK();
+    } else {
+      IteratorBundleReader reader(&bundle_reader);
+      return RestoreStateInternal(ctx, &reader);
+    }
+  }
+
+ protected:
+  class IteratorBundleReader {
+   public:
+    explicit IteratorBundleReader(BundleReader* bundle_reader)
+        : bundle_reader_(bundle_reader) {}
+
+    // Reads a scalar value.
+    template <typename T>
+    Status ReadScalar(T* val, const string& key) {
+      Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
+      TF_RETURN_IF_ERROR(Lookup(StringPiece(key), &val_t));
+      *val = val_t.scalar<T>()();
+      return Status::OK();
+    }
+
+    // Restores the state of a parent iterator recursively.
+    Status RestoreParentState(OpKernelContext* ctx,
+                              const std::unique_ptr<IteratorBase>& parent) {
+      return parent->RestoreStateInternal(ctx, this);
+    }
+
+   private:
+    Status Lookup(StringPiece key, Tensor* val) {
+      return bundle_reader_->Lookup(key, val);
+    }
+
+    BundleReader* bundle_reader_;
+  };
+
+  class IteratorBundleWriter {
+   public:
+    explicit IteratorBundleWriter(BundleWriter* bundle_writer)
+        : bundle_writer_(bundle_writer) {}
+
+    // Writes a scalar value.
+    template <typename T>
+    Status WriteScalar(const T val, const string& key) {
+      Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
+      val_t.scalar<T>()() = val;
+      TF_RETURN_IF_ERROR(Add(StringPiece(key), val_t));
+      return Status::OK();
+    }
+
+    // Saves the state of a parent iterator recursively.
+    Status SaveParentState(OpKernelContext* ctx,
+                           const std::unique_ptr<IteratorBase>& parent) {
+      return parent->SaveStateInternal(ctx, this);
+    }
+
+   private:
+    Status Add(StringPiece key, const Tensor& val) {
+      return bundle_writer_->Add(key, val);
+    }
+
+    BundleWriter* bundle_writer_;
+  };
+
+  // Saves the state of this iterator.
+  // Note: Contents written to `writer` may not get flushed to disk
+  // until the call to `SaveState` in the leaf iterator is finished.
+  // Must be overridden by subclasses.
+  virtual Status SaveStateInternal(OpKernelContext* ctx,
+                                   IteratorBundleWriter* writer) {
+    return errors::Unimplemented("SaveState not implemented.");
+  }
+
+  // Restores the state of this iterator.
+  //
+  // Must be overridden by subclasses.
+  virtual Status RestoreStateInternal(OpKernelContext* ctx,
+                                      IteratorBundleReader* reader) {
+    return errors::Unimplemented("RestoreState not implemented.");
+  }
+
+  bool is_exhausted_ = false;  // Whether the iterator has been exhausted.
+
+ private:
+  static const char kIteratorExhausted[];
};
// Represents a (potentially infinite) range of outputs, where each
@@ -182,6 +302,10 @@ class DatasetIterator : public IteratorBase {
  Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                 bool* end_of_sequence) final {
    port::Tracing::TraceMe activity(params_.prefix);
+    if (is_exhausted_) {
+      *end_of_sequence = true;
+      return Status::OK();
+    }
    return GetNextInternal(ctx, out_tensors, end_of_sequence);
  }
@@ -190,6 +314,11 @@ class DatasetIterator : public IteratorBase {
                                 std::vector<Tensor>* out_tensors,
                                 bool* end_of_sequence) = 0;

+ protected:
+  string full_name(const string& name) {
+    return strings::StrCat(prefix(), ":", name);
+  }
+
 private:
  Params params_;
};
diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc
index c0e4f91991..7f0e11872a 100644
--- a/tensorflow/core/kernels/iterator_ops.cc
+++ b/tensorflow/core/kernels/iterator_ops.cc
@@ -89,6 +89,31 @@ class IteratorResource : public ResourceBase {
    }
  }

+  Status SaveState(OpKernelContext* ctx, StringPiece path) {
+    std::shared_ptr<IteratorBase> captured_iterator(iterator_);
+    if (captured_iterator) {
+      return captured_iterator->SaveState(ctx, path);
+    } else {
+      return errors::FailedPrecondition(
+          "SaveState() failed because the iterator has not been initialized. "
+          "Ensure that you have run the initializer operation for this "
+          "iterator before saving its state.");
+    }
+  }
+
+  Status RestoreState(OpKernelContext* ctx, StringPiece path) {
+    std::shared_ptr<IteratorBase> captured_iterator(iterator_);
+    if (captured_iterator) {
+      return captured_iterator->RestoreState(ctx, path);
+    } else {
+      return errors::FailedPrecondition(
+          "RestoreState() failed because the iterator has not been "
+          "initialized. Ensure that you have run the initializer operation "
+          "for this iterator before restoring its state.");
+    }
+  }
+
  // Transfers ownership of iterator to this. This method is thread-safe.
  Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
    if (iterator) {
@@ -161,6 +186,32 @@ class MakeIteratorOp : public OpKernel {
  }
};

+class SaveIteratorOp : public OpKernel {
+ public:
+  explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    IteratorResource* iterator_resource;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+    const string& path = ctx->input(1).scalar<string>()();
+    OP_REQUIRES_OK(ctx, iterator_resource->SaveState(ctx, path));
+  }
+};
+
+class RestoreIteratorOp : public OpKernel {
+ public:
+  explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    IteratorResource* iterator_resource;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+    const string& path = ctx->input(1).scalar<string>()();
+    OP_REQUIRES_OK(ctx, iterator_resource->RestoreState(ctx, path));
+  }
+};
+
class OneShotIteratorOp : public AsyncOpKernel {
 public:
  explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -504,6 +555,10 @@ class IteratorFromStringHandleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU),
+ SaveIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU),
+ RestoreIteratorOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc
index a32a02f57d..9976c55838 100644
--- a/tensorflow/core/kernels/range_dataset_op.cc
+++ b/tensorflow/core/kernels/range_dataset_op.cc
@@ -86,6 +86,7 @@ class RangeDatasetOp : public DatasetOpKernel {
      if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
          (dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
        *end_of_sequence = true;
+        is_exhausted_ = true;
        return Status::OK();
      }
      Tensor value_tensor(cpu_allocator(), DT_INT64, {});
@@ -97,9 +98,26 @@ class RangeDatasetOp : public DatasetOpKernel {
      return Status::OK();
    }

+   protected:
+    Status SaveStateInternal(OpKernelContext* ctx,
+                             IteratorBundleWriter* writer) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar<int64>(next_, full_name("next")));
+      return Status::OK();
+    }
+
+    Status RestoreStateInternal(OpKernelContext* ctx,
+                                IteratorBundleReader* reader) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar<int64>(&next_, full_name("next")));
+      return Status::OK();
+    }
+
   private:
    mutex mu_;
-    int64 next_;
+    int64 next_ GUARDED_BY(mu_);
  };

  const int64 start_;
diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc
index 407f69cde7..73fc09abc8 100644
--- a/tensorflow/core/kernels/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/reader_dataset_ops.cc
@@ -315,6 +315,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
        // Iteration ends when there are no more files to process.
        if (current_file_index_ == dataset()->filenames_.size()) {
          *end_of_sequence = true;
+          is_exhausted_ = true;
          return Status::OK();
        }
@@ -332,6 +333,51 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
      } while (true);
    }
+   protected:
+    Status SaveStateInternal(OpKernelContext* ctx,
+                             IteratorBundleWriter* writer) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(
+          current_file_index_, full_name("current_file_index")));
+
+      // `input_buffer_` is empty if
+      // 1. GetNext has not been called even once.
+      // 2. All files have been read and the iterator has been exhausted.
+      int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
+      TF_RETURN_IF_ERROR(
+          writer->WriteScalar<int64>(current_pos, full_name("current_pos")));
+      return Status::OK();
+    }
+
+    Status RestoreStateInternal(OpKernelContext* ctx,
+                                IteratorBundleReader* reader) override {
+      mutex_lock l(mu_);
+      int64 current_file_index;
+      TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(
+          &current_file_index, full_name("current_file_index")));
+      current_file_index_ = size_t(current_file_index);
+      int64 current_pos;
+      TF_RETURN_IF_ERROR(
+          reader->ReadScalar<int64>(&current_pos, full_name("current_pos")));
+
+      // Seek to current_pos.
+      input_buffer_.reset();
+      file_.reset();
+      if (current_pos >= 0) {  // There was an active input_buffer_.
+        uint64 file_size;
+        TF_RETURN_IF_ERROR(ctx->env()->GetFileSize(
+            dataset()->filenames_[current_file_index_], &file_size));
+        file_pos_limit_ = file_size - dataset()->footer_bytes_;
+        TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile(
+            dataset()->filenames_[current_file_index_], &file_));
+        input_buffer_.reset(
+            new io::InputBuffer(file_.get(), dataset()->buffer_size_));
+        TF_RETURN_IF_ERROR(input_buffer_->Seek(current_pos));
+      }
+
+      return Status::OK();
+    }
+
   private:
    mutex mu_;
    size_t current_file_index_ GUARDED_BY(mu_) = 0;
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index ef17fc8d7d..6ed69ecf2e 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -107,10 +107,28 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
          input_impl_ = dataset()->input_->MakeIterator(prefix());
        }
        *end_of_sequence = true;
+        is_exhausted_ = true;
        input_impl_.reset();
        return Status::OK();
      }
+   protected:
+    Status SaveStateInternal(OpKernelContext* ctx,
+                             IteratorBundleWriter* writer) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(i_, full_name("i")));
+      TF_RETURN_IF_ERROR(writer->SaveParentState(ctx, input_impl_));
+      return Status::OK();
+    }
+
+    Status RestoreStateInternal(OpKernelContext* ctx,
+                                IteratorBundleReader* reader) override {
+      mutex_lock l(mu_);
+      TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(&i_, full_name("i")));
+      TF_RETURN_IF_ERROR(reader->RestoreParentState(ctx, input_impl_));
+      return Status::OK();
+    }
+
   private:
    mutex mu_;
    int64 i_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 34b9a2119a..f6bd5768d7 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -499,6 +499,24 @@ This operation may be executed multiple times. Each execution will reset the
iterator in `iterator` to the first element of `dataset`.
)doc");
+REGISTER_OP("SaveIterator")
+ .Input("iterator: resource")
+ .Input("path: string")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Saves the state of the `iterator` at `path`.
+
+This state can be restored using "RestoreIterator".
+)doc");
+
+REGISTER_OP("RestoreIterator")
+ .Input("iterator: resource")
+ .Input("path: string")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator".
+)doc");
+
REGISTER_OP("OneShotIterator")
.Output("handle: resource")
.Attr("dataset_factory: func")