aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Jiri Simsa <jsimsa@google.com> 2018-06-03 18:18:12 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-03 18:21:00 -0700
commit: bab05a2191383b3c66e9ea9ee192aef0aa36c218 (patch)
tree: add4edad3c5bb5d19dd358ba2fa0be0cec604fb8
parent: 45198062b58245711d7446aa389f3b9aa2c1535f (diff)
[tf.data] Input pipeline rewrites prototype.
This CL:
- adds the `tf.contrib.data.optimize()` transformation, which can be used to trigger rewrite-based optimization of the input pipeline;
- adds the `tf.data.Dataset._as_serialized_graph()` method, which returns the serialized graph representation of the dataset.

PiperOrigin-RevId: 199068055
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/BUILD 13
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py 89
-rw-r--r-- tensorflow/contrib/data/python/ops/BUILD 15
-rw-r--r-- tensorflow/contrib/data/python/ops/optimization.py 80
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt 20
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt 14
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt 20
-rw-r--r-- tensorflow/core/framework/dataset.h 19
-rw-r--r-- tensorflow/core/kernels/BUILD 2
-rw-r--r-- tensorflow/core/kernels/data/BUILD 47
-rw-r--r-- tensorflow/core/kernels/data/dataset_ops.cc 47
-rw-r--r-- tensorflow/core/kernels/data/identity_dataset_op.cc 102
-rw-r--r-- tensorflow/core/kernels/data/optimize_dataset_op.cc 210
-rw-r--r-- tensorflow/core/ops/dataset_ops.cc 20
-rw-r--r-- tensorflow/python/data/kernel_tests/BUILD 11
-rw-r--r-- tensorflow/python/data/kernel_tests/dataset_ops_test.py 37
-rw-r--r-- tensorflow/python/data/ops/dataset_ops.py 9
17 files changed, 754 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 523d1f2f71..ba707d8d6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -281,6 +281,19 @@ py_test(
)
py_test(
+ name = "optimize_dataset_op_test",
+ size = "small",
+ srcs = ["optimize_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
name = "prefetch_dataset_op_test",
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
new file mode 100644
index 0000000000..30f1847dcd
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetTest(test.TestCase):
+
+ def testDefaultOptimizations(self):
+ dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
+ 10).apply(optimization.optimize())
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(
+ all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testEmptyOptimizations(self):
+ dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
+ 10).apply(optimization.optimize([]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(
+ all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimization(self):
+ dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
+ 10).apply(optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(
+ any([node.op == "MapAndBatchDatasetV2" for node in graph.node]))
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+class OptimizeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testCore(self):
+
+ def build_dataset(num_elements, batch_size):
+ return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
+ batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
+
+ self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index eceecfd174..086661adb7 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -209,6 +209,20 @@ py_library(
)
py_library(
+ name = "optimization",
+ srcs = ["optimization.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":contrib_op_loader",
+ ":gen_dataset_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
name = "resampling",
srcs = ["resampling.py"],
srcs_version = "PY2AND3",
@@ -368,6 +382,7 @@ py_library(
":get_single_element",
":grouping",
":interleave_ops",
+ ":optimization",
":prefetching_ops",
":readers",
":resampling",
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
new file mode 100644
index 0000000000..cad41bce29
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for optimizing `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+def optimize(optimizations=None):
+ """A transformation that applies optimizations.
+
+ Args:
+ optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying
+ optimizations to use. If not specified, an empty list is used, so no
+ optimizations are applied.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return OptimizeDataset(dataset, optimizations)
+
+ return _apply_fn
+
+
+class OptimizeDataset(dataset_ops.Dataset):
+ """A `Dataset` that acts as an identity, and applies optimizations."""
+
+ def __init__(self, input_dataset, optimizations):
+ """See `optimize()` for details."""
+ super(OptimizeDataset, self).__init__()
+ self._input_dataset = input_dataset
+ if optimizations is None:
+ optimizations = []
+ self._optimizations = ops.convert_to_tensor(
+ optimizations, dtype=dtypes.string, name="optimizations")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.optimize_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._optimizations,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt
new file mode 100644
index 0000000000..55dd6179dd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "DatasetToGraph"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the dataset to return the graph representation for.
+END
+ }
+ out_arg {
+ name: "graph"
+ description: <<END
+The graph representation of the dataset (as serialized GraphDef).
+END
+ }
+ summary: "Returns a serialized GraphDef representing `input_dataset`."
+ description: <<END
+Returns a graph representation for `input_dataset`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt
new file mode 100644
index 0000000000..ff2854fd2c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "IdentityDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ summary: "A placeholder for input pipeline graph optimizations."
+ description: <<END
+A placeholder for input pipeline graph optimizations.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt
new file mode 100644
index 0000000000..f26eb6e3c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "OptimizeDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "optimizations"
+ description: <<END
+A `tf.string` vector `tf.Tensor` identifying optimizations to use.
+END
+ }
+ summary: "Creates a dataset by applying optimizations to `input_dataset`."
+ description: <<END
+Creates a dataset by applying optimizations to `input_dataset`.
+END
+}
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 23dc903caf..d8618f391e 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -459,6 +459,8 @@ class DatasetBase : public core::RefCounted {
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
+
+ friend class DatasetToGraphOp; // For access to graph related members.
};
// Base-class for datasets that are built by ops.
@@ -584,6 +586,23 @@ class DatasetOpKernel : public OpKernel {
*output = argument_t->scalar<T>()();
return Status::OK();
}
+
+ template <typename T>
+ Status ParseVectorArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name,
+ std::vector<T>* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsVector(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a vector");
+ }
+ int size = argument_t->vec<T>().size();
+ output->reserve(size);
+ for (int i = 0; i < size; ++i) {
+ output->push_back(argument_t->vec<T>()(i));
+ }
+ return Status::OK();
+ }
};
// Encapsulates the work required to plug unary Datasets into the core
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f9e1d37b08..c7c7879714 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6170,7 +6170,7 @@ cc_library(
tf_kernel_library(
name = "dataset_ops",
deps = [
- "//tensorflow/core/kernels/data:dataset_ops",
+ "//tensorflow/core/kernels/data",
],
)
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index d35aad980d..da330e742e 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -549,21 +549,68 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "identity_dataset_op",
+ srcs = ["identity_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:framework",
+ ],
+)
+
+tf_kernel_library(
+ name = "optimize_dataset_op",
+ srcs = ["optimize_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:grappler_item_builder",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/optimizers:meta_optimizer",
+ "//tensorflow/core/grappler/optimizers/data",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_ops",
+ srcs = ["dataset_ops.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "data",
deps = [
":batch_dataset_op",
":cache_dataset_ops",
":concatenate_dataset_op",
+ ":dataset",
+ ":dataset_ops",
":dense_to_sparse_batch_dataset_op",
":filter_dataset_op",
":flat_map_dataset_op",
":generator_dataset_op",
":group_by_reducer_dataset_op",
":group_by_window_dataset_op",
+ ":identity_dataset_op",
":interleave_dataset_op",
":iterator_ops",
":map_and_batch_dataset_op",
":map_dataset_op",
+ ":optimize_dataset_op",
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
new file mode 100644
index 0000000000..01989a3bd9
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+
+namespace tensorflow {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class DatasetToGraphOp : public OpKernel {
+ public:
+ explicit DatasetToGraphOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ GraphDefBuilder b;
+ DatasetBase::DatasetGraphDefBuilder db(&b);
+ Node* input_node = nullptr;
+ OP_REQUIRES_OK(ctx, db.AddParentDataset(ctx, dataset, &input_node));
+ GraphDef graph_def;
+ OP_REQUIRES_OK(ctx, b.ToGraphDef(&graph_def));
+ Tensor* result;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
+ result->scalar<string>()() = graph_def.SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
+ DatasetToGraphOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/identity_dataset_op.cc b/tensorflow/core/kernels/data/identity_dataset_op.cc
new file mode 100644
index 0000000000..e28f188336
--- /dev/null
+++ b/tensorflow/core/kernels/data/identity_dataset_op.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+
+namespace tensorflow {
+namespace {
+
+// The purpose of identity dataset is to serve as a placeholder when performing
+// optimizations. It is not expected to be surfaced in the Python API.
+class IdentityDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit IdentityDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : GraphDatasetBase(ctx), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Identity")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override { return "IdentityDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize"));
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ return errors::Unimplemented(
+ strings::StrCat(prefix(), "::GetNextInternal"));
+ }
+ };
+
+ const DatasetBase* const input_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("IdentityDataset").Device(DEVICE_CPU),
+ IdentityDatasetOp);
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
new file mode 100644
index 0000000000..8965858e8d
--- /dev/null
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -0,0 +1,210 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class OptimizeDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit OptimizeDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::vector<string> optimizations;
+ OP_REQUIRES_OK(
+ ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
+ Dataset* dataset =
+ new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
+ core::ScopedUnref unref(dataset);
+ OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, output));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const std::vector<string>& optimizations,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphDatasetBase(ctx),
+ input_(input),
+ optimizations_(optimizations),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Optimize")}));
+ }
+
+ Status Optimize(OpKernelContext* ctx, DatasetBase** output) {
+ GraphDefBuilder b;
+ DatasetGraphDefBuilder db(&b);
+ Node* input_node = nullptr;
+ TF_RETURN_IF_ERROR(db.AddParentDataset(ctx, input_, &input_node));
+ string output_node = input_node->name();
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
+
+ Graph graph(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
+ std::vector<Tensor> outputs;
+ GraphRunner graph_runner(ctx->env());
+ // Once rewrites that add/modify functions are introduced, we will need
+ // to persist the results in a function library runtime.
+ TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
+ {output_node}, &outputs));
+ TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], output));
+ (*output)->Ref();
+ return Status::OK();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize"));
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ return errors::Unimplemented(
+ strings::StrCat(prefix(), "::GetNextInternal"));
+ }
+ };
+
+ Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
+ string* output_node) {
+ // Add a fake sink node to allow rewriting the actual sink node.
+ NodeDef* node = graph_def->mutable_node()->Add();
+ node->set_name("FakeSink");
+ node->set_op("IdentityDataset");
+ node->add_input(*output_node);
+ {
+ grappler::GraphView graph(graph_def);
+ NodeDef* sink = graph.GetNode(*output_node);
+ (*node->mutable_attr())["output_shapes"] =
+ sink->attr().at("output_shapes");
+ (*node->mutable_attr())["output_types"] =
+ sink->attr().at("output_types");
+ }
+
+ // Create metagraph.
+ MetaGraphDef meta_graph_def;
+ (*meta_graph_def.mutable_graph_def()) = *graph_def;
+
+ // Grappler determines fetch ops from collection 'train_op'.
+ CollectionDef collection_def;
+ auto node_list = collection_def.mutable_node_list();
+ node_list->add_value("FakeSink");
+ (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
+
+ // Create Grappler item.
+ tensorflow::RewriterConfig rewriter_config;
+ for (const string& optimization : optimizations_) {
+ rewriter_config.add_optimizers(optimization);
+ }
+ // If no optimizations were specified, supply a non-existent optimization
+ // to prevent Grappler from applying the default set of optimizations as
+ // some of them do not work out of the box at the moment (e.g. because we
+ // have no cost model for dataset ops).
+ if (optimizations_.empty()) {
+ rewriter_config.add_optimizers("non-existent");
+ }
+ tensorflow::grappler::ItemConfig item_config;
+ item_config.apply_optimizations = true;
+ std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
+ tensorflow::grappler::GrapplerItemFromMetaGraphDef(
+ "graph", meta_graph_def, item_config);
+ std::unordered_map<string, tensorflow::DeviceProperties> device_map;
+ tensorflow::grappler::VirtualCluster cluster(device_map);
+
+ // Run optimizer.
+ TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
+ *grappler_item, rewriter_config, ctx->device(), &cluster, graph_def));
+
+ // Set `output_node` to the input of the fake sink node.
+ {
+ grappler::GraphView graph(graph_def);
+ grappler::GraphView::InputPort input_port =
+ graph.GetInputPort("FakeSink", 0);
+ *output_node = graph.GetRegularFanin(input_port).node->name();
+ }
+
+ return Status::OK();
+ }
+
+ const DatasetBase* input_;
+ const std::vector<string> optimizations_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
+ OptimizeDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 6d7d8630a7..9bc6c9a30d 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -698,4 +698,24 @@ REGISTER_OP("DatasetToTFRecord")
.Input("compression_type: string")
.SetShapeFn(shape_inference::NoOutputs);
+REGISTER_OP("DatasetToGraph")
+ .Input("input_dataset: variant")
+ .Output("graph: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("IdentityDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("OptimizeDataset")
+ .Input("input_dataset: variant")
+ .Input("optimizations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index ed0c11e6c1..c8fabc4363 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -73,6 +73,17 @@ tf_py_test(
)
tf_py_test(
+ name = "dataset_ops_test",
+ size = "small",
+ srcs = ["dataset_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
name = "filter_dataset_op_test",
size = "small",
srcs = ["filter_dataset_op_test.py"],
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
new file mode 100644
index 0000000000..2c4c11e132
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -0,0 +1,37 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the input pipeline ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class DatasetOpsTest(test.TestCase):
+
+ def testAsSerializedGraph(self):
+ dataset = dataset_ops.Dataset.range(10)
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(any([node.op == "RangeDataset" for node in graph.node]))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6f9b12b123..ea5fc2099c 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -57,6 +57,15 @@ class Dataset(object):
def __init__(self):
pass
+ def _as_serialized_graph(self):
+ """Produces serialized graph representation of the dataset.
+
+ Returns:
+ A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
+ serialized graph.
+ """
+ return gen_dataset_ops.dataset_to_graph(self._as_variant_tensor())
+
@abc.abstractmethod
def _as_variant_tensor(self):
"""Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.