diff options
author | 2017-12-19 11:47:21 -0800 | |
---|---|---|
committer | 2017-12-19 11:57:20 -0800 | |
commit | 49ca74b2782d12e52839ea47d3dd7061c8710004 (patch) | |
tree | 125e8b1ac4192984647c587933284337b93cf7b2 | |
parent | 6e8ff4e9288c81900e69618f58633da1e261098b (diff) |
[tf.data] Saveable iterator for ScanDataset.
PiperOrigin-RevId: 179583460
3 files changed, 96 insertions, 6 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index e0d0759567..3d10bb35e6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -397,6 +397,7 @@ py_test( srcs = ["scan_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 5338ec56bf..e0494736b7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,6 +21,7 @@ import itertools import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -124,5 +125,18 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) +class ScanDatasetSerialzationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index 84ba051468..413d0c0f57 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -64,20 +64,23 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { std::move(other_arguments), &captured_func)); - *output = - new Dataset(input, std::move(initial_state), std::move(captured_func), - state_types_, output_types_, output_shapes_); + *output = new Dataset(ctx, input, func_, std::move(initial_state), + std::move(captured_func), state_types_, output_types_, + output_shapes_); } private: - class Dataset : public DatasetBase { + class Dataset : public GraphDatasetBase { public: - Dataset(const DatasetBase* input, std::vector<Tensor> initial_state, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, std::vector<Tensor> initial_state, std::unique_ptr<CapturedFunction> captured_func, const DataTypeVector& state_types, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes) - : input_(input), + : GraphDatasetBase(ctx), + input_(input), + func_(func), initial_state_(std::move(initial_state)), captured_func_(std::move(captured_func)), state_types_(state_types), @@ -103,6 +106,45 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { string DebugString() override { return "ScanDatasetOp::Dataset"; } + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); + Node* input_node; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node)); + std::vector<Node*> initial_state_nodes; + initial_state_nodes.reserve(initial_state_.size()); + for (const Tensor& t : initial_state_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + initial_state_nodes.emplace_back(node); + } + std::vector<Node*> other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(func_, &f); + AttrValue state_types; + b->BuildAttrValue(state_types_, &state_types); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {{0, input_node}}, + {{1, initial_state_nodes}, {2, other_arguments}}, + {{"f", f}, + {"Tstate", state_types}, + {"Targuments", other_arguments_types_attr}}, + output)); + return Status::OK(); + } + private: class Iterator : public DatasetIterator<Dataset> { public: @@ -185,6 +227,38 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { return s; } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + if (!state_.empty()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("state_size"), state_.size())); + for (int idx = 0; idx < state_.size(); idx++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("state[", idx, "]")), state_[idx])); + } + } + return Status::OK(); + } + + Status RestoreInternal(OpKernelContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + if (reader->Contains(full_name("state_size"))) { + int64 size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("state_size"), &size)); + state_.resize(size); + for (int idx = 0; idx < size; idx++) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("state[", idx, "]")), &state_[idx])); + } + } + return Status::OK(); + } + private: mutex mu_; const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); @@ -192,6 +266,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; + const NameAttrList func_; const std::vector<Tensor> initial_state_; const std::unique_ptr<CapturedFunction> captured_func_; const DataTypeVector state_types_; |