diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data')
20 files changed, 503 insertions, 207 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 74d936cfbc..530c957068 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -4,36 +4,41 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") cc_library( - name = "function_rename", - srcs = ["function_rename.cc"], + name = "filter_fusion", + srcs = ["filter_fusion.cc"], hdrs = [ - "function_rename.h", + "filter_fusion.h", ], visibility = ["//visibility:public"], deps = [ ":graph_utils", + ":fusion_utils", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core:lib", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", ] + tf_protos_all(), ) tf_cc_test( - name = "function_rename_test", - srcs = ["function_rename_test.cc"], + name = "filter_fusion_test", + srcs = ["filter_fusion_test.cc"], visibility = ["//visibility:public"], deps = [ - ":function_rename", + ":filter_fusion", + ":graph_utils", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", - ] + tf_protos_all(), + ], ) cc_library( @@ -46,11 +51,13 @@ cc_library( deps = [ ":graph_utils", "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core:lib_internal", ] + tf_protos_all(), @@ -343,7 +350,7 @@ cc_library( name = "data", visibility = ["//visibility:public"], deps = [ - ":function_rename", + ":filter_fusion", ":latency_all_edges", ":map_and_batch_fusion", ":map_and_filter_fusion", diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc new file mode 100644 index 0000000000..c71aa6e804 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc @@ -0,0 +1,141 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node, + const NodeDef& second_filter_node, + const FunctionDef& fused_function, + MutableGraphView* graph) { + NodeDef fused_node; + graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(), + &fused_node); + + fused_node.set_op("FilterDataset"); + fused_node.add_input(first_filter_node.input(0)); + + auto copy_attribute = [](const string& attribute_name, const NodeDef& from, + NodeDef* to) { + (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); + }; + + auto attr = first_filter_node.attr().at("predicate"); + *attr.mutable_func()->mutable_name() = fused_function.signature().name(); + (*fused_node.mutable_attr())["predicate"] = std::move(attr); + + copy_attribute("Targuments", first_filter_node, &fused_node); + + for (auto key : {"output_shapes", "output_types"}) + copy_attribute(key, second_filter_node, &fused_node); + + return fused_node; +} + +} // namespace + +Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + GraphDef sorted_old_graph = item.graph; + TF_RETURN_IF_ERROR(TopologicalSort(&sorted_old_graph)); + *output = sorted_old_graph; + + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + output->library()); + + auto get_filter_node = [](const NodeDef& node) -> const NodeDef* { + if (node.op() == "FilterDataset") return &node; + return nullptr; + }; + + auto get_fused_predicate = + [&](const NodeDef* first_filter_node, + const NodeDef* second_filter_node) -> FunctionDef* { + const auto& parent_fun = first_filter_node->attr().at("predicate"); + const FunctionDef* first_func = + function_library.Find(parent_fun.func().name()); + const auto& fun = second_filter_node->attr().at("predicate"); + const FunctionDef* second_func = function_library.Find(fun.func().name()); + + if (!fusion_utils::HasSameSignature(first_func->signature(), + second_func->signature())) { + VLOG(1) << "Can't fuse Filters because they have different signature\n"; + return nullptr; + } + + return fusion_utils::FuseFunctions( + *first_func, *second_func, "fused_predicate", + fusion_utils::SameSignature, fusion_utils::SameInput, + fusion_utils::LazyConjunctionOutput, fusion_utils::LazyConjunctionNodes, + output->mutable_library()); + }; + + for (const NodeDef& node : sorted_old_graph.node()) { + const NodeDef* second_filter_node = get_filter_node(node); + if (!second_filter_node) continue; + + const NodeDef* first_filter_node = + get_filter_node(*graph_utils::GetInputNode(*second_filter_node, graph)); + if (!first_filter_node) continue; + + const auto* fused_predicate = + get_fused_predicate(first_filter_node, second_filter_node); + if (!fused_predicate) continue; + const auto* fused_filter_node = graph.AddNode(MakeFusedFilterNode( + *first_filter_node, *second_filter_node, *fused_predicate, &graph)); + + graph.ReplaceInput(*second_filter_node, *fused_filter_node); + + // TODO(prazek): we should run some optimizations on the fused filter + // functions, or make sure that optimization passes run after filter + // fusion. + TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate)); + // TODO(prazek): we could also remove map functions from library if they + // are not used anymore. + nodes_to_delete.insert(first_filter_node->name()); + nodes_to_delete.insert(second_filter_node->name()); + } + + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +void FilterFusion::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(FilterFusion, "filter_fusion"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.h b/tensorflow/core/grappler/optimizers/data/filter_fusion.h index 23ad9470ff..91a0364a46 100644 --- a/tensorflow/core/grappler/optimizers/data/function_rename.h +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.h @@ -13,20 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_ -#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_ #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" namespace tensorflow { namespace grappler { -class FunctionRename : public CustomGraphOptimizer { +// This optimization fuses filter transformations. +class FilterFusion : public CustomGraphOptimizer { public: - FunctionRename() = default; - ~FunctionRename() override = default; + FilterFusion() = default; + ~FilterFusion() override = default; - string name() const override { return "_test_only_function_rename"; }; + string name() const override { return "filter_fusion"; }; Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { @@ -43,4 +44,4 @@ class FunctionRename : public CustomGraphOptimizer { } // end namespace grappler } // end namespace tensorflow -#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_ diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc new file mode 100644 index 0000000000..12b1924efd --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { + return test::function::NDef( + name, "FilterDataset", {string(input_node_name)}, + {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, + {"Targuments", {}}, + {"output_shapes", {}}, + {"output_types", {}}}); +} + +TEST(FilterFusionTest, FuseTwoFilterIntoOne) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeFilterNode("filter1", "range"), + MakeFilterNode("filter2", "filter1")}, + // FunctionLib + { + test::function::IsZero(), + }); + + FilterFusion optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output)); +} + +TEST(FilterFusionTest, FuseThreeNodesIntoOne) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeFilterNode("filter1", "range"), MakeFilterNode("filter2", "filter1"), + MakeFilterNode("filter3", "filter2"), + NDef("cache", "CacheDataset", {"filter3", "filename"}, {})}, + // FunctionLib + { + test::function::IsZero(), + }); + + FilterFusion optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter3", output)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.cc b/tensorflow/core/grappler/optimizers/data/function_rename.cc deleted file mode 100644 index 8cf044d1bd..0000000000 --- a/tensorflow/core/grappler/optimizers/data/function_rename.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/grappler/optimizers/data/function_rename.h" - -#include "tensorflow/core/grappler/clusters/cluster.h" -#include "tensorflow/core/grappler/graph_view.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/op_types.h" -#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" -#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" -#include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace tensorflow { -namespace grappler { - -Status FunctionRename::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* output) { - *output = item.graph; - GraphView graph(output); - int n = output->mutable_library()->function_size(); - for (int i = 0; i < n; ++i) { - FunctionDef* fn = output->mutable_library()->mutable_function(i); - fn->mutable_signature()->set_name(fn->signature().name() + "world"); - } - - return Status::OK(); -} - -void FunctionRename::Feedback(Cluster* cluster, const GrapplerItem& item, - const GraphDef& optimize_output, double result) { - // no-op -} - -REGISTER_GRAPH_OPTIMIZER_AS(FunctionRename, "_test_only_function_rename"); - -} // end namespace grappler -} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc deleted file mode 100644 index 56b8a960a7..0000000000 --- a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/grappler/optimizers/data/function_rename.h" - -#include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace grappler { -namespace { - -TEST(FunctionRenameTest, RenameFunction) { - GrapplerItem item; - GraphDef *graph = &item.graph; - FunctionDef *fn = graph->mutable_library()->add_function(); - fn->mutable_signature()->set_name("hello"); - - FunctionRename optimizer; - GraphDef output; - TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_EQ(output.library().function(0).signature().name(), "helloworld"); -} - -} // namespace -} // namespace grappler -} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index f84f109af6..01a78c04b0 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/fusion_utils.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_def.pb.h" - #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" @@ -52,6 +52,12 @@ string GetOutputNode(const FunctionDef& function, int output_idx) { return function.ret().at(ret_output_name); } +string& GetMutableOutputNode(FunctionDef* function, int output_idx) { + const auto& ret_output_name = + function->signature().output_arg(output_idx).name(); + return function->mutable_ret()->at(ret_output_name); +} + template <typename Iterable> StringCollection GetNames(const Iterable& iterable, int allocate_size) { StringCollection names; @@ -106,7 +112,6 @@ gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable, // Nodes that will be added to the function can have the same name as the nodes // from parent function. void RenameFunctionNodes(const FunctionDef& first_function, - FunctionDef* fused_function, protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse, protobuf::Map<string, string>* rets_to_fuse) { const gtl::FlatMap<string, string> changed_node_names = @@ -149,6 +154,7 @@ OpDef GetUniqueSignature(const OpDef& first_signature, const gtl::FlatMap<string, string> changed_input_names = GetUniqueNames(first_signature.input_arg(), second_signature.input_arg()); OpDef signature; + signature.set_name(second_signature.name()); for (const auto& input_arg : second_signature.input_arg()) { auto& input = *signature.add_input_arg(); @@ -221,12 +227,13 @@ void FuseFunctionNodes(const StringCollection& first_inputs, } // This function looks for direct edges from input to return and rewrites -// them to the coresponding input of the return of `first_function`. +// them to the corresponding input of the return of `first_function`. void FuseReturns(const StringCollection& first_inputs, const StringCollection& second_inputs, const StringCollection& first_outputs, - const SetInputFn& set_input, FunctionDef* fused_function) { - for (auto& ret : *fused_function->mutable_ret()) { + const SetInputFn& set_input, + protobuf::Map<string, string>* fused_ret) { + for (auto& ret : *fused_ret) { auto return_input = ParseNodeConnection(ret.second); auto input_it = std::find(second_inputs.begin(), second_inputs.end(), return_input); @@ -249,6 +256,33 @@ StringCollection GetFunctionOutputs(const FunctionDef& function) { return outputs; } +FunctionDef* CreateFalsePredicate( + const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args, + FunctionDefLibrary* library) { + GraphDef graph; + MutableGraphView graph_view(&graph); + auto* node = graph_utils::AddScalarConstNode(false, &graph_view); + auto* false_predicate = library->add_function(); + graph_utils::SetUniqueGraphFunctionName("false_predicate", library, + false_predicate); + + int num = 0; + for (const auto& fake_arg : fake_args) { + auto* arg = false_predicate->mutable_signature()->add_input_arg(); + arg->set_type(fake_arg.type()); + arg->set_name(strings::StrCat("fake_arg", num)); + num++; + } + + auto* output = false_predicate->mutable_signature()->add_output_arg(); + output->set_name("false_out"); + output->set_type(DT_BOOL); + + (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0"; + *false_predicate->mutable_node_def() = std::move(*graph.mutable_node()); + return false_predicate; +} + void CheckIfCanCompose(const OpDef& first_signature, const OpDef& second_signature) { CHECK(CanCompose(first_signature, second_signature)) @@ -259,6 +293,15 @@ void CheckIfCanCompose(const OpDef& first_signature, } // namespace +void MergeNodes(const FunctionDef& first_function, + const FunctionDef& second_function, FunctionDef* fused_function, + FunctionDefLibrary* library) { + // Copy all nodes from first_function. + fused_function->mutable_node_def()->CopyFrom(first_function.node_def()); + // Copy transformed nodes from the second function. + fused_function->mutable_node_def()->MergeFrom(second_function.node_def()); +} + bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) { // TODO(prazek): Functions can have additional inputs being placeholders // for a values used in function. We should be able to also fuse these @@ -285,8 +328,8 @@ void ComposeSignature(const OpDef& first_signature, void ComposeOutput(const protobuf::Map<string, string>& first_ret, const protobuf::Map<string, string>& second_ret, - FunctionDef* fused_function) { - *fused_function->mutable_ret() = second_ret; + protobuf::Map<string, string>* fused_ret) { + *fused_ret = second_ret; } void CombineSignature(const OpDef& first_signature, @@ -302,41 +345,110 @@ void CombineSignature(const OpDef& first_signature, void CombineOutput(const protobuf::Map<string, string>& first_ret, const protobuf::Map<string, string>& second_ret, - FunctionDef* fused_function) { - *fused_function->mutable_ret() = first_ret; - fused_function->mutable_ret()->insert(second_ret.begin(), second_ret.end()); + protobuf::Map<string, string>* fused_ret) { + *fused_ret = first_ret; + fused_ret->insert(second_ret.begin(), second_ret.end()); +} + +string SameInput(const StringCollection& first_inputs, + const StringCollection& second_inputs, + const StringCollection& first_outputs, int arg_num) { + return first_inputs.at(arg_num); +} + +bool HasSameSignature(const OpDef& first_signature, + const OpDef& second_signature) { + return first_signature.input_arg_size() == + second_signature.input_arg_size() && + first_signature.output_arg_size() == + second_signature.output_arg_size(); +} + +void SameSignature(const OpDef& first_signature, const OpDef& second_signature, + OpDef* fused_signature) { + CHECK(HasSameSignature(first_signature, second_signature)) + << "Functions do not have the same signature"; + // Copy signature from first function. + *fused_signature = first_signature; +} + +void LazyConjunctionNodes(const FunctionDef& first_function, + const FunctionDef& second_function, + FunctionDef* fused_function, + FunctionDefLibrary* library) { + fused_function->mutable_node_def()->CopyFrom(first_function.node_def()); + + NodeDefBuilder if_builder("", "If"); + if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL); + DataTypeVector in_arg_types; + std::vector<NodeDefBuilder::NodeOut> inputs; + for (const auto& input_arg : first_function.signature().input_arg()) { + inputs.push_back({input_arg.name(), 0, input_arg.type()}); + in_arg_types.push_back(input_arg.type()); + } + if_builder.Attr("Tin", in_arg_types); + + if_builder.Attr("Tcond", DT_BOOL); + if_builder.Attr("Tout", DataTypeVector{DT_BOOL}); + if_builder.Attr("_lower_using_switch_merge", true); + + NameAttrList then_branch; + then_branch.set_name(second_function.signature().name()); + if_builder.Attr("then_branch", then_branch); + + auto* false_predicate = + CreateFalsePredicate(first_function.signature().input_arg(), library); + + NameAttrList else_branch; + else_branch.set_name(false_predicate->signature().name()); + if_builder.Attr("else_branch", else_branch); + if_builder.Input(inputs); + + auto* if_node = fused_function->add_node_def(); + // This is guaranteed to succeed. + TF_CHECK_OK(if_builder.Finalize(if_node)); + graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node); + + GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0"; +} + +void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret, + const protobuf::Map<string, string>& second_ret, + protobuf::Map<string, string>* fused_ret) { + CHECK_EQ(first_ret.size(), 1); + CHECK_EQ(second_ret.size(), 1); + // Temporarily copy returns from first_ret. We are going to change the + // output node after creating it. + *fused_ret = first_ret; } -FunctionDef* FuseFunctions(const FunctionDef& first_function, - const FunctionDef& function, - StringPiece fused_name_prefix, - const SetFunctionSignatureFn& set_signature, - const SetInputFn& set_input, - const SetOutputFn& set_output, - FunctionDefLibrary* library) { - if (first_function.attr_size() != 0 || function.attr_size() != 0) +FunctionDef* FuseFunctions( + const FunctionDef& first_function, const FunctionDef& second_function, + StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature, + const SetInputFn& set_input, const SetOutputFn& set_output, + const SetNodesFn& set_nodes, FunctionDefLibrary* library) { + if (first_function.attr_size() != 0 || second_function.attr_size() != 0) return nullptr; // Functions with attributes are currently not supported // This function will be used as a clone of second function, having unique // names. - FunctionDef setup_function = function; + FunctionDef setup_function = second_function; *setup_function.mutable_signature() = GetUniqueSignature( first_function.signature(), setup_function.signature(), setup_function.mutable_ret(), setup_function.mutable_node_def()); FunctionDef* fused_function = library->add_function(); - // Copy all nodes from first_function. - fused_function->mutable_node_def()->CopyFrom(first_function.node_def()); + set_signature(first_function.signature(), setup_function.signature(), fused_function->mutable_signature()); graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library, fused_function); - RenameFunctionNodes(first_function, fused_function, - setup_function.mutable_node_def(), + RenameFunctionNodes(first_function, setup_function.mutable_node_def(), setup_function.mutable_ret()); - set_output(first_function.ret(), setup_function.ret(), fused_function); + set_output(first_function.ret(), setup_function.ret(), + fused_function->mutable_ret()); CHECK(fused_function->signature().output_arg_size() == fused_function->ret_size()) @@ -351,10 +463,10 @@ FunctionDef* FuseFunctions(const FunctionDef& first_function, FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input, setup_function.mutable_node_def()); FuseReturns(first_inputs, second_inputs, first_outputs, set_input, - fused_function); + fused_function->mutable_ret()); + + set_nodes(first_function, setup_function, fused_function, library); - // Copy transformed nodes from the second function. - fused_function->mutable_node_def()->MergeFrom(setup_function.node_def()); return fused_function; } diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/tensorflow/core/grappler/optimizers/data/fusion_utils.h index 41f13f6cb8..19b7002dcd 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.h +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.h @@ -48,14 +48,20 @@ using SetInputFn = const StringCollection& second_function_inputs, const StringCollection& parent_outputs, int arg_num)>; -// This function is invoked with first function ret. It is used to set up -// returns of fused function. If you need to combine outputs -// of first and second function, then this is a right place to create a new -// nodes. +// This function is invoked with first and second function ret. It is used to +// set up returns of fused function. using SetOutputFn = std::function<void(const protobuf::Map<string, string>& parent_ret, const protobuf::Map<string, string>& second_function_ret, - FunctionDef* fused_function)>; + protobuf::Map<string, string>* fused_ret)>; + +using SetNodesFn = std::function<void( + const FunctionDef& first_function, const FunctionDef& second_function, + FunctionDef* fused_function, FunctionDefLibrary* library)>; + +void MergeNodes(const FunctionDef& first_function, + const FunctionDef& second_function, FunctionDef* fused_function, + FunctionDefLibrary* library); // Returns true if functions can be composed. bool CanCompose(const OpDef& first_signature, const OpDef& second_signature); @@ -71,7 +77,7 @@ string ComposeInput(const StringCollection& first_inputs, // second_function(first_function(args...)). void ComposeOutput(const protobuf::Map<string, string>& first_ret, const protobuf::Map<string, string>& second_ret, - FunctionDef* fused_function); + protobuf::Map<string, string>* fused_ret); // Set input signature to `first_function_signature` and output signature // to `first_function_signature` + `second_function_signature` @@ -83,7 +89,32 @@ void CombineSignature(const OpDef& first_signature, // return *first_function(...), *second_function(...) void CombineOutput(const protobuf::Map<string, string>& first_ret, const protobuf::Map<string, string>& second_ret, - FunctionDef* fused_function); + protobuf::Map<string, string>* fused_ret); + +// Returns true if both signatures have the same number of input and output +// args. +bool HasSameSignature(const OpDef& first_signature, + const OpDef& second_signature); + +// Check if both signatures are same and copy it from `first_signature`. +void SameSignature(const OpDef& first_signature, const OpDef& second_signature, + OpDef* fused_signature); + +// Take the same input as first function. +string SameInput(const StringCollection& first_inputs, + const StringCollection& second_inputs, + const StringCollection& first_outputs, int arg_num); + +// Create a fused function that computes the short-circuit logical AND of the +// result of the first function and the result of the second function. +void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret, + const protobuf::Map<string, string>& second_ret, + protobuf::Map<string, string>* fused_ret); + +void LazyConjunctionNodes(const FunctionDef& first_function, + const FunctionDef& second_function, + FunctionDef* fused_function, + FunctionDefLibrary* library); // Fuse `first_function` with `second_function`, setting `fused_name_prefix` as // a name prefix. The nodes from `first_function` are copied unmodified. All @@ -91,13 +122,11 @@ void CombineOutput(const protobuf::Map<string, string>& first_ret, // that are not conflicting with first function. This means that copied nodes // from second function can end up having different names. For explanation of // set up functions see the documentation of the functions types. -FunctionDef* FuseFunctions(const FunctionDef& first_function, - const FunctionDef& second_function, - StringPiece fused_name_prefix, - const SetFunctionSignatureFn& set_signature, - const SetInputFn& set_input, - const SetOutputFn& set_output, - FunctionDefLibrary* library); +FunctionDef* FuseFunctions( + const FunctionDef& first_function, const FunctionDef& second_function, + StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature, + const SetInputFn& set_input, const SetOutputFn& set_output, + const SetNodesFn& set_nodes, FunctionDefLibrary* library); } // namespace fusion_utils } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc index 7ad5d63bf6..d5c6466080 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc @@ -57,10 +57,10 @@ TEST(FusionUtilsTest, FuseFunctionsByComposition) { auto *function = graph.mutable_library()->add_function(); *function = test::function::XTimesTwo(); - auto *fused_function = - FuseFunctions(*parent_function, *function, "fused_maps", - fusion_utils::ComposeSignature, fusion_utils::ComposeInput, - fusion_utils::ComposeOutput, graph.mutable_library()); + auto *fused_function = FuseFunctions( + *parent_function, *function, "fused_maps", fusion_utils::ComposeSignature, + fusion_utils::ComposeInput, fusion_utils::ComposeOutput, + fusion_utils::MergeNodes, graph.mutable_library()); EXPECT_EQ(fused_function->signature().name(), "fused_maps"); EXPECT_EQ(fused_function->signature().input_arg_size(), 1); @@ -98,7 +98,8 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) { auto *fused_function = FuseFunctions(*xtimes_two, *is_zero, "fused_map_and_filter_function", fusion_utils::CombineSignature, fusion_utils::ComposeInput, - fusion_utils::CombineOutput, graph.mutable_library()); + fusion_utils::CombineOutput, fusion_utils::MergeNodes, + graph.mutable_library()); EXPECT_EQ(fused_function->signature().name(), "fused_map_and_filter_function"); @@ -134,10 +135,10 @@ TEST(FusionUtilsTest, FuseSameFunctionWithExtraOutput) { auto *function = graph.mutable_library()->add_function(); *function = test::function::XTimesTwo(); - auto *fused_function = - FuseFunctions(*parent_function, *function, "fused_maps", - fusion_utils::CombineSignature, fusion_utils::ComposeInput, - fusion_utils::CombineOutput, graph.mutable_library()); + auto *fused_function = FuseFunctions( + *parent_function, *function, "fused_maps", fusion_utils::CombineSignature, + fusion_utils::ComposeInput, fusion_utils::CombineOutput, + fusion_utils::MergeNodes, graph.mutable_library()); EXPECT_EQ(fused_function->signature().input_arg_size(), 1); EXPECT_EQ(fused_function->signature().output_arg_size(), 2); @@ -169,7 +170,8 @@ TEST(FusionUtilsTest, ZipFusion) { auto *fused_function = FuseFunctions(*function, *function, "zip_maps", zip_signature, zip_input, - fusion_utils::CombineOutput, graph.mutable_library()); + fusion_utils::CombineOutput, fusion_utils::MergeNodes, + graph.mutable_library()); EXPECT_EQ(fused_function->signature().input_arg_size(), 2); EXPECT_EQ(fused_function->signature().output_arg_size(), 2); diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 883037173b..5a7fe19265 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -94,11 +94,11 @@ NodeDef* AddNode(StringPiece name, StringPiece op, MutableGraphView* graph) { NodeDef node; if (!name.empty()) { - node.set_name(name.ToString()); + node.set_name(string(name)); } else { SetUniqueGraphNodeName(op, graph->GetGraph(), &node); } - node.set_op(op.ToString()); + node.set_op(string(op)); for (const string& input : inputs) { node.add_input(input); } @@ -114,11 +114,11 @@ NodeDef* AddNode(StringPiece name, StringPiece op, FunctionDef* fd) { NodeDef* node = fd->add_node_def(); if (!name.empty()) { - node->set_name(name.ToString()); + node->set_name(string(name)); } else { SetUniqueFunctionNodeName(op, fd, node); } - node->set_op(op.ToString()); + node->set_op(string(op)); for (const string& input : inputs) { node->add_input(input); } @@ -270,7 +270,7 @@ NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) { void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node) { - string name = prefix.ToString(); + string name = string(prefix); int id = graph->node_size(); while (ContainsGraphNodeWithName(name, *graph)) { if (name.rfind("_generated") != std::string::npos && @@ -286,7 +286,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, NodeDef* node) { - string name = prefix.ToString(); + string name = string(prefix); int id = function->node_def_size(); while (ContainsFunctionNodeWithName(name, *function)) { name = strings::StrCat(prefix, "/_", id); @@ -297,7 +297,7 @@ void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, FunctionDef* function) { - string name = prefix.ToString(); + string name = string(prefix); int id = library->function_size(); while (ContainsGraphFunctionWithName(name, *library)) { name = strings::StrCat(prefix, "/_", id); diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc index 0b25b1ea9d..9e382aeef9 100644 --- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc +++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc @@ -33,7 +33,7 @@ namespace { constexpr char kInsertOpName[] = "LatencyStatsDataset"; -NodeDef make_latency_node(const NodeDef& node, MutableGraphView* graph) { +NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) { NodeDef new_node; new_node.set_op(kInsertOpName); graph_utils::SetUniqueGraphNodeName( @@ -96,7 +96,7 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item, } } - graph.InsertNode(node, make_latency_node(node, &graph)); + graph.InsertNode(node, MakeLatencyNode(node, &graph)); } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index 3ce238a30a..63945b8b9e 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -32,9 +32,8 @@ namespace { constexpr char kFusedOpName[] = "MapAndBatchDatasetV2"; -NodeDef make_map_and_batch_node(const NodeDef& map_node, - const NodeDef& batch_node, - MutableGraphView* graph) { +NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node, + MutableGraphView* graph) { NodeDef new_node; new_node.set_op(kFusedOpName); graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(), @@ -104,8 +103,8 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // Use a more descriptive variable name now that we know the node type. const NodeDef& batch_node = node; - GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0); - NodeDef* node2 = graph.GetRegularFanin(input_port).node; + NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph); + if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") { continue; } @@ -113,7 +112,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item, NodeDef* map_node = node2; auto* new_node = - graph.AddNode(make_map_and_batch_node(*map_node, batch_node, &graph)); + graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph)); graph.ReplaceInput(batch_node, *new_node); // Mark the `Map` and `Batch` nodes for removal. diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc index 5e76c9f819..f1844a141c 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc @@ -116,22 +116,25 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, const auto& fun = filter_node->attr().at("predicate"); const FunctionDef* filter_func = function_library.Find(fun.func().name()); if (!fusion_utils::CanCompose(map_func->signature(), - filter_func->signature())) + filter_func->signature())) { + VLOG(1) << "Can't fuse map and filter because the output signature of " + "the map function does not match the input signature of the " + "filter function\n"; return nullptr; + } return fusion_utils::FuseFunctions( *map_func, *filter_func, "fused_map_and_filter_function", fusion_utils::CombineSignature, fusion_utils::ComposeInput, - fusion_utils::CombineOutput, output->mutable_library()); + fusion_utils::CombineOutput, fusion_utils::MergeNodes, + output->mutable_library()); }; for (const NodeDef& node : sorted_old_graph.node()) { const NodeDef* filter_node = get_filter_node(node); if (!filter_node) continue; - GraphView::InputPort input_port = - graph.GetInputPort(filter_node->name(), 0); const NodeDef* map_node = - get_map_node(*graph.GetRegularFanin(input_port).node); + get_map_node(*graph_utils::GetInputNode(*filter_node, graph)); if (!map_node) continue; const auto* fused_function = make_fused_function(map_node, filter_node); diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc index 3b6829ade3..f029a093fa 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc @@ -30,7 +30,7 @@ namespace { NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { return test::function::NDef( - name, "MapDataset", {input_node_name.ToString()}, + name, "MapDataset", {string(input_node_name)}, {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, {"Targuments", {}}, {"output_shapes", {}}, @@ -39,7 +39,7 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { return test::function::NDef( - name, "FilterDataset", {input_node_name.ToString()}, + name, "FilterDataset", {string(input_node_name)}, {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, {"Targuments", {}}, {"output_shapes", {}}, diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index feb370eb9d..a78ecb09f7 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -90,21 +90,25 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item, const auto& fun = map_node->attr().at("f"); const FunctionDef* func = function_library.Find(fun.func().name()); - if (!fusion_utils::CanCompose(parent_func->signature(), func->signature())) + if (!fusion_utils::CanCompose(parent_func->signature(), + func->signature())) { + VLOG(1) << "Can't fuse two maps because the output signature of the " + "first map function does not match the input signature of the " + "second function\n"; return nullptr; + } return fusion_utils::FuseFunctions( *parent_func, *func, "fused_map", fusion_utils::ComposeSignature, fusion_utils::ComposeInput, fusion_utils::ComposeOutput, - output->mutable_library()); + fusion_utils::MergeNodes, output->mutable_library()); }; for (const NodeDef& node : sorted_old_graph.node()) { const NodeDef* map_node = get_map_node(node); if (!map_node) continue; - GraphView::InputPort input_port = graph.GetInputPort(map_node->name(), 0); const NodeDef* parent_map_node = - get_map_node(*graph.GetRegularFanin(input_port).node); + get_map_node(*graph_utils::GetInputNode(*map_node, graph)); if (!parent_map_node) continue; const auto* fused_function = get_fused_function(parent_map_node, map_node); diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc index df6c19dc7c..b25dfbd0b8 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc @@ -30,7 +30,7 @@ namespace { NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { return test::function::NDef( - name, "MapDataset", {input_node_name.ToString()}, + name, "MapDataset", {string(input_node_name)}, {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, {"Targuments", {}}, {"output_shapes", {}}, diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 23f35050f2..a019b77eb7 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -111,10 +112,10 @@ bool HasCapturedInputs(const NodeDef& map_node) { return map_node.attr().at("Targuments").list().type_size() > 0; } -NodeDef make_new_batch_node(const NodeDef& old_batch_node, - const NodeDef& input_node, - const FunctionDef& vectorized_func, - MutableGraphView* graph) { +NodeDef MakeNewBatchNode(const NodeDef& old_batch_node, + const NodeDef& input_node, + const FunctionDef& vectorized_func, + MutableGraphView* graph) { NodeDef batch_node; batch_node.set_op(old_batch_node.op()); graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(), @@ -150,11 +151,11 @@ NodeDef make_new_batch_node(const NodeDef& old_batch_node, return batch_node; } -NodeDef make_new_map_node(const NodeDef& old_map_node, - const NodeDef& old_batch_node, - const NodeDef& new_batch_node, - const FunctionDef& vectorized_func, - MutableGraphView* graph) { +NodeDef MakeNewMapNode(const NodeDef& old_map_node, + const NodeDef& old_batch_node, + const NodeDef& new_batch_node, + const FunctionDef& vectorized_func, + MutableGraphView* graph) { NodeDef map_node; map_node.set_op(old_map_node.op()); graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(), @@ -231,9 +232,9 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item, CHECK_NOTNULL(vectorized_func); auto* new_batch_node = graph.AddNode( - make_new_batch_node(batch_node, *input_node, *vectorized_func, &graph)); + MakeNewBatchNode(batch_node, *input_node, *vectorized_func, &graph)); - auto* new_map_node = graph.AddNode(make_new_map_node( + auto* new_map_node = graph.AddNode(MakeNewMapNode( *map_node, batch_node, *new_batch_node, *vectorized_func, &graph)); graph.ReplaceInput(batch_node, *new_map_node); diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc index be2475bae8..ed1bd6bc97 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc @@ -55,8 +55,8 @@ NodeDef MakeMapNodeHelper( const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, const gtl::ArraySlice<DataType>& output_types) { return test::function::NDef( - name, map_op_name, {input_node_name.ToString()}, - {{"f", FunctionDefHelper::FunctionRef(function_name.ToString())}, + name, map_op_name, {string(input_node_name)}, + {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, {"Targuments", {}}, {"output_shapes", MakeShapeListAttr(output_shapes)}, {"output_types", output_types}}); @@ -76,7 +76,7 @@ NodeDef MakeBatchNode( const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, const gtl::ArraySlice<DataType>& output_types) { return NDef(name, "BatchDataset", - {input_node_name.ToString(), input_batch_size_name.ToString()}, + {string(input_node_name), string(input_batch_size_name)}, {{"output_types", output_types}, {"output_shapes", MakeShapeListAttr(output_shapes)}}); } @@ -87,8 +87,8 @@ NodeDef MakeBatchV2Node( const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, const gtl::ArraySlice<DataType>& output_types) { return NDef(name, "BatchDatasetV2", - {input_node_name.ToString(), input_batch_size_name.ToString(), - input_drop_remainder_name.ToString()}, + {string(input_node_name), string(input_batch_size_name), + string(input_drop_remainder_name)}, {{"output_types", output_types}, {"output_shapes", MakeShapeListAttr(output_shapes)}}); } diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index 55d57b3b97..a26f1000a3 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -69,8 +69,7 @@ Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item, for (const NodeDef& node : item.graph.node()) { if (!IsNoOp(node, graph)) continue; - GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); - NodeDef* const parent = graph.GetRegularFanin(input_port).node; + NodeDef* const parent = graph_utils::GetInputNode(node, graph); graph.ReplaceInput(node, *parent); nodes_to_delete.insert(node.name()); diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index 7c7161c5b2..cb0ff670e8 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -76,8 +76,8 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster, // Use a more descriptive variable name now that we know the node type. const NodeDef& repeat_node = node; - GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0); - NodeDef* node2 = graph.GetRegularFanin(input_port).node; + NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph); + if (node2->op() != "ShuffleDataset") { continue; } |