diff options
author | Piotr Padlewski <prazek@google.com> | 2018-09-14 11:28:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-14 11:32:52 -0700 |
commit | c20a7b81d79d30db9e990309ddb419bcb48120cc (patch) | |
tree | 9ea682cf79bac18653e7690785e0f5e7117b6b8b | |
parent | 89f9080ed0d1a43cb2fa253997b2553c6916f364 (diff) |
[tf.data] Introducing an optimization that parallelizes map transformations.
Stateless MapDatasets can be parallelized by switching to ParallelMapDataset. We set `num_parallel_calls` to 2 for now, but in the future a special value will be used that results in the optimal value being selected dynamically at runtime.
This patch also exposed a memory leak which was fixed.
PiperOrigin-RevId: 213015223
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py | 84 | ||||
-rw-r--r-- | tensorflow/core/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/direct_session.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/graph_mgr.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/framework/function.cc | 13 | ||||
-rw-r--r-- | tensorflow/core/framework/function_testlib.cc | 34 | ||||
-rw-r--r-- | tensorflow/core/framework/function_testlib.h | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/op_segment.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/framework/op_segment.h | 4 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/BUILD | 44 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/map_parallelization.cc | 106 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/map_parallelization.h | 47 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc | 94 |
16 files changed, 461 insertions, 28 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index 7e9ea68047..b3187bf61b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -74,6 +74,23 @@ py_test( ) py_test( + name = "map_parallelization_test", + size = "small", + srcs = ["map_parallelization_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( name = "model_dataset_op_test", size = "medium", srcs = ["model_dataset_op_test.py"], diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py new file mode 100644 index 0000000000..dd547db086 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the MapParallelization optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class MapParallelizationTest(test.TestCase, parameterized.TestCase): + + @staticmethod + def map_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + + def assert_greater(x): + assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x]) + with ops.control_dependencies([assert_op]): + return x + + def random(_): + return random_ops.random_uniform([], + minval=0, + maxval=10, + dtype=dtypes.int64, + seed=42) + + def assert_with_random(x): + x = assert_greater(x) + return random(x) + + return (("Identity", identity, True), ("Increment", increment, True), + ("AssertGreater", assert_greater, True), ("Random", random, False), + ("AssertWithRandom", assert_with_random, False)) + + @parameterized.named_parameters(*map_functions.__func__()) + def testMapParallelization(self, function, should_optimize): + next_nodes = ["ParallelMap"] if should_optimize else ["Map"] + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(next_nodes)).map(function).apply( + optimization.optimize(["map_parallelization"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + # No need to run the pipeline if it was 
not optimized. Also the results + # might be hard to check because of random. + if not should_optimize: + return + r = function(x) + self.assertAllEqual(r, result) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1a86bff5cd..55715bb3a6 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1429,9 +1429,11 @@ cc_library( ":test", ":testlib_ops", "//tensorflow/cc:scope", + "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:random_ops", ], ) diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index b4d8e285bd..af5d5b17e7 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1202,14 +1202,11 @@ Status DirectSession::CreateExecutors( auto opseg = device->op_segment(); params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { - // We do not share the kernel via the OpSegment if the node is - // stateless, or a function. // NOTE(mrry): We must not share function kernels (implemented // using `CallOp`) between subgraphs, because `CallOp::handle_` // is tied to a particular subgraph. Even if the function itself // is stateful, the `CallOp` that invokes it is not. - if (!lib->IsStateful(ndef.op()) || - lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { + if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { return lib->CreateKernel(ndef, kernel); } auto create_fn = [lib, &ndef](OpKernel** kernel) { @@ -1222,10 +1219,8 @@ Status DirectSession::CreateExecutors( create_fn); }; params.delete_kernel = [lib](OpKernel* kernel) { - // If the node is stateful, opseg owns it. Otherwise, delete it. 
- if (kernel && !lib->IsStateful(kernel->type_string())) { + if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) delete kernel; - } }; optimizer.Optimize(lib, options_.env, device, &partition_graph, diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 1c9b69721d..472865ca43 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -414,9 +414,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel( device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef, &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, graph_def_version_, &s); - *kernel = new CallOp(handle, &construction); - if (!s.ok()) { - delete *kernel; + if (s.ok()) { + *kernel = new CallOp(handle, &construction); } return s; } diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 6c146036ae..f7a2967d00 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -233,14 +233,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, params.function_library = lib; params.create_kernel = [session, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { - // We do not share the kernel via the OpSegment if the node is - // stateless, or a function. // NOTE(mrry): We must not share function kernels (implemented // using `CallOp`) between subgraphs, because `CallOp::handle_` // is tied to a particular subgraph. Even if the function itself // is stateful, the `CallOp` that invokes it is not. 
- if (!lib->IsStateful(ndef.op()) || - lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { + if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { return lib->CreateKernel(ndef, kernel); } auto create_fn = [lib, &ndef](OpKernel** kernel) { @@ -252,8 +249,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn); }; params.delete_kernel = [lib](OpKernel* kernel) { - // If the node is stateful, opseg owns it. Otherwise, delete it. - if (kernel && !lib->IsStateful(kernel->type_string())) { + if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) { delete kernel; } }; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index d979353d2f..a17959a448 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1294,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create( for (const auto& r : ret_def) { fdef.mutable_ret()->insert({r.first, r.second}); } + + auto* op_def_registry = OpRegistry::Global(); + // Check if any op is stateful. + for (const auto& n : node_def) { + const OpDef* op_def = nullptr; + auto status = op_def_registry->LookUpOpDef(n.op, &op_def); + // Lookup can fail if e.g. we are calling a function that was not yet + // defined. If it happens, conservatively assume the op is stateful. 
+ if (!status.ok() || op_def->is_stateful()) { + fdef.mutable_signature()->set_is_stateful(true); + } + } return fdef; } @@ -1355,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name, strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first); } } + if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true); } // Returns diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index c5a4f661d2..d5c203d276 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -91,6 +91,40 @@ FunctionDef IsZero() { }); } +FunctionDef RandomUniform() { + const Tensor kZero = test::AsScalar<int64>(0); + const Tensor kTen = test::AsScalar<int64>(10); + + return FDH::Define( + // Name + "RandomUniform", + // Args + {"x: T"}, + // Return values + {"random_uniform: int64"}, + // Attr def + {"T:{float, double, int32, int64, string}"}, + {{{"random_uniform/shape"}, + "Const", + {}, + {{"value", kZero}, {"dtype", DT_INT64}}}, + {{"random_uniform/min"}, + "Const", + {}, + {{"value", kZero}, {"dtype", DT_INT64}}}, + {{"random_uniform/max"}, + "Const", + {}, + {{"value", kTen}, {"dtype", DT_INT64}}}, + {{"random_uniform"}, + "RandomUniformInt", + {}, + {{"T", DT_INT64}, + {"Tout", DT_INT64}, + {"seed", 87654321}, + {"seed2", 42}}}}); +} + FunctionDef XTimesTwo() { const Tensor kTwo = test::AsScalar<int64>(2); return FDH::Define( diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index ad61a76f16..a01743423b 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -84,6 +84,9 @@ FunctionDef NonZero(); // x: T -> bool. 
FunctionDef IsZero(); +// x: T -> int64 +FunctionDef RandomUniform(); + // x:T, y:T -> y:T, x:T FunctionDef Swap(); diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c694e10193..80f2b12987 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ -// TODO(mrry): Convert to std::make_unique when available. OpKernel::OpKernel(OpKernelConstruction* context) - : OpKernel(context, - std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {} + : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {} OpKernel::OpKernel(OpKernelConstruction* context, std::unique_ptr<const NodeDef> node_def) @@ -525,10 +524,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input( return nullptr; } } - // TODO(rmlarsen): Use MakeUnique here. There is already a copy in - // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of - // general cleanup of ownership in this code. - std::unique_ptr<Tensor> output_tensor(new Tensor()); + + auto output_tensor = MakeUnique<Tensor>(); CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); return output_tensor; } diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc index dfc5aa7747..75ed4a4eaf 100644 --- a/tensorflow/core/framework/op_segment.cc +++ b/tensorflow/core/framework/op_segment.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) { delete item; } +bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op) { + // OpSegment should not own kernel if the node is stateless, or a function. + return lib->IsStateful(node_op) && + lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr; +} + } // end namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h index 4433a2554f..37d939ea2b 100644 --- a/tensorflow/core/framework/op_segment.h +++ b/tensorflow/core/framework/op_segment.h @@ -60,6 +60,10 @@ class OpSegment { Status FindOrCreate(const string& session_handle, const string& node_name, OpKernel** kernel, CreateKernelFn create_fn); + // Returns true if OpSegment should own the kernel. 
+ static bool ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op); + private: // op name -> OpKernel typedef std::unordered_map<string, OpKernel*> KernelMap; diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 530c957068..e84df10778 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -19,7 +19,6 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", @@ -56,8 +55,8 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:functional_ops", + "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core:lib_internal", ] + tf_protos_all(), @@ -107,7 +106,6 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/kernels:cast_op", ], ) @@ -164,7 +162,6 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work. 
], ) @@ -256,7 +253,6 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", @@ -275,6 +271,43 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/kernels:control_flow_ops", + ], +) + +cc_library( + name = "map_parallelization", + srcs = ["map_parallelization.cc"], + hdrs = [ + "map_parallelization.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/utils:topological_sort", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "map_parallelization_test", + srcs = ["map_parallelization_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + ":map_parallelization", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", ], ) @@ -355,6 +388,7 @@ cc_library( ":map_and_batch_fusion", ":map_and_filter_fusion", ":map_fusion", + ":map_parallelization", ":map_vectorization", ":noop_elimination", ":shuffle_and_repeat_fusion", diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc new file mode 100644 index 
0000000000..305325e434 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { +namespace { + +bool CanParallelize(const FunctionDef& function, + const FunctionLibraryDefinition& library) { + if (!function.signature().is_stateful()) return true; + + for (const auto& node : function.node_def()) { + const OpDef* op_def; + TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def)); + // Assert is marked as stateful, but it does not have any state (except + // changing io). Similarly to CUDA, we do not give guarantee that the + // assert operation that would fail would be the first one, so that we can + // parallelize it. 
+ if (op_def->is_stateful() && op_def->name() != "Assert") return false; + } + + return true; +} + +NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) { + NodeDef parallel_map = map_node; + graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(), + &parallel_map); + parallel_map.set_op("ParallelMapDataset"); + // TODO(b/114475558): We want to set `num_parallel_calls` to a special value, + // so that dynamic tuning will pick the optimal value at runtime. Because + // this feature is not yet implemented, we set it to 2, which is the smallest + // value that introduces parallelism. + auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph); + parallel_map.add_input(num_parallel_calls->name()); + + return parallel_map; +} + +} // namespace + +Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + auto get_map_node = [](const NodeDef& node) -> const NodeDef* { + if (node.op() == "MapDataset") return &node; + return nullptr; + }; + + for (const NodeDef& node : item.graph.node()) { + const NodeDef* map_node = get_map_node(node); + if (!map_node) continue; + + auto* function = + function_library.Find(map_node->attr().at("f").func().name()); + if (!CanParallelize(*function, function_library)) continue; + + auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph)); + graph.ReplaceInput(*map_node, *parallel_map); + + // TODO(prazek): we could also remove map functions from library if they + // are not used anymore. 
+ nodes_to_delete.insert(map_node->name()); + } + + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h new file mode 100644 index 0000000000..ac9cf7e12a --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimization parallelizes MapDataset when function is stateless. 
+class MapParallelization : public CustomGraphOptimizer { + public: + MapParallelization() = default; + ~MapParallelization() override = default; + + string name() const override { return "map_parallelization"; }; + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc new file mode 100644 index 0000000000..b2a5d9b6af --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "MapDataset", {string(input_node_name)}, + {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", {}}, + {"output_types", {}}}); +} + +const char stateless_fun_name[] = "XTimesTwo"; +const char stateful_fun_name[] = "RandomUniform"; + +TEST(MapParallelizationTest, ParallelizeSimpleMap) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeMapNode("map1", "range", stateless_fun_name)}, + // FunctionLib + { + test::function::XTimesTwo(), + }); + + MapParallelization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); +} + +TEST(MapParallelization, ParallelizeAssert) { + using test::function::NDef; + GrapplerItem item; + 
item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeMapNode("map1", "range", stateful_fun_name), + MakeMapNode("map2", "map1", stateless_fun_name), + NDef("cache", "CacheDataset", {"map2", "filename"}, {})}, + // FunctionLib + { + test::function::XTimesTwo(), + test::function::RandomUniform(), + }); + + MapParallelization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)); + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow |