diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-28 16:10:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 16:17:55 -0700 |
commit | 478d370eb116ad2294134d75a886637a7d6da225 (patch) | |
tree | 279ef8e8a2c9abeeda583393a986f055b9be314c | |
parent | a98bac521406bedef3ff2b9af9564b21ddda4d82 (diff) |
[tf.data] Use Graph instead of GraphDef/FunctionDef for vectorization transforms
PiperOrigin-RevId: 215011835
12 files changed, 574 insertions, 375 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 81c1bddf67..5a3abbb545 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -124,10 +124,10 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", - "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", ] + tf_protos_all(), ) @@ -523,6 +523,7 @@ cc_library( ":function_utils", ":graph_utils", "@com_google_absl//absl/strings", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -538,6 +539,7 @@ tf_cc_test( srcs = ["vectorization_utils_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_utils", ":function_utils", ":vectorization_utils", "//tensorflow/core:framework", @@ -547,7 +549,10 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + # For ops we need registered + "//tensorflow/core/kernels/data:dataset_ops", "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:logging_ops", "//tensorflow/tools/graph_transforms:transform_utils", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 5dd7819100..3af34f6904 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -116,8 +116,8 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op, // is unique across the graph. 
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node); -// Sets the node name using the `prefix` name as a prefix while guaranteeing the -// name is unique across the graph. +// Sets the function name using the `prefix` name as a prefix while guaranteeing +// the name is unique across the function library. void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, FunctionDef* function); diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 32ab912619..9328a7ca99 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -86,21 +86,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, // efficient vectorization with VectorizeMapDefun. FunctionDef* vectorized_func = CreateMapDefunWrapper(map_node, orig_func, library); - NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0); - DCHECK_EQ(map_defun_node->op(), "MapDefun"); - - // Create a copy of the original function so that we can mutate it, and - // attach that to the map defun node. 
- FunctionDef* map_defun_fn = library->add_function(); - *map_defun_fn = orig_func; - graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library, - map_defun_fn); - (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name( - map_defun_fn->signature().name()); - - vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn, - map_defun_node); - return vectorized_func; + const NodeDef& map_defun_node = vectorized_func->node_def(0); + DCHECK_EQ(map_defun_node.op(), "MapDefun"); + + // TODO(b/116285210): Unreferenced functions should get cleaned up later + FunctionDef* result; + Status s = vectorization_utils::VectorizeMapDefun( + *vectorized_func, map_defun_node, library, &result); + + if (!s.ok()) { + LOG(ERROR) << "VectorizeMapDefun failed: " << s; + return vectorized_func; + } + return result; } bool IsOutputShapesFullyDefined(const NodeDef& node) { diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc index ed1bd6bc97..f4faf41549 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc @@ -30,72 +30,51 @@ namespace { using test::function::GDef; using test::function::NDef; -void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims, - TensorShapeProto* t) { - for (size_t i = 0; i < dims.size(); ++i) { - auto* d = t->add_dim(); - d->set_size(dims[i]); - } -} - -AttrValue MakeShapeListAttr( - const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) { - AttrValue shapes_attr; - for (size_t i = 0; i < shapes.size(); ++i) { - MakeTensorShapeProtoHelper(shapes[i], - shapes_attr.mutable_list()->add_shape()); - } - - return shapes_attr; -} - -NodeDef MakeMapNodeHelper( - StringPiece name, StringPiece input_node_name, StringPiece function_name, - StringPiece map_op_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const 
gtl::ArraySlice<DataType>& output_types) { +NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name, + StringPiece function_name, StringPiece map_op_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { return test::function::NDef( name, map_op_name, {string(input_node_name)}, {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, {"Targuments", {}}, - {"output_shapes", MakeShapeListAttr(output_shapes)}, + {"output_shapes", output_shapes}, {"output_types", output_types}}); } -NodeDef MakeMapNode( - StringPiece name, StringPiece input_node_name, StringPiece function_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset", output_shapes, output_types); } -NodeDef MakeBatchNode( - StringPiece name, StringPiece input_node_name, - StringPiece input_batch_size_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { - return NDef(name, "BatchDataset", - {string(input_node_name), string(input_batch_size_name)}, - {{"output_types", output_types}, - {"output_shapes", MakeShapeListAttr(output_shapes)}}); +NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name, + StringPiece input_batch_size_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { + return NDef( + name, "BatchDataset", + {string(input_node_name), string(input_batch_size_name)}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); } -NodeDef MakeBatchV2Node( - StringPiece name, StringPiece input_node_name, - StringPiece input_batch_size_name, StringPiece 
input_drop_remainder_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { - return NDef(name, "BatchDatasetV2", - {string(input_node_name), string(input_batch_size_name), - string(input_drop_remainder_name)}, - {{"output_types", output_types}, - {"output_shapes", MakeShapeListAttr(output_shapes)}}); +NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name, + StringPiece input_batch_size_name, + StringPiece input_drop_remainder_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { + return NDef( + name, "BatchDatasetV2", + {string(input_node_name), string(input_batch_size_name), + string(input_drop_remainder_name)}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); } -NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) { +NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) { return NDef(name, "RangeDataset", inputs, - {{"output_shapes", MakeShapeListAttr({{}})}, + {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}, {"output_types", gtl::ArraySlice<DataType>({DT_INT64})}}); } @@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) { item.graph = GDef( {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), NDef("input", "InputDataset", {}, - {{"output_shapes", MakeShapeListAttr({{}})}}), + {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}), MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}), MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})}, // FunctionLib @@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) { TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); } +TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) { + GrapplerItem item; + item.graph = GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", 
{}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + MakeRangeNode("range", {"start", "stop", "step"}), + MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}), + MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})}, + // FunctionLib + {FunctionDefHelper::Create( + "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {}, + {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}}, + {{"res", "o:z"}, {"res2", "o:z"}})}); + MapVectorization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(), + 1); + EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(), + 1); + const NodeDef& map_node = + output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output)); + const NodeDef& batch_node = + output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output)); + EXPECT_EQ(map_node.input(0), batch_node.name()); + EXPECT_EQ(batch_node.input(0), "range"); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD index 1462cb234d..37aa24b947 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -9,13 +9,14 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") VECTORIZER_DEPS = [ ":vectorizer_registry", - "//tensorflow/core/grappler/optimizers/data:function_utils", + "//tensorflow/core/grappler/optimizers/data:graph_utils", ] + tf_protos_all() cc_library( name = "vectorizer", hdrs = ["vectorizer.h"], deps = [ + "//tensorflow/core:core_cpu", "//tensorflow/core:lib", ] + tf_protos_all(), ) diff --git 
a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc index c1739737a0..3af6bab409 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" namespace tensorflow { @@ -23,26 +23,21 @@ namespace vectorization_utils { class CastVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { - if (inputs.size() != 1) { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<Port>* input_ports, + std::vector<Port>* output_ports) override { + Status s; + if (node.num_inputs() != 1) { return errors::Internal("Cast op should only have one input."); } - // Add new Cast node - NodeDef* new_cast_node = outer_scope->add_node_def(); - *new_cast_node = node; - new_cast_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", node.name()), outer_scope, - new_cast_node); - new_cast_node->set_input(0, inputs[0]); - - // Add the output mapping to conversion map - (*conversion_map)[strings::StrCat(node.name(), ":y:0")] = - strings::StrCat(new_cast_node->name(), ":y:0"); + // Add new Cast node with the same op and attrs as the original node + auto new_cast_node = outer_scope->AddNode(node.def(), &s); + TF_RETURN_IF_ERROR(s); + // Add input and output mappings + 
input_ports->push_back({new_cast_node, 0}); + output_ports->push_back({new_cast_node, 0}); return Status::OK(); } }; diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc index 776d3179c5..74ce520ce1 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" namespace tensorflow { @@ -23,31 +23,29 @@ namespace vectorization_utils { class UnpackVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { - if (inputs.size() != 1) { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<Port>* input_ports, + std::vector<Port>* output_ports) override { + Status s; + if (node.num_inputs() != 1) { return errors::Internal("Unpack op should only have one input."); } - // Add new Unpack node - NodeDef* new_unpack_node = outer_scope->add_node_def(); - *new_unpack_node = node; - new_unpack_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", node.name()), outer_scope, - new_unpack_node); + // Add new Unpack node with the same op and attrs as the original node + auto new_unpack_node = outer_scope->AddNode(node.def(), &s); + TF_RETURN_IF_ERROR(s); // Increment "axis" attr by 1: - (*new_unpack_node->mutable_attr())["axis"].set_i( - node.attr().at("axis").i() + 1); - 
new_unpack_node->set_input(0, inputs[0]); + int new_axis = node.def().attr().at("axis").i() + 1; + new_unpack_node->AddAttr("axis", new_axis); - // Add the output mappings to conversion map - int num = new_unpack_node->attr().at("num").i(); + // Add the input mappings + input_ports->push_back({new_unpack_node, 0}); + + // Add the output mappings + int num = node.def().attr().at("num").i(); for (int i = 0; i < num; ++i) { - (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] = - strings::StrCat(new_unpack_node->name(), ":output:", i); + output_ports->push_back({new_unpack_node, i}); } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h index d341dbba7d..56eb88c95e 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h @@ -17,30 +17,33 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { namespace grappler { namespace vectorization_utils { +// Describes a tensor with its operation Node and output position +typedef std::pair<Node*, int> Port; + // Interface for vectorization of TensorFlow operations. See `CastVectorizer` // for an example. 
class Vectorizer { public: virtual ~Vectorizer() {} - // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope` + // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope` // that produce the same vector output(s) as executing `node`'s op - // on elements of the vector inputs, and adding mappings to `conversion_map` - // from old output tensor names to new (vectorized) output tensor names. - // The new node(s) collectively have the same number of inputs and outputs as - // the node being converted, and use the tensor names in `inputs` as their - // inputs. - virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) = 0; + // on elements of the vector inputs. The new Node(s) collectively have the + // same number of input and output ports as the node being converted. + // Adds mappings for the new nodes' input and output ports to `input_ports` + // and `output_ports` respectively, where the i'th Port in each corresponds + // to the i'th input/output port of the node to be converted. 
+ virtual Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<Port>* input_ports, + std::vector<Port>* output_ports) = 0; }; } // namespace vectorization_utils diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc index 86e303564b..663ceba027 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc @@ -24,9 +24,9 @@ namespace vectorization_utils { class TestVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<Port>* inputs, + std::vector<Port>* outputs) override { return Status::OK(); } }; @@ -39,10 +39,12 @@ TEST(TestVectorizer, TestTestVectorizer) { auto vectorizer = VectorizerRegistry::Global()->Get("test_op"); EXPECT_NE(vectorizer, nullptr); - FunctionDef function; - NodeDef node; - std::map<string, string> conversion_map; - EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok()); + Graph g(OpRegistry::Global()); + NodeDef node_def; + Status s; + Node* node = g.AddNode(node_def, &s); + std::vector<Port> inputs, outputs; + EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok()); } } // namespace vectorization_utils diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index cb56b65985..cea667f668 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -14,13 +14,17 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" +#include <memory> #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" @@ -36,255 +40,346 @@ namespace tensorflow { namespace grappler { namespace vectorization_utils { -using function_utils::FunctionDefTensorDesc; - namespace { -void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node, - const string& output_retval, const DataType t) { - // Set to unknown shape - TensorShapeProto tensor_shape_proto; - PartialTensorShape().AsProto(&tensor_shape_proto); +// Describes a tensor with its operation Node and output position +typedef std::pair<Node*, int> TensorDesc; - function_utils::AddFunctionOutputWithUniqueName( - "vectorized_out", output_retval, map_defun_fn, t); +const char* const kRetValOp = "_Retval"; - *(*map_defun_node->mutable_attr())["output_shapes"] - .mutable_list() - ->add_shape() = tensor_shape_proto; - (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t); +void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, + Graph* graph) { + // NOTE: We need two for loops here because we can't mutate the set of output + // edges as we iterate over them. 
+ std::vector<const Edge*> edges_to_replace; + for (auto edge : old_src.first->out_edges()) { + if (edge->src_output() == old_src.second) { + edges_to_replace.push_back(edge); + } + } + for (auto edge : edges_to_replace) { + graph->AddEdge(new_src.first, new_src.second, edge->dst(), + edge->dst_input()); + graph->RemoveEdge(edge); + } } -void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node, int output_position) { - DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size()) - << "Trying to remove output that doesn't exist. Output number: " - << output_position; +Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, + const TensorDesc& output) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DataType type = output.first->output_type(output.second); + int index = map_defun_fn->ret_nodes.size(); - int num_later_outputs = - map_defun_fn->signature().output_arg_size() - output_position - 1; + NodeDef ret_node_def; + ret_node_def.set_name("map_out"); + ret_node_def.set_op(kRetValOp); + AddNodeAttr("T", type, &ret_node_def); + AddNodeAttr("index", index, &ret_node_def); - // Remove from map_defun_fn's ret dict and output args - map_defun_fn->mutable_ret()->erase( - map_defun_fn->signature().output_arg(output_position).name()); - map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange( - output_position, 1); + Status s; + Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s); + TF_RETURN_IF_ERROR(s); - // Renumber outputs that come after - for (int i = 0; i < num_later_outputs; ++i) { - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node->name(), - ":output:", output_position + i + 1), - strings::StrCat(map_defun_node->name(), - ":output:", output_position + i), - outer_scope); - } - map_defun_node->mutable_attr() - ->at("output_shapes") - .mutable_list() - ->mutable_shape() - ->DeleteSubrange(output_position, 1); - 
map_defun_node->mutable_attr() - ->at("output_types") - .mutable_list() - ->mutable_type() - ->ExtractSubrange(output_position, 1, nullptr); + map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0); + map_defun_fn->ret_nodes.push_back(ret_node); + map_defun_fn->ret_types.push_back(type); + + return s; } -int FindOutputToConvert(const FunctionDef& function, - const std::set<string>& unconvertible, - FunctionDefTensorDesc* f) { - for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) { - const string& ret_key = function.signature().output_arg(i).name(); - *f = FunctionDefTensorDesc(function.ret().at(ret_key)); +void RemoveMapDefunOutput(int output_position, Graph* outer_scope, + FunctionBody* map_defun_fn, Node* map_defun_node) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DCHECK_LT(output_position, map_defun_fn->ret_nodes.size()) + << "Trying to remove output that doesn't exist. Output number: " + << output_position; + + int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1; - if (unconvertible.find(f->node_name) == unconvertible.end()) { - return i; - } + // Modify map_defun_fn's signature and remove the output node from its graph + map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]); + map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() + + output_position); + map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() + + output_position); + + // Renumber the nodes and edges that come after + for (int i = 0; i < num_later_outputs; ++i) { + ReplaceEdgeSources({map_defun_node, output_position + i + 1}, + {map_defun_node, output_position + i}, outer_scope); + // Each ret node has an "index" attr that has to be updated + map_defun_fn->ret_nodes[output_position + i]->AddAttr("index", + output_position + i); } - return -1; } // Helper class that vectorizes the body of a MapDefun node, adding new // operations to the graph that collectively compute 
the same value as what // running the MapDefun function on slices of the input would produce. -// Each instance of the class encapsulates all the data necessary to vectorize a -// MapDefun op in place. +// This class transforms the input FunctionDefs into their corresponding +// Graph objects and works on the graphs directly, then converts them back +// to FunctionDefs when GetResult is called. class Vectorization { public: - Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) - : outer_scope_(outer_scope), - map_defun_fn_(map_defun_fn), - map_defun_node_(map_defun_node) {} + explicit Vectorization(FunctionDefLibrary* lib) + : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {} - // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in - // the outer_scope_, until there are no convertible outputs remaining. - // This method is idempotent. - void Vectorize(); + // Adds the vectorized function and new map_defun_fn to lib, and points + // vectorized_function to the former. Returns an error status if + // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere + // along the way. + Status Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDef** result); private: - // Vectorizes the map defun function's output at output_position - Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc); - // Given a descriptor of the original output tensor, gets a string - // corresponding to the converted output tensor. - Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc, - string* converted); - Status AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc); + // Converts FunctionDefs to Graphs. + Status Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node); + + // Converts Graphs back to FunctionDefs and adds them to `lib_`. 
+ Status GetResult(FunctionDef** vectorized_function); + + // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in + // `outer_scope_`, until there are no convertible outputs remaining. + void VectorizeHelper(); + + // Vectorizes map_defun_fn's output at output_position. + Status ConvertOutput(int output_position); // Adds mappings from node's outputs tensors to converted output tensors, // creating the necessary new node(s). Generally, the steps to convert an op // are: - // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_, - // and modify map_defun_node_ attrs accordingly - // 2) Create new node(s) in outer_scope_ that act on batched input tensors. + // 1) Create new node(s) in `outer_scope_` that act on batched input tensors. // These operations collectively compute the same value as what running // the original operation on slices of the input tensors would produce. // For example, a Cast op in MapDefun translates to a Cast op in - // outer_scope_, since the vectorized version of Cast is itself. - // 3) Set inputs of new node(s) to the corresponding converted inputs (that - // are now outputs of map_defun_node_) - // 4) For each output of the old node, add the mapping of output strings to - // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0") - Status AddConversionMappingFromOp(const NodeDef& node, - const FunctionDefTensorDesc& output_desc); - - // Maps a tensor name to the name of the corresponding vectorized tensor. For - // example, "Cast:y:0" -> "Vectorize/Cast:y:0" - std::map<string, string> conversion_map_; - // Unconvertible node names - std::set<string> unconvertible_; - - FunctionDef* outer_scope_; - FunctionDef* map_defun_fn_; - NodeDef* map_defun_node_; + // `outer_scope_`, since the vectorized version of Cast is itself. + // 2) Promote the inputs of the op inputs to outputs of the + // `map_defun_node_` and `map_defun_fn_`. 
+ // 3) Add edges between the promoted inputs (that are now outputs of + // `map_defun_node_`) and the input ports of the new node(s). + // 4) For each output of the old node, add the mapping of output tensors to + // the conversion map. + Status AddConversionMapping(Node* op_node); + + // Maps a tensor to the corresponding vectorized tensor. For example, + // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0} + std::map<TensorDesc, TensorDesc> conversion_map_; + + // Unconvertible ret nodes + std::set<Node*> unconvertible_; + + FunctionDefLibrary* lib_; // Not owned + FunctionLibraryDefinition lib_def_; + // Note that FunctionBody has a pointer to a Graph object that corresponds + // to the function's subgraph, with additional kArgOp and kRetValOp nodes + // that denote the function arguments and return values. These nodes have the + // attrs "T" for the type, and "index" for the argument / retval index + // respectively. FunctionBody also keeps track of arg/ret_nodes and + // arg/ret_types, that should be ordered according to argument/output indices. + std::unique_ptr<Graph> outer_scope_; + std::unique_ptr<FunctionBody> map_defun_fn_; + Node* map_defun_node_ = nullptr; // Owned by `outer_scope_` + Status status_; }; -Status Vectorization::AddConversionMappingFromOp( - const NodeDef& node, const FunctionDefTensorDesc& output_desc) { - for (const string& input_name : node.input()) { - if (IsControlInput(input_name)) { +Status Vectorization::AddConversionMapping(Node* op_node) { + for (auto edge : op_node->in_edges()) { + if (edge->IsControlEdge()) { return errors::InvalidArgument( "Vectorizing outputs with control inputs is currently not " "supported."); } } - // TODO(rachelim): Have some mechanism for registering converters and some - // uniform, simpler way to represent them. 
- - DataTypeVector types; - const OpDef* op_def = nullptr; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def)); - TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types)); - - std::vector<string> promoted_inputs; - promoted_inputs.reserve(node.input_size()); - for (int i = 0; i < node.input_size(); ++i) { - promoted_inputs.push_back(strings::StrCat( - map_defun_node_->name(), - ":output:", map_defun_fn_->signature().output_arg_size() + i)); - } - - auto vectorizer = VectorizerRegistry::Global()->Get(node.op()); + auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string()); if (vectorizer == nullptr) { return errors::Unimplemented("No vectorizer registered for op: ", - node.op()); + op_node->type_string()); + } + std::vector<Port> input_ports, output_ports; + input_ports.reserve(op_node->num_inputs()); + output_ports.reserve(op_node->num_outputs()); + TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), + &input_ports, &output_ports)); + + std::vector<const Edge*> input_edges; + TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges)); + + if (op_node->num_outputs() != output_ports.size() || + op_node->num_inputs() != input_ports.size() || + input_edges.size() != input_ports.size()) { + return errors::Internal("Vectorizer inputs/outputs don't match."); } - TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_, - &conversion_map_)); + // Promote the inputs of the op to MapDefun outputs and connect the edges + // accordingly. + for (size_t i = 0; i < op_node->num_inputs(); ++i) { + auto edge = input_edges[i]; + TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, + {edge->src(), edge->src_output()})); + outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1, + input_ports[i].first, input_ports[i].second); + } - // If we get here, the conversion was successful, so we promote the inputs - // of the ops to MapDefun outputs. 
- for (int i = 0; i < types.size(); ++i) { - AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]); + // Add output mappings. + for (size_t i = 0; i < op_node->num_outputs(); ++i) { + conversion_map_.insert({{op_node, i}, std::move(output_ports[i])}); } return Status::OK(); } -Status Vectorization::AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc) { - int input_index = function_utils::FindFunctionInputWithName( - output_desc.node_name, *map_defun_fn_); - if (input_index == -1) { - return errors::Internal("Cannot convert non-existent input."); +Status Vectorization::ConvertOutput(int output_position) { + // ret_edge->src() is the actual op that generated the retval, and + // ret_edge->dst() is the retval node whose op is "_Retval" + const Edge* ret_edge; + TF_RETURN_IF_ERROR( + map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge)); + + TensorDesc output({ret_edge->src(), ret_edge->src_output()}); + TensorDesc converted_output; + if (auto found = gtl::FindOrNull(conversion_map_, output)) { + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. + converted_output = *found; + } else { + TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); + converted_output = conversion_map_.at(output); } - conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index); + ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, + outer_scope_.get()); + RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(), + map_defun_node_); + return Status::OK(); } -Status Vectorization::ConvertOutputHelper( - const FunctionDefTensorDesc& output_desc, string* converted) { - // It's possible the output already has a mapping, if it comes from a node - // that has already been converted. 
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) { - *converted = *found; - return Status::OK(); +Status Vectorization::Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, + FunctionDef** result) { + TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node)); + VectorizeHelper(); + return GetResult(result); +} + +void Vectorization::VectorizeHelper() { + while (true) { + int output_position = graph_utils::GetFirstElementIndexWithPredicate( + [this](Node* n) { + return this->unconvertible_.find(n) == this->unconvertible_.end(); + }, + map_defun_fn_->ret_nodes); + + // No outputs left to convert + if (output_position == -1) break; + + Status s = ConvertOutput(output_position); + if (!s.ok()) { + Node* output_node = map_defun_fn_->ret_nodes.at(output_position); + VLOG(2) << "Could not convert the output at node: " + << output_node->DebugString() << "\nError: " << s; + unconvertible_.insert(output_node); + } } - int index = function_utils::FindFunctionNodeWithName(output_desc.node_name, - *map_defun_fn_); - if (index == -1) { // The output comes from an input - TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc)); + // If we've converted all the outputs of the MapDefun function, we no longer + // need the MapDefun node and can delete it. 
+ if (map_defun_fn_->ret_nodes.empty()) { + outer_scope_->RemoveNode(map_defun_node_); } else { - TF_RETURN_IF_ERROR(AddConversionMappingFromOp( - map_defun_fn_->node_def(index), output_desc)); + // Update MapDefun node attrs accordingly + DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size()); + map_defun_node_->AddAttr( + "output_shapes", + std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size())); + map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); } - *converted = conversion_map_.at(output_desc.full_str); - return Status::OK(); } +Status Vectorization::Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node) { + // Convert outer_scope and map_defun_fn to FunctionBodys so we can + // work on Graphs directly. + const FunctionDef* map_defun_fn = + lib_def_.Find(map_defun_node.attr().at("f").func().name()); + + if (map_defun_fn == nullptr) { + return errors::NotFound("Could not find function with name ", + map_defun_node.attr().at("f").func().name(), + " in function library."); + } -Status Vectorization::ConvertOutput(int output_position, - const FunctionDefTensorDesc& output_desc) { - string converted_output_name; - TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name)); + auto get_func_sig = [this](const string& op, const OpDef** sig) { + return this->lib_def_.LookUpOpDef(op, sig); + }; + + FunctionBody* outer_fn; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_, + get_func_sig, &outer_fn)); + // We don't need outer_fn, just the graph + outer_scope_.reset(outer_fn->graph); + outer_fn->graph = nullptr; + delete outer_fn; + + FunctionBody* tmp; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_, + get_func_sig, &tmp)); + map_defun_fn_.reset(tmp); + + // Find the MapDefun node in outer_scope_ + int node_id = graph_utils::GetFirstElementIndexWithPredicate( + [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); }, + 
outer_scope_->nodes()); + if (node_id == -1) { + return errors::NotFound("Could not find node with name ", + map_defun_node.name(), " in outer_scope."); + } + map_defun_node_ = outer_scope_->FindNodeId(node_id); + + // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to + // the conversion map + for (auto arg_node : map_defun_fn_->arg_nodes) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node( + arg_node->attrs().Find("index")->i(), &input_node)); - // Remove the old output and make everything that referenced it point - // to the new string - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node_->name(), ":output:", output_position), - converted_output_name, outer_scope_); - RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_, - output_position); + conversion_map_.insert({{arg_node, 0}, {input_node, 0}}); + } return Status::OK(); } -void Vectorization::Vectorize() { - while (true) { - FunctionDefTensorDesc desc; - int output_position = - FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc); - if (output_position == -1) break; +Status Vectorization::GetResult(FunctionDef** vectorized_function) { + TF_RETURN_IF_ERROR(status_); - if (!ConvertOutput(output_position, desc).ok()) { - unconvertible_.insert(desc.node_name); - } - } + if (!map_defun_fn_->ret_nodes.empty()) { + FunctionDef* map_defun_fn = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn)); - // If we've converted all the outputs of the MapDefun function, we no longer - // need the MapDefun node and can delete it. 
- if (map_defun_fn_->signature().output_arg_size() == 0) { - outer_scope_->mutable_node_def()->DeleteSubrange( - function_utils::FindFunctionNodeWithName(map_defun_node_->name(), - *outer_scope_), - 1); + AttrValue func_attr; + func_attr.mutable_func()->set_name(map_defun_fn->signature().name()); + map_defun_node_->AddAttr("f", func_attr); } - if (!unconvertible_.empty()) { - VLOG(2) << "The following nodes could not be converted: [" - << absl::StrJoin(unconvertible_, ", ") << "]."; - } + *vectorized_function = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_, + *vectorized_function); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *outer_scope_, (*vectorized_function)->signature().name(), + *vectorized_function)); + return Status::OK(); } + } // namespace -void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) { - Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize(); +Status VectorizeMapDefun(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDefLibrary* lib, + FunctionDef** result) { + *result = nullptr; + return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result); } } // end namespace vectorization_utils diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h index bb405faa77..bd7d390900 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h @@ -24,22 +24,28 @@ namespace tensorflow { namespace grappler { namespace vectorization_utils { -// Given a function, `map_defun_fn`, that is mapped across some input vector -// elements via a MapDefun operation, `VectorizeMapDefun` attempts to -// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the -// `outer_scope`; that is, replacing `map_defun_fn` operations with new -// `outer_scope` 
operations that produce the same vector output(s) as executing -// the `map_defun_fn` operations on elements of vector input(s) would. If all -// `map_defun_fn` operations are successfully lifted, `map_defun_node` is -// eliminated from `outer_scope` altogether. However, if some operations cannot -// be lifted, and this vectorization only succeeds partially, `map_defun_node` -// remains to be used for operations that were not lifted. +// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`) +// that maps a function in lib across some input vector elements, +// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope` +// by "lifting" operations from the MapDefun function to the new function +// (`result`); that is, replacing operations in the MapDefun function with +// operations that produce the same vector output(s) as executing the original +// operations on elements of vector input(s) would. If all operations in the +// MapDefun function are successfully lifted, `result` has no MapDefun node +// altogether. However, if some operations cannot be lifted, and this +// vectorization only succeeds partially, a MapDefun node remains in `result` to +// be used for operations that were not lifted, and the modified MapDefun +// function is added to `lib`. The newly vectorized function `result` is also +// added to `lib`. +// +// Returns Status::OK() if the vectorization is completely or partially +// successful. Otherwise, returns an error, and sets `result` to nullptr. // // Example: // If the input to the `VectorizeMapDefun` function is a MapDefun // whose `map_defun_fn` performs the Cast operation, the vectorization will // eliminate the MapDefun. This is because the Cast operation supports -// any tensor shape and can thus be lifted to the `outer_scope`. +// any tensor shape and can thus be lifted to `result`. 
// // Before: // @@ -68,7 +74,7 @@ namespace vectorization_utils { // // After: // -// outer_scope +------+ +// result +------+ // +---------------+ Arg0 +---------+ // | +---+--+ | // | | | @@ -80,8 +86,9 @@ namespace vectorization_utils { // +---------------+ Ret0 +---------+ // +------+ // -void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node); +Status VectorizeMapDefun(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDefLibrary* lib, + FunctionDef** result); } // end namespace vectorization_utils } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index e129fa9237..1ff62217dd 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" @@ -60,6 +61,11 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, return node; } +string GetRetval(const FunctionDef& function_def, int index) { + return function_def.ret().at( + function_def.signature().output_arg(index).name()); +} + // TODO(rachelim): Use FunctionDefHelper::Create instead FunctionDef CreateFunction( StringPiece name, const std::vector<std::pair<string, DataType>>& inputs, @@ -85,7 +91,6 @@ FunctionDef CreateFunction( return func; } -TEST(FunctionDefInputDescTest, ConstructedCorrectly) {} // Before: // @@ -133,10 +138,15 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { {{}, {}}, 
inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - EXPECT_EQ(outer.ret().at("mapdefun"), "ret0"); - EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1"); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_EQ(GetRetval(*vectorized, 0), "ret0"); + EXPECT_EQ(GetRetval(*vectorized, 1), "ret1"); } // Before: @@ -149,12 +159,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // | +-----------+ Arg0 +---+ Arg1 +----+ | // | | +---+--+ +---+--+ | | // | | | | | | -// | | +------+ | +---v--+ | | -// | | |Const | | | Op0 | | | -// | | +---v--+ | +---+--+ | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | // | | | | | | | // | | | +---v--+ +---v--+ | | -// | | +---| XOp1 | | XOp2 | | | +// | | +---| XOp1 | | Cast | | | // | | +---+--+ +---+--+ | | // | | | | | | // | | MapDefun +---v--+ +---v--+ | | @@ -165,23 +175,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // +---------------+ Ret0 +---+ Ret1 +--------+ // +------+ +------+ // -// where XOp1 and XOp2 are not convertible. +// where XOp1 is not convertible. // // After: // -// No change because the ops are not convertible. 
+// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ | | +// | +-----------+ Arg0 +-+ | | +// | | +---+--+ | | | +// | | | | | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | +// | | | | | | | +// | | | +---v--+ | +---v--+ | +// | | +---| XOp1 | | | Cast | | +// | | +---+--+ | +---+--+ | +// | | | | | | +// | | MapDefun +---v--+ | | | +// | +-----------+ Ret0 +-+ | | +// | +---+--+ | | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ // TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, - {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}}); + {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}}); + // TODO(rachelim): If we ever write a converter for MatMul, we have to + // change this test. 
NodeDef* x_op1 = - function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner); + function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner); CHECK_NOTNULL(x_op1); + graph_transforms::SetNodeAttr("T", DT_INT32, x_op1); - NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner); - CHECK_NOTNULL(x_op2); + NodeDef* cast_node = + AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner); + CHECK_NOTNULL(cast_node); FunctionDef outer = CreateFunction( "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}}, @@ -193,12 +230,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); - // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); - EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto map_defun_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized)); + // The Cast node should be converted just fine. + EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0"); + + // The inner function should only have one retval. 
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1); } // Before: @@ -257,14 +304,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -330,16 +382,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + 
function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -411,21 +468,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), "x"); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -486,7 +548,7 @@ TEST(VectorizeMapDefunTest, 
VectorizeDefunChainedConvertibleOps) { {"ret1", "MyUnstack:output:1"}, {"ret2", "MyUnstack:output:2"}}); NodeDef* cast_op = - AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner); CHECK_NOTNULL(cast_op); NodeDef* unstack_op = AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner); @@ -505,25 +567,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0")); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), 
strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 2); + EXPECT_EQ(vectorized->node_def_size(), 2); } // Before: @@ -561,9 +628,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}}, {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); - // The attrs aren't relevant - NodeDef* print_op = - function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner); + NodeDef* print_op = function_utils::AddNode( + "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner); + graph_transforms::SetNodeAttr("T", DT_INT32, print_op); + graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}), + print_op); CHECK_NOTNULL(print_op); NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64, false, &inner); @@ -578,11 +647,27 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); + // We check this somewhat manually as the names of nodes may have changed + EXPECT_EQ(vectorized->node_def_size(), 1); + const NodeDef& map_defun_node = vectorized->node_def(0); + EXPECT_EQ(map_defun_node.op(), "MapDefun"); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + + const NodeDef& print_node = map_defun_fn->node_def( + function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn)); + const NodeDef& cast_node = map_defun_fn->node_def( + 
function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn)); + string control_input = strings::StrCat("^", print_node.name()); + EXPECT_TRUE(cast_node.input(0) == control_input || + cast_node.input(1) == control_input); } // TODO(rachelim): More test cases when we get around to implementing them: |