aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar Piotr Padlewski <prazek@google.com>2018-09-26 10:09:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 10:13:29 -0700
commit00ae12ad8bf5c348e4c31448e3922cbaab54cc03 (patch)
treeacf191a108f56683a16212ccd78fd6000927f3c2 /tensorflow/core/grappler
parent23a07f2c1444509986eece54e486cdcf0b8e32e4 (diff)
Hoisting RandomUniform out of functions
This patch introduces an optimization that hoists RandomUniform out of map functions. By doing so, we make the function stateless, which is crucial for parallelization and vectorization. PiperOrigin-RevId: 214623178
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD65
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.h36
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc15
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h10
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc289
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h55
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc84
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc21
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc2
19 files changed, 651 insertions, 87 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index cf305cebe1..d42a560cb2 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -22,6 +22,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -31,6 +32,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":graph_test_utils",
":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
@@ -146,6 +148,62 @@ tf_cc_test(
)
cc_library(
+ name = "graph_test_utils",
+ testonly = 1,
+ srcs = ["graph_test_utils.cc"],
+ hdrs = [
+ "graph_test_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "hoist_random_uniform",
+ srcs = ["hoist_random_uniform.cc"],
+ hdrs = [
+ "hoist_random_uniform.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "hoist_random_uniform_test",
+ srcs = ["hoist_random_uniform_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":hoist_random_uniform",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "latency_all_edges",
srcs = ["latency_all_edges.cc"],
hdrs = [
@@ -256,7 +314,7 @@ cc_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core:ptr_util",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -265,6 +323,7 @@ tf_cc_test(
srcs = ["map_and_filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_and_filter_fusion",
"//tensorflow/core:framework",
@@ -294,6 +353,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -302,6 +362,7 @@ tf_cc_test(
srcs = ["map_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_fusion",
"//tensorflow/core:framework",
@@ -339,6 +400,7 @@ tf_cc_test(
srcs = ["map_parallelization_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_parallelization",
"//tensorflow/core:framework",
@@ -422,6 +484,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
index c71aa6e804..1ad495bbad 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
fused_node.set_op("FilterDataset");
fused_node.add_input(first_filter_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = first_filter_node.attr().at("predicate");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["predicate"] = std::move(attr);
- copy_attribute("Targuments", first_filter_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, second_filter_node, &fused_node);
+ graph_utils::CopyAttribute(key, second_filter_node, &fused_node);
return fused_node;
}
@@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// functions, or make sure that optimization passes run after filter
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(first_filter_node->name());
nodes_to_delete.insert(second_filter_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
index 12b1924efd..c8becc5cc0 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,14 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
new file mode 100644
index 0000000000..b2eec7220e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+// Builds a `MapDataset` test node called `name` that reads from
+// `input_node_name` and applies the library function `function_name`.
+// `Targuments`, `output_shapes` and `output_types` are left empty.
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+                    StringPiece function_name) {
+  const auto mapped_function =
+      FunctionDefHelper::FunctionRef(string(function_name));
+  return test::function::NDef(
+      name, "MapDataset", {string(input_node_name)},
+      {{"f", mapped_function},
+       {"Targuments", {}},
+       {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+       {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+// Builds a `FilterDataset` test node called `name` that reads from
+// `input_node_name` and filters with the library function `function_name`.
+// `Targuments`, `output_shapes` and `output_types` are left empty.
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+                       StringPiece function_name) {
+  return test::function::NDef(
+      name, "FilterDataset", {string(input_node_name)},
+      {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+       {"Targuments", {}},
+       {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+       // `output_types` is a list of types, not shapes; keep it consistent
+       // with MakeMapNode.
+       {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
new file mode 100644
index 0000000000..ca0fde997d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+// Creates a `MapDataset` node named `name` whose single input is
+// `input_node_name` and whose `f` attribute references `function_name`.
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "XTimesTwo");
+
+// Creates a `FilterDataset` node named `name` whose single input is
+// `input_node_name` and whose `predicate` attribute references
+// `function_name`.
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "IsZero");
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 2dd9ee822e..48825d0346 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -260,6 +260,21 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
}
function->mutable_signature()->set_name(std::move(name));
}
+
+// Overwrites (or creates) attribute `attribute_name` on `to_node` with the
+// value that `from` holds under the same name.
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+                   NodeDef* to_node) {
+  const auto& source_value = from.attr().at(attribute_name);
+  auto& destination_attrs = *to_node->mutable_attr();
+  destination_attrs[attribute_name] = source_value;
+}
+
+// Sets `attribute_name` on `to_node` to `first`'s value and then merges the
+// list stored under the same name on `second` into it.
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+                         const NodeDef& second, NodeDef* to_node) {
+  CopyAttribute(attribute_name, first, to_node);
+
+  auto* merged_list =
+      to_node->mutable_attr()->at(attribute_name).mutable_list();
+  const auto& extra_entries = second.attr().at(attribute_name).list();
+  merged_list->MergeFrom(extra_entries);
+}
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index b117482db2..189a72d255 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -106,6 +106,16 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
+// Copies attribute having name `attribute_name` from node `from` to node
+// `to_node`.
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node);
+
+// Concatenates list attribute having name `attribute_name` from `first` and
+// `second` node, setting it to `to_node`.
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
new file mode 100644
index 0000000000..ce0b2db039
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+// Builds the `MapDataset` node that stands in for `map_node`: it reads from
+// `zip_node`, forwards the remaining inputs of `map_node`, and invokes
+// `stateless_function` instead of the original map function.
+NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
+                         const FunctionDef& stateless_function,
+                         MutableGraphView* graph) {
+  NodeDef new_map;
+  graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
+                                      &new_map);
+  new_map.set_op("MapDataset");
+
+  // Input 0 becomes the zipped dataset; every later input (the placeholders)
+  // is carried over unchanged.
+  new_map.add_input(zip_node.name());
+  for (int input_idx = 1; input_idx < map_node.input_size(); ++input_idx) {
+    new_map.add_input(map_node.input(input_idx));
+  }
+
+  // Start from the original `f` attribute and retarget it at the new function.
+  auto function_attr = map_node.attr().at("f");
+  auto* function_ref = function_attr.mutable_func();
+  *function_ref->mutable_name() = stateless_function.signature().name();
+  *function_ref->mutable_attr() = stateless_function.attr();
+  (*new_map.mutable_attr())["f"] = std::move(function_attr);
+
+  for (auto key : {"Targuments", "output_shapes", "output_types"}) {
+    graph_utils::CopyAttribute(key, map_node, &new_map);
+  }
+
+  // `use_inter_op_parallelism` is copied only when `map_node` carries it.
+  const auto* inter_op_attr =
+      gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism");
+  if (inter_op_attr != nullptr) {
+    (*new_map.mutable_attr())["use_inter_op_parallelism"] = *inter_op_attr;
+  }
+
+  return new_map;
+}
+
+// Builds a `RandomDataset` node whose two inputs are scalar int64 constants
+// holding the `seed` and `seed2` attributes of `random_uniform_node`. The
+// constants are added to `graph`; the dataset node is returned by value.
+NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
+                          MutableGraphView* graph) {
+  NodeDef seeds_dataset;
+  seeds_dataset.set_op("RandomDataset");
+  graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
+                                      &seeds_dataset);
+
+  const auto& uniform_attrs = random_uniform_node.attr();
+  const auto* first_seed = graph_utils::AddScalarConstNode<int64>(
+      uniform_attrs.at("seed").i(), graph);
+  const auto* second_seed = graph_utils::AddScalarConstNode<int64>(
+      uniform_attrs.at("seed2").i(), graph);
+  seeds_dataset.add_input(first_seed->name());
+  seeds_dataset.add_input(second_seed->name());
+
+  // One output component: a default (dimension-less) shape of type int64.
+  auto& dataset_attrs = *seeds_dataset.mutable_attr();
+  dataset_attrs["output_shapes"].mutable_list()->add_shape();
+  dataset_attrs["output_types"].mutable_list()->add_type(DT_INT64);
+
+  return seeds_dataset;
+}
+
+// Builds a `BatchDatasetV2` node that groups the elements of `random_dataset`
+// into batches of two. The batch size (2) and the drop-remainder flag (false)
+// are added to `graph` as scalar constants.
+NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
+  NodeDef seed_pairs;
+  seed_pairs.set_op("BatchDatasetV2");
+  graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
+                                      &seed_pairs);
+
+  const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
+  const auto* drop_remainder = graph_utils::AddScalarConstNode(false, graph);
+  seed_pairs.add_input(random_dataset.name());
+  seed_pairs.add_input(batch_size->name());
+  seed_pairs.add_input(drop_remainder->name());
+
+  auto& batch_attrs = *seed_pairs.mutable_attr();
+  // One output component: a rank-1 shape whose only dimension has size -1.
+  auto* batched_shape =
+      batch_attrs["output_shapes"].mutable_list()->add_shape();
+  batched_shape->mutable_dim()->Add()->set_size(-1);
+  batch_attrs["output_types"].mutable_list()->add_type(DT_INT64);
+
+  return seed_pairs;
+}
+
+// Builds a two-input `ZipDataset` node over `first_node` and `second_node`.
+// Its `output_shapes` and `output_types` are those of `first_node` followed by
+// those of `second_node`.
+NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
+                    MutableGraphView* graph) {
+  NodeDef zipped;
+  graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
+                                      &zipped);
+  zipped.set_op("ZipDataset");
+
+  for (const NodeDef* zip_input : {&first_node, &second_node}) {
+    zipped.add_input(zip_input->name());
+  }
+
+  graph_utils::ConcatAttributeList("output_shapes", first_node, second_node,
+                                   &zipped);
+  graph_utils::ConcatAttributeList("output_types", first_node, second_node,
+                                   &zipped);
+
+  // `N` is the number of zipped inputs.
+  (*zipped.mutable_attr())["N"].set_i(2);
+
+  return zipped;
+}
+
+// Adds an int64 seed argument to `signature`, positioned just before the
+// trailing `num_placeholders` input arguments, and returns a pointer to it.
+OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
+  const int seed_position = signature->input_arg_size() - num_placeholders;
+
+  // Append an empty argument, then swap it backwards one slot at a time until
+  // it sits at `seed_position`; the placeholders each shift right by one.
+  signature->add_input_arg();
+  auto* input_args = signature->mutable_input_arg();
+  int current = signature->input_arg_size() - 1;
+  while (current > seed_position) {
+    input_args->SwapElements(current - 1, current);
+    --current;
+  }
+
+  OpDef_ArgDef* seed_arg = signature->mutable_input_arg(seed_position);
+  seed_arg->set_name(strings::StrCat("seed_arg", seed_position));
+  seed_arg->set_type(DT_INT64);
+  return seed_arg;
+}
+
+// Adds to `library` a copy of `map_function` that uses
+// `StatelessRandomUniform` instead of `RandomUniform`, making it less
+// stateful. The copy takes an extra int64 seed argument, inserted before the
+// last `num_placeholders` inputs, which feeds the rewritten node.
+// `is_stateful` tells whether the copy still contains stateful ops after the
+// rewrite (e.g. `Assert`); the signature's `is_stateful` flag is set to it.
+// `map_function` is expected to contain a `RandomUniform` node.
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
+                                            bool is_stateful,
+                                            int num_placeholders,
+                                            FunctionDefLibrary* library) {
+  FunctionDef* stateless_function = library->add_function();
+  *stateless_function = map_function;
+  // The copy inherits the flag of `map_function`, so it has to be overwritten
+  // unconditionally: setting it only when `is_stateful` is true would leave a
+  // function that became fully stateless still marked as stateful.
+  stateless_function->mutable_signature()->set_is_stateful(is_stateful);
+  graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
+                                          stateless_function);
+
+  auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
+                                      num_placeholders);
+
+  auto* const random_uniform = stateless_function->mutable_node_def(
+      function_utils::FindFunctionNodeWithOp("RandomUniform",
+                                             *stateless_function));
+
+  // Replace RandomUniform node with StatelessRandomUniform: the seed now
+  // arrives as an input, so the `seed`/`seed2` attributes are dropped.
+  random_uniform->set_op("StatelessRandomUniform");
+  random_uniform->add_input(seed_arg->name());
+  (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
+  random_uniform->mutable_attr()->erase("seed");
+  random_uniform->mutable_attr()->erase("seed2");
+
+  return stateless_function;
+}
+// Returns true if `map_function` is stateful, contains exactly one
+// `RandomUniform` op and has no other stateful ops except `Assert`.
+// `*is_stateful_after_hoisting` is set to false if `RandomUniform` is the only
+// stateful op, and to true if stateful `Assert` ops would remain after
+// hoisting. It is left untouched when `map_function` is not stateful.
+// `*random_uniform_op` must be nullptr on entry; on success it points at the
+// `RandomUniform` node inside `map_function`.
+bool CanHoistRandomUniform(const FunctionDef& map_function,
+                           const FunctionLibraryDefinition& library,
+                           bool* is_stateful_after_hoisting,
+                           const NodeDef** random_uniform_op) {
+  if (!map_function.signature().is_stateful()) return false;
+  *is_stateful_after_hoisting = true;
+
+  bool have_other_stateful_ops = false;
+
+  for (const auto& node : map_function.node_def()) {
+    const OpDef* op_def = nullptr;
+    // An op that cannot be resolved cannot be shown to be stateless, so give
+    // up on this function instead of aborting the whole process.
+    if (!library.LookUpOpDef(node.op(), &op_def).ok()) return false;
+
+    // Stateless nodes never prevent hoisting.
+    if (!op_def->is_stateful()) continue;
+
+    // Assert is tolerated, as it does not actually have a state, but it keeps
+    // the function marked as stateful after hoisting.
+    if (op_def->name() == "Assert") {
+      have_other_stateful_ops = true;
+      continue;
+    }
+
+    // TODO(prazek): For now we only handle RandomUniform, we should handle
+    // RandomUniformInt as well.
+    if (op_def->name() != "RandomUniform") return false;
+
+    // TODO(prazek): For now we can only hoist single RandomUniform.
+    if (*random_uniform_op != nullptr) return false;
+
+    *random_uniform_op = &node;
+  }
+
+  *is_stateful_after_hoisting = have_other_stateful_ops;
+
+  // Have we found single RandomUniform?
+  return *random_uniform_op != nullptr;
+}
+
+// Counts the placeholder inputs of `map_node`: input 0 of a MapDataset is the
+// argument to the function, and every input after it is a placeholder.
+int NumberOfPlaceholders(const NodeDef& map_node) {
+  const int total_inputs = map_node.input_size();
+  return total_inputs - 1;
+}
+
+} // namespace
+
+// Copies `item.graph` into `output` and rewrites every `MapDataset` whose
+// function passes `CanHoistRandomUniform`. For each such node it adds
+// RandomDataset -> BatchDatasetV2(2) -> ZipDataset(parent, batch) and a new
+// MapDataset calling a less stateful copy of the function, redirects the
+// consumers of the old map to the new one, and deletes the old map node.
+// Map nodes whose function or input node cannot be resolved are left as is.
+// `cluster` is not used.
+Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                    GraphDef* output) {
+  *output = item.graph;
+
+  MutableGraphView graph(output);
+  std::set<string> nodes_to_delete;
+  FunctionLibraryDefinition function_library(OpRegistry::Global(),
+                                             item.graph.library());
+
+  auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+    // TODO(prazek): we could also handle ParallelMapDataset and
+    // MapAndBatchDataset.
+    if (node.op() == "MapDataset") return &node;
+    return nullptr;
+  };
+
+  for (const NodeDef& node : item.graph.node()) {
+    const NodeDef* map_node = get_map_node(node);
+    if (!map_node) continue;
+
+    const auto& fun = map_node->attr().at("f");
+    const FunctionDef* func = function_library.Find(fun.func().name());
+    // The function may be missing from the library; nothing to hoist then.
+    if (func == nullptr) continue;
+
+    const NodeDef* random_uniform_op = nullptr;
+    bool is_stateful_after_hoisting = true;
+    if (!CanHoistRandomUniform(*func, function_library,
+                               &is_stateful_after_hoisting, &random_uniform_op))
+      continue;
+
+    // Resolve the upstream dataset before adding any node, so that a map that
+    // has to be skipped does not leave orphan nodes in the graph.
+    const NodeDef* parent_node = graph_utils::GetInputNode(*map_node, graph);
+    if (parent_node == nullptr) continue;
+
+    const auto* random_seed_dataset =
+        graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph));
+
+    const auto* batch_dataset =
+        graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph));
+
+    const auto* zip_node =
+        graph.AddNode(MakeZipNode(*parent_node, *batch_dataset, &graph));
+
+    const auto* stateless_func = MakeLessStatefulFunction(
+        *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node),
+        output->mutable_library());
+
+    const auto* stateless_map = graph.AddNode(
+        MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
+
+    graph.ReplaceInput(*map_node, *stateless_map);
+
+    // TODO(b/116285210): we could also remove map functions from library if
+    // they are not used anymore.
+    nodes_to_delete.insert(map_node->name());
+  }
+
+  graph.DeleteNodes(nodes_to_delete);
+  return Status::OK();
+}
+
+// Feedback hook overridden from the optimizer base class. This optimizer does
+// not use feedback, so every argument is ignored.
+void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
new file mode 100644
index 0000000000..d1bcf6782d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization hoists instances of `random_uniform` out of a function
+// with the aim of making it stateless. It creates a new function that takes a
+// random seed as an extra argument and uses `stateless_random_uniform` instead
+// of `random_uniform`.
+// It also creates RandomDataset(seed).batch(2), which is zipped with the old
+// input to the map. The RandomDataset is batched because
+// `stateless_random_uniform` needs 2 seeds.
+// TODO(prazek): for now only `RandomUniform` is handled, but we could handle
+// `RandomUniformInt` similarly.
+class HoistRandomUniform : public CustomGraphOptimizer {
+ public:
+  HoistRandomUniform() = default;
+  ~HoistRandomUniform() override = default;
+
+  // Returns the name identifying this optimizer.
+  string name() const override { return "hoist_random_uniform"; }
+
+  // This optimizer has nothing to configure; `config` is ignored.
+  Status Init(
+      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    return Status::OK();
+  }
+
+  // Writes the rewritten copy of `item.graph` to `output`.
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* output) override;
+
+  // No-op: this optimizer does not use feedback.
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
new file mode 100644
index 0000000000..455459e3f6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+// Runs the optimizer on range -> map1(RandomUniform) -> cache and checks that
+// map1 is replaced by a MapDataset reading from ZipDataset(range, batch),
+// where batch is a BatchDatasetV2 over a RandomDataset.
+TEST(HoistRandomUniform, SimpleHoisting) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}}),
+ graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"),
+ NDef("cache", "CacheDataset", {"map1", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::RandomUniform(),
+ });
+
+ HoistRandomUniform optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // The original map node must be gone and all four new ops must be present.
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
+ const int zip_dataset_id =
+ graph_utils::FindGraphNodeWithOp("ZipDataset", output);
+ const int random_dataset_id =
+ graph_utils::FindGraphNodeWithOp("RandomDataset", output);
+ const int batch_random_id =
+ graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output);
+ ASSERT_NE(random_dataset_id, -1);
+ ASSERT_NE(zip_dataset_id, -1);
+ ASSERT_NE(new_map_id, -1);
+ ASSERT_NE(batch_random_id, -1);
+
+ const auto& new_map = output.node(new_map_id);
+ const auto& zip = output.node(zip_dataset_id);
+ const auto& random = output.node(random_dataset_id);
+ const auto& batch = output.node(batch_random_id);
+
+ // Check the wiring: new map <- zip <- (range, batch), batch <- random.
+ ASSERT_EQ(new_map.input_size(), 1);
+ EXPECT_EQ(new_map.input(0), zip.name());
+
+ ASSERT_EQ(zip.input_size(), 2);
+ EXPECT_EQ(zip.input(0), "range");
+ EXPECT_EQ(zip.input(1), batch.name());
+
+ ASSERT_EQ(batch.input_size(), 3);
+ EXPECT_EQ(batch.input(0), random.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 63945b8b9e..e66766eb23 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
// Set `f` and `Targuments` attributes.
for (auto key : {"f", "Targuments"}) {
- (*new_node.mutable_attr())[key] = map_node.attr().at(key);
+ graph_utils::CopyAttribute(key, map_node, &new_node);
}
+
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = batch_node.attr().at(key);
+ graph_utils::CopyAttribute(key, batch_node, &new_node);
}
return new_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index f1844a141c..c4868eacbb 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = map_node.attr().at("f");
attr.mutable_func()->set_name(fused_function.signature().name());
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", map_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr;
// Add the predicate output attributes.
(*fused_node.mutable_attr())["output_types"]
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index f029a093fa..6e6da37d7c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,24 +28,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
-
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
+using graph_tests_utils::MakeMapNode;
TEST(MapAndFilterFusionTest, FuseMapAndFilter) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index a78ecb09f7..bd943342e8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
-
fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = parent_map_node.attr().at("f");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", parent_map_node, &fused_node);
-
+ graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+ auto value_or_false = [](const AttrValue* attr) {
+ if (!attr) return false;
+ return attr->b();
+ };
+
+ const auto* first_parallelism =
+ gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism");
+ const auto* second_parallelism =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism");
+ // Some graphs cannot execute with use_inter_op_parallelism=False, so we need
+ // to set it to true if either of the ops has it set to true.
+ if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) {
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+ }
return fused_node;
}
@@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(parent_map_node->name());
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index b25dfbd0b8..8889f9dddd 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -28,14 +29,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeMapNode;
TEST(MapFusionTest, FuseTwoMapNodesIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
index 305325e434..782c9f48b7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -84,9 +84,6 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
graph.ReplaceInput(*map_node, *parallel_map);
-
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
index b2a5d9b6af..9fdfe8af30 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,16 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
- StringPiece function_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
+using graph_tests_utils::MakeMapNode;
const char stateless_fun_name[] = "XTimesTwo";
const char stateful_fun_name[] = "RandomUniform";
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 7a2f1910da..32ab912619 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -35,10 +35,6 @@ namespace tensorflow {
namespace grappler {
namespace {
-void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
- (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
-}
-
// Returns a FunctionDef containing a MapDefun op that wraps the original
// function.
FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
@@ -61,7 +57,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
for (const string& k : {"f", "output_types", "output_shapes"}) {
// Function, output types and (unbatched) shapes are the same as the
// original map node.
- CopyAttribute(k, map_node, map_defun_node);
+ graph_utils::CopyAttribute(k, map_node, map_defun_node);
}
// Get types of input arguments from original map function
@@ -195,13 +191,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
}
// Set attrs
- CopyAttribute("Targuments", old_map_node, &map_node);
+ graph_utils::CopyAttribute("Targuments", old_map_node, &map_node);
auto& func_attr = (*map_node.mutable_attr())["f"];
func_attr.mutable_func()->set_name(vectorized_func.signature().name());
for (auto key : {"output_shapes", "output_types"}) {
- CopyAttribute(key, old_batch_node, &map_node);
+ graph_utils::CopyAttribute(key, old_batch_node, &map_node);
}
+
+ (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+
return map_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index cb0ff670e8..99c4afa634 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = repeat_node.attr().at(key);
+ graph_utils::CopyAttribute(key, repeat_node, &new_node);
}
return new_node;
};