From 1038927c096ecc81ca48665871d1be390444b121 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Mon, 23 Oct 2017 11:07:10 -0700 Subject: Add SerializeIterator op that serializes an IteratorResource into a variant tensor. Add DeserializeIterator op that builds IteratorResource from a variant tensor. Move BundleReaderWrapper and BundleWriterWrapper from dataset.h to iterator_ops.cc. Add generic key-value store interfaces IteratorStateReader and IteratorStateWriter for reading/writing state of iterators. Get rid of IteratorBundleReader and IteratorBundleWriter. PiperOrigin-RevId: 173140858 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 4 + .../data/python/kernel_tests/iterator_ops_test.py | 29 +- .../python/kernel_tests/range_dataset_op_test.py | 67 ++-- .../python/kernel_tests/reader_dataset_ops_test.py | 25 +- tensorflow/core/BUILD | 1 + tensorflow/core/framework/iterator.proto | 17 + tensorflow/core/kernels/BUILD | 1 + tensorflow/core/kernels/dataset.h | 189 ++++------- tensorflow/core/kernels/iterator_ops.cc | 355 +++++++++++++++------ tensorflow/core/kernels/parse_tensor_op.cc | 1 + tensorflow/core/kernels/range_dataset_op.cc | 11 +- tensorflow/core/kernels/reader_dataset_ops.cc | 17 +- tensorflow/core/kernels/repeat_dataset_op.cc | 13 +- tensorflow/core/ops/compat/ops_history.v1.pbtxt | 24 -- tensorflow/core/ops/dataset_ops.cc | 42 +-- tensorflow/python/kernel_tests/BUILD | 5 + .../python/kernel_tests/iterator_ops_test.py | 29 +- .../python/kernel_tests/range_dataset_op_test.py | 67 ++-- .../python/kernel_tests/reader_dataset_ops_test.py | 26 +- 19 files changed, 537 insertions(+), 386 deletions(-) create mode 100644 tensorflow/core/framework/iterator.proto diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index c34c9dad9b..b3175e3e56 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -185,6 +185,7 @@ py_test( 
"//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:gradients", + "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:script_ops", @@ -252,6 +253,8 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", "//tensorflow/python:variables", @@ -274,6 +277,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 20f6d6ba34..bda9a2a4a3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops @@ -538,9 +539,23 @@ class IteratorTest(test.TestCase): def testIncorrectIteratorRestore(self): - def _iterator_checkpoint_prefix(): + def _path(): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + _path(), parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(iterator_resource): + iterator_state_variant = 
parsing_ops.parse_tensor( + io_ops.read_file(_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_range_dataset_graph(): start = 1 stop = 10 @@ -548,22 +563,18 @@ class IteratorTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = _iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op def _build_reader_dataset_graph(): filenames = ["test"] # Does not exist but we don't care in this test. - path = _iterator_checkpoint_prefix() iterator = readers.FixedLengthRecordDataset( filenames, 1, 0, 0).make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op # Saving iterator for RangeDataset graph. 
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index c8a0072809..c944eb4a49 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -29,6 +29,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -193,6 +195,21 @@ class RangeDatasetTest(test.TestCase): def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def testSaveRestore(self): def _build_graph(start, stop): @@ -200,10 +217,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = 
self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -246,14 +261,13 @@ class RangeDatasetTest(test.TestCase): def testRestoreWithoutBuildingDatasetGraph(self): - def _build_graph(start, stop, num_epochs, path): + def _build_graph(start, stop, num_epochs): dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -262,10 +276,8 @@ class RangeDatasetTest(test.TestCase): num_epochs = 5 break_point = 5 break_epoch = 3 - path = self._iterator_checkpoint_prefix() with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs, - path) + init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) with self.test_session(graph=g) as sess: sess.run(variables.global_variables_initializer()) sess.run(init_op) @@ -282,8 +294,7 @@ class RangeDatasetTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + restore_op = self._restore_op(iterator._iterator_resource) get_next = iterator.get_next() with self.test_session(graph=g) as sess: sess.run(restore_op) @@ -302,10 +313,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = 
gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -343,10 +352,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -379,10 +386,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -424,10 +429,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return 
init_op, get_next, save_op, restore_op start = 2 @@ -471,10 +474,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index c9f88f3dfc..2682e8bdfa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -276,18 +277,31 @@ class FixedLengthRecordReaderTest(test.TestCase): def _iterator_checkpoint_path(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_path(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) + restore_op = 
gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_iterator_graph(self, num_epochs): filenames = self._createFiles() - path = self._iterator_checkpoint_path() dataset = (readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes) .repeat(num_epochs)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op def _restore_iterator(self): @@ -295,8 +309,7 @@ class FixedLengthRecordReaderTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) get_next = iterator.get_next() - restore_op = gen_dataset_ops.restore_iterator( - iterator._iterator_resource, self._iterator_checkpoint_path()) + restore_op = self._restore_op(iterator._iterator_resource) return restore_op, get_next def testSaveRestore(self): diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 6ad93a73f4..c4f880da9d 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [ "framework/function.proto", "framework/graph.proto", "framework/graph_transfer_info.proto", + "framework/iterator.proto", "framework/kernel_def.proto", "framework/log_memory.proto", "framework/node_def.proto", diff --git a/tensorflow/core/framework/iterator.proto b/tensorflow/core/framework/iterator.proto new file mode 100644 index 0000000000..7e5f5ea2e0 --- /dev/null +++ b/tensorflow/core/framework/iterator.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option 
java_outer_classname = "IteratorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.util"; + +// Protocol buffer representing the metadata for an iterator's state stored +// as a Variant tensor. +message IteratorStateMetadata { + // A user-specified version string. + string version = 1; + + // Keys for tensors in the VariantTensorDataProto. + repeated string keys = 2; +} diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index d931f12f6d..f5bfa60199 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6061,6 +6061,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index f9ffc4e065..a906113466 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -17,12 +17,14 @@ limitations under the License. #include +#include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -39,54 +41,25 @@ namespace tensorflow { class ResourceMgr; -class BundleReaderWrapper { +// Interface for reading values from a key-value store. +// Used for restoring iterator state. 
+class IteratorStateReader { public: - BundleReaderWrapper(BundleReader* bundle_reader) - : bundle_reader_(bundle_reader) {} + virtual Status ReadScalar(StringPiece key, int64* val) = 0; + virtual Status ReadScalar(StringPiece key, string* val) = 0; + virtual bool Contains(StringPiece key) = 0; - // Reads a scalar value. - template - Status ReadScalar(StringPiece key, T* val) { - Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); - TF_RETURN_IF_ERROR(Lookup(key, &val_t)); - *val = val_t.scalar()(); - return Status::OK(); - } - - bool Contains(StringPiece key) { return bundle_reader_->Contains(key); } - - private: - Status Lookup(StringPiece key, Tensor* val) { - return bundle_reader_->Lookup(key, val); - } - - BundleReader* bundle_reader_; + virtual ~IteratorStateReader() {} }; -class BundleWriterWrapper { +// Interface for writing values to a key-value store. +// Used for saving iterator state. +class IteratorStateWriter { public: - // Note: We intentionally do not provide a constructor that builds a - // BundleWriter from the checkpoint path because we want the caller to be - // in-charge of calling BundleWriter::Finish(). If we expose the Finish() - // method here it may be called pre-maturely by users of this object. - explicit BundleWriterWrapper(BundleWriter* bundle_writer) - : bundle_writer_(bundle_writer) {} - - // Writes a scalar value. - template - Status WriteScalar(StringPiece key, const T val) { - Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); - val_t.scalar()() = val; - TF_RETURN_IF_ERROR(Add(key, val_t)); - return Status::OK(); - } + virtual Status WriteScalar(StringPiece key, const int64& val) = 0; + virtual Status WriteScalar(StringPiece key, const string& val) = 0; - private: - Status Add(StringPiece key, const Tensor& val) { - return bundle_writer_->Add(key, val); - } - - BundleWriter* bundle_writer_; + virtual ~IteratorStateWriter() {} }; // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. 
@@ -249,10 +222,6 @@ class IteratorContext { // range of outputs is typically represented by an `DatasetBase`, // defined below. class IteratorBase { - protected: - class IteratorBundleReader; - class IteratorBundleWriter; - public: virtual ~IteratorBase() {} @@ -284,87 +253,53 @@ class IteratorBase { virtual const std::vector& output_shapes() const = 0; // Saves the state of this iterator. - virtual Status Save(OpKernelContext* ctx, const string& path) { - BundleWriter bundle_writer(ctx->env(), path); - TF_RETURN_IF_ERROR(bundle_writer.status()); - IteratorBundleWriter writer(&bundle_writer); - TF_RETURN_IF_ERROR(Save(ctx, &writer)); - return bundle_writer.Finish(); + virtual Status Save(IteratorStateWriter* writer) { + if (is_exhausted_) { + LOG(INFO) << "Iterator exhausted."; + return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted); + } else { + return SaveInternal(writer); + } } - virtual Status Restore(OpKernelContext* ctx, const string& path) { - if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) { - return errors::NotFound( - "Failed to restore Iterator state. No file found at ", - MetaFilename(path)); + // Restores the state of this iterator. + virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) { + if (reader->Contains(kIteratorExhausted)) { + LOG(INFO) << "Iterator exhausted. Nothing to restore."; + is_exhausted_ = true; + return Status::OK(); + } else { + return RestoreInternal(ctx, reader); } - BundleReader bundle_reader(ctx->env(), path); - TF_RETURN_IF_ERROR(bundle_reader.status()); - IteratorBundleReader reader(&bundle_reader); - return Restore(ctx, &reader); } static const char kIteratorExhausted[]; protected: // This is needed so that sub-classes of IteratorBase can call - // `RestoreInternal` on their parent iterators, e.g., in + // `SaveInternal` on their parent iterators, e.g., in // `RepeatDataasetOp::Dataset`. 
- class IteratorBundleReader : public BundleReaderWrapper { - public: - IteratorBundleReader(BundleReader* bundle_reader) - : BundleReaderWrapper(bundle_reader) {} - - // Restores the state of a parent iterator recursively. - Status RestoreParent(OpKernelContext* ctx, - const std::unique_ptr& parent) { - return parent->RestoreInternal(ctx, this); - } - }; + Status SaveParent(IteratorStateWriter* writer, + const std::unique_ptr& parent) { + return parent->SaveInternal(writer); + } // This is needed so that sub-classes of IteratorBase can call - // `SaveInternal` on their parent iterators, e.g., in + // `RestoreInternal` on their parent iterators, e.g., in // `RepeatDataasetOp::Dataset`. - class IteratorBundleWriter : public BundleWriterWrapper { - public: - IteratorBundleWriter(BundleWriter* bundle_writer) - : BundleWriterWrapper(bundle_writer) {} - // Saves the state of a parent iterator recursively. - Status SaveParent(OpKernelContext* ctx, - const std::unique_ptr& parent) { - return parent->SaveInternal(ctx, this); - } - }; - - virtual Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) { - if (is_exhausted_) { - LOG(INFO) << "Iterator exhausted."; - return writer->WriteScalar(kIteratorExhausted, - kIteratorExhausted); - } else { - return SaveInternal(ctx, writer); - } + Status RestoreParent(OpKernelContext* ctx, IteratorStateReader* reader, + const std::unique_ptr& parent) { + return parent->RestoreInternal(ctx, reader); } - // Saves the state of this iterator. - virtual Status SaveInternal(OpKernelContext* ctx, - IteratorBundleWriter* writer) { + // Saves the state of this iterator recursively. + virtual Status SaveInternal(IteratorStateWriter* writer) { return errors::Unimplemented("SaveInternal"); } - virtual Status Restore(OpKernelContext* ctx, IteratorBundleReader* reader) { - if (reader->Contains(kIteratorExhausted)) { - LOG(INFO) << "Iterator exhausted. 
Nothing to restore."; - is_exhausted_ = true; - return Status::OK(); - } else { - return RestoreInternal(ctx, reader); - } - } - - // Restores the state of this iterator. + // Restores the state of this iterator recursively. virtual Status RestoreInternal(OpKernelContext* ctx, - IteratorBundleReader* reader) { + IteratorStateReader* reader) { return errors::Unimplemented("RestoreInternal"); } @@ -404,7 +339,7 @@ class DatasetBase : public core::RefCounted { virtual string DebugString() = 0; // Serializes the dataset and writes it to the `writer`. - virtual Status Save(BundleWriterWrapper* writer) const { + virtual Status Save(IteratorStateWriter* writer) const { return errors::Unimplemented("DatasetBase::Save"); } @@ -435,20 +370,14 @@ class GraphDatasetBase : public DatasetBase { const string op_name() const { return op_name_; } - Status Save(BundleWriterWrapper* writer) const override { - GraphDefBuilder b; - DatasetGraphDefBuilder db(&b); - Node* node = nullptr; - TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node)); - string output_name = node->name(); - GraphDef graph_def; - TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); + Status Save(IteratorStateWriter* writer) const override { string serialized_graph_def; - graph_def.SerializeToString(&serialized_graph_def); + string output_node; + TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node)); TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); + writer->WriteScalar(kDatasetGraphKey, serialized_graph_def)); TF_RETURN_IF_ERROR( - writer->WriteScalar(kDatasetGraphOutputNodeKey, output_name)); + writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node)); return Status::OK(); } @@ -460,6 +389,18 @@ class GraphDatasetBase : public DatasetBase { static const char kDatasetGraphOutputNodeKey[]; private: + Status Serialize(string* serialized_graph_def, string* output_node) const { + GraphDefBuilder b; + DatasetGraphDefBuilder db(&b); + Node* node = nullptr; + 
TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node)); + *output_node = node->name(); + GraphDef graph_def; + TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); + graph_def.SerializeToString(serialized_graph_def); + return Status::OK(); + } + const string op_name_; }; @@ -505,18 +446,18 @@ class DatasetIterator : public IteratorBase { return GetNextInternal(ctx, out_tensors, end_of_sequence); } - protected: - Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) final { + Status Save(IteratorStateWriter* writer) final { TF_RETURN_IF_ERROR(dataset()->Save(writer)); - return IteratorBase::Save(ctx, writer); + return IteratorBase::Save(writer); } + protected: // Internal implementation of GetNext that is wrapped in tracing logic. virtual Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) = 0; - string full_name(const string& name) { + string full_name(const string& name) const { return strings::StrCat(prefix(), ":", name); } diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index df13edc83a..b7c1fff2a9 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/framework/iterator.pb.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/resource_op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -35,6 +37,8 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following ops. 
+const char kIteratorVariantTypeName[] = "tensorflow::Iterator"; + Status VerifyTypesMatch(const DataTypeVector& expected, const DataTypeVector& received) { if (expected.size() != received.size()) { @@ -93,10 +97,10 @@ class IteratorResource : public ResourceBase { } } - Status Save(OpKernelContext* ctx, const string& path) { + Status Save(IteratorStateWriter* writer) { std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { - return captured_iterator->Save(ctx, path); + return captured_iterator->Save(writer); } else { return errors::FailedPrecondition( "Save() failed because the iterator has not been initialized. " @@ -105,49 +109,34 @@ class IteratorResource : public ResourceBase { } } - Status Restore(OpKernelContext* ctx, const string& path) { - if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) { - return errors::NotFound( - "Failed to restore Iterator state. No file found at ", - MetaFilename(path)); - } - - BundleReader bundle_reader(ctx->env(), path); - TF_RETURN_IF_ERROR(bundle_reader.status()); - BundleReaderWrapper reader(&bundle_reader); - if (reader.Contains(GraphDatasetBase::kDatasetGraphKey)) { - string serialized_graph_def; - TF_RETURN_IF_ERROR(reader.ReadScalar(GraphDatasetBase::kDatasetGraphKey, - &serialized_graph_def)); - GraphDef graph_def; - graph_def.ParseFromString(serialized_graph_def); - // TODO(srbs): Is there a way of getting the op registry of the original - // graph. 
- Graph graph(OpRegistry::Global()); - TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); - string output_node; - TF_RETURN_IF_ERROR(reader.ReadScalar( - GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node)); - std::vector outputs; - GraphRunner graph_runner(ctx->env()); - TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {}, - {output_node}, &outputs)); - DatasetBase* dataset; - TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); - TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator"))); - } else if (reader.Contains(IteratorBase::kIteratorExhausted)) { - TF_RETURN_IF_ERROR(set_iterator(std::unique_ptr( - new ExhaustedIterator(output_dtypes_, output_shapes_)))); + Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) { + string serialized_graph_def; + TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey, + &serialized_graph_def)); + GraphDef graph_def; + if (!graph_def.ParseFromString(serialized_graph_def)) { + return errors::Internal("Error parsing dataset GraphDef."); } + string output_node; + TF_RETURN_IF_ERROR(reader->ReadScalar( + GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node)); + DatasetBase* dataset = nullptr; + Graph graph(OpRegistry::Global()); + TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); + std::vector outputs; + GraphRunner graph_runner(ctx->env()); + TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {}, + {output_node}, &outputs)); + TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset)); + + TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator"))); std::shared_ptr captured_iterator(iterator_); if (captured_iterator) { - // TODO(srbs): Figure a way to pass bundle_reader here. - return captured_iterator->Restore(ctx, path); + return captured_iterator->Restore(ctx, reader); } else { return errors::FailedPrecondition( - "Failed to restore iterator from ", path, - ". 
Make sure the checkpoint ", + "Failed to restore iterator. Make sure the checkpoint ", "is not corrupt. If the checkpoint does not contain the GraphDef, ", "you will need to initialize your iterator before restoring."); } @@ -174,43 +163,194 @@ class IteratorResource : public ResourceBase { } private: - // A no-op iterator which always sets end_of_sequence = true. An instance of - // this is returned when attempting to restore an exhausted iterator. This is - // needed because the Dataset GraphDef may not have been saved for exhausted - // iterators so the actual Iterator can not be built. - class ExhaustedIterator : public IteratorBase { - public: - ExhaustedIterator(const DataTypeVector& output_dtypes, - const std::vector& output_shapes) - : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {} - Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) final { - *end_of_sequence = true; - return Status::OK(); - } + std::shared_ptr iterator_; + const DataTypeVector output_dtypes_; + const std::vector output_shapes_; +}; + +// Helper class for reading data from a VariantTensorData object. +class VariantTensorDataReader : public IteratorStateReader { + public: + explicit VariantTensorDataReader(const VariantTensorData* data) + : data_(data) { + PreProcess(); + } + + // Returns OK iff the initialization was successful, i.e., + // pre-processing did not have errors. 
+ Status status() const { return status_; } + + Status ReadScalar(StringPiece key, int64* val) override { + return ReadScalarInternal(key, val); + } + + Status ReadScalar(StringPiece key, string* val) override { + return ReadScalarInternal(key, val); + } - const DataTypeVector& output_dtypes() const override { - return output_dtypes_; + bool Contains(StringPiece key) override { + return map_.find(key.ToString()) != map_.end(); + } + + private: + void PreProcess() { + string metadata; + data_->get_metadata(&metadata); + IteratorStateMetadata proto; + if (!proto.ParseFromString(metadata)) { + status_ = errors::Internal("Error parsing IteratorStateMetadata."); + return; + } + size_t num_entries = proto.keys_size(); + CHECK_EQ(num_entries, data_->tensors_size()); + for (size_t i = 0; i < num_entries; i++) { + map_[proto.keys(i)] = i; } + } - const std::vector& output_shapes() const override { - return output_shapes_; + template + Status ReadScalarInternal(StringPiece key, T* val) { + if (map_.find(key.ToString()) == map_.end()) { + return errors::NotFound(key); } + *val = data_->tensors(map_[key.ToString()]).scalar()(); + return Status::OK(); + } - virtual const std::vector& output_shapes() { - return output_shapes_; + std::map map_; + const VariantTensorData* data_; // Not owned. + Status status_; +}; + +// Helper class for writing data to a VariantTensorData object. +class VariantTensorDataWriter : public IteratorStateWriter { + public: + // Does not take ownership of data. + explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {} + + Status WriteScalar(StringPiece key, const int64& val) override { + return WriteScalarInternal(key, val); + } + + Status WriteScalar(StringPiece key, const string& val) override { + return WriteScalarInternal(key, val); + } + + // Writes the metadata to `data_`. 
+ Status Flush() { + string metadata; + if (!metadata_proto_.SerializeToString(&metadata)) { + return errors::Internal("Unable to serialize IteratorStateMetadata."); } + data_->set_metadata(metadata); + return Status::OK(); + } - private: - const DataTypeVector output_dtypes_; - const std::vector output_shapes_; - }; + private: + template + Status WriteScalarInternal(StringPiece key, const T& val) { + // Write key to the metadata proto. This gets written to `data_` + // when `Flush()` is called. We do this lazily to avoid multiple + // serialization calls. + metadata_proto_.add_keys(key.ToString()); + + // Update tensors. + Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); + val_t.scalar()() = val; + *(data_->add_tensors()) = std::move(val_t); + return Status::OK(); + } - std::shared_ptr iterator_; - const DataTypeVector output_dtypes_; - const std::vector output_shapes_; + VariantTensorData* data_; + // TODO(srbs): Set the version string. + IteratorStateMetadata metadata_proto_; +}; + +// Wrapper for encoding/decoding the iterator state stored in a Variant tensor. +// The get() method returns an IteratorStateReader which can be used +// to restore iterator state. +// +// Usage example: +// +// Encoding: +// +// Tensor t(DT_VARIANT, TensorShape({})); +// t->scalar()() = IteratorStateVariant(iterator_resource); +// +// Encode() sets the type_name of the VariantTensorData object to +// IteratorStateVariant::TypeName(). +// +// Decoding: +// +// Variant v = ; +// DecodeUnaryVariant(&v); +// IteratorStateVariant* wrapper = v.get(); +// iterator_resource->Restore(ctx, wrapper->get()) +// +// The type_name of the VariantTensorData object to be decoded must +// match IteratorStateVariant::TypeName(). 
+class IteratorStateVariant { + public: + IteratorStateVariant() : data_(nullptr) {} + IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) { + if (other.data_) { + Decode(*other.data_); + } + } + // Initializes this object with the current state of the iterator so + // that it can be written on the next call to Encode(). + Status InitializeFromIterator(IteratorResource* iterator_resource) { + data_.reset(new VariantTensorData()); + data_->set_type_name(TypeName()); + VariantTensorDataWriter writer(data_.get()); + TF_RETURN_IF_ERROR(iterator_resource->Save(&writer)); + TF_RETURN_IF_ERROR(writer.Flush()); + return Status::OK(); + } + string TypeName() const { return kIteratorVariantTypeName; } + void Encode(VariantTensorData* data) const { *data = *data_; } + bool Decode(const VariantTensorData& data) { + if (data.type_name() != TypeName()) { + return false; + } + std::unique_ptr tensor_data(new VariantTensorData); + *tensor_data = data; + std::unique_ptr reader( + new VariantTensorDataReader(tensor_data.get())); + status_ = reader->status(); + if (!status_.ok()) { + return false; + } + data_ = std::move(tensor_data); + reader_ = std::move(reader); + return true; + } + IteratorStateReader* get() { return reader_.get(); } + Status status() const { return status_; } + string DebugString() const { + if (data_) { + return strings::StrCat("IteratorStateVariant<", + "data: ", data_->DebugString(), + " status: ", status_.ToString(), ">"); + } else { + return strings::StrCat("IteratorStateVariant"); + } + } + + private: + std::unique_ptr reader_; + Status status_; + std::unique_ptr data_; }; +// Register the reader class in the global variant decode_fn registry +// so that a Variant containing a serialized representation of iterator state +// can be decoded using DecodeUnaryVariant. If we don't do this we will need +// to manually decode the returned Variant using MaybeDecodeAndCopy in +// DeserializeIteratorOp which is not recommended. 
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, + kIteratorVariantTypeName); + // TODO(mrry): Can we simply use the template kernel here? class IteratorHandleOp : public ResourceOpKernel { public: @@ -294,37 +434,6 @@ class ToSingleElementOp : public OpKernel { } }; -class SaveIteratorOp : public OpKernel { - public: - explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()), - errors::InvalidArgument("SaveIteratorOp: path must be scalar")); - const string& path = ctx->input(1).scalar()(); - OP_REQUIRES_OK(ctx, iterator_resource->Save(ctx, path)); - } -}; - -class RestoreIteratorOp : public OpKernel { - public: - explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - OP_REQUIRES( - ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()), - errors::InvalidArgument("RestoreIteratorOp: path must be scalar")); - const string& path = ctx->input(1).scalar()(); - OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, path)); - } -}; - class OneShotIteratorOp : public AsyncOpKernel { public: explicit OneShotIteratorOp(OpKernelConstruction* ctx) @@ -644,15 +753,55 @@ class IteratorFromStringHandleOp : public OpKernel { std::vector output_shapes_; }; +class SerializeIteratorOp : public OpKernel { + public: + explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& resource_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a 
scalar")); + + // Validate that the handle corresponds to a real resource, and + // that it is an IteratorResource. + IteratorResource* iterator_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); + iterator_resource->Unref(); + Tensor* variant_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t)); + IteratorStateVariant v; + OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource)); + variant_t->scalar()() = v; + } +}; + +class DeserializeIteratorOp : public OpKernel { + public: + explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + // Validate that the handle corresponds to a real resource, and + // that it is an IteratorResource. + IteratorResource* iterator_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); + + Variant variant = ctx->input(1).scalar()(); + auto* wrapper = variant.get(); + OP_REQUIRES(ctx, wrapper != nullptr, + errors::InvalidArgument( + "DeserializeIteratorOp: Unable to parse variant tensor.")); + OP_REQUIRES_OK(ctx, wrapper->status()); + OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get())); + } +}; + REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU), MakeIteratorOp); REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU), ToSingleElementOp); -REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU), - SaveIteratorOp); -REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU), - RestoreIteratorOp); REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU), OneShotIteratorOp); REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU), @@ -661,6 +810,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU), IteratorToStringHandleOp); 
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU), IteratorFromStringHandleOp); +REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), + SerializeIteratorOp); +REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), + DeserializeIteratorOp); } // namespace diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc index ab91a6ef67..6b599612ad 100644 --- a/tensorflow/core/kernels/parse_tensor_op.cc +++ b/tensorflow/core/kernels/parse_tensor_op.cc @@ -92,6 +92,7 @@ class SerializeTensorOp : public OpKernel { Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint("T"), \ SerializeTensorOp); TF_CALL_ALL_TYPES(REGISTER) +TF_CALL_variant(REGISTER) #undef REGISTER } // namespace tensorflow diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc index a57c21a590..7adfcc4f8d 100644 --- a/tensorflow/core/kernels/range_dataset_op.cc +++ b/tensorflow/core/kernels/range_dataset_op.cc @@ -112,19 +112,16 @@ class RangeDatasetOp : public DatasetOpKernel { } protected: - Status SaveInternal(OpKernelContext* ctx, - IteratorBundleWriter* writer) override { + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("next"), next_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_)); return Status::OK(); } Status RestoreInternal(OpKernelContext* ctx, - IteratorBundleReader* reader) override { + IteratorStateReader* reader) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("next"), &next_)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_)); return Status::OK(); } diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc index b455c28e07..fb88c55f73 100644 --- a/tensorflow/core/kernels/reader_dataset_ops.cc +++ 
b/tensorflow/core/kernels/reader_dataset_ops.cc @@ -356,31 +356,30 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { } protected: - Status SaveInternal(OpKernelContext* ctx, - IteratorBundleWriter* writer) override { + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name("current_file_index"), current_file_index_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), + current_file_index_)); // `input_buffer_` is empty if // 1. GetNext has not been called even once. // 2. All files have been read and iterator has been exhausted. int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1; TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("current_pos"), current_pos)); + writer->WriteScalar(full_name("current_pos"), current_pos)); return Status::OK(); } Status RestoreInternal(OpKernelContext* ctx, - IteratorBundleReader* reader) override { + IteratorStateReader* reader) override { mutex_lock l(mu_); int64 current_file_index; - TF_RETURN_IF_ERROR(reader->ReadScalar( - full_name("current_file_index"), &current_file_index)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), + &current_file_index)); current_file_index_ = size_t(current_file_index); int64 current_pos; TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("current_pos"), &current_pos)); + reader->ReadScalar(full_name("current_pos"), &current_pos)); // Seek to current_pos.
input_buffer_.reset(); diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc index 5d836927d2..9813e99a70 100644 --- a/tensorflow/core/kernels/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/repeat_dataset_op.cc @@ -124,19 +124,18 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(OpKernelContext* ctx, - IteratorBundleWriter* writer) override { + Status SaveInternal(IteratorStateWriter* writer) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); - TF_RETURN_IF_ERROR(writer->SaveParent(ctx, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_)); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); return Status::OK(); } Status RestoreInternal(OpKernelContext* ctx, - IteratorBundleReader* reader) override { + IteratorStateReader* reader) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); - TF_RETURN_IF_ERROR(reader->RestoreParent(ctx, input_impl_)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_)); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); return Status::OK(); } diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 6772024263..c5ceb14a09 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -28753,18 +28753,6 @@ op { } is_stateful: true } -op { - name: "RestoreIterator" - input_arg { - name: "iterator" - type: DT_RESOURCE - } - input_arg { - name: "path" - type: DT_STRING - } - is_stateful: true -} op { name: "RestoreSlice" input_arg { @@ -29548,18 +29536,6 @@ op { } is_stateful: true } -op { - name: "SaveIterator" - input_arg { - name: "iterator" - type: DT_RESOURCE - } - input_arg { - name: "path" - type: DT_STRING - } - is_stateful: true -} op { name: "SaveSlices" input_arg { diff --git 
a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 566049179a..8b77e3f9f0 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -598,24 +598,6 @@ This operation may be executed multiple times. Each execution will reset the iterator in `iterator` to the first element of `dataset`. )doc"); -REGISTER_OP("SaveIterator") - .Input("iterator: resource") - .Input("path: string") - .SetShapeFn(shape_inference::NoOutputs) - .Doc(R"doc( -Saves the state of the `iterator` at `path`. - -This state can be restored using "RestoreIterator". -)doc"); - -REGISTER_OP("RestoreIterator") - .Input("iterator: resource") - .Input("path: string") - .SetShapeFn(shape_inference::NoOutputs) - .Doc(R"doc( -Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator". -)doc"); - REGISTER_OP("OneShotIterator") .Output("handle: resource") .Attr("dataset_factory: func") @@ -737,4 +719,28 @@ output_shapes: If specified, defines the shape of each tuple component in an element produced by the resulting iterator. )doc"); +REGISTER_OP("SerializeIterator") + .Input("resource_handle: resource") + .Output("serialized: variant") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Converts the given `resource_handle` representing an iterator to a variant tensor. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + +REGISTER_OP("DeserializeIterator") + .Input("resource_handle: resource") + .Input("serialized: variant") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Converts the given variant tensor to an iterator and stores it in the given resource. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. 
+)doc"); + } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 0e36c3498a..b02bae95fd 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2886,7 +2886,9 @@ tf_py_test( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:io_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", "//tensorflow/python:variables", @@ -2907,7 +2909,9 @@ tf_py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", @@ -3022,6 +3026,7 @@ tf_py_test( "//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:gradients", + "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:script_ops", diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py index b5ec9f7db0..2128ef4ae1 100644 --- a/tensorflow/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/kernel_tests/iterator_ops_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops @@ -538,9 +539,23 @@ class IteratorTest(test.TestCase): def testIncorrectIteratorRestore(self): - def _iterator_checkpoint_prefix(): + def _path(): 
return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + _path(), parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_range_dataset_graph(): start = 1 stop = 10 @@ -548,22 +563,18 @@ class IteratorTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = _iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op def _build_reader_dataset_graph(): filenames = ["test"] # Does not exist but we don't care in this test. - path = _iterator_checkpoint_prefix() iterator = readers.FixedLengthRecordDataset( filenames, 1, 0, 0).make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op # Saving iterator for RangeDataset graph. 
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py index 8291967155..0c530522b8 100644 --- a/tensorflow/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py @@ -27,6 +27,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase): def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def testSaveRestore(self): def _build_graph(start, stop): @@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, 
restore_op # Saving and restoring in different sessions. @@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase): def testRestoreWithoutBuildingDatasetGraph(self): - def _build_graph(start, stop, num_epochs, path): + def _build_graph(start, stop, num_epochs): dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase): num_epochs = 5 break_point = 5 break_epoch = 3 - path = self._iterator_checkpoint_prefix() with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs, - path) + init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) with self.test_session(graph=g) as sess: sess.run(variables.global_variables_initializer()) sess.run(init_op) @@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + restore_op = self._restore_op(iterator._iterator_resource) get_next = iterator.get_next() with self.test_session(graph=g) as sess: sess.run(restore_op) @@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = 
gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -447,10 +450,8 @@ class 
RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py index 38420328ef..c8e7333b4b 100644 --- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py @@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase): def _iterator_checkpoint_path(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_path(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_iterator_graph(self, num_epochs): filenames = 
self._createFiles() - path = self._iterator_checkpoint_path() dataset = (readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes) .repeat(num_epochs)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op def _restore_iterator(self): @@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase): output_shapes = tensor_shape.scalar() iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) get_next = iterator.get_next() - restore_op = gen_dataset_ops.restore_iterator( - iterator._iterator_resource, self._iterator_checkpoint_path()) + restore_op = self._restore_op(iterator._iterator_resource) return restore_op, get_next def testSaveRestore(self): -- cgit v1.2.3