author    Rachel Lim <rachelim@google.com>                   2018-09-19 13:43:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-19 13:46:34 -0700
commit    eb2fe78e60ea8465443d8c653381a788ec581e49 (patch)
tree      b4dfbf11a0e35316aa45f912a9060fe46dc46d21
parent    125bf1dbb76c05bf5f88f14e77387ce35f986621 (diff)
[tf.data] MapVectorization optimization: C++ conversion framework to vectorize a MapDefun function. Also implements conversion for two ops: Cast and Unpack.
PiperOrigin-RevId: 213686720
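
For orientation, a minimal sketch of how the new entry point is invoked (hypothetical caller code, not part of this patch; the signature matches vectorization_utils.h added below):

  #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"

  namespace tensorflow {
  namespace grappler {

  // Attempts to lift ops out of `inner_fn` (the MapDefun's function) into
  // `outer_fn`. Conversions currently exist for Cast and Unpack; if every
  // output of `inner_fn` is lifted, `map_defun_node` is removed entirely,
  // otherwise it remains for the outputs that could not be converted.
  void VectorizeExample(FunctionDef* outer_fn, FunctionDef* inner_fn,
                        NodeDef* map_defun_node) {
    vectorization_utils::VectorizeMapDefun(outer_fn, inner_fn, map_defun_node);
  }

  }  // namespace grappler
  }  // namespace tensorflow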
-rw-r--r--  tensorflow/core/grappler/optimizers/data/BUILD                        |  76
-rw-r--r--  tensorflow/core/grappler/optimizers/data/function_utils.cc            | 196
-rw-r--r--  tensorflow/core/grappler/optimizers/data/function_utils.h             | 108
-rw-r--r--  tensorflow/core/grappler/optimizers/data/function_utils_test.cc       | 164
-rw-r--r--  tensorflow/core/grappler/optimizers/data/fusion_utils.cc              |   3
-rw-r--r--  tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc         |   5
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_utils.cc               |  82
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_utils.h                |  26
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_utils_test.cc          |  82
-rw-r--r--  tensorflow/core/grappler/optimizers/data/map_vectorization.cc         |   5
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.cc       | 341
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.h        |  90
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc  | 600
13 files changed, 1596 insertions(+), 182 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index e84df10778..7128a50be0 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -49,6 +49,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":function_utils",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -67,6 +68,7 @@ tf_cc_test(
srcs = ["fusion_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":fusion_utils",
":graph_utils",
"//tensorflow/core:framework",
@@ -78,6 +80,40 @@ tf_cc_test(
)
cc_library(
+ name = "function_utils",
+ srcs = ["function_utils.cc"],
+ hdrs = [
+ "function_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_utils_test",
+ srcs = ["function_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -137,6 +173,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":graph_utils",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
@@ -409,3 +446,42 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
],
)
+
+cc_library(
+ name = "vectorization_utils",
+ srcs = ["vectorization_utils.cc"],
+ hdrs = [
+ "vectorization_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/utils:functions",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "vectorization_utils_test",
+ srcs = ["vectorization_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":vectorization_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
new file mode 100644
index 0000000000..e95ea1a4c1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -0,0 +1,196 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+template <typename Predicate, typename Collection>
+std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ std::vector<int> indices = {};
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ indices.push_back(idx);
+ }
+ idx++;
+ }
+ return indices;
+}
+
+} // namespace
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
+ const string& output, int position)
+ : node_name(node_name), node_output(output), position(position) {
+ full_str = strings::StrCat(node_name, ":", node_output, ":", position);
+}
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
+ // Parses node_name:node_output:position string into its components.
+ full_str = input;
+ StringPiece capture;
+ StringPiece remaining;
+
+ // Parse "node_name"
+ if (strings::Scanner(input)
+ .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+ .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_name = string(capture.data(), capture.size());
+ }
+
+ // Parse "node_output" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .One(strings::Scanner::LETTER)
+ .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_output = string(capture.data(), capture.size());
+ }
+
+ // Parse "position" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .Many(strings::Scanner::DIGIT)
+ .GetResult(nullptr, &capture)) {
+ CHECK(strings::safe_strto32(capture, &position));
+ }
+}
+
+// TODO(rachelim): Create a utility class similar to MutableGraphView for
+// FunctionDefs, and use that to manipulate functions. It would be more
+// performant to keep mappings of nodes -> inputs/outputs, so that we don't
+// have to search over all nodes each time.
+// Note that we're not using GrapplerFunctionItem because it doesn't cover
+// some of our desired uses (eg changing the outputs of a function), and the
+// FunctionDef -> GraphDef conversion isn't really necessary in this case.
+void ReplaceReferences(const string& from, const string& to,
+ FunctionDef* func) {
+ for (NodeDef& n : *func->mutable_node_def()) {
+ std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
+ to);
+ }
+
+ for (auto& p : *func->mutable_ret()) {
+ if (p.second == from) {
+ p.second = to;
+ }
+ }
+}
+
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt) {
+ string name = string(prefix);
+ int id = function->signature().output_arg_size();
+ while (ContainsFunctionOutputWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+ output->set_name(name);
+ output->set_type(dt);
+
+ (*function->mutable_ret())[name] = string(output_tensor_name);
+}
+
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd) {
+ NodeDef* node = fd->add_node_def();
+ if (!name.empty()) {
+ node->set_name(string(name));
+ } else {
+ SetUniqueFunctionNodeName(op, fd, node);
+ }
+ node->set_op(string(op));
+ for (const string& input : inputs) {
+ node->add_input(input);
+ }
+ for (auto attr : attributes) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+ return node;
+}
+
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionNodeWithName(name, function) != -1;
+}
+
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ return FindFunctionNodeWithOp(op, function) != -1;
+}
+
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionOutputWithName(name, function) != -1;
+}
+
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().input_arg());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().output_arg());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const NodeDef& node) { return node.name() == name; },
+ function.node_def());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&op](const NodeDef& node) { return node.op() == op; },
+ function.node_def());
+
+ return indices.empty() ? -1 : indices.front();
+}
+
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node) {
+ string name = string(prefix);
+ int id = function->node_def_size();
+ while (ContainsFunctionNodeWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ node->set_name(std::move(name));
+}
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
new file mode 100644
index 0000000000..d4ce824652
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+// This namespace contains utility functions for querying and modifying
+// FunctionDefs.
+
+// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
+// have the format node_name:node_output:position (if they derive from nodes),
+// or input_name (if they derive from an argument).
+struct FunctionDefTensorDesc {
+ FunctionDefTensorDesc() = default;
+
+ FunctionDefTensorDesc(const string& node_name, const string& output,
+ int position);
+
+ // Parses node_name:node_output:position string into its components.
+ explicit FunctionDefTensorDesc(const string& input);
+
+ // TODO(rachelim): Add provisions to deal with special formats, like how
+ // GrapplerFunctionItem expands node output range if position is not defined
+ string full_str;
+ string node_name;
+ string node_output;
+ int position = -1;
+};
+
+// Replaces all references to the `from` tensor in `func`'s node inputs and
+// retvals with the `to` tensor. This is similar to
+// `MutableGraphView::ReplaceInputs`.
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
+
+// Adds a function output to the function def, ensuring that the output key
+// is unique, and maps to output_tensor_name in the ret dict.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt);
+
+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Checks whether the function contains a node with the given op.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Checks whether the function contains an output with the given name.
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Returns the index of the function input with the given name, or -1 if no
+// such input exists.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function output with the given name, or -1 if no
+// such output exists.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given op or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Sets the function node name using `prefix` as a prefix while guaranteeing
+// that the name is unique across the function's nodes.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node);
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
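
As a quick orientation, a hypothetical snippet (not part of this patch) showing how the helpers declared above are typically combined; it mirrors the patterns used in function_utils_test.cc below:

  #include "tensorflow/core/grappler/optimizers/data/function_utils.h"
  #include "tensorflow/core/lib/strings/strcat.h"

  namespace tensorflow {
  namespace grappler {

  void FunctionUtilsExample(FunctionDef* fn) {
    // Passing an empty name makes AddNode pick a unique one via
    // SetUniqueFunctionNodeName (e.g. "Cast", "Cast/_1", ...).
    NodeDef* cast = function_utils::AddNode("", "Cast", /*inputs=*/{"x"},
                                            /*attributes=*/{}, fn);

    // FunctionDefTensorDesc parses "node_name:node_output:position" strings.
    function_utils::FunctionDefTensorDesc desc(
        strings::StrCat(cast->name(), ":y:0"));
    // desc.node_name == cast->name(), desc.node_output == "y",
    // desc.position == 0.

    // The Find* helpers return an index into fn->node_def(), or -1 if absent.
    int index = function_utils::FindFunctionNodeWithOp("Cast", *fn);
    DCHECK_NE(index, -1);
  }

  }  // namespace grappler
  }  // namespace tensorflow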
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
new file mode 100644
index 0000000000..3739e20eb1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+TEST(FunctionDefTensorDesc, Parsing) {
+ FunctionDefTensorDesc f("Cast:y:0");
+ EXPECT_EQ(f.full_str, "Cast:y:0");
+ EXPECT_EQ(f.node_name, "Cast");
+ EXPECT_EQ(f.node_output, "y");
+ EXPECT_EQ(f.position, 0);
+
+ FunctionDefTensorDesc f2("Arg0");
+ EXPECT_EQ(f2.full_str, "Arg0");
+ EXPECT_EQ(f2.node_name, "Arg0");
+ EXPECT_EQ(f2.node_output, "");
+ EXPECT_EQ(f2.position, -1);
+}
+
+TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
+ {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
+ NodeDef* derive_node =
+ AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);
+ // Check that both the input to "X" and retval of "outer" are replaced.
+ ReplaceReferences("MapDefun:output:0", "arg0", &outer);
+ EXPECT_EQ(outer.ret().at("out"), "arg0");
+ EXPECT_EQ(derive_node->input(0), "arg0");
+}
+
+TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
+ FunctionDef function = test::function::XTimesTwo();
+ AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64);
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function));
+ EXPECT_EQ(function.ret().at("y/_1"), "two");
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithName(
+ "weird_name_that_should_not_be_there", function));
+ EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
+ function));
+ EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y", function));
+ EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function));
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionInputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionInputWithName("x", function), 0);
+ EXPECT_EQ(FindFunctionInputWithName("not_a_name", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionOutputWithName("y", function), 0);
+ EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1);
+}
+
+TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
+ FunctionDef function = test::function::XTimesTwo();
+ NodeDef node;
+ SetUniqueFunctionNodeName("abc", &function, &node);
+ for (const NodeDef& function_node : function.node_def()) {
+ EXPECT_NE(node.name(), function_node.name());
+ }
+ auto* new_node = function.add_node_def();
+ *new_node = node;
+
+ NodeDef other;
+ SetUniqueFunctionNodeName("abc", &function, &other);
+ EXPECT_NE(other.name(), new_node->name());
+}
+
+TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
+ FunctionDef func;
+ const char* op_name = "xxx";
+ AddNode(op_name, op_name, {}, {}, &func);
+
+ const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+ EXPECT_EQ(node1.op(), op_name);
+ EXPECT_EQ(node1.input_size(), 0);
+ EXPECT_EQ(node1.attr_size(), 0);
+
+ const std::vector<string> inputs({"input1", "input2"});
+ AddNode("", op_name, inputs, {}, &func);
+ const NodeDef& node2 =
+ func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+ EXPECT_EQ(node2.op(), op_name);
+ EXPECT_EQ(node2.attr_size(), 0);
+ EXPECT_EQ(node2.input_size(), inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ EXPECT_EQ(node2.input(i), inputs[i]);
+ }
+
+ AttrValue a1, a2;
+ a1.set_type(DT_INT32);
+ a2.set_type(DT_INT64);
+ const std::vector<std::pair<string, AttrValue>> attrs(
+ {{"attr1", a1}, {"attr2", a2}});
+ AddNode("", op_name, {}, attrs, &func);
+ const NodeDef& node3 =
+ func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+ EXPECT_EQ(node3.op(), op_name);
+ EXPECT_EQ(node3.input_size(), 0);
+ EXPECT_EQ(node3.attr_size(), attrs.size());
+ for (size_t i = 0; i < attrs.size(); ++i) {
+ EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+ }
+}
+
+} // namespace
+} // namespace function_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index 01a78c04b0..b3bfee138f 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -407,7 +408,7 @@ void LazyConjunctionNodes(const FunctionDef& first_function,
auto* if_node = fused_function->add_node_def();
// This is guaranteed to succeed.
TF_CHECK_OK(if_builder.Finalize(if_node));
- graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+ function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
}
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index d5c6466080..e667affeea 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -110,9 +111,9 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
CheckUniqueNames(*fused_function);
ASSERT_TRUE(
- graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
const auto &equal_node = fused_function->node_def(
- graph_utils::FindFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::FindFunctionNodeWithOp("Equal", *fused_function));
EXPECT_EQ(xtimes_two->signature().output_arg(0).name(),
fused_function->signature().output_arg(0).name());
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index d4ab444036..b3f60e34f9 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -108,26 +108,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
return graph->AddNode(std::move(node));
}
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd) {
- NodeDef* node = fd->add_node_def();
- if (!name.empty()) {
- node->set_name(string(name));
- } else {
- SetUniqueFunctionNodeName(op, fd, node);
- }
- node->set_op(string(op));
- for (const string& input : inputs) {
- node->add_input(input);
- }
- for (auto attr : attributes) {
- (*node->mutable_attr())[attr.first] = attr.second;
- }
- return node;
-}
-
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
@@ -196,6 +176,11 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) {
return true;
}
+bool ContainsGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ return FindGraphFunctionWithName(name, library) != -1;
+}
+
bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
return FindGraphNodeWithName(name, graph) != -1;
}
@@ -204,18 +189,14 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
return FindGraphNodeWithOp(op, graph) != -1;
}
-bool ContainsGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- return FindGraphFunctionWithName(name, library) != -1;
-}
-
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function) {
- return FindFunctionNodeWithName(name, function) != -1;
-}
-
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- return FindFunctionNodeWithOp(op, function) != -1;
+int FindGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const FunctionDef& function) {
+ return function.signature().name() == name;
+ },
+ library.function());
+ return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
@@ -237,31 +218,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
}
-int FindGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const FunctionDef& function) {
- return function.signature().name() == name;
- },
- library.function());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const NodeDef& node) { return node.name() == name; },
- function.node_def());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&op](const NodeDef& node) { return node.op() == op; },
- function.node_def());
-
- return indices.empty() ? -1 : indices.front();
-}
-
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
if (node.input_size() == 0) return nullptr;
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
@@ -284,17 +240,6 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
node->set_name(std::move(name));
}
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node) {
- string name = string(prefix);
- int id = function->node_def_size();
- while (ContainsFunctionNodeWithName(name, *function)) {
- name = strings::StrCat(prefix, "/_", id);
- ++id;
- }
- node->set_name(std::move(name));
-}
-
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
string name = string(prefix);
@@ -305,7 +250,6 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
}
function->mutable_signature()->set_name(std::move(name));
}
-
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 6f431c232d..1652afcd9e 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,12 +37,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
-// Adds a node to a FunctionDef.
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd);
-
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@@ -76,13 +70,6 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);
bool ContainsGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Checks whether the function contains a node with the given name.
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function);
-
-// Checks whether the function contains a node with the given op.
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -95,14 +82,6 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Returns the index of the function node with the given name or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
-
-// Returns the index of the function node with the given op or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Returns the index of the first node with the given op or -1 if no such node
// exists.
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -119,11 +98,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the function node name using the `prefix` as a prefix while guaranteeing
-// the name is unique across the functions nodes.
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node);
-
// Sets the node name using the `prefix` name as a prefix while guaranteeing the
// name is unique across the graph.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index c19ac7b880..6877c207c4 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -112,20 +112,6 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
ContainsGraphFunctionWithName(new_function->signature().name(), library));
}
-TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithName(
- "weird_name_that_should_not_be_there", function));
- EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
-}
-
-TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
- function));
- EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
-}
-
TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -150,22 +136,6 @@ TEST(GraphUtilsTest, FindGraphNodeWithName) {
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
}
-TEST(GraphUtilsTest, FindFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
-}
-
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
-}
-
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FunctionDefLibrary library;
EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
@@ -225,21 +195,6 @@ TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
EXPECT_NE(node2->name(), node3->name());
}
-TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
- FunctionDef function = test::function::XTimesTwo();
- NodeDef node;
- SetUniqueFunctionNodeName("abc", &function, &node);
- for (const NodeDef& function_node : function.node_def()) {
- EXPECT_NE(node.name(), function_node.name());
- }
- auto* new_node = function.add_node_def();
- *new_node = node;
-
- NodeDef other;
- SetUniqueFunctionNodeName("abc", &function, &other);
- EXPECT_NE(other.name(), new_node->name());
-}
-
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
FunctionDefLibrary library;
FunctionDef* new_function = library.add_function();
@@ -251,43 +206,6 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
other_function->signature().name());
}
-TEST(GraphUtilsTest, AddNodeToFunctionDef) {
- FunctionDef func;
- const char* op_name = "xxx";
- AddNode(op_name, op_name, {}, {}, &func);
-
- const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
- EXPECT_EQ(node1.op(), op_name);
- EXPECT_EQ(node1.input_size(), 0);
- EXPECT_EQ(node1.attr_size(), 0);
-
- const std::vector<string> inputs({"input1", "input2"});
- AddNode("", op_name, inputs, {}, &func);
- const NodeDef& node2 =
- func.node_def(FindFunctionNodeWithName("xxx/_2", func));
- EXPECT_EQ(node2.op(), op_name);
- EXPECT_EQ(node2.attr_size(), 0);
- EXPECT_EQ(node2.input_size(), inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- EXPECT_EQ(node2.input(i), inputs[i]);
- }
-
- AttrValue a1, a2;
- a1.set_type(DT_INT32);
- a2.set_type(DT_INT64);
- const std::vector<std::pair<string, AttrValue>> attrs(
- {{"attr1", a1}, {"attr2", a2}});
- AddNode("", op_name, {}, attrs, &func);
- const NodeDef& node3 =
- func.node_def(FindFunctionNodeWithName("xxx/_3", func));
- EXPECT_EQ(node3.op(), op_name);
- EXPECT_EQ(node3.input_size(), 0);
- EXPECT_EQ(node3.attr_size(), attrs.size());
- for (size_t i = 0; i < attrs.size(); ++i) {
- EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
- }
-}
-
TEST(GraphUtilsTest, GetInputNode) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a019b77eb7..07766aa7b3 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -52,8 +53,8 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// Add MapDefun node
NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
map_defun_node->set_op("MapDefun");
- graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
- map_defun_node);
+ function_utils::SetUniqueFunctionNodeName(map_defun_node->op(),
+ vectorized_func, map_defun_node);
// Set attrs and inputs
for (const string& k : {"f", "output_types", "output_shapes"}) {
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000000..6a59eb0d32
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+using function_utils::FunctionDefTensorDesc;
+
+namespace {
+
+void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
+ const string& output_retval, const DataType t) {
+ // Set to unknown shape
+ TensorShapeProto tensor_shape_proto;
+ PartialTensorShape().AsProto(&tensor_shape_proto);
+
+ function_utils::AddFunctionOutputWithUniqueName(
+ "vectorized_out", output_retval, map_defun_fn, t);
+
+ *(*map_defun_node->mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape() = tensor_shape_proto;
+ (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+}
+
+void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, int output_position) {
+ DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs =
+ map_defun_fn->signature().output_arg_size() - output_position - 1;
+
+ // Remove from map_defun_fn's ret dict and output args
+ map_defun_fn->mutable_ret()->erase(
+ map_defun_fn->signature().output_arg(output_position).name());
+ map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
+ output_position, 1);
+
+ // Renumber outputs that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i + 1),
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i),
+ outer_scope);
+ }
+ map_defun_node->mutable_attr()
+ ->at("output_shapes")
+ .mutable_list()
+ ->mutable_shape()
+ ->DeleteSubrange(output_position, 1);
+ map_defun_node->mutable_attr()
+ ->at("output_types")
+ .mutable_list()
+ ->mutable_type()
+ ->ExtractSubrange(output_position, 1, nullptr);
+}
+
+Status ConvertCastOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, const NodeDef& cast_node,
+ const FunctionDefTensorDesc& output_desc,
+ std::map<string, string>* conversion_map) {
+ if (output_desc.node_output != "y" || output_desc.position != 0) {
+ // We expect the Cast node to have only one output, with the name "y".
+ return errors::Internal("Cannot convert Cast op output.");
+ }
+
+ // Promote Cast inputs to outputs of MapDefun
+ DCHECK_EQ(cast_node.input_size(), 1);
+ AddMapDefunOutput(map_defun_fn, map_defun_node, cast_node.input(0),
+ cast_node.attr().at("SrcT").type());
+
+ // Add new Cast node
+ NodeDef* new_cast_node = outer_scope->add_node_def();
+ *new_cast_node = cast_node;
+ new_cast_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", cast_node.name()), outer_scope,
+ new_cast_node);
+ new_cast_node->set_input(
+ 0, strings::StrCat(map_defun_node->name(), ":output:",
+ map_defun_fn->signature().output_arg_size() - 1));
+
+ // Add the output mapping to conversion map
+ (*conversion_map)[strings::StrCat(output_desc.node_name, ":y:0")] =
+ strings::StrCat(new_cast_node->name(), ":y:0");
+
+ return Status::OK();
+}
+
+Status ConvertUnpackOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, const NodeDef& unpack_node,
+ const FunctionDefTensorDesc& output_desc,
+ std::map<string, string>* conversion_map) {
+ if (output_desc.node_output != "output") {
+ return errors::Internal("Cannot convert Unpack op output.");
+ }
+
+ // Promote Unpack inputs to outputs of MapDefun
+ AddMapDefunOutput(map_defun_fn, map_defun_node, unpack_node.input(0),
+ unpack_node.attr().at("T").type());
+
+ // Add new Unpack node
+ NodeDef* new_unpack_node = outer_scope->add_node_def();
+ *new_unpack_node = unpack_node;
+ new_unpack_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", unpack_node.name()), outer_scope,
+ new_unpack_node);
+
+ // Increment "axis" attr by 1:
+ (*new_unpack_node->mutable_attr())["axis"].set_i(
+ unpack_node.attr().at("axis").i() + 1);
+ new_unpack_node->set_input(
+ 0, strings::StrCat(map_defun_node->name(), ":output:",
+ map_defun_fn->signature().output_arg_size() - 1));
+
+ // Add the output mappings to conversion map
+ int num = new_unpack_node->attr().at("num").i();
+ for (int i = 0; i < num; ++i) {
+ (*conversion_map)[strings::StrCat(output_desc.node_name, ":output:", i)] =
+ strings::StrCat(new_unpack_node->name(), ":output:", i);
+ }
+
+ return Status::OK();
+}
+
+int FindOutputToConvert(const FunctionDef& function,
+ const std::set<string>& unconvertible,
+ FunctionDefTensorDesc* f) {
+ for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
+ const string& ret_key = function.signature().output_arg(i).name();
+ *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+
+ if (unconvertible.find(f->node_name) == unconvertible.end()) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// Each instance of the class encapsulates all the data necessary to vectorize a
+// MapDefun op in place.
+class Vectorization {
+ public:
+ Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node)
+ : outer_scope_(outer_scope),
+ map_defun_fn_(map_defun_fn),
+ map_defun_node_(map_defun_node) {}
+
+ // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
+ // the outer_scope_, until there are no convertible outputs remaining.
+ // This method is idempotent.
+ void Vectorize();
+
+ private:
+ // Vectorizes the map defun function's output at output_position
+ Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
+ // Given a descriptor of the original output tensor, gets a string
+ // corresponding to the converted output tensor.
+ Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
+ string* converted);
+ Status AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc);
+
+ // Adds mappings from the node's output tensors to converted output tensors,
+ // creating the necessary new node(s). Generally, the steps to convert an op
+ // are:
+ // 1) Promote the op's inputs to outputs of the map_defun_fn_,
+ // and modify map_defun_node_ attrs accordingly
+ // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // These operations collectively compute the same value as what running
+ // the original operation on slices of the input tensors would produce.
+ // For example, a Cast op in MapDefun translates to a Cast op in
+ // outer_scope_, since the vectorized version of Cast is Cast itself.
+ // 3) Set inputs of new node(s) to the corresponding converted inputs (that
+ // are now outputs of map_defun_node_)
+ // 4) For each output of the old node, add the mapping of output strings to
+ // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
+ Status AddConversionMappingFromOp(const NodeDef& node,
+ const FunctionDefTensorDesc& output_desc);
+
+ // Maps a tensor name to the name of the corresponding vectorized tensor. For
+ // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
+ std::map<string, string> conversion_map_;
+ // Unconvertible node names
+ std::set<string> unconvertible_;
+
+ FunctionDef* outer_scope_;
+ FunctionDef* map_defun_fn_;
+ NodeDef* map_defun_node_;
+};
+
+Status Vectorization::AddConversionMappingFromOp(
+ const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
+ for (const string& input_name : node.input()) {
+ if (IsControlInput(input_name)) {
+ return errors::InvalidArgument(
+ "Vectorizing outputs with control inputs is currently not "
+ "supported.");
+ }
+ }
+
+ // TODO(rachelim): Have some mechanism for registering converters and some
+ // uniform, simpler way to represent them.
+
+ // TODO(rachelim): Do step (1) outside of the individual op converters, when
+ // we know how to find out the type of the input.
+ if (node.op() == "Cast") {
+ return ConvertCastOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
+ output_desc, &conversion_map_);
+ } else if (node.op() == "Unpack") {
+ return ConvertUnpackOp(outer_scope_, map_defun_fn_, map_defun_node_, node,
+ output_desc, &conversion_map_);
+ }
+ return errors::Unimplemented("Op converter for \"", node.op(),
+ "\" not implemented yet");
+}
+
+Status Vectorization::AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc) {
+ int input_index = function_utils::FindFunctionInputWithName(
+ output_desc.node_name, *map_defun_fn_);
+ if (input_index == -1) {
+ return errors::Internal("Cannot convert non-existent input.");
+ }
+
+ conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutputHelper(
+ const FunctionDefTensorDesc& output_desc, string* converted) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
+ *converted = *found;
+ return Status::OK();
+ }
+
+ int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
+ *map_defun_fn_);
+ if (index == -1) { // The output comes from an input
+ TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
+ map_defun_fn_->node_def(index), output_desc));
+ }
+ *converted = conversion_map_.at(output_desc.full_str);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position,
+ const FunctionDefTensorDesc& output_desc) {
+ string converted_output_name;
+ TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+
+ // Remove the old output and make everything that referenced it point
+ // to the new string
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node_->name(), ":output:", output_position),
+ converted_output_name, outer_scope_);
+ RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
+ output_position);
+
+ return Status::OK();
+}
+
+void Vectorization::Vectorize() {
+ while (true) {
+ FunctionDefTensorDesc desc;
+ int output_position =
+ FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
+ if (output_position == -1) break;
+
+ if (!ConvertOutput(output_position, desc).ok()) {
+ unconvertible_.insert(desc.node_name);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->signature().output_arg_size() == 0) {
+ outer_scope_->mutable_node_def()->DeleteSubrange(
+ function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
+ *outer_scope_),
+ 1);
+ }
+
+ if (!unconvertible_.empty()) {
+ VLOG(2) << "The following nodes could not be converted: ["
+ << absl::StrJoin(unconvertible_, ", ") << "].";
+ }
+}
+} // namespace
+
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node) {
+ Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+}
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
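
To make the four conversion steps listed in the AddConversionMappingFromOp comment concrete, here is a hypothetical converter for another elementwise op (e.g. "Neg"; not part of this patch), patterned directly on ConvertCastOp above. Elementwise ops vectorize trivially, since applying the same op to the batched tensor yields the same values:

  Status ConvertNegOp(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
                      NodeDef* map_defun_node, const NodeDef& neg_node,
                      const FunctionDefTensorDesc& output_desc,
                      std::map<string, string>* conversion_map) {
    if (output_desc.node_output != "y" || output_desc.position != 0) {
      // Like Cast, Neg has a single output named "y".
      return errors::Internal("Cannot convert Neg op output.");
    }

    // 1) Promote the op's input to an output of the MapDefun function.
    AddMapDefunOutput(map_defun_fn, map_defun_node, neg_node.input(0),
                      neg_node.attr().at("T").type());

    // 2) Create the new node in the outer scope, acting on the batched tensor.
    NodeDef* new_neg_node = outer_scope->add_node_def();
    *new_neg_node = neg_node;
    new_neg_node->clear_name();
    function_utils::SetUniqueFunctionNodeName(
        strings::StrCat("vectorized/", neg_node.name()), outer_scope,
        new_neg_node);

    // 3) Wire its input to the MapDefun output added in step 1.
    new_neg_node->set_input(
        0, strings::StrCat(map_defun_node->name(), ":output:",
                           map_defun_fn->signature().output_arg_size() - 1));

    // 4) Record the output mapping so later conversions and ReplaceReferences
    //    can rewire consumers of the old output.
    (*conversion_map)[strings::StrCat(output_desc.node_name, ":y:0")] =
        strings::StrCat(new_neg_node->name(), ":y:0");

    return Status::OK();
  }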
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000000..bb405faa77
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Given a function, `map_defun_fn`, that is mapped across some input vector
+// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
+// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
+// `outer_scope`; that is, replacing `map_defun_fn` operations with new
+// `outer_scope` operations that produce the same vector output(s) as executing
+// the `map_defun_fn` operations on elements of vector input(s) would. If all
+// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
+// eliminated from `outer_scope` altogether. However, if some operations cannot
+// be lifted and the vectorization succeeds only partially, `map_defun_node`
+// remains in `outer_scope` to handle the operations that were not lifted.
+//
+// Example:
+// If the input to the `VectorizeMapDefun` function is a MapDefun
+// whose `map_defun_fn` performs the Cast operation, the vectorization will
+// eliminate the MapDefun. This is because the Cast operation supports
+// any tensor shape and can thus be lifted to the `outer_scope`.
+//
+// Before:
+//
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | map_defun_fn +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node);
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
new file mode 100644
index 0000000000..e129fa9237
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -0,0 +1,600 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+ DataType src, DataType dst, bool truncate,
+ FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("SrcT", src, node);
+ graph_transforms::SetNodeAttr("DstT", dst, node);
+ graph_transforms::SetNodeAttr("Truncate", truncate, node);
+ return node;
+}
+
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+ DataType t, int axis, int num, FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("T", t, node);
+ graph_transforms::SetNodeAttr("axis", axis, node);
+ graph_transforms::SetNodeAttr("num", num, node);
+ return node;
+}
+
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+ const std::vector<DataType>& t_arguments,
+ const std::vector<DataType>& output_types,
+ const std::vector<TensorShape>& output_shapes,
+ const string& function_name, FunctionDef* fn) {
+ NameAttrList func;
+ func.set_name(function_name);
+ NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("output_types", output_types, node);
+ graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
+ graph_transforms::SetNodeAttr("f", func, node);
+ return node;
+}
+
+// TODO(rachelim): Use FunctionDefHelper::Create instead
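+// Builds a FunctionDef with the given name, input args, and output args;
+// `rets` maps each output arg name to the tensor that produces it.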
+FunctionDef CreateFunction(
+ StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
+ const std::vector<std::pair<string, DataType>>& outputs,
+ const std::map<string, string>& rets) {
+ FunctionDef func;
+ auto* signature = func.mutable_signature();
+ signature->set_name(string(name));
+ for (const auto& x : inputs) {
+ auto* arg_def = signature->add_input_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : outputs) {
+ auto* arg_def = signature->add_output_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : rets) {
+ (*func.mutable_ret())[x.first] = x.second;
+ }
+
+ return func;
+}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | | | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
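+  // The inner function is an identity mapping, so vectorization should remove
+  // the MapDefun node and wire the outer rets directly to the outer args.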
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "arg0"}, {"ret1", "arg1"}});
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_FALSE(function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
+ EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | +------+ | +---v--+ | |
+// | | |Const | | | Op0 | | |
+// | | +---v--+ | +---+--+ | |
+// | | | | | | |
+// | | | +---v--+ +---v--+ | |
+// | | +---| XOp1 | | XOp2 | | |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+// where XOp1 and XOp2 are not convertible.
+//
+// After:
+//
+// No change because the ops are not convertible.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
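+  // XOp1 and XOp2 are fictitious op types that the vectorizer cannot convert,
+  // so the rewrite should leave both functions untouched.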
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ NodeDef* x_op1 =
+ function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(x_op1);
+
+ NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
+ CHECK_NOTNULL(x_op2);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
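+  // A single convertible Cast op: after vectorization the Cast should operate
+  // directly on the batched outer input and the MapDefun node should be gone.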
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_FALSE(function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +---------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +----------+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +----------+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
+ // Tests that behavior is correct when an output is used more than once.
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}, {"ret1", DT_INT64}},
+ {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_FALSE(function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
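+  // The Unpack op inside the MapDefun unstacks along axis 0; once hoisted to
+  // the outer function it should unstack along axis 1 to skip the new leading
+  // batch dimension.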
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_FALSE(function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), "x");
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | +---+--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---+--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
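+  // Both Cast and Unpack are convertible: the hoisted Unpack should consume
+  // the hoisted Cast's output, with its axis shifted from 0 to 1.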
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+  EXPECT_FALSE(function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 2);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | +---------+ | |
+// | | +---v--+ | | |
+// | | |Print | | | |
+// | | +---+--+ | | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// No change because we don't deal with control inputs for now.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
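+  // The Cast op carries a control dependency ("^Print") on the Print op;
+  // control inputs are not handled yet, so vectorization should be a no-op.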
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+  // The Print node's attrs aren't relevant to this test, so none are set.
+ NodeDef* print_op =
+ function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(print_op);
+ NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
+ false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+  // The outer function should be unchanged.
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+}
+
+// TODO(rachelim): More test cases when we get around to implementing them:
+// [] A badly defined converter, e.g. doesn't produce nodes that have the
+// same number of outputs/inputs as the nodes to be converted
+// [] Converter where the 'converted' form has multiple nodes.
+// [] Case with dependent nodes, e.g. ops with const inputs that are
+// broadcasted.
+// [] Python-side tests to actually run the functions to make sure
+// they work.
+
+} // namespace
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow