diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-19 13:43:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 13:46:34 -0700 |
commit | eb2fe78e60ea8465443d8c653381a788ec581e49 (patch) | |
tree | b4dfbf11a0e35316aa45f912a9060fe46dc46d21 | |
parent | 125bf1dbb76c05bf5f88f14e77387ce35f986621 (diff) |
[tf.data] MapVectorization optimization: C++ conversion framework to vectorize a MapDefun function. Also implements conversion for two ops: Cast and Unpack.
PiperOrigin-RevId: 213686720
13 files changed, 1596 insertions, 182 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index e84df10778..7128a50be0 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -49,6 +49,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_utils", + ":function_utils", "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -67,6 +68,7 @@ tf_cc_test( srcs = ["fusion_utils_test.cc"], visibility = ["//visibility:public"], deps = [ + ":function_utils", ":fusion_utils", ":graph_utils", "//tensorflow/core:framework", @@ -78,6 +80,40 @@ tf_cc_test( ) cc_library( + name = "function_utils", + srcs = ["function_utils.cc"], + hdrs = [ + "function_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "function_utils_test", + srcs = ["function_utils_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/tools/graph_transforms:transform_utils", + ], +) + +cc_library( name = "graph_utils", srcs = ["graph_utils.cc"], hdrs = [ @@ -137,6 +173,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":function_utils", ":graph_utils", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", @@ -409,3 +446,42 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", ], ) 
+ +cc_library( + name = "vectorization_utils", + srcs = ["vectorization_utils.cc"], + hdrs = [ + "vectorization_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + ":graph_utils", + "@com_google_absl//absl/strings", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:functions", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "vectorization_utils_test", + srcs = ["vectorization_utils_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + ":vectorization_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/tools/graph_transforms:transform_utils", + ] + tf_protos_all(), +) diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc new file mode 100644 index 0000000000..e95ea1a4c1 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc @@ -0,0 +1,196 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" + +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace grappler { +namespace function_utils { +namespace { + +template <typename Predicate, typename Collection> +std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate, + const Collection& collection) { + std::vector<int> indices = {}; + unsigned idx = 0; + for (auto&& element : collection) { + if (predicate(element)) { + indices.push_back(idx); + } + idx++; + } + return indices; +} + +} // namespace + +FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name, + const string& output, int position) + : node_name(node_name), node_output(output), position(position) { + full_str = strings::StrCat(node_name, ":", node_output, ":", position); +} + +FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) { + // Parses node_name:node_output:position string into its components. 
+ full_str = input; + StringPiece capture; + StringPiece remaining; + + // Parse "node_name" + if (strings::Scanner(input) + .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) + .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_name = string(capture.data(), capture.size()); + } + + // Parse "node_output" if it exists + if (strings::Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .One(strings::Scanner::LETTER) + .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_output = string(capture.data(), capture.size()); + } + + // Parse "position" if it exists + if (strings::Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .Many(strings::Scanner::DIGIT) + .GetResult(nullptr, &capture)) { + CHECK(strings::safe_strto32(capture, &position)); + } +} + +// TODO(rachelim): Create a utility class similar to MutableGraphView for +// FunctionDefs, and use that to manipulate functions. It'll be more +// performant if we kept mappings of nodes->inputs/outputs, so that we don't +// have to search over all nodes each time. +// Note that we're not using GrapplerFunctionItem because it doesn't cover +// some of our desired uses (eg changing the outputs of a function), and the +// FunctionDef -> GraphDef conversion isn't really necessary in this case. 
+void ReplaceReferences(const string& from, const string& to, + FunctionDef* func) { + for (NodeDef& n : *func->mutable_node_def()) { + std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from, + to); + } + + for (auto& p : *func->mutable_ret()) { + if (p.second == from) { + p.second = to; + } + } +} + +void AddFunctionOutputWithUniqueName(StringPiece prefix, + StringPiece output_tensor_name, + FunctionDef* function, DataType dt) { + string name = string(prefix); + int id = function->signature().output_arg_size(); + while (ContainsFunctionOutputWithName(name, *function)) { + name = strings::StrCat(prefix, "/_", id); + ++id; + } + auto* output = function->mutable_signature()->mutable_output_arg()->Add(); + output->set_name(name); + output->set_type(dt); + + (*function->mutable_ret())[name] = string(output_tensor_name); +} + +NodeDef* AddNode(StringPiece name, StringPiece op, + const std::vector<string>& inputs, + const std::vector<std::pair<string, AttrValue>>& attributes, + FunctionDef* fd) { + NodeDef* node = fd->add_node_def(); + if (!name.empty()) { + node->set_name(string(name)); + } else { + SetUniqueFunctionNodeName(op, fd, node); + } + node->set_op(string(op)); + for (const string& input : inputs) { + node->add_input(input); + } + for (auto attr : attributes) { + (*node->mutable_attr())[attr.first] = attr.second; + } + return node; +} + +bool ContainsFunctionNodeWithName(StringPiece name, + const FunctionDef& function) { + return FindFunctionNodeWithName(name, function) != -1; +} + +bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { + return FindFunctionNodeWithOp(op, function) != -1; +} + +bool ContainsFunctionOutputWithName(StringPiece name, + const FunctionDef& function) { + return FindFunctionOutputWithName(name, function) != -1; +} + +int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) { + std::vector<int> indices = GetElementIndicesWithPredicate( + [&name](const OpDef_ArgDef& arg) { 
return arg.name() == name; }, + function.signature().input_arg()); + return indices.empty() ? -1 : indices.front(); +} + +int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) { + std::vector<int> indices = GetElementIndicesWithPredicate( + [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, + function.signature().output_arg()); + return indices.empty() ? -1 : indices.front(); +} + +int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { + std::vector<int> indices = GetElementIndicesWithPredicate( + [&name](const NodeDef& node) { return node.name() == name; }, + function.node_def()); + return indices.empty() ? -1 : indices.front(); +} + +int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { + std::vector<int> indices = GetElementIndicesWithPredicate( + [&op](const NodeDef& node) { return node.op() == op; }, + function.node_def()); + + return indices.empty() ? -1 : indices.front(); +} + +void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, + NodeDef* node) { + string name = string(prefix); + int id = function->node_def_size(); + while (ContainsFunctionNodeWithName(name, *function)) { + name = strings::StrCat(prefix, "/_", id); + ++id; + } + node->set_name(std::move(name)); +} + +} // end namespace function_utils +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h new file mode 100644 index 0000000000..d4ce824652 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/function_utils.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace grappler { +namespace function_utils { +// This namespace contains utility functions for querying and modifying +// FunctionDefs. + +// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings +// have the format node_name:node_output:position (if they derive from nodes), +// or input_name (if they derive from an argument). +struct FunctionDefTensorDesc { + FunctionDefTensorDesc() = default; + + FunctionDefTensorDesc(const string& node_name, const string& output, + int position); + + // Parses node_name:node_output:position string into its components. 
+ explicit FunctionDefTensorDesc(const string& input); + + // TODO(rachelim): Add provisions to deal with special formats, like how + // GrapplerFunctionItem expands node output range if position is not defined + string full_str; + string node_name; + string node_output; + int position = -1; +}; + +// Replaces all references to `from` tensor in func's nodes' inputs and retvals +// to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`. +void ReplaceReferences(const string& from, const string& to, FunctionDef* func); + +// Adds a function output to the function def, ensuring that the output key +// is unique, and maps to output_tensor_name in the ret dict. +void AddFunctionOutputWithUniqueName(StringPiece prefix, + StringPiece output_tensor_name, + FunctionDef* function, DataType dt); + +// Adds a node to a FunctionDef. +NodeDef* AddNode(StringPiece name, StringPiece op, + const std::vector<string>& inputs, + const std::vector<std::pair<string, AttrValue>>& attributes, + FunctionDef* fd); + +// Checks whether the function contains a node with the given name. +bool ContainsFunctionNodeWithName(StringPiece name, + const FunctionDef& function); + +// Checks whether the function contains a node with the given op. +bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function); + +// Checks whether the function contains an output with the given name. +bool ContainsFunctionOutputWithName(StringPiece name, + const FunctionDef& function); + +// Returns the index of the function input with the given name or -1 if the +// function input does not exist. +int FindFunctionInputWithName(StringPiece name, const FunctionDef& function); + +// Returns the index of the function output with the given name or -1 if the +// function output does not exist. +int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function); + +// Returns the index of the function node with the given name or -1 if the +// function node does not exist. 
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function); + +// Returns the index of the function node with the given op or -1 if the +// function node does not exist. +int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function); + +// Sets the function node name using the `prefix` as a prefix while guaranteeing +// the name is unique across the function's nodes. +void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, + NodeDef* node); + +} // end namespace function_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc new file mode 100644 index 0000000000..3739e20eb1 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace grappler { +namespace function_utils { +namespace { + +TEST(FunctionDefTensorDesc, Parsing) { + FunctionDefTensorDesc f("Cast:y:0"); + EXPECT_EQ(f.full_str, "Cast:y:0"); + EXPECT_EQ(f.node_name, "Cast"); + EXPECT_EQ(f.node_output, "y"); + EXPECT_EQ(f.position, 0); + + FunctionDefTensorDesc f2("Arg0"); + EXPECT_EQ(f2.full_str, "Arg0"); + EXPECT_EQ(f2.node_name, "Arg0"); + EXPECT_EQ(f2.node_output, ""); + EXPECT_EQ(f2.position, -1); +} + +TEST(ReplaceReferencesTest, ReplaceReferencesTest) { + FunctionDef outer = FunctionDefHelper::Create( + "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {}, + {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}}); + NodeDef* derive_node = + AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer); + // Check that both the input to "X" and retval of "outer" are replaced. 
+ ReplaceReferences("MapDefun:output:0", "arg0", &outer); + EXPECT_EQ(outer.ret().at("out"), "arg0"); + EXPECT_EQ(derive_node->input(0), "arg0"); +} + +TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) { + FunctionDef function = test::function::XTimesTwo(); + AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64); + EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function)); + EXPECT_EQ(function.ret().at("y/_1"), "two"); +} + +TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_FALSE(ContainsFunctionNodeWithName( + "weird_name_that_should_not_be_there", function)); + EXPECT_TRUE(ContainsFunctionNodeWithName("two", function)); +} + +TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there", + function)); + EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function)); +} + +TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_TRUE(ContainsFunctionOutputWithName("y", function)); + EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function)); +} + +TEST(FunctionUtilsTest, FindFunctionNodeWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ( + FindFunctionNodeWithName("weird_name_that_should_not_be_there", function), + -1); + EXPECT_NE(FindFunctionNodeWithName("two", function), -1); +} + +TEST(FunctionUtilsTest, FindFunctionNodeWithOp) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ( + FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function), + -1); + EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1); +} + +TEST(FunctionUtilsTest, FindFunctionInputWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ(FindFunctionInputWithName("x", function), 0); + EXPECT_EQ(FindFunctionInputWithName("not_a_name", 
function), -1); +} + +TEST(FunctionUtilsTest, FindFunctionOutputWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ(FindFunctionOutputWithName("y", function), 0); + EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1); +} + +TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) { + FunctionDef function = test::function::XTimesTwo(); + NodeDef node; + SetUniqueFunctionNodeName("abc", &function, &node); + for (const NodeDef& function_node : function.node_def()) { + EXPECT_NE(node.name(), function_node.name()); + } + auto* new_node = function.add_node_def(); + *new_node = node; + + NodeDef other; + SetUniqueFunctionNodeName("abc", &function, &other); + EXPECT_NE(other.name(), new_node->name()); +} + +TEST(FunctionUtilsTest, AddNodeToFunctionDef) { + FunctionDef func; + const char* op_name = "xxx"; + AddNode(op_name, op_name, {}, {}, &func); + + const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func)); + EXPECT_EQ(node1.op(), op_name); + EXPECT_EQ(node1.input_size(), 0); + EXPECT_EQ(node1.attr_size(), 0); + + const std::vector<string> inputs({"input1", "input2"}); + AddNode("", op_name, inputs, {}, &func); + const NodeDef& node2 = + func.node_def(FindFunctionNodeWithName("xxx/_2", func)); + EXPECT_EQ(node2.op(), op_name); + EXPECT_EQ(node2.attr_size(), 0); + EXPECT_EQ(node2.input_size(), inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + EXPECT_EQ(node2.input(i), inputs[i]); + } + + AttrValue a1, a2; + a1.set_type(DT_INT32); + a2.set_type(DT_INT64); + const std::vector<std::pair<string, AttrValue>> attrs( + {{"attr1", a1}, {"attr2", a2}}); + AddNode("", op_name, {}, attrs, &func); + const NodeDef& node3 = + func.node_def(FindFunctionNodeWithName("xxx/_3", func)); + EXPECT_EQ(node3.op(), op_name); + EXPECT_EQ(node3.input_size(), 0); + EXPECT_EQ(node3.attr_size(), attrs.size()); + for (size_t i = 0; i < attrs.size(); ++i) { + EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type()); + } +} 
+ +} // namespace +} // namespace function_utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index 01a78c04b0..b3bfee138f 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -407,7 +408,7 @@ void LazyConjunctionNodes(const FunctionDef& first_function, auto* if_node = fused_function->add_node_def(); // This is guaranteed to succeed. TF_CHECK_OK(if_builder.Finalize(if_node)); - graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node); + function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node); GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0"; } diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc index d5c6466080..e667affeea 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -110,9 +111,9 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) { CheckUniqueNames(*fused_function); ASSERT_TRUE( - graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function)); + function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function)); const auto &equal_node = fused_function->node_def( - graph_utils::FindFunctionNodeWithOp("Equal", *fused_function)); + function_utils::FindFunctionNodeWithOp("Equal", *fused_function)); EXPECT_EQ(xtimes_two->signature().output_arg(0).name(), fused_function->signature().output_arg(0).name()); diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index d4ab444036..b3f60e34f9 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -108,26 +108,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op, return graph->AddNode(std::move(node)); } -NodeDef* AddNode(StringPiece name, StringPiece op, - const std::vector<string>& inputs, - const std::vector<std::pair<string, AttrValue>>& attributes, - FunctionDef* fd) { - NodeDef* node = fd->add_node_def(); - if (!name.empty()) { - node->set_name(string(name)); - } else { - SetUniqueFunctionNodeName(op, fd, node); - } - node->set_op(string(op)); - for (const string& input : inputs) { - node->add_input(input); - } - for (auto attr : attributes) { - (*node->mutable_attr())[attr.first] = attr.second; - } - return node; -} - template <> NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) { return AddScalarConstNodeHelper( @@ -196,6 +176,11 @@ bool Compare(const GraphDef& 
g1, const GraphDef& g2) { return true; } +bool ContainsGraphFunctionWithName(StringPiece name, + const FunctionDefLibrary& library) { + return FindGraphFunctionWithName(name, library) != -1; +} + bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) { return FindGraphNodeWithName(name, graph) != -1; } @@ -204,18 +189,14 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) { return FindGraphNodeWithOp(op, graph) != -1; } -bool ContainsGraphFunctionWithName(StringPiece name, - const FunctionDefLibrary& library) { - return FindGraphFunctionWithName(name, library) != -1; -} - -bool ContainsFunctionNodeWithName(StringPiece name, - const FunctionDef& function) { - return FindFunctionNodeWithName(name, function) != -1; -} - -bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { - return FindFunctionNodeWithOp(op, function) != -1; +int FindGraphFunctionWithName(StringPiece name, + const FunctionDefLibrary& library) { + std::vector<int> indices = GetElementIndicesWithPredicate( + [&name](const FunctionDef& function) { + return function.signature().name() == name; + }, + library.function()); + return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) { @@ -237,31 +218,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op, [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); } -int FindGraphFunctionWithName(StringPiece name, - const FunctionDefLibrary& library) { - std::vector<int> indices = GetElementIndicesWithPredicate( - [&name](const FunctionDef& function) { - return function.signature().name() == name; - }, - library.function()); - return indices.empty() ? 
-1 : indices.front(); -} - -int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( - [&name](const NodeDef& node) { return node.name() == name; }, - function.node_def()); - return indices.empty() ? -1 : indices.front(); -} - -int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( - [&op](const NodeDef& node) { return node.op() == op; }, - function.node_def()); - - return indices.empty() ? -1 : indices.front(); -} - NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) { if (node.input_size() == 0) return nullptr; GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); @@ -284,17 +240,6 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, node->set_name(std::move(name)); } -void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, - NodeDef* node) { - string name = string(prefix); - int id = function->node_def_size(); - while (ContainsFunctionNodeWithName(name, *function)) { - name = strings::StrCat(prefix, "/_", id); - ++id; - } - node->set_name(std::move(name)); -} - void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, FunctionDef* function) { string name = string(prefix); @@ -305,7 +250,6 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, } function->mutable_signature()->set_name(std::move(name)); } - } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 6f431c232d..1652afcd9e 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -37,12 +37,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<std::pair<string, 
AttrValue>>& attributes, MutableGraphView* graph); -// Adds a node to a FunctionDef. -NodeDef* AddNode(StringPiece name, StringPiece op, - const std::vector<string>& inputs, - const std::vector<std::pair<string, AttrValue>>& attributes, - FunctionDef* fd); - // Adds a Const node with the given value to the graph. template <typename T> NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) { @@ -76,13 +70,6 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph); bool ContainsGraphFunctionWithName(StringPiece name, const FunctionDefLibrary& library); -// Checks whether the function contains a node with the given name. -bool ContainsFunctionNodeWithName(StringPiece name, - const FunctionDef& function); - -// Checks whether the function contains a node with the given op. -bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function); - // Checks whether the graph contains a node with the given op. bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph); @@ -95,14 +82,6 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph); int FindGraphFunctionWithName(StringPiece name, const FunctionDefLibrary& library); -// Returns the index of the function node with the given name or -1 if the -// function node does not exist. -int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function); - -// Returns the index of the function node with the given op or -1 if the -// function node does not exist. -int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function); - // Returns the index of the first node with the given op or -1 if no such node // exists. int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph); @@ -119,11 +98,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op, // is unique across the graph. 
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node); -// Sets the function node name using the `prefix` as a prefix while guaranteeing -// the name is unique across the functions nodes. -void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, - NodeDef* node); - // Sets the node name using the `prefix` name as a prefix while guaranteeing the // name is unique across the graph. void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index c19ac7b880..6877c207c4 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -112,20 +112,6 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { ContainsGraphFunctionWithName(new_function->signature().name(), library)); } -TEST(GraphUtilsTest, ContainsFunctionNodeWithName) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_FALSE(ContainsFunctionNodeWithName( - "weird_name_that_should_not_be_there", function)); - EXPECT_TRUE(ContainsFunctionNodeWithName("two", function)); -} - -TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there", - function)); - EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function)); -} - TEST(GraphUtilsTest, ContainsNodeWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); @@ -150,22 +136,6 @@ TEST(GraphUtilsTest, FindGraphNodeWithName) { EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); } -TEST(GraphUtilsTest, FindFunctionNodeWithName) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_EQ( - FindFunctionNodeWithName("weird_name_that_should_not_be_there", function), - -1); - EXPECT_NE(FindFunctionNodeWithName("two", function), -1); -} - 
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_EQ( - FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function), - -1); - EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1); -} - TEST(GraphUtilsTest, FindGraphFunctionWithName) { FunctionDefLibrary library; EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1); @@ -225,21 +195,6 @@ TEST(GraphUtilsTest, SetUniqueGraphNodeName) { EXPECT_NE(node2->name(), node3->name()); } -TEST(GraphUtilsTest, SetUniqueFunctionNodeName) { - FunctionDef function = test::function::XTimesTwo(); - NodeDef node; - SetUniqueFunctionNodeName("abc", &function, &node); - for (const NodeDef& function_node : function.node_def()) { - EXPECT_NE(node.name(), function_node.name()); - } - auto* new_node = function.add_node_def(); - *new_node = node; - - NodeDef other; - SetUniqueFunctionNodeName("abc", &function, &other); - EXPECT_NE(other.name(), new_node->name()); -} - TEST(GraphUtilsTest, SetUniqueGraphFunctionName) { FunctionDefLibrary library; FunctionDef* new_function = library.add_function(); @@ -251,43 +206,6 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) { other_function->signature().name()); } -TEST(GraphUtilsTest, AddNodeToFunctionDef) { - FunctionDef func; - const char* op_name = "xxx"; - AddNode(op_name, op_name, {}, {}, &func); - - const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func)); - EXPECT_EQ(node1.op(), op_name); - EXPECT_EQ(node1.input_size(), 0); - EXPECT_EQ(node1.attr_size(), 0); - - const std::vector<string> inputs({"input1", "input2"}); - AddNode("", op_name, inputs, {}, &func); - const NodeDef& node2 = - func.node_def(FindFunctionNodeWithName("xxx/_2", func)); - EXPECT_EQ(node2.op(), op_name); - EXPECT_EQ(node2.attr_size(), 0); - EXPECT_EQ(node2.input_size(), inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - EXPECT_EQ(node2.input(i), inputs[i]); - } - - AttrValue a1, a2; - 
a1.set_type(DT_INT32); - a2.set_type(DT_INT64); - const std::vector<std::pair<string, AttrValue>> attrs( - {{"attr1", a1}, {"attr2", a2}}); - AddNode("", op_name, {}, attrs, &func); - const NodeDef& node3 = - func.node_def(FindFunctionNodeWithName("xxx/_3", func)); - EXPECT_EQ(node3.op(), op_name); - EXPECT_EQ(node3.input_size(), 0); - EXPECT_EQ(node3.attr_size(), attrs.size()); - for (size_t i = 0; i < attrs.size(); ++i) { - EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type()); - } -} - TEST(GraphUtilsTest, GetInputNode) { GraphDef graph_def; MutableGraphView graph(&graph_def); diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index a019b77eb7..07766aa7b3 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -52,8 +53,8 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, // Add MapDefun node NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add(); map_defun_node->set_op("MapDefun"); - graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func, - map_defun_node); + function_utils::SetUniqueFunctionNodeName(map_defun_node->op(), + vectorized_func, map_defun_node); // Set attrs and inputs for (const string& k : {"f", "output_types", "output_shapes"}) { diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc 
b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc new file mode 100644 index 0000000000..6a59eb0d32 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" + +#include "absl/strings/str_join.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/scanner.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +using function_utils::FunctionDefTensorDesc; + +namespace { + +void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node, + const string& output_retval, const DataType t) { + // Set to 
unknown shape + TensorShapeProto tensor_shape_proto; + PartialTensorShape().AsProto(&tensor_shape_proto); + + function_utils::AddFunctionOutputWithUniqueName( + "vectorized_out", output_retval, map_defun_fn, t); + + *(*map_defun_node->mutable_attr())["output_shapes"] + .mutable_list() + ->add_shape() = tensor_shape_proto; + (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t); +} + +void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, + NodeDef* map_defun_node, int output_position) { + DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size()) + << "Trying to remove output that doesn't exist. Output number: " + << output_position; + + int num_later_outputs = + map_defun_fn->signature().output_arg_size() - output_position - 1; + + // Remove from map_defun_fn's ret dict and output args + map_defun_fn->mutable_ret()->erase( + map_defun_fn->signature().output_arg(output_position).name()); + map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange( + output_position, 1); + + // Renumber outputs that come after + for (int i = 0; i < num_later_outputs; ++i) { + function_utils::ReplaceReferences( + strings::StrCat(map_defun_node->name(), + ":output:", output_position + i + 1), + strings::StrCat(map_defun_node->name(), + ":output:", output_position + i), + outer_scope); + } + map_defun_node->mutable_attr() + ->at("output_shapes") + .mutable_list() + ->mutable_shape() + ->DeleteSubrange(output_position, 1); + map_defun_node->mutable_attr() + ->at("output_types") + .mutable_list() + ->mutable_type() + ->ExtractSubrange(output_position, 1, nullptr); +} + +Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn, + NodeDef* map_defun_node, const NodeDef& cast_node, + const FunctionDefTensorDesc& output_desc, + std::map<string, string>* conversion_map) { + if (output_desc.node_output != "y" || output_desc.position != 0) { + // We expect the Cast node to have only one output, with the 
name "y". + return errors::Internal("Cannot convert Cast op output."); + } + + // Promote Cast inputs to outputs of MapDefun + DCHECK_EQ(cast_node.input_size(), 1); + AddMapDefunOutput(map_defun_fn, map_defun_node, cast_node.input(0), + cast_node.attr().at("SrcT").type()); + + // Add new Cast node + NodeDef* new_cast_node = outer_scope->add_node_def(); + *new_cast_node = cast_node; + new_cast_node->clear_name(); + function_utils::SetUniqueFunctionNodeName( + strings::StrCat("vectorized/", cast_node.name()), outer_scope, + new_cast_node); + new_cast_node->set_input( + 0, strings::StrCat(map_defun_node->name(), ":output:", + map_defun_fn->signature().output_arg_size() - 1)); + + // Add the output mapping to conversion map + (*conversion_map)[strings::StrCat(output_desc.node_name, ":y:0")] = + strings::StrCat(new_cast_node->name(), ":y:0"); + + return Status::OK(); +} + +Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn, + NodeDef* map_defun_node, const NodeDef& unpack_node, + const FunctionDefTensorDesc& output_desc, + std::map<string, string>* conversion_map) { + if (output_desc.node_output != "output") { + return errors::Internal("Cannot convert Unpack op output."); + } + + // Promote Unpack inputs to outputs of MapDefun + AddMapDefunOutput(map_defun_fn, map_defun_node, unpack_node.input(0), + unpack_node.attr().at("T").type()); + + // Add new Unpack node + NodeDef* new_unpack_node = outer_scope->add_node_def(); + *new_unpack_node = unpack_node; + new_unpack_node->clear_name(); + function_utils::SetUniqueFunctionNodeName( + strings::StrCat("vectorized/", unpack_node.name()), outer_scope, + new_unpack_node); + + // The vectorized input gains a leading batch dimension, so a non-negative + // "axis" must shift right by one. A negative "axis" counts from the last + // dimension and therefore already refers to the same dimension, so it is + // left unchanged. A missing attr is treated as 0 (the default declared for + // "axis" by the Unpack op) instead of aborting in Map::at(). + const auto axis_attr = unpack_node.attr().find("axis"); + const int64 axis = + axis_attr == unpack_node.attr().end() ? 0 : axis_attr->second.i(); + (*new_unpack_node->mutable_attr())["axis"].set_i( + axis >= 0 ? axis + 1 : axis); + new_unpack_node->set_input( + 0, strings::StrCat(map_defun_node->name(), ":output:", + map_defun_fn->signature().output_arg_size() - 1)); + + // Add the output mappings to conversion map + int num = 
new_unpack_node->attr().at("num").i(); + for (int i = 0; i < num; ++i) { + (*conversion_map)[strings::StrCat(output_desc.node_name, ":output:", i)] = + strings::StrCat(new_unpack_node->name(), ":output:", i); + } + + return Status::OK(); +} + +int FindOutputToConvert(const FunctionDef& function, + const std::set<string>& unconvertible, + FunctionDefTensorDesc* f) { + for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) { + const string& ret_key = function.signature().output_arg(i).name(); + *f = FunctionDefTensorDesc(function.ret().at(ret_key)); + + if (unconvertible.find(f->node_name) == unconvertible.end()) { + return i; + } + } + return -1; +} + +// Helper class that vectorizes the body of a MapDefun node, adding new +// operations to the graph that collectively compute the same value as what +// running the MapDefun function on slices of the input would produce. +// Each instance of the class encapsulates all the data necessary to vectorize a +// MapDefun op in place. +class Vectorization { + public: + Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn, + NodeDef* map_defun_node) + : outer_scope_(outer_scope), + map_defun_fn_(map_defun_fn), + map_defun_node_(map_defun_node) {} + + // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in + // the outer_scope_, until there are no convertible outputs remaining. + // This method is idempotent. + void Vectorize(); + + private: + // Vectorizes the map defun function's output at output_position + Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc); + // Given a descriptor of the original output tensor, gets a string + // corresponding to the converted output tensor. 
+ Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc, + string* converted); + Status AddConversionMappingFromInput( + const FunctionDefTensorDesc& output_desc); + + // Adds mappings from node's output tensors to converted output tensors, + // creating the necessary new node(s). Generally, the steps to convert an op + // are: + // 1) Promote the inputs of the op to outputs of the map_defun_fn_, + // and modify map_defun_node_ attrs accordingly + // 2) Create new node(s) in outer_scope_ that act on batched input tensors. + // These operations collectively compute the same value as what running + // the original operation on slices of the input tensors would produce. + // For example, a Cast op in MapDefun translates to a Cast op in + // outer_scope_, since the vectorized version of Cast is itself. + // 3) Set inputs of new node(s) to the corresponding converted inputs (that + // are now outputs of map_defun_node_) + // 4) For each output of the old node, add the mapping of output strings to + // the conversion map (e.g. "Cast:y:0" -> "vectorized/Cast:y:0") + Status AddConversionMappingFromOp(const NodeDef& node, + const FunctionDefTensorDesc& output_desc); + + // Maps a tensor name to the name of the corresponding vectorized tensor. For + // example, "Cast:y:0" -> "vectorized/Cast:y:0" + std::map<string, string> conversion_map_; + // Unconvertible node names + std::set<string> unconvertible_; + + FunctionDef* outer_scope_; + FunctionDef* map_defun_fn_; + NodeDef* map_defun_node_; +}; + +Status Vectorization::AddConversionMappingFromOp( + const NodeDef& node, const FunctionDefTensorDesc& output_desc) { + for (const string& input_name : node.input()) { + if (IsControlInput(input_name)) { + return errors::InvalidArgument( + "Vectorizing outputs with control inputs is currently not " + "supported."); + } + } + + // TODO(rachelim): Have some mechanism for registering converters and some + // uniform, simpler way to represent them. 
+ + // TODO(rachelim): Do step (1) outside of the individual op converters, when + // we know how to find out the type of the input. + if (node.op() == "Cast") { + return ConvertCastOp(outer_scope_, map_defun_fn_, map_defun_node_, node, + output_desc, &conversion_map_); + } else if (node.op() == "Unpack") { + return ConvertUnpackOp(outer_scope_, map_defun_fn_, map_defun_node_, node, + output_desc, &conversion_map_); + } + return errors::Unimplemented("Op converter for \"", node.op(), + "\" not implemented yet"); +} + +Status Vectorization::AddConversionMappingFromInput( + const FunctionDefTensorDesc& output_desc) { + int input_index = function_utils::FindFunctionInputWithName( + output_desc.node_name, *map_defun_fn_); + if (input_index == -1) { + return errors::Internal("Cannot convert non-existent input."); + } + + conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index); + return Status::OK(); +} + +Status Vectorization::ConvertOutputHelper( + const FunctionDefTensorDesc& output_desc, string* converted) { + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. 
+ if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) { + *converted = *found; + return Status::OK(); + } + + int index = function_utils::FindFunctionNodeWithName(output_desc.node_name, + *map_defun_fn_); + if (index == -1) { // The output comes from an input + TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc)); + } else { + TF_RETURN_IF_ERROR(AddConversionMappingFromOp( + map_defun_fn_->node_def(index), output_desc)); + } + *converted = conversion_map_.at(output_desc.full_str); + return Status::OK(); +} + +Status Vectorization::ConvertOutput(int output_position, + const FunctionDefTensorDesc& output_desc) { + string converted_output_name; + TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name)); + + // Remove the old output and make everything that referenced it point + // to the new string + function_utils::ReplaceReferences( + strings::StrCat(map_defun_node_->name(), ":output:", output_position), + converted_output_name, outer_scope_); + RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_, + output_position); + + return Status::OK(); +} + +void Vectorization::Vectorize() { + while (true) { + FunctionDefTensorDesc desc; + int output_position = + FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc); + if (output_position == -1) break; + + if (!ConvertOutput(output_position, desc).ok()) { + unconvertible_.insert(desc.node_name); + } + } + + // If we've converted all the outputs of the MapDefun function, we no longer + // need the MapDefun node and can delete it. 
+ if (map_defun_node_ != nullptr && + map_defun_fn_->signature().output_arg_size() == 0) { + const int node_index = function_utils::FindFunctionNodeWithName( + map_defun_node_->name(), *outer_scope_); + // Skip the deletion when the node cannot be found: DeleteSubrange must not + // be handed the -1 "not found" sentinel as its start index. + if (node_index != -1) { + outer_scope_->mutable_node_def()->DeleteSubrange(node_index, 1); + // The node has just been destroyed; clear the pointer so that a repeated + // call to Vectorize() does not dereference it. + map_defun_node_ = nullptr; + } + } + + if (!unconvertible_.empty()) { + VLOG(2) << "The following nodes could not be converted: [" + << absl::StrJoin(unconvertible_, ", ") << "]."; + } +} +} // namespace + +void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, + NodeDef* map_defun_node) { + Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize(); +} + +} // end namespace vectorization_utils +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h new file mode 100644 index 0000000000..bb405faa77 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_ + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +// Given a function, `map_defun_fn`, that is mapped across some input vector +// elements via a MapDefun operation, `VectorizeMapDefun` attempts to +// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the +// `outer_scope`; that is, replacing `map_defun_fn` operations with new +// `outer_scope` operations that produce the same vector output(s) as executing +// the `map_defun_fn` operations on elements of vector input(s) would. If all +// `map_defun_fn` operations are successfully lifted, `map_defun_node` is +// eliminated from `outer_scope` altogether. However, if some operations cannot +// be lifted, and this vectorization only succeeds partially, `map_defun_node` +// remains to be used for operations that were not lifted. +// +// Example: +// If the input to the `VectorizeMapDefun` function is a MapDefun +// whose `map_defun_fn` performs the Cast operation, the vectorization will +// eliminate the MapDefun. This is because the Cast operation supports +// any tensor shape and can thus be lifted to the `outer_scope`. 
+// +// Before: +// +// +// outer_scope +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | map_defun_fn +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// outer_scope +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, + NodeDef* map_defun_node); + +} // end namespace vectorization_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc new file mode 100644 index 0000000000..e129fa9237 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -0,0 +1,600 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { +namespace { + +NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs, + DataType src, DataType dst, bool truncate, + FunctionDef* fn) { + NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn); + graph_transforms::SetNodeAttr("SrcT", src, node); + graph_transforms::SetNodeAttr("DstT", dst, node); + graph_transforms::SetNodeAttr("Truncate", truncate, node); + return node; +} + +NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs, + DataType t, int axis, int num, FunctionDef* fn) { + NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn); + graph_transforms::SetNodeAttr("T", t, node); + graph_transforms::SetNodeAttr("axis", axis, node); + graph_transforms::SetNodeAttr("num", num, node); + return node; +} + +NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, + const std::vector<DataType>& t_arguments, + const std::vector<DataType>& output_types, + const std::vector<TensorShape>& output_shapes, + const string& function_name, FunctionDef* fn) { + NameAttrList func; + func.set_name(function_name); + NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn); + graph_transforms::SetNodeAttr("Targuments", t_arguments, node); + graph_transforms::SetNodeAttr("output_types", output_types, node); + graph_transforms::SetNodeAttr("output_shapes", output_shapes, node); + graph_transforms::SetNodeAttr("f", func, node); + return node; +} + +// 
TODO(rachelim): Use FunctionDefHelper::Create instead +FunctionDef CreateFunction( + StringPiece name, const std::vector<std::pair<string, DataType>>& inputs, + const std::vector<std::pair<string, DataType>>& outputs, + const std::map<string, string>& rets) { + FunctionDef func; + auto* signature = func.mutable_signature(); + signature->set_name(string(name)); + for (const auto& x : inputs) { + auto* arg_def = signature->add_input_arg(); + arg_def->set_name(x.first); + arg_def->set_type(x.second); + } + for (const auto& x : outputs) { + auto* arg_def = signature->add_output_arg(); + arg_def->set_name(x.first); + arg_def->set_type(x.second); + } + for (const auto& x : rets) { + (*func.mutable_ret())[x.first] = x.second; + } + + return func; +} + +TEST(FunctionDefInputDescTest, ConstructedCorrectly) {} + +// Before: +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// | +-----------+ Arg0 +---+ Arg1 +----+ | +// | | +---+--+ +---+--+ | | +// | | | | | | +// | | MapDefun +---v--+ +---v--+ | | +// | +-----------+ Ret0 +---+ Ret1 +----+ | +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +// +// After: +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | | | | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, + {{"ret0", "arg0"}, {"ret1", "arg1"}}); + FunctionDef outer = CreateFunction( + "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, + {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); 
+ + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}, + {{}, {}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + EXPECT_EQ(outer.ret().at("mapdefun"), "ret0"); + EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1"); +} + +// Before: +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// | +-----------+ Arg0 +---+ Arg1 +----+ | +// | | +---+--+ +---+--+ | | +// | | | | | | +// | | +------+ | +---v--+ | | +// | | |Const | | | Op0 | | | +// | | +---v--+ | +---+--+ | | +// | | | | | | | +// | | | +---v--+ +---v--+ | | +// | | +---| XOp1 | | XOp2 | | | +// | | +---+--+ +---+--+ | | +// | | | | | | +// | | MapDefun +---v--+ +---v--+ | | +// | +-----------+ Ret0 +---+ Ret1 +----+ | +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +// where XOp1 and XOp2 are not convertible. +// +// After: +// +// No change because the ops are not convertible. 
+// +TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, + {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}}); + NodeDef* x_op1 = + function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner); + CHECK_NOTNULL(x_op1); + + NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner); + CHECK_NOTNULL(x_op2); + + FunctionDef outer = CreateFunction( + "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}}, + {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); + + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}, + {{}, {}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + FunctionDef outer_copy(outer); + FunctionDef inner_copy(inner); + VectorizeMapDefun(&outer, &inner, map_defun); + // They should be unchanged + EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); + EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner)); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT64}}, 
{{"ret0", "Cast:y:0"}}); + NodeDef* cast_op = + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + CHECK_NOTNULL(cast_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT64}}, + {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& cast_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + EXPECT_EQ(cast_node.input(0), "x"); + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(outer.node_def_size(), 1); +} + +// Before: +// +// +------+ +// +---------------+ Arg0 +-------------------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +---------------+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | +----------+ | | +// | | | | | | +// | | MapDefun +---v--+ +---v--+ | | +// | +-----------+ Ret0 +---+ Ret1 +----+ | +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +-------------------+ +// | +---+--+ | +// | | | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +----------+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) { + // Tests that behavior is correct when an output is used more than once. 
+ FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT64}, {"ret1", DT_INT64}}, + {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}}); + NodeDef* cast_op = + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + CHECK_NOTNULL(cast_op); + + FunctionDef outer = CreateFunction( + "outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}}, + {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64}, + {{}, {}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& cast_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + EXPECT_EQ(cast_node.input(0), "x"); + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(outer.ret().at("mapdefun_0"), + strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(outer.node_def_size(), 1); +} + +// Before: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +------------------+ Arg0 +------------------+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v---+ num=3 | | +// | | |Unstack| axis=0 | | +// | | ++--+--++ | | +// | | | | | | | +// | | +----+ | +-------+ | | +// | | | | | | | +// | | MapDefun +---v--+ +-v----+ +--v---+ | | +// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ | +// | +---+--+ +--+---+ +--+---+ | +// | | | | | +// | +---v--+ +--v---+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +// +// After: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v---+ num=3 | +// | |Unstack| 
axis=1 | +// | ++--+--++ | +// | | | | | +// | +----+ | +-------+ | +// | | | | | +// | | | | | +// | +---v--+ +-v----+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) { + FunctionDef inner = CreateFunction( + "inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}}, + {{"ret0", "MyUnstack:output:0"}, + {"ret1", "MyUnstack:output:1"}, + {"ret2", "MyUnstack:output:2"}}); + NodeDef* unstack_op = + AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner); + CHECK_NOTNULL(unstack_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT32}, + {"mapdefun_0", DT_INT32}, + {"mapdefun_1", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, + {"mapdefun_0", "MapDefun:output:1"}, + {"mapdefun_1", "MapDefun:output:2"}}); + + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32}, + {{1}, {1}, {1}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& unpack_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + EXPECT_EQ(unpack_node.input(0), "x"); + EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); + EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); + EXPECT_EQ(unpack_node.attr().at("num").i(), 3); + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(unpack_node.name(), ":output:0")); + EXPECT_EQ(outer.ret().at("mapdefun_0"), + strings::StrCat(unpack_node.name(), ":output:1")); + EXPECT_EQ(outer.ret().at("mapdefun_1"), + strings::StrCat(unpack_node.name(), ":output:2")); + EXPECT_EQ(outer.node_def_size(), 1); +} + +// Before: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | 
| | +// | +---v--+ | +// | +------------------+ Arg0 +------------------+ | +// | | +---+--+ | | +// | | | | | +// | | +---+--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | +---v---+ num=3 | | +// | | |Unstack| axis=0 | | +// | | ++--+--++ | | +// | | | | | | | +// | | +----+ | +-------+ | | +// | | | | | | | +// | | MapDefun +---v--+ +-v----+ +--v---+ | | +// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ | +// | +---+--+ +--+---+ +--+---+ | +// | | | | | +// | +---v--+ +--v---+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +// +// After: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | | | +// | +---+--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v---+ num=3 | +// | |Unstack| axis=1 | +// | ++--+--++ | +// | | | | | +// | +----+ | +-------+ | +// | | | | | +// | | | | | +// | +---v--+ +-v----+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { + FunctionDef inner = CreateFunction( + "inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}}, + {{"ret0", "MyUnstack:output:0"}, + {"ret1", "MyUnstack:output:1"}, + {"ret2", "MyUnstack:output:2"}}); + NodeDef* cast_op = + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + CHECK_NOTNULL(cast_op); + NodeDef* unstack_op = + AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner); + CHECK_NOTNULL(unstack_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT32}, + {"mapdefun_0", DT_INT32}, + {"mapdefun_1", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, + {"mapdefun_0", "MapDefun:output:1"}, + {"mapdefun_1", "MapDefun:output:2"}}); + + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32}, + 
{{1}, {1}, {1}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& cast_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + EXPECT_EQ(cast_node.input(0), "x"); + const NodeDef& unpack_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); + EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); + EXPECT_EQ(unpack_node.attr().at("num").i(), 3); + + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(unpack_node.name(), ":output:0")); + EXPECT_EQ(outer.ret().at("mapdefun_0"), + strings::StrCat(unpack_node.name(), ":output:1")); + EXPECT_EQ(outer.ret().at("mapdefun_1"), + strings::StrCat(unpack_node.name(), ":output:2")); + EXPECT_EQ(outer.node_def_size(), 2); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +---+--+ | | +// | | +---------+ | | +// | | +---v--+ | | | +// | | |Print | | | | +// | | +---+--+ | | | +// | | : +---v--+ | | +// | | ::::::> Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// No change because we don't deal with control inputs for now. 
+// +TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); + // The attrs aren't relevant + NodeDef* print_op = + function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner); + CHECK_NOTNULL(print_op); + NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64, + false, &inner); + CHECK_NOTNULL(cast_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT64}}, + {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + FunctionDef outer_copy(outer); + FunctionDef inner_copy(inner); + VectorizeMapDefun(&outer, &inner, map_defun); + // They should be unchanged + EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); +} + +// TODO(rachelim): More test cases when we get around to implementing them: +// [] A badly defined converter, e.g. doesn't produce nodes that have the +// same number of outputs/inputs as the nodes to be converted +// [] Converter where the 'converted' form has multiple nodes. +// [] Case with dependent nodes, e.g. ops with const inputs that are +// broadcasted. +// [] Python-side tests to actually run the functions to make sure +// they work. + +} // namespace +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow |