aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Piotr Padlewski <prazek@google.com>2018-09-14 11:28:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 11:32:52 -0700
commitc20a7b81d79d30db9e990309ddb419bcb48120cc (patch)
tree9ea682cf79bac18653e7690785e0f5e7117b6b8b
parent89f9080ed0d1a43cb2fa253997b2553c6916f364 (diff)
[tf.data] Introducing an optimization that parallelizes map transformations.
Stateless MapDatasets can be paralellized by switching to ParallelMapDataset. We set `num_parallel_calls` to 2 for now, but in the future a special value will be used that result in the optimal value to be selected dynamically at runtime. This patch also exposed a memory leak which was fixed. PiperOrigin-RevId: 213015223
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py84
-rw-r--r--tensorflow/core/BUILD2
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc9
-rw-r--r--tensorflow/core/common_runtime/function.cc5
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc8
-rw-r--r--tensorflow/core/framework/function.cc13
-rw-r--r--tensorflow/core/framework/function_testlib.cc34
-rw-r--r--tensorflow/core/framework/function_testlib.h3
-rw-r--r--tensorflow/core/framework/op_kernel.cc11
-rw-r--r--tensorflow/core/framework/op_segment.cc8
-rw-r--r--tensorflow/core/framework/op_segment.h4
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD44
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc106
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.h47
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc94
16 files changed, 461 insertions, 28 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index 7e9ea68047..b3187bf61b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -74,6 +74,23 @@ py_test(
)
py_test(
+ name = "map_parallelization_test",
+ size = "small",
+ srcs = ["map_parallelization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
name = "model_dataset_op_test",
size = "medium",
srcs = ["model_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
new file mode 100644
index 0000000000..dd547db086
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
@@ -0,0 +1,84 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def assert_greater(x):
+ assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+ with ops.control_dependencies([assert_op]):
+ return x
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=0,
+ maxval=10,
+ dtype=dtypes.int64,
+ seed=42)
+
+ def assert_with_random(x):
+ x = assert_greater(x)
+ return random(x)
+
+ return (("Identity", identity, True), ("Increment", increment, True),
+ ("AssertGreater", assert_greater, True), ("Random", random, False),
+ ("AssertWithRandom", assert_with_random, False))
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testMapParallelization(self, function, should_optimize):
+ next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(next_nodes)).map(function).apply(
+ optimization.optimize(["map_parallelization"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ # No need to run the pipeline if it was not optimized. Also the results
+ # might be hard to check because of random.
+ if not should_optimize:
+ return
+ r = function(x)
+ self.assertAllEqual(r, result)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1a86bff5cd..55715bb3a6 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1429,9 +1429,11 @@ cc_library(
":test",
":testlib_ops",
"//tensorflow/cc:scope",
+ "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
+ "//tensorflow/core/kernels:random_ops",
],
)
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index b4d8e285bd..af5d5b17e7 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1202,14 +1202,11 @@ Status DirectSession::CreateExecutors(
auto opseg = device->op_segment();
params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -1222,10 +1219,8 @@ Status DirectSession::CreateExecutors(
create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
delete kernel;
- }
};
optimizer.Optimize(lib, options_.env, device, &partition_graph,
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 1c9b69721d..472865ca43 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -414,9 +414,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(
device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
&fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, graph_def_version_, &s);
- *kernel = new CallOp(handle, &construction);
- if (!s.ok()) {
- delete *kernel;
+ if (s.ok()) {
+ *kernel = new CallOp(handle, &construction);
}
return s;
}
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 6c146036ae..f7a2967d00 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -233,14 +233,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
params.function_library = lib;
params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -252,8 +249,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
delete kernel;
}
};
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index d979353d2f..a17959a448 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1294,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create(
for (const auto& r : ret_def) {
fdef.mutable_ret()->insert({r.first, r.second});
}
+
+ auto* op_def_registry = OpRegistry::Global();
+ // Check if any op is stateful.
+ for (const auto& n : node_def) {
+ const OpDef* op_def = nullptr;
+ auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
+ // Lookup can fail if e.g. we are calling a function that was not yet
+ // defined. If it happens, conservatively assume the op is stateful.
+ if (!status.ok() || op_def->is_stateful()) {
+ fdef.mutable_signature()->set_is_stateful(true);
+ }
+ }
return fdef;
}
@@ -1355,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name,
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
}
}
+ if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
}
// Returns
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index c5a4f661d2..d5c203d276 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -91,6 +91,40 @@ FunctionDef IsZero() {
});
}
+FunctionDef RandomUniform() {
+ const Tensor kZero = test::AsScalar<int64>(0);
+ const Tensor kTen = test::AsScalar<int64>(10);
+
+ return FDH::Define(
+ // Name
+ "RandomUniform",
+ // Args
+ {"x: T"},
+ // Return values
+ {"random_uniform: int64"},
+ // Attr def
+ {"T:{float, double, int32, int64, string}"},
+ {{{"random_uniform/shape"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/min"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/max"},
+ "Const",
+ {},
+ {{"value", kTen}, {"dtype", DT_INT64}}},
+ {{"random_uniform"},
+ "RandomUniformInt",
+ {},
+ {{"T", DT_INT64},
+ {"Tout", DT_INT64},
+ {"seed", 87654321},
+ {"seed2", 42}}}});
+}
+
FunctionDef XTimesTwo() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index ad61a76f16..a01743423b 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -84,6 +84,9 @@ FunctionDef NonZero();
// x: T -> bool.
FunctionDef IsZero();
+// x: T -> int64
+FunctionDef RandomUniform();
+
// x:T, y:T -> y:T, x:T
FunctionDef Swap();
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c694e10193..80f2b12987 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
// OpKernel ------------------------------------------------------------------
-// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
- : OpKernel(context,
- std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+ : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
OpKernel::OpKernel(OpKernelConstruction* context,
std::unique_ptr<const NodeDef> node_def)
@@ -525,10 +524,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input(
return nullptr;
}
}
- // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
- // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
- // general cleanup of ownership in this code.
- std::unique_ptr<Tensor> output_tensor(new Tensor());
+
+ auto output_tensor = MakeUnique<Tensor>();
CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
return output_tensor;
}
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
index dfc5aa7747..75ed4a4eaf 100644
--- a/tensorflow/core/framework/op_segment.cc
+++ b/tensorflow/core/framework/op_segment.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) {
delete item;
}
+bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op) {
+ // OpSegment should not own kernel if the node is stateless, or a function.
+ return lib->IsStateful(node_op) &&
+ lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
index 4433a2554f..37d939ea2b 100644
--- a/tensorflow/core/framework/op_segment.h
+++ b/tensorflow/core/framework/op_segment.h
@@ -60,6 +60,10 @@ class OpSegment {
Status FindOrCreate(const string& session_handle, const string& node_name,
OpKernel** kernel, CreateKernelFn create_fn);
+ // Returns true if OpSegment should own the kernel.
+ static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op);
+
private:
// op name -> OpKernel
typedef std::unordered_map<string, OpKernel*> KernelMap;
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 530c957068..e84df10778 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -19,7 +19,6 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -56,8 +55,8 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:functional_ops",
+ "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
@@ -107,7 +106,6 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core/kernels:cast_op",
],
)
@@ -164,7 +162,6 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work.
],
)
@@ -256,7 +253,6 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -275,6 +271,43 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/kernels:control_flow_ops",
+ ],
+)
+
+cc_library(
+ name = "map_parallelization",
+ srcs = ["map_parallelization.cc"],
+ hdrs = [
+ "map_parallelization.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_parallelization_test",
+ srcs = ["map_parallelization_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":map_parallelization",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
],
)
@@ -355,6 +388,7 @@ cc_library(
":map_and_batch_fusion",
":map_and_filter_fusion",
":map_fusion",
+ ":map_parallelization",
":map_vectorization",
":noop_elimination",
":shuffle_and_repeat_fusion",
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
new file mode 100644
index 0000000000..305325e434
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -0,0 +1,106 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+bool CanParallelize(const FunctionDef& function,
+ const FunctionLibraryDefinition& library) {
+ if (!function.signature().is_stateful()) return true;
+
+ for (const auto& node : function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+ // Assert is marked as stateful, but it does not have any state (except
+ // changing io). Similarly to CUDA, we do not give guarantee that the
+ // assert operation that would fail would be the first one, so that we can
+ // parallelize it.
+ if (op_def->is_stateful() && op_def->name() != "Assert") return false;
+ }
+
+ return true;
+}
+
+NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
+ NodeDef parallel_map = map_node;
+ graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
+ &parallel_map);
+ parallel_map.set_op("ParallelMapDataset");
+ // TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
+ // so that dynamic tunning will pick the optimal value at runtime. Because
+ // this feature is not yet implemented, we set it to 2, which is the smallest
+ // value that introduces parallelism.
+ auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph);
+ parallel_map.add_input(num_parallel_calls->name());
+
+ return parallel_map;
+}
+
+} // namespace
+
+Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ auto* function =
+ function_library.Find(map_node->attr().at("f").func().name());
+ if (!CanParallelize(*function, function_library)) continue;
+
+ auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
+ graph.ReplaceInput(*map_node, *parallel_map);
+
+ // TODO(prazek): we could also remove map functions from library if they
+ // are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
new file mode 100644
index 0000000000..ac9cf7e12a
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization parallelizes MapDataset when function is stateless.
+class MapParallelization : public CustomGraphOptimizer {
+ public:
+ MapParallelization() = default;
+ ~MapParallelization() override = default;
+
+ string name() const override { return "map_parallelization"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
new file mode 100644
index 0000000000..b2a5d9b6af
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", {}},
+ {"output_types", {}}});
+}
+
+const char stateless_fun_name[] = "XTimesTwo";
+const char stateful_fun_name[] = "RandomUniform";
+
+TEST(MapParallelizationTest, ParallelizeSimpleMap) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateless_fun_name)},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+TEST(MapParallelization, ParallelizeAssert) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateful_fun_name),
+ MakeMapNode("map2", "map1", stateless_fun_name),
+ NDef("cache", "CacheDataset", {"map2", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ test::function::RandomUniform(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow