author    Saurabh Saxena <srbs@google.com>  2017-10-23 11:07:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-23 11:11:17 -0700
commit    1038927c096ecc81ca48665871d1be390444b121 (patch)
tree      40c7ff20843bc62f248153b5b85e8116e16c3f4c
parent    57f3e529d935e6b08a6c0a3a418ad367d9314fde (diff)
Add SerializeIterator op that serializes an IteratorResource into a variant tensor.
Add DeserializeIterator op that builds IteratorResource from a variant tensor.
Move BundleReaderWrapper and BundleWriterWrapper from dataset.h to iterator_ops.cc.
Add generic key-value store interfaces IteratorStateReader and IteratorStateWriter for reading/writing state of iterators.
Get rid of IteratorBundleReader and IteratorBundleWriter.

PiperOrigin-RevId: 173140858
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD                    |   4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py     |  29
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py |  67
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 25
-rw-r--r--  tensorflow/core/BUILD                                                |   1
-rw-r--r--  tensorflow/core/framework/iterator.proto                             |  17
-rw-r--r--  tensorflow/core/kernels/BUILD                                        |   1
-rw-r--r--  tensorflow/core/kernels/dataset.h                                    | 189
-rw-r--r--  tensorflow/core/kernels/iterator_ops.cc                              | 355
-rw-r--r--  tensorflow/core/kernels/parse_tensor_op.cc                           |   1
-rw-r--r--  tensorflow/core/kernels/range_dataset_op.cc                          |  11
-rw-r--r--  tensorflow/core/kernels/reader_dataset_ops.cc                        |  17
-rw-r--r--  tensorflow/core/kernels/repeat_dataset_op.cc                         |  13
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt                      |  24
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc                                   |  42
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                                 |   5
-rw-r--r--  tensorflow/python/kernel_tests/iterator_ops_test.py                  |  29
-rw-r--r--  tensorflow/python/kernel_tests/range_dataset_op_test.py              |  67
-rw-r--r--  tensorflow/python/kernel_tests/reader_dataset_ops_test.py            |  26
19 files changed, 537 insertions(+), 386 deletions(-)
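
The change replaces the dedicated SaveIterator/RestoreIterator kernels with a composition of existing ops. A minimal Python sketch of the new save/restore pattern, mirroring the _save_op/_restore_op helpers added to the tests below (the checkpoint path here is a placeholder, and iterator._iterator_resource is the private handle the tests reach into):

    import os

    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import gen_dataset_ops
    from tensorflow.python.ops import io_ops
    from tensorflow.python.ops import parsing_ops

    path = os.path.join("/tmp", "iterator")  # placeholder checkpoint path
    iterator = dataset_ops.Dataset.range(0, 10).make_initializable_iterator()

    # Save: serialize the iterator state into a scalar variant tensor, then
    # write the tensor's wire form to a file.
    state_variant = gen_dataset_ops.serialize_iterator(iterator._iterator_resource)
    save_op = io_ops.write_file(path, parsing_ops.serialize_tensor(state_variant))

    # Restore: read the file, parse it back into a variant tensor, and
    # rebuild the iterator state from it.
    restored = parsing_ops.parse_tensor(io_ops.read_file(path), dtypes.variant)
    restore_op = gen_dataset_ops.deserialize_iterator(iterator._iterator_resource,
                                                      restored)

Serializing a DT_VARIANT tensor through serialize_tensor is what the new TF_CALL_variant(REGISTER) line in parse_tensor_op.cc enables.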
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index c34c9dad9b..b3175e3e56 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -185,6 +185,7 @@ py_test(
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",
@@ -252,6 +253,8 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
@@ -274,6 +277,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 20f6d6ba34..bda9a2a4a3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
@@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
def testIncorrectIteratorRestore(self):
- def _iterator_checkpoint_prefix():
+ def _path():
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def _build_range_dataset_graph():
start = 1
stop = 10
@@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = _iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
def _build_reader_dataset_graph():
filenames = ["test"] # Does not exist but we don't care in this test.
- path = _iterator_checkpoint_prefix()
iterator = readers.FixedLengthRecordDataset(
filenames, 1, 0, 0).make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
# Saving iterator for RangeDataset graph.
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index c8a0072809..c944eb4a49 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -29,6 +29,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -193,6 +195,21 @@ class RangeDatasetTest(test.TestCase):
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def testSaveRestore(self):
def _build_graph(start, stop):
@@ -200,10 +217,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -246,14 +261,13 @@ class RangeDatasetTest(test.TestCase):
def testRestoreWithoutBuildingDatasetGraph(self):
- def _build_graph(start, stop, num_epochs, path):
+ def _build_graph(start, stop, num_epochs):
dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -262,10 +276,8 @@ class RangeDatasetTest(test.TestCase):
num_epochs = 5
break_point = 5
break_epoch = 3
- path = self._iterator_checkpoint_prefix()
with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
- path)
+ init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
@@ -282,8 +294,7 @@ class RangeDatasetTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
with self.test_session(graph=g) as sess:
sess.run(restore_op)
@@ -302,10 +313,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -343,10 +352,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -379,10 +386,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@@ -424,10 +429,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@@ -471,10 +474,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index c9f88f3dfc..2682e8bdfa 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -276,18 +277,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
def _iterator_checkpoint_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_path(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def _build_iterator_graph(self, num_epochs):
filenames = self._createFiles()
- path = self._iterator_checkpoint_path()
dataset = (readers.FixedLengthRecordDataset(
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
.repeat(num_epochs))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
def _restore_iterator(self):
@@ -295,8 +309,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
get_next = iterator.get_next()
- restore_op = gen_dataset_ops.restore_iterator(
- iterator._iterator_resource, self._iterator_checkpoint_path())
+ restore_op = self._restore_op(iterator._iterator_resource)
return restore_op, get_next
def testSaveRestore(self):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6ad93a73f4..c4f880da9d 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [
"framework/function.proto",
"framework/graph.proto",
"framework/graph_transfer_info.proto",
+ "framework/iterator.proto",
"framework/kernel_def.proto",
"framework/log_memory.proto",
"framework/node_def.proto",
diff --git a/tensorflow/core/framework/iterator.proto b/tensorflow/core/framework/iterator.proto
new file mode 100644
index 0000000000..7e5f5ea2e0
--- /dev/null
+++ b/tensorflow/core/framework/iterator.proto
@@ -0,0 +1,17 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "IteratorProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.util";
+
+// Protocol buffer representing the metadata for an iterator's state stored
+// as a Variant tensor.
+message IteratorStateMetadata {
+ // A user-specified version string.
+ string version = 1;
+
+ // Keys for tensors in the VariantTensorDataProto.
+ repeated string keys = 2;
+}
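
The variant payload produced by SerializeIterator pairs this metadata with a VariantTensorData object: keys(i) names the i-th tensor, which is how VariantTensorDataReader (below) builds its key-to-index map. A short illustrative sketch, assuming the proto's generated Python bindings are available as iterator_pb2 and using a hypothetical key string:

    # Assumed module path for the generated bindings of framework/iterator.proto.
    from tensorflow.core.framework import iterator_pb2

    metadata = iterator_pb2.IteratorStateMetadata()
    metadata.version = "1"                         # user-specified version string
    metadata.keys.append("Iterator::Range:next")   # one key per stored tensor

    # The serialized proto becomes the VariantTensorData metadata field;
    # keys[i] corresponds to tensors(i) in the same VariantTensorDataProto.
    serialized = metadata.SerializeToString()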
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index d931f12f6d..f5bfa60199 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6061,6 +6061,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index f9ffc4e065..a906113466 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -17,12 +17,14 @@ limitations under the License.
#include <memory>
+#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -39,54 +41,25 @@ namespace tensorflow {
class ResourceMgr;
-class BundleReaderWrapper {
+// Interface for reading values from a key-value store.
+// Used for restoring iterator state.
+class IteratorStateReader {
public:
- BundleReaderWrapper(BundleReader* bundle_reader)
- : bundle_reader_(bundle_reader) {}
+ virtual Status ReadScalar(StringPiece key, int64* val) = 0;
+ virtual Status ReadScalar(StringPiece key, string* val) = 0;
+ virtual bool Contains(StringPiece key) = 0;
- // Reads a scalar value.
- template <typename T>
- Status ReadScalar(StringPiece key, T* val) {
- Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
- TF_RETURN_IF_ERROR(Lookup(key, &val_t));
- *val = val_t.scalar<T>()();
- return Status::OK();
- }
-
- bool Contains(StringPiece key) { return bundle_reader_->Contains(key); }
-
- private:
- Status Lookup(StringPiece key, Tensor* val) {
- return bundle_reader_->Lookup(key, val);
- }
-
- BundleReader* bundle_reader_;
+ virtual ~IteratorStateReader() {}
};
-class BundleWriterWrapper {
+// Interface for writing values to a key-value store.
+// Used for saving iterator state.
+class IteratorStateWriter {
public:
- // Note: We intentionally do not provide a constructor that builds a
- // BundleWriter from the checkpoint path because we want the caller to be
- // in-charge of calling BundleWriter::Finish(). If we expose the Finish()
- // method here it may be called pre-maturely by users of this object.
- explicit BundleWriterWrapper(BundleWriter* bundle_writer)
- : bundle_writer_(bundle_writer) {}
-
- // Writes a scalar value.
- template <typename T>
- Status WriteScalar(StringPiece key, const T val) {
- Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
- val_t.scalar<T>()() = val;
- TF_RETURN_IF_ERROR(Add(key, val_t));
- return Status::OK();
- }
+ virtual Status WriteScalar(StringPiece key, const int64& val) = 0;
+ virtual Status WriteScalar(StringPiece key, const string& val) = 0;
- private:
- Status Add(StringPiece key, const Tensor& val) {
- return bundle_writer_->Add(key, val);
- }
-
- BundleWriter* bundle_writer_;
+ virtual ~IteratorStateWriter() {}
};
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
@@ -249,10 +222,6 @@ class IteratorContext {
// range of outputs is typically represented by an `DatasetBase`,
// defined below.
class IteratorBase {
- protected:
- class IteratorBundleReader;
- class IteratorBundleWriter;
-
public:
virtual ~IteratorBase() {}
@@ -284,87 +253,53 @@ class IteratorBase {
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
// Saves the state of this iterator.
- virtual Status Save(OpKernelContext* ctx, const string& path) {
- BundleWriter bundle_writer(ctx->env(), path);
- TF_RETURN_IF_ERROR(bundle_writer.status());
- IteratorBundleWriter writer(&bundle_writer);
- TF_RETURN_IF_ERROR(Save(ctx, &writer));
- return bundle_writer.Finish();
+ virtual Status Save(IteratorStateWriter* writer) {
+ if (is_exhausted_) {
+ LOG(INFO) << "Iterator exhausted.";
+ return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
+ } else {
+ return SaveInternal(writer);
+ }
}
- virtual Status Restore(OpKernelContext* ctx, const string& path) {
- if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
- return errors::NotFound(
- "Failed to restore Iterator state. No file found at ",
- MetaFilename(path));
+ // Restores the state of this iterator.
+ virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
+ if (reader->Contains(kIteratorExhausted)) {
+ LOG(INFO) << "Iterator exhausted. Nothing to restore.";
+ is_exhausted_ = true;
+ return Status::OK();
+ } else {
+ return RestoreInternal(ctx, reader);
}
- BundleReader bundle_reader(ctx->env(), path);
- TF_RETURN_IF_ERROR(bundle_reader.status());
- IteratorBundleReader reader(&bundle_reader);
- return Restore(ctx, &reader);
}
static const char kIteratorExhausted[];
protected:
// This is needed so that sub-classes of IteratorBase can call
- // `RestoreInternal` on their parent iterators, e.g., in
+ // `SaveInternal` on their parent iterators, e.g., in
// `RepeatDataasetOp::Dataset`.
- class IteratorBundleReader : public BundleReaderWrapper {
- public:
- IteratorBundleReader(BundleReader* bundle_reader)
- : BundleReaderWrapper(bundle_reader) {}
-
- // Restores the state of a parent iterator recursively.
- Status RestoreParent(OpKernelContext* ctx,
- const std::unique_ptr<IteratorBase>& parent) {
- return parent->RestoreInternal(ctx, this);
- }
- };
+ Status SaveParent(IteratorStateWriter* writer,
+ const std::unique_ptr<IteratorBase>& parent) {
+ return parent->SaveInternal(writer);
+ }
// This is needed so that sub-classes of IteratorBase can call
- // `SaveInternal` on their parent iterators, e.g., in
+ // `RestoreInternal` on their parent iterators, e.g., in
// `RepeatDataasetOp::Dataset`.
- class IteratorBundleWriter : public BundleWriterWrapper {
- public:
- IteratorBundleWriter(BundleWriter* bundle_writer)
- : BundleWriterWrapper(bundle_writer) {}
- // Saves the state of a parent iterator recursively.
- Status SaveParent(OpKernelContext* ctx,
- const std::unique_ptr<IteratorBase>& parent) {
- return parent->SaveInternal(ctx, this);
- }
- };
-
- virtual Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) {
- if (is_exhausted_) {
- LOG(INFO) << "Iterator exhausted.";
- return writer->WriteScalar<string>(kIteratorExhausted,
- kIteratorExhausted);
- } else {
- return SaveInternal(ctx, writer);
- }
+ Status RestoreParent(OpKernelContext* ctx, IteratorStateReader* reader,
+ const std::unique_ptr<IteratorBase>& parent) {
+ return parent->RestoreInternal(ctx, reader);
}
- // Saves the state of this iterator.
- virtual Status SaveInternal(OpKernelContext* ctx,
- IteratorBundleWriter* writer) {
+ // Saves the state of this iterator recursively.
+ virtual Status SaveInternal(IteratorStateWriter* writer) {
return errors::Unimplemented("SaveInternal");
}
- virtual Status Restore(OpKernelContext* ctx, IteratorBundleReader* reader) {
- if (reader->Contains(kIteratorExhausted)) {
- LOG(INFO) << "Iterator exhausted. Nothing to restore.";
- is_exhausted_ = true;
- return Status::OK();
- } else {
- return RestoreInternal(ctx, reader);
- }
- }
-
- // Restores the state of this iterator.
+ // Restores the state of this iterator recursively.
virtual Status RestoreInternal(OpKernelContext* ctx,
- IteratorBundleReader* reader) {
+ IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
@@ -404,7 +339,7 @@ class DatasetBase : public core::RefCounted {
virtual string DebugString() = 0;
// Serializes the dataset and writes it to the `writer`.
- virtual Status Save(BundleWriterWrapper* writer) const {
+ virtual Status Save(IteratorStateWriter* writer) const {
return errors::Unimplemented("DatasetBase::Save");
}
@@ -435,20 +370,14 @@ class GraphDatasetBase : public DatasetBase {
const string op_name() const { return op_name_; }
- Status Save(BundleWriterWrapper* writer) const override {
- GraphDefBuilder b;
- DatasetGraphDefBuilder db(&b);
- Node* node = nullptr;
- TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
- string output_name = node->name();
- GraphDef graph_def;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ Status Save(IteratorStateWriter* writer) const override {
string serialized_graph_def;
- graph_def.SerializeToString(&serialized_graph_def);
+ string output_node;
+ TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node));
TF_RETURN_IF_ERROR(
- writer->WriteScalar<string>(kDatasetGraphKey, serialized_graph_def));
+ writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
TF_RETURN_IF_ERROR(
- writer->WriteScalar<string>(kDatasetGraphOutputNodeKey, output_name));
+ writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
return Status::OK();
}
@@ -460,6 +389,18 @@ class GraphDatasetBase : public DatasetBase {
static const char kDatasetGraphOutputNodeKey[];
private:
+ Status Serialize(string* serialized_graph_def, string* output_node) const {
+ GraphDefBuilder b;
+ DatasetGraphDefBuilder db(&b);
+ Node* node = nullptr;
+ TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
+ *output_node = node->name();
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ graph_def.SerializeToString(serialized_graph_def);
+ return Status::OK();
+ }
+
const string op_name_;
};
@@ -505,18 +446,18 @@ class DatasetIterator : public IteratorBase {
return GetNextInternal(ctx, out_tensors, end_of_sequence);
}
- protected:
- Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) final {
+ Status Save(IteratorStateWriter* writer) final {
TF_RETURN_IF_ERROR(dataset()->Save(writer));
- return IteratorBase::Save(ctx, writer);
+ return IteratorBase::Save(writer);
}
+ protected:
// Internal implementation of GetNext that is wrapped in tracing logic.
virtual Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) = 0;
- string full_name(const string& name) {
+ string full_name(const string& name) const {
return strings::StrCat(prefix(), ":", name);
}
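
The IteratorStateReader/IteratorStateWriter interfaces above reduce iterator checkpointing to scalar key-value access, decoupling SaveInternal/RestoreInternal from any particular backing store (previously a BundleWriter, now a VariantTensorData). A purely illustrative Python analogue of that contract, backed by a dict; the key string is hypothetical:

    class DictStateWriter(object):
      """Illustrative stand-in for IteratorStateWriter."""

      def __init__(self):
        self.store = {}

      def write_scalar(self, key, val):
        self.store[key] = val

    class DictStateReader(object):
      """Illustrative stand-in for IteratorStateReader."""

      def __init__(self, store):
        self.store = store

      def read_scalar(self, key):
        return self.store[key]

      def contains(self, key):
        return key in self.store

    # A RangeDataset-style SaveInternal/RestoreInternal round trip:
    writer = DictStateWriter()
    writer.write_scalar("Iterator::Range:next", 5)          # save the position
    reader = DictStateReader(writer.store)
    assert reader.read_scalar("Iterator::Range:next") == 5  # restore it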
diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc
index df13edc83a..b7c1fff2a9 100644
--- a/tensorflow/core/kernels/iterator_ops.cc
+++ b/tensorflow/core/kernels/iterator_ops.cc
@@ -16,9 +16,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -35,6 +37,8 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following ops.
+const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
+
Status VerifyTypesMatch(const DataTypeVector& expected,
const DataTypeVector& received) {
if (expected.size() != received.size()) {
@@ -93,10 +97,10 @@ class IteratorResource : public ResourceBase {
}
}
- Status Save(OpKernelContext* ctx, const string& path) {
+ Status Save(IteratorStateWriter* writer) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
- return captured_iterator->Save(ctx, path);
+ return captured_iterator->Save(writer);
} else {
return errors::FailedPrecondition(
"Save() failed because the iterator has not been initialized. "
@@ -105,49 +109,34 @@ class IteratorResource : public ResourceBase {
}
}
- Status Restore(OpKernelContext* ctx, const string& path) {
- if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
- return errors::NotFound(
- "Failed to restore Iterator state. No file found at ",
- MetaFilename(path));
- }
-
- BundleReader bundle_reader(ctx->env(), path);
- TF_RETURN_IF_ERROR(bundle_reader.status());
- BundleReaderWrapper reader(&bundle_reader);
- if (reader.Contains(GraphDatasetBase::kDatasetGraphKey)) {
- string serialized_graph_def;
- TF_RETURN_IF_ERROR(reader.ReadScalar(GraphDatasetBase::kDatasetGraphKey,
- &serialized_graph_def));
- GraphDef graph_def;
- graph_def.ParseFromString(serialized_graph_def);
- // TODO(srbs): Is there a way of getting the op registry of the original
- // graph.
- Graph graph(OpRegistry::Global());
- TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
- string output_node;
- TF_RETURN_IF_ERROR(reader.ReadScalar(
- GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
- std::vector<Tensor> outputs;
- GraphRunner graph_runner(ctx->env());
- TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
- {output_node}, &outputs));
- DatasetBase* dataset;
- TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
- TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
- } else if (reader.Contains(IteratorBase::kIteratorExhausted)) {
- TF_RETURN_IF_ERROR(set_iterator(std::unique_ptr<IteratorBase>(
- new ExhaustedIterator(output_dtypes_, output_shapes_))));
+ Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
+ string serialized_graph_def;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
+ &serialized_graph_def));
+ GraphDef graph_def;
+ if (!graph_def.ParseFromString(serialized_graph_def)) {
+ return errors::Internal("Error parsing dataset GraphDef.");
}
+ string output_node;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
+ DatasetBase* dataset = nullptr;
+ Graph graph(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
+ std::vector<Tensor> outputs;
+ GraphRunner graph_runner(ctx->env());
+ TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
+ {output_node}, &outputs));
+ TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
+
+ TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
- // TODO(srbs): Figure a way to pass bundle_reader here.
- return captured_iterator->Restore(ctx, path);
+ return captured_iterator->Restore(ctx, reader);
} else {
return errors::FailedPrecondition(
- "Failed to restore iterator from ", path,
- ". Make sure the checkpoint ",
+ "Failed to restore iterator. Make sure the checkpoint ",
"is not corrupt. If the checkpoint does not contain the GraphDef, ",
"you will need to initialize your iterator before restoring.");
}
@@ -174,43 +163,194 @@ class IteratorResource : public ResourceBase {
}
private:
- // A no-op iterator which always sets end_of_sequence = true. An instance of
- // this is returned when attempting to restore an exhausted iterator. This is
- // needed because the Dataset GraphDef may not have been saved for exhausted
- // iterators so the actual Iterator can not be built.
- class ExhaustedIterator : public IteratorBase {
- public:
- ExhaustedIterator(const DataTypeVector& output_dtypes,
- const std::vector<PartialTensorShape>& output_shapes)
- : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
- Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) final {
- *end_of_sequence = true;
- return Status::OK();
- }
+ std::shared_ptr<IteratorBase> iterator_;
+ const DataTypeVector output_dtypes_;
+ const std::vector<PartialTensorShape> output_shapes_;
+};
+
+// Helper class for reading data from a VariantTensorData object.
+class VariantTensorDataReader : public IteratorStateReader {
+ public:
+ explicit VariantTensorDataReader(const VariantTensorData* data)
+ : data_(data) {
+ PreProcess();
+ }
+
+ // Returns OK iff the initialization was successful, i.e.,
+ // pre-processing did not have errors.
+ Status status() const { return status_; }
+
+ Status ReadScalar(StringPiece key, int64* val) override {
+ return ReadScalarInternal(key, val);
+ }
+
+ Status ReadScalar(StringPiece key, string* val) override {
+ return ReadScalarInternal(key, val);
+ }
- const DataTypeVector& output_dtypes() const override {
- return output_dtypes_;
+ bool Contains(StringPiece key) override {
+ return map_.find(key.ToString()) != map_.end();
+ }
+
+ private:
+ void PreProcess() {
+ string metadata;
+ data_->get_metadata(&metadata);
+ IteratorStateMetadata proto;
+ if (!proto.ParseFromString(metadata)) {
+ status_ = errors::Internal("Error parsing IteratorStateMetadata.");
+ return;
+ }
+ size_t num_entries = proto.keys_size();
+ CHECK_EQ(num_entries, data_->tensors_size());
+ for (size_t i = 0; i < num_entries; i++) {
+ map_[proto.keys(i)] = i;
}
+ }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return output_shapes_;
+ template <typename T>
+ Status ReadScalarInternal(StringPiece key, T* val) {
+ if (map_.find(key.ToString()) == map_.end()) {
+ return errors::NotFound(key);
}
+ *val = data_->tensors(map_[key.ToString()]).scalar<T>()();
+ return Status::OK();
+ }
- virtual const std::vector<PartialTensorShape>& output_shapes() {
- return output_shapes_;
+ std::map<string, size_t> map_;
+ const VariantTensorData* data_; // Not owned.
+ Status status_;
+};
+
+// Helper class for writing data to a VariantTensorData object.
+class VariantTensorDataWriter : public IteratorStateWriter {
+ public:
+ // Does not take ownership of data.
+ explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
+
+ Status WriteScalar(StringPiece key, const int64& val) override {
+ return WriteScalarInternal(key, val);
+ }
+
+ Status WriteScalar(StringPiece key, const string& val) override {
+ return WriteScalarInternal(key, val);
+ }
+
+ // Writes the metadata to `data_`.
+ Status Flush() {
+ string metadata;
+ if (!metadata_proto_.SerializeToString(&metadata)) {
+ return errors::Internal("Unable to serialize IteratorStateMetadata.");
}
+ data_->set_metadata(metadata);
+ return Status::OK();
+ }
- private:
- const DataTypeVector output_dtypes_;
- const std::vector<PartialTensorShape> output_shapes_;
- };
+ private:
+ template <typename T>
+ Status WriteScalarInternal(StringPiece key, const T& val) {
+ // Write key to the metadata proto. This gets written to `data_`
+ // when `Flush()` is called. We do this lazily to avoid multiple
+ // serialization calls.
+ metadata_proto_.add_keys(key.ToString());
+
+ // Update tensors.
+ Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
+ val_t.scalar<T>()() = val;
+ *(data_->add_tensors()) = std::move(val_t);
+ return Status::OK();
+ }
- std::shared_ptr<IteratorBase> iterator_;
- const DataTypeVector output_dtypes_;
- const std::vector<PartialTensorShape> output_shapes_;
+ VariantTensorData* data_;
+ // TODO(srbs): Set the version string.
+ IteratorStateMetadata metadata_proto_;
+};
+
+// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
+// The get() method returns an IteratorStateReader which can be used
+// to restore iterator state.
+//
+// Usage example:
+//
+// Encoding:
+//
+// Tensor t(DT_VARIANT, TensorShape({}));
+// t->scalar<Variant>()() = IteratorStateVariant(iterator_resource);
+//
+// Encode() sets the type_name of the VariantTensorData object to
+// IteratorStateVariant::TypeName().
+//
+// Decoding:
+//
+// Variant v = <VariantTensorDataProto object>;
+// DecodeUnaryVariant(&v);
+// IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
+// iterator_resource->Restore(ctx, wrapper->get())
+//
+// The type_name of the VariantTensorData object to be decoded must
+// match IteratorStateVariant::TypeName().
+class IteratorStateVariant {
+ public:
+ IteratorStateVariant() : data_(nullptr) {}
+ IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
+ if (other.data_) {
+ Decode(*other.data_);
+ }
+ }
+ // Initializes this object with the current state of the iterator so
+ // that it can be written on the next call to Encode().
+ Status InitializeFromIterator(IteratorResource* iterator_resource) {
+ data_.reset(new VariantTensorData());
+ data_->set_type_name(TypeName());
+ VariantTensorDataWriter writer(data_.get());
+ TF_RETURN_IF_ERROR(iterator_resource->Save(&writer));
+ TF_RETURN_IF_ERROR(writer.Flush());
+ return Status::OK();
+ }
+ string TypeName() const { return kIteratorVariantTypeName; }
+ void Encode(VariantTensorData* data) const { *data = *data_; }
+ bool Decode(const VariantTensorData& data) {
+ if (data.type_name() != TypeName()) {
+ return false;
+ }
+ std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
+ *tensor_data = data;
+ std::unique_ptr<VariantTensorDataReader> reader(
+ new VariantTensorDataReader(tensor_data.get()));
+ status_ = reader->status();
+ if (!status_.ok()) {
+ return false;
+ }
+ data_ = std::move(tensor_data);
+ reader_ = std::move(reader);
+ return true;
+ }
+ IteratorStateReader* get() { return reader_.get(); }
+ Status status() const { return status_; }
+ string DebugString() const {
+ if (data_) {
+ return strings::StrCat("IteratorStateVariant<",
+ "data: ", data_->DebugString(),
+ " status: ", status_.ToString(), ">");
+ } else {
+ return strings::StrCat("IteratorStateVariant<empty>");
+ }
+ }
+
+ private:
+ std::unique_ptr<IteratorStateReader> reader_;
+ Status status_;
+ std::unique_ptr<VariantTensorData> data_;
};
+// Register the reader class in the global variant decode_fn registry
+// so that a Variant containing a serialized representation of iterator state
+// can be decoded using DecodeUnaryVariant. If we don't do this we will need
+// to manually decode the returned Variant using MaybeDecodeAndCopy in
+// DeserializeIteratorOp which is not recommended.
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
+ kIteratorVariantTypeName);
+
// TODO(mrry): Can we simply use the template kernel here?
class IteratorHandleOp : public ResourceOpKernel<IteratorResource> {
public:
@@ -294,37 +434,6 @@ class ToSingleElementOp : public OpKernel {
}
};
-class SaveIteratorOp : public OpKernel {
- public:
- explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- IteratorResource* iterator_resource;
- OP_REQUIRES_OK(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
- errors::InvalidArgument("SaveIteratorOp: path must be scalar"));
- const string& path = ctx->input(1).scalar<string>()();
- OP_REQUIRES_OK(ctx, iterator_resource->Save(ctx, path));
- }
-};
-
-class RestoreIteratorOp : public OpKernel {
- public:
- explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- IteratorResource* iterator_resource;
- OP_REQUIRES_OK(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
- OP_REQUIRES(
- ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
- errors::InvalidArgument("RestoreIteratorOp: path must be scalar"));
- const string& path = ctx->input(1).scalar<string>()();
- OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, path));
- }
-};
-
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -644,15 +753,55 @@ class IteratorFromStringHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
+class SerializeIteratorOp : public OpKernel {
+ public:
+ explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& resource_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+ errors::InvalidArgument("resource_handle must be a scalar"));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an IteratorResource.
+ IteratorResource* iterator_resource;
+ OP_REQUIRES_OK(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+ iterator_resource->Unref();
+ Tensor* variant_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
+ IteratorStateVariant v;
+ OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource));
+ variant_t->scalar<Variant>()() = v;
+ }
+};
+
+class DeserializeIteratorOp : public OpKernel {
+ public:
+ explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an IteratorResource.
+ IteratorResource* iterator_resource;
+ OP_REQUIRES_OK(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+
+ Variant variant = ctx->input(1).scalar<Variant>()();
+ auto* wrapper = variant.get<IteratorStateVariant>();
+ OP_REQUIRES(ctx, wrapper != nullptr,
+ errors::InvalidArgument(
+ "DeserializeIteratorOp: Unable to parse variant tensor."));
+ OP_REQUIRES_OK(ctx, wrapper->status());
+ OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
+ }
+};
+
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
-REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU),
- SaveIteratorOp);
-REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU),
- RestoreIteratorOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
@@ -661,6 +810,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
+ SerializeIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
+ DeserializeIteratorOp);
} // namespace
diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc
index ab91a6ef67..6b599612ad 100644
--- a/tensorflow/core/kernels/parse_tensor_op.cc
+++ b/tensorflow/core/kernels/parse_tensor_op.cc
@@ -92,6 +92,7 @@ class SerializeTensorOp : public OpKernel {
Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SerializeTensorOp<T>);
TF_CALL_ALL_TYPES(REGISTER)
+TF_CALL_variant(REGISTER)
#undef REGISTER
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc
index a57c21a590..7adfcc4f8d 100644
--- a/tensorflow/core/kernels/range_dataset_op.cc
+++ b/tensorflow/core/kernels/range_dataset_op.cc
@@ -112,19 +112,16 @@ class RangeDatasetOp : public DatasetOpKernel {
}
protected:
- Status SaveInternal(OpKernelContext* ctx,
- IteratorBundleWriter* writer) override {
+ Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(
- writer->WriteScalar<int64>(full_name("next"), next_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
- IteratorBundleReader* reader) override {
+ IteratorStateReader* reader) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(
- reader->ReadScalar<int64>(full_name("next"), &next_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_));
return Status::OK();
}
diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc
index b455c28e07..fb88c55f73 100644
--- a/tensorflow/core/kernels/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/reader_dataset_ops.cc
@@ -356,31 +356,30 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
}
protected:
- Status SaveInternal(OpKernelContext* ctx,
- IteratorBundleWriter* writer) override {
+ Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(
- full_name("current_file_index"), current_file_index_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+ current_file_index_));
// `input_buffer_` is empty if
// 1. GetNext has not been called even once.
// 2. All files have been read and iterator has been exhausted.
int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
TF_RETURN_IF_ERROR(
- writer->WriteScalar<int64>(full_name("current_pos"), current_pos));
+ writer->WriteScalar(full_name("current_pos"), current_pos));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
- IteratorBundleReader* reader) override {
+ IteratorStateReader* reader) override {
mutex_lock l(mu_);
int64 current_file_index;
- TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(
- full_name("current_file_index"), &current_file_index));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+ &current_file_index));
current_file_index_ = size_t(current_file_index);
int64 current_pos;
TF_RETURN_IF_ERROR(
- reader->ReadScalar<int64>(full_name("current_pos"), &current_pos));
+ reader->ReadScalar(full_name("current_pos"), &current_pos));
// Seek to current_pos.
input_buffer_.reset();
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index 5d836927d2..9813e99a70 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -124,19 +124,18 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
}
protected:
- Status SaveInternal(OpKernelContext* ctx,
- IteratorBundleWriter* writer) override {
+ Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(full_name("i"), i_));
- TF_RETURN_IF_ERROR(writer->SaveParent(ctx, input_impl_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(OpKernelContext* ctx,
- IteratorBundleReader* reader) override {
+ IteratorStateReader* reader) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(full_name("i"), &i_));
- TF_RETURN_IF_ERROR(reader->RestoreParent(ctx, input_impl_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
return Status::OK();
}
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 6772024263..c5ceb14a09 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -28754,18 +28754,6 @@ op {
is_stateful: true
}
op {
- name: "RestoreIterator"
- input_arg {
- name: "iterator"
- type: DT_RESOURCE
- }
- input_arg {
- name: "path"
- type: DT_STRING
- }
- is_stateful: true
-}
-op {
name: "RestoreSlice"
input_arg {
name: "file_pattern"
@@ -29549,18 +29537,6 @@ op {
is_stateful: true
}
op {
- name: "SaveIterator"
- input_arg {
- name: "iterator"
- type: DT_RESOURCE
- }
- input_arg {
- name: "path"
- type: DT_STRING
- }
- is_stateful: true
-}
-op {
name: "SaveSlices"
input_arg {
name: "filename"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 566049179a..8b77e3f9f0 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -598,24 +598,6 @@ This operation may be executed multiple times. Each execution will reset the
iterator in `iterator` to the first element of `dataset`.
)doc");
-REGISTER_OP("SaveIterator")
- .Input("iterator: resource")
- .Input("path: string")
- .SetShapeFn(shape_inference::NoOutputs)
- .Doc(R"doc(
-Saves the state of the `iterator` at `path`.
-
-This state can be restored using "RestoreIterator".
-)doc");
-
-REGISTER_OP("RestoreIterator")
- .Input("iterator: resource")
- .Input("path: string")
- .SetShapeFn(shape_inference::NoOutputs)
- .Doc(R"doc(
-Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator".
-)doc");
-
REGISTER_OP("OneShotIterator")
.Output("handle: resource")
.Attr("dataset_factory: func")
@@ -737,4 +719,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
element produced by the resulting iterator.
)doc");
+REGISTER_OP("SerializeIterator")
+ .Input("resource_handle: resource")
+ .Output("serialized: variant")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+ .Input("resource_handle: resource")
+ .Input("serialized: variant")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 0e36c3498a..b02bae95fd 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2886,7 +2886,9 @@ tf_py_test(
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
@@ -2907,7 +2909,9 @@ tf_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:iterator_ops",
@@ -3022,6 +3026,7 @@ tf_py_test(
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",
diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py
index b5ec9f7db0..2128ef4ae1 100644
--- a/tensorflow/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/kernel_tests/iterator_ops_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
@@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
def testIncorrectIteratorRestore(self):
- def _iterator_checkpoint_prefix():
+ def _path():
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def _build_range_dataset_graph():
start = 1
stop = 10
@@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = _iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
def _build_reader_dataset_graph():
filenames = ["test"] # Does not exist but we don't care in this test.
- path = _iterator_checkpoint_prefix()
iterator = readers.FixedLengthRecordDataset(
filenames, 1, 0, 0).make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
# Saving iterator for RangeDataset graph.
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py
index 8291967155..0c530522b8 100644
--- a/tensorflow/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py
@@ -27,6 +27,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase):
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def testSaveRestore(self):
def _build_graph(start, stop):
@@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase):
def testRestoreWithoutBuildingDatasetGraph(self):
- def _build_graph(start, stop, num_epochs, path):
+ def _build_graph(start, stop, num_epochs):
dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase):
num_epochs = 5
break_point = 5
break_epoch = 3
- path = self._iterator_checkpoint_prefix()
with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
- path)
+ init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
@@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
with self.test_session(graph=g) as sess:
sess.run(restore_op)
@@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@@ -447,10 +450,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index 38420328ef..c8e7333b4b 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
def _iterator_checkpoint_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_path(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def _build_iterator_graph(self, num_epochs):
filenames = self._createFiles()
- path = self._iterator_checkpoint_path()
dataset = (readers.FixedLengthRecordDataset(
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
.repeat(num_epochs))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
def _restore_iterator(self):
@@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
get_next = iterator.get_next()
- restore_op = gen_dataset_ops.restore_iterator(
- iterator._iterator_resource, self._iterator_checkpoint_path())
+ restore_op = self._restore_op(iterator._iterator_resource)
return restore_op, get_next
def testSaveRestore(self):