about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Shivani Agrawal <shivaniagrawal@google.com> 2017-12-19 11:47:21 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-12-19 11:57:20 -0800
commit: 49ca74b2782d12e52839ea47d3dd7061c8710004 (patch)
tree: 125e8b1ac4192984647c587933284337b93cf7b2
parent: 6e8ff4e9288c81900e69618f58633da1e261098b (diff)
[tf.data] Saveable iterator for ScanDataset.
PiperOrigin-RevId: 179583460
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/BUILD 1
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py 14
-rw-r--r-- tensorflow/core/kernels/data/scan_dataset_op.cc 87
3 files changed, 96 insertions, 6 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index e0d0759567..3d10bb35e6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -397,6 +397,7 @@ py_test(
srcs = ["scan_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index 5338ec56bf..e0494736b7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -21,6 +21,7 @@ import itertools
import numpy as np
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -124,5 +125,18 @@ class ScanDatasetTest(test.TestCase):
scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
+class ScanDatasetSerialzationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_elements):
+ return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
+ scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
+
+ def testScanCore(self):
+ num_output = 5
+ self.run_core_tests(lambda: self._build_dataset(num_output),
+ lambda: self._build_dataset(2), num_output)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 84ba051468..413d0c0f57 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -64,20 +64,23 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
std::move(other_arguments),
&captured_func));
- *output =
- new Dataset(input, std::move(initial_state), std::move(captured_func),
- state_types_, output_types_, output_shapes_);
+ *output = new Dataset(ctx, input, func_, std::move(initial_state),
+ std::move(captured_func), state_types_, output_types_,
+ output_shapes_);
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphDatasetBase {
public:
- Dataset(const DatasetBase* input, std::vector<Tensor> initial_state,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func, std::vector<Tensor> initial_state,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& state_types,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : input_(input),
+ : GraphDatasetBase(ctx),
+ input_(input),
+ func_(func),
initial_state_(std::move(initial_state)),
captured_func_(std::move(captured_func)),
state_types_(state_types),
@@ -103,6 +106,45 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
string DebugString() override { return "ScanDatasetOp::Dataset"; }
+ protected:
+    // Emits the graph node that describes this dataset into builder `b`.
+    //
+    // Registers the function named by `func_`, adds the parent dataset, and
+    // adds one tensor node per initial-state tensor and per captured input.
+    // The node is then added with the parent as single input 0, the state
+    // and captured-argument nodes as list inputs 1 and 2, and the attrs
+    // "f", "Tstate" and "Targuments". The first failing builder call's
+    // status is returned; on success `*output` is set by `AddDataset`.
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
+
+      Node* parent_node;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &parent_node));
+
+      // One node per tensor of the initial scan state.
+      std::vector<Node*> state_nodes;
+      state_nodes.reserve(initial_state_.size());
+      for (const Tensor& state_tensor : initial_state_) {
+        Node* tensor_node;
+        TF_RETURN_IF_ERROR(b->AddTensor(state_tensor, &tensor_node));
+        state_nodes.push_back(tensor_node);
+      }
+
+      // One node per captured input, recording each dtype so that the
+      // "Targuments" attr can be built from them below.
+      const auto& captured = captured_func_->captured_inputs();
+      std::vector<Node*> captured_nodes;
+      captured_nodes.reserve(captured.size());
+      DataTypeVector captured_types;
+      captured_types.reserve(captured.size());
+      for (const Tensor& captured_tensor : captured) {
+        Node* tensor_node;
+        TF_RETURN_IF_ERROR(b->AddTensor(captured_tensor, &tensor_node));
+        captured_nodes.push_back(tensor_node);
+        captured_types.push_back(captured_tensor.dtype());
+      }
+
+      AttrValue func_attr;
+      b->BuildAttrValue(func_, &func_attr);
+      AttrValue state_types_attr;
+      b->BuildAttrValue(state_types_, &state_types_attr);
+      AttrValue captured_types_attr;
+      b->BuildAttrValue(captured_types, &captured_types_attr);
+
+      return b->AddDataset(this, {{0, parent_node}},
+                           {{1, state_nodes}, {2, captured_nodes}},
+                           {{"f", func_attr},
+                            {"Tstate", state_types_attr},
+                            {"Targuments", captured_types_attr}},
+                           output);
+    }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -185,6 +227,38 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
return s;
}
+ protected:
+      // Checkpoints this iterator into `writer`.
+      //
+      // The upstream iterator is saved first via `SaveParent`. If `state_`
+      // is non-empty, its length is written under "state_size" and each
+      // tensor under "state[<i>]"; an empty `state_` writes neither key.
+      // Returns the first non-OK status from the writer, otherwise OK.
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        if (state_.empty()) {
+          return Status::OK();
+        }
+        // `RestoreInternal` reads this scalar back into an int64, so make
+        // the conversion from the container's size type explicit.
+        const int64 state_size = static_cast<int64>(state_.size());
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("state_size"), state_size));
+        // An int64 index keeps the bound check signed-vs-signed (the old
+        // loop compared an `int` against the unsigned `size()`).
+        for (int64 idx = 0; idx < state_size; ++idx) {
+          TF_RETURN_IF_ERROR(writer->WriteTensor(
+              full_name(strings::StrCat("state[", idx, "]")), state_[idx]));
+        }
+        return Status::OK();
+      }
+
+      // Restores this iterator from `reader`.
+      //
+      // The upstream iterator is restored first via `RestoreParent`. If the
+      // checkpoint has a "state_size" entry, `state_` is resized to that
+      // length and each tensor is read back from "state[<i>]"; without the
+      // entry `state_` is left untouched. A negative "state_size" is
+      // rejected with InvalidArgument instead of being passed to `resize`,
+      // where it would convert to an enormous unsigned length.
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        if (!reader->Contains(full_name("state_size"))) {
+          return Status::OK();
+        }
+        int64 size;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("state_size"), &size));
+        if (size < 0) {
+          return errors::InvalidArgument(
+              "Invalid state_size in ScanDataset iterator checkpoint: ", size);
+        }
+        state_.resize(size);
+        // int64 index matches the type of `size`, so the counter cannot
+        // overflow before reaching the bound.
+        for (int64 idx = 0; idx < size; ++idx) {
+          TF_RETURN_IF_ERROR(reader->ReadTensor(
+              full_name(strings::StrCat("state[", idx, "]")), &state_[idx]));
+        }
+        return Status::OK();
+      }
+
private:
mutex mu_;
const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
@@ -192,6 +266,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const NameAttrList func_;
const std::vector<Tensor> initial_state_;
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector state_types_;