diff options
author | 2018-06-03 18:18:12 -0700 | |
---|---|---|
committer | 2018-06-03 18:21:00 -0700 | |
commit | bab05a2191383b3c66e9ea9ee192aef0aa36c218 (patch) | |
tree | add4edad3c5bb5d19dd358ba2fa0be0cec604fb8 | |
parent | 45198062b58245711d7446aa389f3b9aa2c1535f (diff) |
[tf.data] Input pipeline rewrites prototype.
This CL:
- adds `tf.contrib.data.optimize()` transformation that can be used to trigger rewrite-based optimization for the input pipeline.
- adds `tf.data.Dataset._as_serialized_graph()` method that returns the serialized graph representation of the dataset
PiperOrigin-RevId: 199068055
17 files changed, 754 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 523d1f2f71..ba707d8d6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -281,6 +281,19 @@ py_test( ) py_test( + name = "optimize_dataset_op_test", + size = "small", + srcs = ["optimize_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( name = "prefetch_dataset_op_test", size = "small", srcs = ["prefetch_dataset_op_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py new file mode 100644 index 0000000000..30f1847dcd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testDefaultOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testEmptyOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimization(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index eceecfd174..086661adb7 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -209,6 +209,20 @@ py_library( ) py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( name = "resampling", srcs = ["resampling.py"], srcs_version = "PY2AND3", @@ -368,6 +382,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":optimization", ":prefetching_ops", ":readers", ":resampling", diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py new file mode 100644 index 0000000000..cad41bce29 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class OptimizeDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(OptimizeDataset, self).__init__() + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def _as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt new file mode 100644 index 0000000000..55dd6179dd --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt @@ -0,0 +1,20 @@ +op { + graph_op_name: "DatasetToGraph" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the dataset to return the graph representation for. +END + } + out_arg { + name: "graph" + description: <<END +The graph representation of the dataset (as serialized GraphDef). +END + } + summary: "Returns a serialized GraphDef representing `input_dataset`." + description: <<END +Returns a graph representation for `input_dataset`. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt new file mode 100644 index 0000000000..ff2854fd2c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_IdentityDataset.pbtxt @@ -0,0 +1,14 @@ +op { + graph_op_name: "IdentityDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + summary: "A placeholder for input pipeline graph optimizations." + description: <<END +A placeholder for input pipeline graph optimizations. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt new file mode 100644 index 0000000000..f26eb6e3c3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt @@ -0,0 +1,20 @@ +op { + graph_op_name: "OptimizeDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + in_arg { + name: "optimizations" + description: <<END +A `tf.string` vector `tf.Tensor` identifying optimizations to use. +END + } + summary: "Creates a dataset by applying optimizations to `input_dataset`." + description: <<END +Creates a dataset by applying optimizations to `input_dataset`. +END +} diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 23dc903caf..d8618f391e 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -459,6 +459,8 @@ class DatasetBase : public core::RefCounted { virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const = 0; + + friend class DatasetToGraphOp; // For access to graph related members. }; // Base-class for datasets that are built by ops. @@ -584,6 +586,23 @@ class DatasetOpKernel : public OpKernel { *output = argument_t->scalar<T>()(); return Status::OK(); } + + template <typename T> + Status ParseVectorArgument(OpKernelContext* ctx, + const StringPiece& argument_name, + std::vector<T>* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsVector(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a vector"); + } + int size = argument_t->vec<T>().size(); + output->reserve(size); + for (int i = 0; i < size; ++i) { + output->push_back(argument_t->vec<T>()(i)); + } + return Status::OK(); + } }; // Encapsulates the work required to plug unary Datasets into the core diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f9e1d37b08..c7c7879714 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6170,7 +6170,7 @@ cc_library( tf_kernel_library( name = "dataset_ops", deps = [ - "//tensorflow/core/kernels/data:dataset_ops", + "//tensorflow/core/kernels/data", ], ) diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index d35aad980d..da330e742e 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -549,21 +549,68 @@ tf_kernel_library( ) tf_kernel_library( + name = "identity_dataset_op", + srcs = ["identity_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:framework", + ], +) + +tf_kernel_library( + name = "optimize_dataset_op", + srcs = ["optimize_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + "//tensorflow/core/grappler/optimizers/data", + ], +) + +tf_kernel_library( name = "dataset_ops", + srcs = ["dataset_ops.cc"], + deps = [ + ":dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_kernel_library( + name = "data", deps = [ ":batch_dataset_op", ":cache_dataset_ops", ":concatenate_dataset_op", + ":dataset", + ":dataset_ops", ":dense_to_sparse_batch_dataset_op", ":filter_dataset_op", ":flat_map_dataset_op", ":generator_dataset_op", ":group_by_reducer_dataset_op", ":group_by_window_dataset_op", + ":identity_dataset_op", ":interleave_dataset_op", ":iterator_ops", ":map_and_batch_dataset_op", ":map_dataset_op", + ":optimize_dataset_op", ":padded_batch_dataset_op", ":parallel_interleave_dataset_op", ":parallel_map_dataset_op", diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc new file mode 100644 index 0000000000..01989a3bd9 --- /dev/null +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/data/dataset.h" + +namespace tensorflow { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. +class DatasetToGraphOp : public OpKernel { + public: + explicit DatasetToGraphOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + DatasetBase* dataset; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); + GraphDefBuilder b; + DatasetBase::DatasetGraphDefBuilder db(&b); + Node* input_node = nullptr; + OP_REQUIRES_OK(ctx, db.AddParentDataset(ctx, dataset, &input_node)); + GraphDef graph_def; + OP_REQUIRES_OK(ctx, b.ToGraphDef(&graph_def)); + Tensor* result; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result)); + result->scalar<string>()() = graph_def.SerializeAsString(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU), + DatasetToGraphOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/identity_dataset_op.cc b/tensorflow/core/kernels/data/identity_dataset_op.cc new file mode 100644 index 0000000000..e28f188336 --- /dev/null +++ b/tensorflow/core/kernels/data/identity_dataset_op.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <map> + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset.h" + +namespace tensorflow { +namespace { + +// The purpose of identity dataset is to serve as a placeholder when performing +// optimizations. It is not expected to be surfaced in the Python API. +class IdentityDatasetOp : public UnaryDatasetOpKernel { + public: + explicit IdentityDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + *output = new Dataset(ctx, input); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input) + : GraphDatasetBase(ctx), input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Identity")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "IdentityDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize")); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + return errors::Unimplemented( + strings::StrCat(prefix(), "::GetNextInternal")); + } + }; + + const DatasetBase* const input_; + }; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("IdentityDataset").Device(DEVICE_CPU), + IdentityDatasetOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc new file mode 100644 index 0000000000..8965858e8d --- /dev/null +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -0,0 +1,210 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <map> + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/grappler_item_builder.h" +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. +class OptimizeDatasetOp : public UnaryDatasetOpKernel { + public: + explicit OptimizeDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + std::vector<string> optimizations; + OP_REQUIRES_OK( + ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations)); + Dataset* dataset = + new Dataset(ctx, input, optimizations, output_types_, output_shapes_); + core::ScopedUnref unref(dataset); + OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, output)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const std::vector<string>& optimizations, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : GraphDatasetBase(ctx), + input_(input), + optimizations_(optimizations), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Optimize")})); + } + + Status Optimize(OpKernelContext* ctx, DatasetBase** output) { + GraphDefBuilder b; + DatasetGraphDefBuilder db(&b); + Node* input_node = nullptr; + TF_RETURN_IF_ERROR(db.AddParentDataset(ctx, input_, &input_node)); + string output_node = input_node->name(); + GraphDef graph_def; + TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); + TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node)); + + Graph graph(OpRegistry::Global()); + TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); + std::vector<Tensor> outputs; + GraphRunner graph_runner(ctx->env()); + // Once rewrites that add/modify functions are introduced, we will need + // persist the results in a function library runtime. + TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {}, + {output_node}, &outputs)); + TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], output)); + (*output)->Ref(); + return Status::OK(); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { return "OptimizeDatasetOp::Dataset"; } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize")); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + return errors::Unimplemented( + strings::StrCat(prefix(), "::GetNextInternal")); + } + }; + + Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def, + string* output_node) { + // Add a fake sink node to allow rewriting the actual sink node. + NodeDef* node = graph_def->mutable_node()->Add(); + node->set_name("FakeSink"); + node->set_op("IdentityDataset"); + node->add_input(*output_node); + { + grappler::GraphView graph(graph_def); + NodeDef* sink = graph.GetNode(*output_node); + (*node->mutable_attr())["output_shapes"] = + sink->attr().at("output_shapes"); + (*node->mutable_attr())["output_types"] = + sink->attr().at("output_types"); + } + + // Create metagraph. + MetaGraphDef meta_graph_def; + (*meta_graph_def.mutable_graph_def()) = *graph_def; + + // Grappler determines fetch ops from collection 'train_op'. + CollectionDef collection_def; + auto node_list = collection_def.mutable_node_list(); + node_list->add_value("FakeSink"); + (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def; + + // Create Grappler item. + tensorflow::RewriterConfig rewriter_config; + for (const string& optimization : optimizations_) { + rewriter_config.add_optimizers(optimization); + } + // If no optimizations were specified, supply a non-existent optimization + // to prevent Grappler from applying the default set of optimizations as + // some of them do not work out of the box at the moment (e.g. because we + // have no cost model for dataset ops). + if (optimizations_.empty()) { + rewriter_config.add_optimizers("non-existent"); + } + tensorflow::grappler::ItemConfig item_config; + item_config.apply_optimizations = true; + std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item = + tensorflow::grappler::GrapplerItemFromMetaGraphDef( + "graph", meta_graph_def, item_config); + std::unordered_map<string, tensorflow::DeviceProperties> device_map; + tensorflow::grappler::VirtualCluster cluster(device_map); + + // Run optimizer. + TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( + *grappler_item, rewriter_config, ctx->device(), &cluster, graph_def)); + + // Set `output_node` to the input of the fake sink node. + { + grappler::GraphView graph(graph_def); + grappler::GraphView::InputPort input_port = + graph.GetInputPort("FakeSink", 0); + *output_node = graph.GetRegularFanin(input_port).node->name(); + } + + return Status::OK(); + } + + const DatasetBase* input_; + const std::vector<string> optimizations_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + const int graph_def_version_; + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU), + OptimizeDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 6d7d8630a7..9bc6c9a30d 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -698,4 +698,24 @@ REGISTER_OP("DatasetToTFRecord") .Input("compression_type: string") .SetShapeFn(shape_inference::NoOutputs); +REGISTER_OP("DatasetToGraph") + .Input("input_dataset: variant") + .Output("graph: string") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("IdentityDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("OptimizeDataset") + .Input("input_dataset: variant") + .Input("optimizations: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + } // namespace tensorflow diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index ed0c11e6c1..c8fabc4363 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -73,6 +73,17 @@ tf_py_test( ) tf_py_test( + name = "dataset_ops_test", + size = "small", + srcs = ["dataset_ops_test.py"], + additional_deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +tf_py_test( name = "filter_dataset_op_test", size = "small", srcs = ["filter_dataset_op_test.py"], diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py new file mode 100644 index 0000000000..2c4c11e132 --- /dev/null +++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py @@ -0,0 +1,37 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the input pipeline ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class DatasetOpsTest(test.TestCase): + + def testAsSerializedGraph(self): + dataset = dataset_ops.Dataset.range(10) + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue(any([node.op != "RangeDataset" for node in graph.node])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 6f9b12b123..ea5fc2099c 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -57,6 +57,15 @@ class Dataset(object): def __init__(self): pass + def _as_serialized_graph(self): + """Produces serialized graph representation of the dataset. + + Returns: + A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a + serialized graph. + """ + return gen_dataset_ops.dataset_to_graph(self._as_variant_tensor()) + @abc.abstractmethod def _as_variant_tensor(self): """Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset. |