aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data')
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD25
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.cc141
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.h (renamed from tensorflow/core/grappler/optimizers/data/function_rename.h)15
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc91
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename.cc51
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename_test.cc42
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.cc166
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.h57
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc22
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/data/latency_all_edges.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion.cc12
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion_test.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc23
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc4
20 files changed, 503 insertions, 207 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 74d936cfbc..530c957068 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -4,36 +4,41 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
cc_library(
- name = "function_rename",
- srcs = ["function_rename.cc"],
+ name = "filter_fusion",
+ srcs = ["filter_fusion.cc"],
hdrs = [
- "function_rename.h",
+ "filter_fusion.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":fusion_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:lib",
- "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
] + tf_protos_all(),
)
tf_cc_test(
- name = "function_rename_test",
- srcs = ["function_rename_test.cc"],
+ name = "filter_fusion_test",
+ srcs = ["filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
- ":function_rename",
+ ":filter_fusion",
+ ":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
- ] + tf_protos_all(),
+ ],
)
cc_library(
@@ -46,11 +51,13 @@ cc_library(
deps = [
":graph_utils",
"//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:functional_ops",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
@@ -343,7 +350,7 @@ cc_library(
name = "data",
visibility = ["//visibility:public"],
deps = [
- ":function_rename",
+ ":filter_fusion",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
new file mode 100644
index 0000000000..c71aa6e804
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -0,0 +1,141 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
+ const NodeDef& second_filter_node,
+ const FunctionDef& fused_function,
+ MutableGraphView* graph) {
+ NodeDef fused_node;
+ graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(),
+ &fused_node);
+
+ fused_node.set_op("FilterDataset");
+ fused_node.add_input(first_filter_node.input(0));
+
+ auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
+ NodeDef* to) {
+ (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
+ };
+
+ auto attr = first_filter_node.attr().at("predicate");
+ *attr.mutable_func()->mutable_name() = fused_function.signature().name();
+ (*fused_node.mutable_attr())["predicate"] = std::move(attr);
+
+ copy_attribute("Targuments", first_filter_node, &fused_node);
+
+ for (auto key : {"output_shapes", "output_types"})
+ copy_attribute(key, second_filter_node, &fused_node);
+
+ return fused_node;
+}
+
+} // namespace
+
+Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ GraphDef sorted_old_graph = item.graph;
+ TF_RETURN_IF_ERROR(TopologicalSort(&sorted_old_graph));
+ *output = sorted_old_graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ output->library());
+
+ auto get_filter_node = [](const NodeDef& node) -> const NodeDef* {
+ if (node.op() == "FilterDataset") return &node;
+ return nullptr;
+ };
+
+ auto get_fused_predicate =
+ [&](const NodeDef* first_filter_node,
+ const NodeDef* second_filter_node) -> FunctionDef* {
+ const auto& parent_fun = first_filter_node->attr().at("predicate");
+ const FunctionDef* first_func =
+ function_library.Find(parent_fun.func().name());
+ const auto& fun = second_filter_node->attr().at("predicate");
+ const FunctionDef* second_func = function_library.Find(fun.func().name());
+
+ if (!fusion_utils::HasSameSignature(first_func->signature(),
+ second_func->signature())) {
+ VLOG(1) << "Can't fuse Filters because they have different signature\n";
+ return nullptr;
+ }
+
+ return fusion_utils::FuseFunctions(
+ *first_func, *second_func, "fused_predicate",
+ fusion_utils::SameSignature, fusion_utils::SameInput,
+ fusion_utils::LazyConjunctionOutput, fusion_utils::LazyConjunctionNodes,
+ output->mutable_library());
+ };
+
+ for (const NodeDef& node : sorted_old_graph.node()) {
+ const NodeDef* second_filter_node = get_filter_node(node);
+ if (!second_filter_node) continue;
+
+ const NodeDef* first_filter_node =
+ get_filter_node(*graph_utils::GetInputNode(*second_filter_node, graph));
+ if (!first_filter_node) continue;
+
+ const auto* fused_predicate =
+ get_fused_predicate(first_filter_node, second_filter_node);
+ if (!fused_predicate) continue;
+ const auto* fused_filter_node = graph.AddNode(MakeFusedFilterNode(
+ *first_filter_node, *second_filter_node, *fused_predicate, &graph));
+
+ graph.ReplaceInput(*second_filter_node, *fused_filter_node);
+
+ // TODO(prazek): we should run some optimizations on the fused filter
+ // functions, or make sure that optimization passes run after filter
+ // fusion.
+ TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
+ // TODO(prazek): we could also remove map functions from library if they
+ // are not used anymore.
+ nodes_to_delete.insert(first_filter_node->name());
+ nodes_to_delete.insert(second_filter_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void FilterFusion::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(FilterFusion, "filter_fusion");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.h b/tensorflow/core/grappler/optimizers/data/filter_fusion.h
index 23ad9470ff..91a0364a46 100644
--- a/tensorflow/core/grappler/optimizers/data/function_rename.h
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.h
@@ -13,20 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
namespace grappler {
-class FunctionRename : public CustomGraphOptimizer {
+// This optimization fuses filter transformations.
+class FilterFusion : public CustomGraphOptimizer {
public:
- FunctionRename() = default;
- ~FunctionRename() override = default;
+ FilterFusion() = default;
+ ~FilterFusion() override = default;
- string name() const override { return "_test_only_function_rename"; };
+ string name() const override { return "filter_fusion"; };
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
@@ -43,4 +44,4 @@ class FunctionRename : public CustomGraphOptimizer {
} // end namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
new file mode 100644
index 0000000000..12b1924efd
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
+ return test::function::NDef(
+ name, "FilterDataset", {string(input_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
+ {"Targuments", {}},
+ {"output_shapes", {}},
+ {"output_types", {}}});
+}
+
+TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeFilterNode("filter1", "range"),
+ MakeFilterNode("filter2", "filter1")},
+ // FunctionLib
+ {
+ test::function::IsZero(),
+ });
+
+ FilterFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output));
+}
+
+TEST(FilterFusionTest, FuseThreeNodesIntoOne) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeFilterNode("filter1", "range"), MakeFilterNode("filter2", "filter1"),
+ MakeFilterNode("filter3", "filter2"),
+ NDef("cache", "CacheDataset", {"filter3", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::IsZero(),
+ });
+
+ FilterFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter3", output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.cc b/tensorflow/core/grappler/optimizers/data/function_rename.cc
deleted file mode 100644
index 8cf044d1bd..0000000000
--- a/tensorflow/core/grappler/optimizers/data/function_rename.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
-
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-
-Status FunctionRename::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* output) {
- *output = item.graph;
- GraphView graph(output);
- int n = output->mutable_library()->function_size();
- for (int i = 0; i < n; ++i) {
- FunctionDef* fn = output->mutable_library()->mutable_function(i);
- fn->mutable_signature()->set_name(fn->signature().name() + "world");
- }
-
- return Status::OK();
-}
-
-void FunctionRename::Feedback(Cluster* cluster, const GrapplerItem& item,
- const GraphDef& optimize_output, double result) {
- // no-op
-}
-
-REGISTER_GRAPH_OPTIMIZER_AS(FunctionRename, "_test_only_function_rename");
-
-} // end namespace grappler
-} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
deleted file mode 100644
index 56b8a960a7..0000000000
--- a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
-
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-TEST(FunctionRenameTest, RenameFunction) {
- GrapplerItem item;
- GraphDef *graph = &item.graph;
- FunctionDef *fn = graph->mutable_library()->add_function();
- fn->mutable_signature()->set_name("hello");
-
- FunctionRename optimizer;
- GraphDef output;
- TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
- EXPECT_EQ(output.library().function(0).signature().name(), "helloworld");
-}
-
-} // namespace
-} // namespace grappler
-} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index f84f109af6..01a78c04b0 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_def.pb.h"
-
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
@@ -52,6 +52,12 @@ string GetOutputNode(const FunctionDef& function, int output_idx) {
return function.ret().at(ret_output_name);
}
+string& GetMutableOutputNode(FunctionDef* function, int output_idx) {
+ const auto& ret_output_name =
+ function->signature().output_arg(output_idx).name();
+ return function->mutable_ret()->at(ret_output_name);
+}
+
template <typename Iterable>
StringCollection GetNames(const Iterable& iterable, int allocate_size) {
StringCollection names;
@@ -106,7 +112,6 @@ gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable,
// Nodes that will be added to the function can have the same name as the nodes
// from parent function.
void RenameFunctionNodes(const FunctionDef& first_function,
- FunctionDef* fused_function,
protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse,
protobuf::Map<string, string>* rets_to_fuse) {
const gtl::FlatMap<string, string> changed_node_names =
@@ -149,6 +154,7 @@ OpDef GetUniqueSignature(const OpDef& first_signature,
const gtl::FlatMap<string, string> changed_input_names =
GetUniqueNames(first_signature.input_arg(), second_signature.input_arg());
OpDef signature;
+ signature.set_name(second_signature.name());
for (const auto& input_arg : second_signature.input_arg()) {
auto& input = *signature.add_input_arg();
@@ -221,12 +227,13 @@ void FuseFunctionNodes(const StringCollection& first_inputs,
}
// This function looks for direct edges from input to return and rewrites
-// them to the coresponding input of the return of `first_function`.
+// them to the corresponding input of the return of `first_function`.
void FuseReturns(const StringCollection& first_inputs,
const StringCollection& second_inputs,
const StringCollection& first_outputs,
- const SetInputFn& set_input, FunctionDef* fused_function) {
- for (auto& ret : *fused_function->mutable_ret()) {
+ const SetInputFn& set_input,
+ protobuf::Map<string, string>* fused_ret) {
+ for (auto& ret : *fused_ret) {
auto return_input = ParseNodeConnection(ret.second);
auto input_it =
std::find(second_inputs.begin(), second_inputs.end(), return_input);
@@ -249,6 +256,33 @@ StringCollection GetFunctionOutputs(const FunctionDef& function) {
return outputs;
}
+FunctionDef* CreateFalsePredicate(
+ const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args,
+ FunctionDefLibrary* library) {
+ GraphDef graph;
+ MutableGraphView graph_view(&graph);
+ auto* node = graph_utils::AddScalarConstNode(false, &graph_view);
+ auto* false_predicate = library->add_function();
+ graph_utils::SetUniqueGraphFunctionName("false_predicate", library,
+ false_predicate);
+
+ int num = 0;
+ for (const auto& fake_arg : fake_args) {
+ auto* arg = false_predicate->mutable_signature()->add_input_arg();
+ arg->set_type(fake_arg.type());
+ arg->set_name(strings::StrCat("fake_arg", num));
+ num++;
+ }
+
+ auto* output = false_predicate->mutable_signature()->add_output_arg();
+ output->set_name("false_out");
+ output->set_type(DT_BOOL);
+
+ (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0";
+ *false_predicate->mutable_node_def() = std::move(*graph.mutable_node());
+ return false_predicate;
+}
+
void CheckIfCanCompose(const OpDef& first_signature,
const OpDef& second_signature) {
CHECK(CanCompose(first_signature, second_signature))
@@ -259,6 +293,15 @@ void CheckIfCanCompose(const OpDef& first_signature,
} // namespace
+void MergeNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function, FunctionDef* fused_function,
+ FunctionDefLibrary* library) {
+ // Copy all nodes from first_function.
+ fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+ // Copy transformed nodes from the second function.
+ fused_function->mutable_node_def()->MergeFrom(second_function.node_def());
+}
+
bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) {
// TODO(prazek): Functions can have additional inputs being placeholders
// for a values used in function. We should be able to also fuse these
@@ -285,8 +328,8 @@ void ComposeSignature(const OpDef& first_signature,
void ComposeOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function) {
- *fused_function->mutable_ret() = second_ret;
+ protobuf::Map<string, string>* fused_ret) {
+ *fused_ret = second_ret;
}
void CombineSignature(const OpDef& first_signature,
@@ -302,41 +345,110 @@ void CombineSignature(const OpDef& first_signature,
void CombineOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function) {
- *fused_function->mutable_ret() = first_ret;
- fused_function->mutable_ret()->insert(second_ret.begin(), second_ret.end());
+ protobuf::Map<string, string>* fused_ret) {
+ *fused_ret = first_ret;
+ fused_ret->insert(second_ret.begin(), second_ret.end());
+}
+
+string SameInput(const StringCollection& first_inputs,
+ const StringCollection& second_inputs,
+ const StringCollection& first_outputs, int arg_num) {
+ return first_inputs.at(arg_num);
+}
+
+bool HasSameSignature(const OpDef& first_signature,
+ const OpDef& second_signature) {
+ return first_signature.input_arg_size() ==
+ second_signature.input_arg_size() &&
+ first_signature.output_arg_size() ==
+ second_signature.output_arg_size();
+}
+
+void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
+ OpDef* fused_signature) {
+ CHECK(HasSameSignature(first_signature, second_signature))
+ << "Functions do not have the same signature";
+ // Copy signature from first function.
+ *fused_signature = first_signature;
+}
+
+void LazyConjunctionNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function,
+ FunctionDef* fused_function,
+ FunctionDefLibrary* library) {
+ fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+
+ NodeDefBuilder if_builder("", "If");
+ if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL);
+ DataTypeVector in_arg_types;
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ for (const auto& input_arg : first_function.signature().input_arg()) {
+ inputs.push_back({input_arg.name(), 0, input_arg.type()});
+ in_arg_types.push_back(input_arg.type());
+ }
+ if_builder.Attr("Tin", in_arg_types);
+
+ if_builder.Attr("Tcond", DT_BOOL);
+ if_builder.Attr("Tout", DataTypeVector{DT_BOOL});
+ if_builder.Attr("_lower_using_switch_merge", true);
+
+ NameAttrList then_branch;
+ then_branch.set_name(second_function.signature().name());
+ if_builder.Attr("then_branch", then_branch);
+
+ auto* false_predicate =
+ CreateFalsePredicate(first_function.signature().input_arg(), library);
+
+ NameAttrList else_branch;
+ else_branch.set_name(false_predicate->signature().name());
+ if_builder.Attr("else_branch", else_branch);
+ if_builder.Input(inputs);
+
+ auto* if_node = fused_function->add_node_def();
+ // This is guaranteed to succeed.
+ TF_CHECK_OK(if_builder.Finalize(if_node));
+ graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+
+ GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
+}
+
+void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
+ const protobuf::Map<string, string>& second_ret,
+ protobuf::Map<string, string>* fused_ret) {
+ CHECK_EQ(first_ret.size(), 1);
+ CHECK_EQ(second_ret.size(), 1);
+ // Temporarily copy returns from first_ret. We are going to change the
+ // output node after creating it.
+ *fused_ret = first_ret;
}
-FunctionDef* FuseFunctions(const FunctionDef& first_function,
- const FunctionDef& function,
- StringPiece fused_name_prefix,
- const SetFunctionSignatureFn& set_signature,
- const SetInputFn& set_input,
- const SetOutputFn& set_output,
- FunctionDefLibrary* library) {
- if (first_function.attr_size() != 0 || function.attr_size() != 0)
+FunctionDef* FuseFunctions(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
+ const SetInputFn& set_input, const SetOutputFn& set_output,
+ const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
+ if (first_function.attr_size() != 0 || second_function.attr_size() != 0)
return nullptr; // Functions with attributes are currently not supported
// This function will be used as a clone of second function, having unique
// names.
- FunctionDef setup_function = function;
+ FunctionDef setup_function = second_function;
*setup_function.mutable_signature() = GetUniqueSignature(
first_function.signature(), setup_function.signature(),
setup_function.mutable_ret(), setup_function.mutable_node_def());
FunctionDef* fused_function = library->add_function();
- // Copy all nodes from first_function.
- fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+
set_signature(first_function.signature(), setup_function.signature(),
fused_function->mutable_signature());
graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library,
fused_function);
- RenameFunctionNodes(first_function, fused_function,
- setup_function.mutable_node_def(),
+ RenameFunctionNodes(first_function, setup_function.mutable_node_def(),
setup_function.mutable_ret());
- set_output(first_function.ret(), setup_function.ret(), fused_function);
+ set_output(first_function.ret(), setup_function.ret(),
+ fused_function->mutable_ret());
CHECK(fused_function->signature().output_arg_size() ==
fused_function->ret_size())
@@ -351,10 +463,10 @@ FunctionDef* FuseFunctions(const FunctionDef& first_function,
FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input,
setup_function.mutable_node_def());
FuseReturns(first_inputs, second_inputs, first_outputs, set_input,
- fused_function);
+ fused_function->mutable_ret());
+
+ set_nodes(first_function, setup_function, fused_function, library);
- // Copy transformed nodes from the second function.
- fused_function->mutable_node_def()->MergeFrom(setup_function.node_def());
return fused_function;
}
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/tensorflow/core/grappler/optimizers/data/fusion_utils.h
index 41f13f6cb8..19b7002dcd 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.h
@@ -48,14 +48,20 @@ using SetInputFn =
const StringCollection& second_function_inputs,
const StringCollection& parent_outputs, int arg_num)>;
-// This function is invoked with first function ret. It is used to set up
-// returns of fused function. If you need to combine outputs
-// of first and second function, then this is a right place to create a new
-// nodes.
+// This function is invoked with first and second function ret. It is used to
+// set up returns of fused function.
using SetOutputFn =
std::function<void(const protobuf::Map<string, string>& parent_ret,
const protobuf::Map<string, string>& second_function_ret,
- FunctionDef* fused_function)>;
+ protobuf::Map<string, string>* fused_ret)>;
+
+using SetNodesFn = std::function<void(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ FunctionDef* fused_function, FunctionDefLibrary* library)>;
+
+void MergeNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function, FunctionDef* fused_function,
+ FunctionDefLibrary* library);
// Returns true if functions can be composed.
bool CanCompose(const OpDef& first_signature, const OpDef& second_signature);
@@ -71,7 +77,7 @@ string ComposeInput(const StringCollection& first_inputs,
// second_function(first_function(args...)).
void ComposeOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function);
+ protobuf::Map<string, string>* fused_ret);
// Set input signature to `first_function_signature` and output signature
// to `first_function_signature` + `second_function_signature`
@@ -83,7 +89,32 @@ void CombineSignature(const OpDef& first_signature,
// return *first_function(...), *second_function(...)
void CombineOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function);
+ protobuf::Map<string, string>* fused_ret);
+
+// Returns true if both signatures have the same number of input and output
+// args.
+bool HasSameSignature(const OpDef& first_signature,
+ const OpDef& second_signature);
+
+// Check if both signatures are same and copy it from `first_signature`.
+void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
+ OpDef* fused_signature);
+
+// Take the same input as first function.
+string SameInput(const StringCollection& first_inputs,
+ const StringCollection& second_inputs,
+ const StringCollection& first_outputs, int arg_num);
+
+// Create a fused function that computes the short-circuit logical AND of the
+// result of the first function and the result of the second function.
+void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
+ const protobuf::Map<string, string>& second_ret,
+ protobuf::Map<string, string>* fused_ret);
+
+void LazyConjunctionNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function,
+ FunctionDef* fused_function,
+ FunctionDefLibrary* library);
// Fuse `first_function` with `second_function`, setting `fused_name_prefix` as
// a name prefix. The nodes from `first_function` are copied unmodified. All
@@ -91,13 +122,11 @@ void CombineOutput(const protobuf::Map<string, string>& first_ret,
// that are not conflicting with first function. This means that copied nodes
// from second function can end up having different names. For explanation of
// set up functions see the documentation of the functions types.
-FunctionDef* FuseFunctions(const FunctionDef& first_function,
- const FunctionDef& second_function,
- StringPiece fused_name_prefix,
- const SetFunctionSignatureFn& set_signature,
- const SetInputFn& set_input,
- const SetOutputFn& set_output,
- FunctionDefLibrary* library);
+FunctionDef* FuseFunctions(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
+ const SetInputFn& set_input, const SetOutputFn& set_output,
+ const SetNodesFn& set_nodes, FunctionDefLibrary* library);
} // namespace fusion_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index 7ad5d63bf6..d5c6466080 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -57,10 +57,10 @@ TEST(FusionUtilsTest, FuseFunctionsByComposition) {
auto *function = graph.mutable_library()->add_function();
*function = test::function::XTimesTwo();
- auto *fused_function =
- FuseFunctions(*parent_function, *function, "fused_maps",
- fusion_utils::ComposeSignature, fusion_utils::ComposeInput,
- fusion_utils::ComposeOutput, graph.mutable_library());
+ auto *fused_function = FuseFunctions(
+ *parent_function, *function, "fused_maps", fusion_utils::ComposeSignature,
+ fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
+ fusion_utils::MergeNodes, graph.mutable_library());
EXPECT_EQ(fused_function->signature().name(), "fused_maps");
EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
@@ -98,7 +98,8 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
auto *fused_function =
FuseFunctions(*xtimes_two, *is_zero, "fused_map_and_filter_function",
fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, graph.mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ graph.mutable_library());
EXPECT_EQ(fused_function->signature().name(),
"fused_map_and_filter_function");
@@ -134,10 +135,10 @@ TEST(FusionUtilsTest, FuseSameFunctionWithExtraOutput) {
auto *function = graph.mutable_library()->add_function();
*function = test::function::XTimesTwo();
- auto *fused_function =
- FuseFunctions(*parent_function, *function, "fused_maps",
- fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, graph.mutable_library());
+ auto *fused_function = FuseFunctions(
+ *parent_function, *function, "fused_maps", fusion_utils::CombineSignature,
+ fusion_utils::ComposeInput, fusion_utils::CombineOutput,
+ fusion_utils::MergeNodes, graph.mutable_library());
EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
@@ -169,7 +170,8 @@ TEST(FusionUtilsTest, ZipFusion) {
auto *fused_function =
FuseFunctions(*function, *function, "zip_maps", zip_signature, zip_input,
- fusion_utils::CombineOutput, graph.mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ graph.mutable_library());
EXPECT_EQ(fused_function->signature().input_arg_size(), 2);
EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 883037173b..5a7fe19265 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -94,11 +94,11 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
MutableGraphView* graph) {
NodeDef node;
if (!name.empty()) {
- node.set_name(name.ToString());
+ node.set_name(string(name));
} else {
SetUniqueGraphNodeName(op, graph->GetGraph(), &node);
}
- node.set_op(op.ToString());
+ node.set_op(string(op));
for (const string& input : inputs) {
node.add_input(input);
}
@@ -114,11 +114,11 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
FunctionDef* fd) {
NodeDef* node = fd->add_node_def();
if (!name.empty()) {
- node->set_name(name.ToString());
+ node->set_name(string(name));
} else {
SetUniqueFunctionNodeName(op, fd, node);
}
- node->set_op(op.ToString());
+ node->set_op(string(op));
for (const string& input : inputs) {
node->add_input(input);
}
@@ -270,7 +270,7 @@ NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
NodeDef* node) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
if (name.rfind("_generated") != std::string::npos &&
@@ -286,7 +286,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
NodeDef* node) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = function->node_def_size();
while (ContainsFunctionNodeWithName(name, *function)) {
name = strings::StrCat(prefix, "/_", id);
@@ -297,7 +297,7 @@ void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = library->function_size();
while (ContainsGraphFunctionWithName(name, *library)) {
name = strings::StrCat(prefix, "/_", id);
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
index 0b25b1ea9d..9e382aeef9 100644
--- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
+++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
@@ -33,7 +33,7 @@ namespace {
constexpr char kInsertOpName[] = "LatencyStatsDataset";
-NodeDef make_latency_node(const NodeDef& node, MutableGraphView* graph) {
+NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kInsertOpName);
graph_utils::SetUniqueGraphNodeName(
@@ -96,7 +96,7 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
- graph.InsertNode(node, make_latency_node(node, &graph));
+ graph.InsertNode(node, MakeLatencyNode(node, &graph));
}
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 3ce238a30a..63945b8b9e 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -32,9 +32,8 @@ namespace {
constexpr char kFusedOpName[] = "MapAndBatchDatasetV2";
-NodeDef make_map_and_batch_node(const NodeDef& map_node,
- const NodeDef& batch_node,
- MutableGraphView* graph) {
+NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
+ MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kFusedOpName);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
@@ -104,8 +103,8 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// Use a more descriptive variable name now that we know the node type.
const NodeDef& batch_node = node;
- GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
+
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
continue;
}
@@ -113,7 +112,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* map_node = node2;
auto* new_node =
- graph.AddNode(make_map_and_batch_node(*map_node, batch_node, &graph));
+ graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph));
graph.ReplaceInput(batch_node, *new_node);
// Mark the `Map` and `Batch` nodes for removal.
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index 5e76c9f819..f1844a141c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -116,22 +116,25 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto& fun = filter_node->attr().at("predicate");
const FunctionDef* filter_func = function_library.Find(fun.func().name());
if (!fusion_utils::CanCompose(map_func->signature(),
- filter_func->signature()))
+ filter_func->signature())) {
+ VLOG(1) << "Can't fuse map and filter because the output signature of "
+ "the map function does not match the input signature of the "
+ "filter function\n";
return nullptr;
+ }
return fusion_utils::FuseFunctions(
*map_func, *filter_func, "fused_map_and_filter_function",
fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, output->mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ output->mutable_library());
};
for (const NodeDef& node : sorted_old_graph.node()) {
const NodeDef* filter_node = get_filter_node(node);
if (!filter_node) continue;
- GraphView::InputPort input_port =
- graph.GetInputPort(filter_node->name(), 0);
const NodeDef* map_node =
- get_map_node(*graph.GetRegularFanin(input_port).node);
+ get_map_node(*graph_utils::GetInputNode(*filter_node, graph));
if (!map_node) continue;
const auto* fused_function = make_fused_function(map_node, filter_node);
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index 3b6829ade3..f029a093fa 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -30,7 +30,7 @@ namespace {
NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
return test::function::NDef(
- name, "MapDataset", {input_node_name.ToString()},
+ name, "MapDataset", {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
{"Targuments", {}},
{"output_shapes", {}},
@@ -39,7 +39,7 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
return test::function::NDef(
- name, "FilterDataset", {input_node_name.ToString()},
+ name, "FilterDataset", {string(input_node_name)},
{{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
{"Targuments", {}},
{"output_shapes", {}},
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index feb370eb9d..a78ecb09f7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -90,21 +90,25 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto& fun = map_node->attr().at("f");
const FunctionDef* func = function_library.Find(fun.func().name());
- if (!fusion_utils::CanCompose(parent_func->signature(), func->signature()))
+ if (!fusion_utils::CanCompose(parent_func->signature(),
+ func->signature())) {
+ VLOG(1) << "Can't fuse two maps because the output signature of the "
+ "first map function does not match the input signature of the "
+ "second function\n";
return nullptr;
+ }
return fusion_utils::FuseFunctions(
*parent_func, *func, "fused_map", fusion_utils::ComposeSignature,
fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
- output->mutable_library());
+ fusion_utils::MergeNodes, output->mutable_library());
};
for (const NodeDef& node : sorted_old_graph.node()) {
const NodeDef* map_node = get_map_node(node);
if (!map_node) continue;
- GraphView::InputPort input_port = graph.GetInputPort(map_node->name(), 0);
const NodeDef* parent_map_node =
- get_map_node(*graph.GetRegularFanin(input_port).node);
+ get_map_node(*graph_utils::GetInputNode(*map_node, graph));
if (!parent_map_node) continue;
const auto* fused_function = get_fused_function(parent_map_node, map_node);
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index df6c19dc7c..b25dfbd0b8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -30,7 +30,7 @@ namespace {
NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
return test::function::NDef(
- name, "MapDataset", {input_node_name.ToString()},
+ name, "MapDataset", {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
{"Targuments", {}},
{"output_shapes", {}},
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 23f35050f2..a019b77eb7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -111,10 +112,10 @@ bool HasCapturedInputs(const NodeDef& map_node) {
return map_node.attr().at("Targuments").list().type_size() > 0;
}
-NodeDef make_new_batch_node(const NodeDef& old_batch_node,
- const NodeDef& input_node,
- const FunctionDef& vectorized_func,
- MutableGraphView* graph) {
+NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
+ const NodeDef& input_node,
+ const FunctionDef& vectorized_func,
+ MutableGraphView* graph) {
NodeDef batch_node;
batch_node.set_op(old_batch_node.op());
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
@@ -150,11 +151,11 @@ NodeDef make_new_batch_node(const NodeDef& old_batch_node,
return batch_node;
}
-NodeDef make_new_map_node(const NodeDef& old_map_node,
- const NodeDef& old_batch_node,
- const NodeDef& new_batch_node,
- const FunctionDef& vectorized_func,
- MutableGraphView* graph) {
+NodeDef MakeNewMapNode(const NodeDef& old_map_node,
+ const NodeDef& old_batch_node,
+ const NodeDef& new_batch_node,
+ const FunctionDef& vectorized_func,
+ MutableGraphView* graph) {
NodeDef map_node;
map_node.set_op(old_map_node.op());
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
@@ -231,9 +232,9 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
CHECK_NOTNULL(vectorized_func);
auto* new_batch_node = graph.AddNode(
- make_new_batch_node(batch_node, *input_node, *vectorized_func, &graph));
+ MakeNewBatchNode(batch_node, *input_node, *vectorized_func, &graph));
- auto* new_map_node = graph.AddNode(make_new_map_node(
+ auto* new_map_node = graph.AddNode(MakeNewMapNode(
*map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
graph.ReplaceInput(batch_node, *new_map_node);
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index be2475bae8..ed1bd6bc97 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -55,8 +55,8 @@ NodeDef MakeMapNodeHelper(
const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
const gtl::ArraySlice<DataType>& output_types) {
return test::function::NDef(
- name, map_op_name, {input_node_name.ToString()},
- {{"f", FunctionDefHelper::FunctionRef(function_name.ToString())},
+ name, map_op_name, {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
{"Targuments", {}},
{"output_shapes", MakeShapeListAttr(output_shapes)},
{"output_types", output_types}});
@@ -76,7 +76,7 @@ NodeDef MakeBatchNode(
const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
const gtl::ArraySlice<DataType>& output_types) {
return NDef(name, "BatchDataset",
- {input_node_name.ToString(), input_batch_size_name.ToString()},
+ {string(input_node_name), string(input_batch_size_name)},
{{"output_types", output_types},
{"output_shapes", MakeShapeListAttr(output_shapes)}});
}
@@ -87,8 +87,8 @@ NodeDef MakeBatchV2Node(
const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
const gtl::ArraySlice<DataType>& output_types) {
return NDef(name, "BatchDatasetV2",
- {input_node_name.ToString(), input_batch_size_name.ToString(),
- input_drop_remainder_name.ToString()},
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
{{"output_types", output_types},
{"output_shapes", MakeShapeListAttr(output_shapes)}});
}
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
index 55d57b3b97..a26f1000a3 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -69,8 +69,7 @@ Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item,
for (const NodeDef& node : item.graph.node()) {
if (!IsNoOp(node, graph)) continue;
- GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
- NodeDef* const parent = graph.GetRegularFanin(input_port).node;
+ NodeDef* const parent = graph_utils::GetInputNode(node, graph);
graph.ReplaceInput(node, *parent);
nodes_to_delete.insert(node.name());
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index 7c7161c5b2..cb0ff670e8 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -76,8 +76,8 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Use a more descriptive variable name now that we know the node type.
const NodeDef& repeat_node = node;
- GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph);
+
if (node2->op() != "ShuffleDataset") {
continue;
}