aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-28 16:10:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 16:17:55 -0700
commit478d370eb116ad2294134d75a886637a7d6da225 (patch)
tree279ef8e8a2c9abeeda583393a986f055b9be314c
parenta98bac521406bedef3ff2b9af9564b21ddda4d82 (diff)
[tf.data] Use Graph instead of GraphDef/FunctionDef for vectorization transforms
PiperOrigin-RevId: 215011835
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD7
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h4
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc28
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD3
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc36
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h23
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc451
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h35
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc205
12 files changed, 574 insertions, 375 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 81c1bddf67..5a3abbb545 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -124,10 +124,10 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
] + tf_protos_all(),
)
@@ -523,6 +523,7 @@ cc_library(
":function_utils",
":graph_utils",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -538,6 +539,7 @@ tf_cc_test(
srcs = ["vectorization_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
":function_utils",
":vectorization_utils",
"//tensorflow/core:framework",
@@ -547,7 +549,10 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ # For ops we need registered
+ "//tensorflow/core/kernels/data:dataset_ops",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:logging_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 5dd7819100..3af34f6904 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -116,8 +116,8 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the node name using the `prefix` name as a prefix while guaranteeing the
-// name is unique across the graph.
+// Sets the function name using the `prefix` name as a prefix while guaranteeing
+// the name is unique across the function library.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 32ab912619..9328a7ca99 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -86,21 +86,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// efficient vectorization with VectorizeMapDefun.
FunctionDef* vectorized_func =
CreateMapDefunWrapper(map_node, orig_func, library);
- NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
- DCHECK_EQ(map_defun_node->op(), "MapDefun");
-
- // Create a copy of the original function so that we can mutate it, and
- // attach that to the map defun node.
- FunctionDef* map_defun_fn = library->add_function();
- *map_defun_fn = orig_func;
- graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
- map_defun_fn);
- (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
- map_defun_fn->signature().name());
-
- vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
- map_defun_node);
- return vectorized_func;
+ const NodeDef& map_defun_node = vectorized_func->node_def(0);
+ DCHECK_EQ(map_defun_node.op(), "MapDefun");
+
+ // TODO(b/116285210): Unreferenced functions should get cleaned up later
+ FunctionDef* result;
+ Status s = vectorization_utils::VectorizeMapDefun(
+ *vectorized_func, map_defun_node, library, &result);
+
+ if (!s.ok()) {
+ LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ return vectorized_func;
+ }
+ return result;
}
bool IsOutputShapesFullyDefined(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index ed1bd6bc97..f4faf41549 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -30,72 +30,51 @@ namespace {
using test::function::GDef;
using test::function::NDef;
-void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
- TensorShapeProto* t) {
- for (size_t i = 0; i < dims.size(); ++i) {
- auto* d = t->add_dim();
- d->set_size(dims[i]);
- }
-}
-
-AttrValue MakeShapeListAttr(
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
- AttrValue shapes_attr;
- for (size_t i = 0; i < shapes.size(); ++i) {
- MakeTensorShapeProtoHelper(shapes[i],
- shapes_attr.mutable_list()->add_shape());
- }
-
- return shapes_attr;
-}
-
-NodeDef MakeMapNodeHelper(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- StringPiece map_op_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name, StringPiece map_op_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return test::function::NDef(
name, map_op_name, {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef(string(function_name))},
{"Targuments", {}},
- {"output_shapes", MakeShapeListAttr(output_shapes)},
+ {"output_shapes", output_shapes},
{"output_types", output_types}});
}
-NodeDef MakeMapNode(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
output_shapes, output_types);
}
-NodeDef MakeBatchNode(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDataset",
- {string(input_node_name), string(input_batch_size_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeBatchV2Node(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDatasetV2",
- {string(input_node_name), string(input_batch_size_name),
- string(input_drop_remainder_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ StringPiece input_drop_remainder_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
+NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) {
return NDef(name, "RangeDataset", inputs,
- {{"output_shapes", MakeShapeListAttr({{}})},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})},
{"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}
@@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
item.graph = GDef(
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("input", "InputDataset", {},
- {{"output_shapes", MakeShapeListAttr({{}})}}),
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}),
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
// FunctionLib
@@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}
+TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {FunctionDefHelper::Create(
+ "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {},
+ {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
+ {{"res", "o:z"}, {"res2", "o:z"}})});
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 1462cb234d..37aa24b947 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -9,13 +9,14 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
- "//tensorflow/core/grappler/optimizers/data:function_utils",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
] + tf_protos_all()
cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index c1739737a0..3af6bab409 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,26 +23,21 @@ namespace vectorization_utils {
class CastVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
}
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
-
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
+ // Add new Cast node with the same op and attrs as the original node
+ auto new_cast_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
+ // Add input and output mappings
+ input_ports->push_back({new_cast_node, 0});
+ output_ports->push_back({new_cast_node, 0});
return Status::OK();
}
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 776d3179c5..74ce520ce1 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,31 +23,29 @@ namespace vectorization_utils {
class UnpackVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_unpack_node);
+ // Add new Unpack node with the same op and attrs as the original node
+ auto new_unpack_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
// Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
+ int new_axis = node.def().attr().at("axis").i() + 1;
+ new_unpack_node->AddAttr("axis", new_axis);
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
+ // Add the input mappings
+ input_ports->push_back({new_unpack_node, 0});
+
+ // Add the output mappings
+ int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
+ output_ports->push_back({new_unpack_node, i});
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index d341dbba7d..56eb88c95e 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -17,30 +17,33 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> Port;
+
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
class Vectorizer {
public:
virtual ~Vectorizer() {}
- // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs, and adding mappings to `conversion_map`
- // from old output tensor names to new (vectorized) output tensor names.
- // The new node(s) collectively have the same number of inputs and outputs as
- // the node being converted, and use the tensor names in `inputs` as their
- // inputs.
- virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) = 0;
+ // on elements of the vector inputs. The new Node(s) collectively have the
+ // same number of input and output ports as the node being converted.
+ // Adds mappings for the new nodes' input and output ports to `inputs` and
+ // `outputs` respectively, where the i'th Port in inputs/outputs
+ // corresponds to the i'th input/output port of the node to be converted.
+ virtual Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) = 0;
};
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 86e303564b..663ceba027 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -24,9 +24,9 @@ namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* inputs,
+ std::vector<Port>* outputs) override {
return Status::OK();
}
};
@@ -39,10 +39,12 @@ TEST(TestVectorizer, TestTestVectorizer) {
auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
EXPECT_NE(vectorizer, nullptr);
- FunctionDef function;
- NodeDef node;
- std::map<string, string> conversion_map;
- EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+ Graph g(OpRegistry::Global());
+ NodeDef node_def;
+ Status s;
+ Node* node = g.AddNode(node_def, &s);
+ std::vector<Port> inputs, outputs;
+ EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
}
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cb56b65985..cea667f668 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -36,255 +40,346 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-using function_utils::FunctionDefTensorDesc;
-
namespace {
-void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
- const string& output_retval, const DataType t) {
- // Set to unknown shape
- TensorShapeProto tensor_shape_proto;
- PartialTensorShape().AsProto(&tensor_shape_proto);
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
- function_utils::AddFunctionOutputWithUniqueName(
- "vectorized_out", output_retval, map_defun_fn, t);
+const char* const kRetValOp = "_Retval";
- *(*map_defun_node->mutable_attr())["output_shapes"]
- .mutable_list()
- ->add_shape() = tensor_shape_proto;
- (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
}
-void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, int output_position) {
- DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
- << "Trying to remove output that doesn't exist. Output number: "
- << output_position;
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
- int num_later_outputs =
- map_defun_fn->signature().output_arg_size() - output_position - 1;
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
- // Remove from map_defun_fn's ret dict and output args
- map_defun_fn->mutable_ret()->erase(
- map_defun_fn->signature().output_arg(output_position).name());
- map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
- output_position, 1);
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
- // Renumber outputs that come after
- for (int i = 0; i < num_later_outputs; ++i) {
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i + 1),
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i),
- outer_scope);
- }
- map_defun_node->mutable_attr()
- ->at("output_shapes")
- .mutable_list()
- ->mutable_shape()
- ->DeleteSubrange(output_position, 1);
- map_defun_node->mutable_attr()
- ->at("output_types")
- .mutable_list()
- ->mutable_type()
- ->ExtractSubrange(output_position, 1, nullptr);
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
}
-int FindOutputToConvert(const FunctionDef& function,
- const std::set<string>& unconvertible,
- FunctionDefTensorDesc* f) {
- for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
- const string& ret_key = function.signature().output_arg(i).name();
- *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
- if (unconvertible.find(f->node_name) == unconvertible.end()) {
- return i;
- }
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
}
- return -1;
}
// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
-// Each instance of the class encapsulates all the data necessary to vectorize a
-// MapDefun op in place.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
class Vectorization {
public:
- Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node)
- : outer_scope_(outer_scope),
- map_defun_fn_(map_defun_fn),
- map_defun_node_(map_defun_node) {}
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
- // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
- // the outer_scope_, until there are no convertible outputs remaining.
- // This method is idempotent.
- void Vectorize();
+ // Adds the vectorized function and new map_defun_fn to lib, and points
+ // vectorized_function to the former. Returns an error status if
+ // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
+ // along the way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Vectorizes the map defun function's output at output_position
- Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
- // Given a descriptor of the original output tensor, gets a string
- // corresponding to the converted output tensor.
- Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
- string* converted);
- Status AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc);
+ // Converts FunctionDefs to Graphs.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
// Adds mappings from node's outputs tensors to converted output tensors,
// creating the necessary new node(s). Generally, the steps to convert an op
// are:
- // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
- // and modify map_defun_node_ attrs accordingly
- // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
// These operations collectively compute the same value as what running
// the original operation on slices of the input tensors would produce.
// For example, a Cast op in MapDefun translates to a Cast op in
- // outer_scope_, since the vectorized version of Cast is itself.
- // 3) Set inputs of new node(s) to the corresponding converted inputs (that
- // are now outputs of map_defun_node_)
- // 4) For each output of the old node, add the mapping of output strings to
- // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
- Status AddConversionMappingFromOp(const NodeDef& node,
- const FunctionDefTensorDesc& output_desc);
-
- // Maps a tensor name to the name of the corresponding vectorized tensor. For
- // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
- std::map<string, string> conversion_map_;
- // Unconvertible node names
- std::set<string> unconvertible_;
-
- FunctionDef* outer_scope_;
- FunctionDef* map_defun_fn_;
- NodeDef* map_defun_node_;
+ // `outer_scope_`, since the vectorized version of Cast is itself.
+ // 2) Promote the inputs of the op to outputs of the
+ // `map_defun_node_` and `map_defun_fn_`.
+ // 3) Add edges between the promoted inputs (that are now outputs of
+ // `map_defun_node_`) and the input ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+ Status AddConversionMapping(Node* op_node);
+
+ // Maps a tensor to the corresponding vectorized tensor. For example,
+ // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
+ std::map<TensorDesc, TensorDesc> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+ // Note that FunctionBody has a pointer to a Graph object that corresponds
+ // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+ // that denote the function arguments and return values. These nodes have the
+ // attrs "T" for the type, and "index" for the argument / retval index
+ // respectively. FunctionBody also keeps track of arg/ret_nodes and
+ // arg/ret_types, which should be ordered according to argument/output indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope_`
+ Status status_;
};
-Status Vectorization::AddConversionMappingFromOp(
- const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
- for (const string& input_name : node.input()) {
- if (IsControlInput(input_name)) {
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
return errors::InvalidArgument(
"Vectorizing outputs with control inputs is currently not "
"supported.");
}
}
- // TODO(rachelim): Have some mechanism for registering converters and some
- // uniform, simpler way to represent them.
-
- DataTypeVector types;
- const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
- TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
-
- std::vector<string> promoted_inputs;
- promoted_inputs.reserve(node.input_size());
- for (int i = 0; i < node.input_size(); ++i) {
- promoted_inputs.push_back(strings::StrCat(
- map_defun_node_->name(),
- ":output:", map_defun_fn_->signature().output_arg_size() + i));
- }
-
- auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
if (vectorizer == nullptr) {
return errors::Unimplemented("No vectorizer registered for op: ",
- node.op());
+ op_node->type_string());
+ }
+ std::vector<Port> input_ports, output_ports;
+ input_ports.reserve(op_node->num_inputs());
+ output_ports.reserve(op_node->num_outputs());
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ &input_ports, &output_ports));
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ if (op_node->num_outputs() != output_ports.size() ||
+ op_node->num_inputs() != input_ports.size() ||
+ input_edges.size() != input_ports.size()) {
+ return errors::Internal("Vectorizer inputs/outputs don't match.");
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
- &conversion_map_));
+ // Promote the inputs of the op to MapDefun outputs and connect the edges
+ // accordingly.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
+ input_ports[i].first, input_ports[i].second);
+ }
- // If we get here, the conversion was successful, so we promote the inputs
- // of the ops to MapDefun outputs.
- for (int i = 0; i < types.size(); ++i) {
- AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
}
return Status::OK();
}
-Status Vectorization::AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc) {
- int input_index = function_utils::FindFunctionInputWithName(
- output_desc.node_name, *map_defun_fn_);
- if (input_index == -1) {
- return errors::Internal("Cannot convert non-existent input.");
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+ if (auto found = gtl::FindOrNull(conversion_map_, output)) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ converted_output = *found;
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ converted_output = conversion_map_.at(output);
}
- conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
return Status::OK();
}
-Status Vectorization::ConvertOutputHelper(
- const FunctionDefTensorDesc& output_desc, string* converted) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
- *converted = *found;
- return Status::OK();
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
}
- int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
- *map_defun_fn_);
- if (index == -1) { // The output comes from an input
- TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
} else {
- TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
- map_defun_fn_->node_def(index), output_desc));
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
- *converted = conversion_map_.at(output_desc.full_str);
- return Status::OK();
}
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+ // Convert outer_scope and map_defun_fn to FunctionBodys so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
-Status Vectorization::ConvertOutput(int output_position,
- const FunctionDefTensorDesc& output_desc) {
- string converted_output_name;
- TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
+ }
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to
+ // the conversion map
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
- // Remove the old output and make everything that referenced it point
- // to the new string
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node_->name(), ":output:", output_position),
- converted_output_name, outer_scope_);
- RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
- output_position);
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ }
return Status::OK();
}
-void Vectorization::Vectorize() {
- while (true) {
- FunctionDefTensorDesc desc;
- int output_position =
- FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
- if (output_position == -1) break;
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
- if (!ConvertOutput(output_position, desc).ok()) {
- unconvertible_.insert(desc.node_name);
- }
- }
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
- // If we've converted all the outputs of the MapDefun function, we no longer
- // need the MapDefun node and can delete it.
- if (map_defun_fn_->signature().output_arg_size() == 0) {
- outer_scope_->mutable_node_def()->DeleteSubrange(
- function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
- *outer_scope_),
- 1);
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
}
- if (!unconvertible_.empty()) {
- VLOG(2) << "The following nodes could not be converted: ["
- << absl::StrJoin(unconvertible_, ", ") << "].";
- }
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
}
+
} // namespace
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node) {
- Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
} // end namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
index bb405faa77..bd7d390900 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -24,22 +24,28 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-// Given a function, `map_defun_fn`, that is mapped across some input vector
-// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
-// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
-// `outer_scope`; that is, replacing `map_defun_fn` operations with new
-// `outer_scope` operations that produce the same vector output(s) as executing
-// the `map_defun_fn` operations on elements of vector input(s) would. If all
-// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
-// eliminated from `outer_scope` altogether. However, if some operations cannot
-// be lifted, and this vectorization only succeeds partially, `map_defun_node`
-// remains to be used for operations that were not lifted.
+// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`)
+// that maps a function in lib across some input vector elements,
+// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope`
+// by "lifting" operations from the MapDefun function to the new function
+// (`result`); that is, replacing operations in the MapDefun function with
+// operations that produce the same vector output(s) as executing the original
+// operations on elements of vector input(s) would. If all operations in the
+// MapDefun function are successfully lifted, `result` has no MapDefun node
+// altogether. However, if some operations cannot be lifted, and this
+// vectorization only succeeds partially, a MapDefun node remains in `result` to
+// be used for operations that were not lifted, and the modified MapDefun
+// function is added to `lib`. The newly vectorized function `result` is also
+// added to `lib`.
+//
+// Returns Status::OK() if the vectorization is completely or partially
+// successful. Otherwise, returns an error, and sets `result` to nullptr.
//
// Example:
// If the input to the `VectorizeMapDefun` function is a MapDefun
// whose `map_defun_fn` performs the Cast operation, the vectorization will
// eliminate the MapDefun. This is because the Cast operation supports
-// any tensor shape and can thus be lifted to the `outer_scope`.
+// any tensor shape and can thus be lifted to `result`.
//
// Before:
//
@@ -68,7 +74,7 @@ namespace vectorization_utils {
//
// After:
//
-// outer_scope +------+
+// result +------+
// +---------------+ Arg0 +---------+
// | +---+--+ |
// | | |
@@ -80,8 +86,9 @@ namespace vectorization_utils {
// +---------------+ Ret0 +---------+
// +------+
//
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node);
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result);
} // end namespace vectorization_utils
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index e129fa9237..1ff62217dd 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -60,6 +61,11 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
return node;
}
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
+
// TODO(rachelim): Use FunctionDefHelper::Create instead
FunctionDef CreateFunction(
StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
@@ -85,7 +91,6 @@ FunctionDef CreateFunction(
return func;
}
-TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
// Before:
//
@@ -133,10 +138,15 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
- EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
}
// Before:
@@ -149,12 +159,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// | +-----------+ Arg0 +---+ Arg1 +----+ |
// | | +---+--+ +---+--+ | |
// | | | | | |
-// | | +------+ | +---v--+ | |
-// | | |Const | | | Op0 | | |
-// | | +---v--+ | +---+--+ | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
// | | | | | | |
// | | | +---v--+ +---v--+ | |
-// | | +---| XOp1 | | XOp2 | | |
+// | | +---| XOp1 | | Cast | | |
// | | +---+--+ +---+--+ | |
// | | | | | |
// | | MapDefun +---v--+ +---v--+ | |
@@ -165,23 +175,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// +---------------+ Ret0 +---+ Ret1 +--------+
// +------+ +------+
//
-// where XOp1 and XOp2 are not convertible.
+// where XOp1 is not convertible.
//
// After:
//
-// No change because the ops are not convertible.
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
{{"ret0", DT_INT32}, {"ret1", DT_INT32}},
- {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
NodeDef* x_op1 =
- function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
- NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
- CHECK_NOTNULL(x_op2);
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
FunctionDef outer = CreateFunction(
"outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
@@ -193,12 +230,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
- // They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
- EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
}
// Before:
@@ -257,14 +304,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -330,16 +382,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -411,21 +468,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), "x");
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -486,7 +548,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{"ret1", "MyUnstack:output:1"},
{"ret2", "MyUnstack:output:2"}});
NodeDef* cast_op =
- AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
CHECK_NOTNULL(cast_op);
NodeDef* unstack_op =
AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
@@ -505,25 +567,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 2);
+ EXPECT_EQ(vectorized->node_def_size(), 2);
}
// Before:
@@ -561,9 +628,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}},
{{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
- // The attrs aren't relevant
- NodeDef* print_op =
- function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
CHECK_NOTNULL(print_op);
NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
false, &inner);
@@ -578,11 +647,27 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ // We check this somewhat manually as the names of nodes may have changed
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
}
// TODO(rachelim): More test cases when we get around to implementing them: