diff options
Diffstat (limited to 'tensorflow/core/grappler')
63 files changed, 2911 insertions, 883 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index f716cd72c9..28fd7565cc 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -74,6 +74,10 @@ class GraphProperties { // shape information. void ClearInputProperties(const string& node_name); void ClearOutputProperties(const string& node_name); + // Returns true if we have *any* properties. + bool has_properties() const { + return input_properties_.size() > 0 || output_properties_.size() > 0; + } private: // Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 362092a6cf..db10f586bc 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {}); Output g = ops::Shape(s.WithOpName("g"), c); Output h = ops::Fill(s.WithOpName("h"), g, zero); + Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1}); + Output j = ops::Sum(s.WithOpName("j"), a, zero_idx); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { ASSERT_EQ(2, shape_f.dim_size()); EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size()); EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size()); + + const auto shape_j = properties.GetOutputProperties("j").at(0).shape(); + ASSERT_EQ(1, shape_j.dim_size()); + EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size()); } TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt 
b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt index c94ee2f227..0ec95dd684 100644 --- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt +++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt @@ -88,6 +88,13 @@ library { } } } + attr { + key: "output_shapes" + value { + list { + } + } + } } ret { key: "while" diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index 2619a9a8f3..de0a63fc4e 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -20,23 +20,25 @@ limitations under the License. namespace tensorflow { namespace grappler { -int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { - for (int output_arg_id = 0; output_arg_id < op.output_arg_size(); - ++output_arg_id) { +namespace { +int OpPortIdToArgId(const NodeDef& node, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, + int port_id) { + for (int arg_id = 0; arg_id < args.size(); ++arg_id) { if (port_id < 0) { return -1; } else if (port_id == 0) { - return output_arg_id; + return arg_id; } - // Default is 1 port per output arg. + // Default is 1 port per arg. 
int n = 1; - const auto& output_arg = op.output_arg(output_arg_id); - if (!output_arg.number_attr().empty()) { - n = node.attr().at(output_arg.number_attr()).i(); - } else if (!output_arg.type_list_attr().empty()) { - n = node.attr().at(output_arg.type_list_attr()).list().type_size(); + const auto& arg = args.Get(arg_id); + if (!arg.number_attr().empty()) { + n = node.attr().at(arg.number_attr()).i(); + } else if (!arg.type_list_attr().empty()) { + n = node.attr().at(arg.type_list_attr()).list().type_size(); } if (n < 0) { @@ -44,13 +46,22 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { DCHECK_GE(n, 0); return -1; } else if (port_id < n) { - return output_arg_id; + return arg_id; } port_id -= n; } return -1; } +} // end namespace + +int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { + return OpPortIdToArgId(node, op.output_arg(), port_id); +} + +int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { + return OpPortIdToArgId(node, op.input_arg(), port_id); +} GraphView::GraphView(GraphDef* graph) : graph_(graph) { for (int i = 0; i < graph_->node_size(); i++) { @@ -72,7 +83,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) { void GraphView::AddFanouts(NodeDef* node) { for (int i = 0; i < node->input_size(); ++i) { OutputPort fanin; - string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); + const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); fanin.node = nodes_[fanin_name]; InputPort input; diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index ec946ca3b5..09c36a1368 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Map a node/op's output port_id to arg_id. +// Map a node/op's input/output port_id to arg_id. 
// // The port_id refers to the n-th tensor of the node, while the arg_id refers to // the n-th arg of the op. These two can be different if an op's arg is a list @@ -34,6 +34,7 @@ namespace grappler { // // We return -1 for any invalid port_id (i.e., no corresponding arg_id). int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); +int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); // A utility class to simplify the traversal of a GraphDef. class GraphView { diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 3d7d2faf7c..f90e2c8cfc 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -26,7 +26,7 @@ namespace { class GraphViewTest : public ::testing::Test {}; -TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) { +TEST_F(GraphViewTest, OpPortIdToArgIdShapeN) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); ops::ShapeN b(s.WithOpName("b"), {a, a, a}); @@ -45,9 +45,16 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) { EXPECT_TRUE( OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok()); - EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0)); - EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1)); + // Const has 0 inputs, 1 output. + EXPECT_EQ(-1, OpInputPortIdToArgId(a_node_def, *a_op_def, 0)); + EXPECT_EQ(0, OpOutputPortIdToArgId(a_node_def, *a_op_def, 0)); + EXPECT_EQ(-1, OpOutputPortIdToArgId(a_node_def, *a_op_def, 1)); + // ShapeN has N=3 inputs and outputs. 
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0)); + EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 1)); + EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 2)); + EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 3)); EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0)); EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1)); EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2)); @@ -55,7 +62,7 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) { EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4)); } -TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) { +TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) { for (int num_splits : {1, 2}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10}); @@ -70,6 +77,13 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) { EXPECT_TRUE( OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok()); + // We have 4 inputs. 
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0)); + EXPECT_EQ(1, OpInputPortIdToArgId(b_node_def, *b_op_def, 1)); + EXPECT_EQ(2, OpInputPortIdToArgId(b_node_def, *b_op_def, 2)); + EXPECT_EQ(3, OpInputPortIdToArgId(b_node_def, *b_op_def, 3)); + EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 4)); + for (int port_id = 0; port_id <= num_splits * 3; ++port_id) { int arg_id = -1; if (port_id < num_splits * 3) { diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index bbc0fedd22..2c490f3966 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) { restore_op = other.restore_op; save_restore_loc_tensor = other.save_restore_loc_tensor; queue_runners = other.queue_runners; + allowed_optimizations = other.allowed_optimizations; graph.Swap(graph_def); } diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index 939e5fa046..a0748abfe6 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -77,6 +77,15 @@ struct GrapplerItem { // Return a set of node names that must be preserved. This includes feed and // fetch nodes, keep_ops, init_ops. std::unordered_set<string> NodesToPreserve() const; + + // Restrict types of optimizations that are allowed for this GrapplerItem. + struct AllowedOptimizations { + // Is it allowed to add nodes to the graph that do not have registered + // gradient function. + bool non_differentiable_rewrites = true; + }; + + AllowedOptimizations allowed_optimizations; }; // Return the transitive fanin of a set of terminal nodes. 
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 029515ad3c..369046666d 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( const string feed_name = NodeName(feed_node); new_item->feed.emplace_back(feed_name, Tensor()); } + for (const auto& fetch_node : cfg.fetch_nodes) { + new_item->fetch.emplace_back(NodeName(fetch_node)); + } - // Attempt to detect the fetch node(s). - if (meta_graph.collection_def().count("train_op") > 0) { + // Attempt to detect the fetch node(s) if they were not set explicitly. + if (new_item->fetch.empty() && + meta_graph.collection_def().count("train_op") > 0) { const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); if (nodes.has_node_list()) { for (const auto& node : nodes.node_list().value()) { diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index aafd2fdcda..1698587f8c 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -49,6 +49,8 @@ struct ItemConfig { bool prune_graph = false; // Override feed nodes list. std::set<string> feed_nodes; + // Override fetch nodes list. + std::set<string> fetch_nodes; }; // Factory method for creating a GrapplerItem from a MetaGraphDef. 
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 4b90bf3038..d00981f174 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) { EXPECT_EQ(item2->feed[0].second.NumElements(), 1); } +TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), 0); + auto y = ops::Const(s.WithOpName("y"), 1); + auto z = ops::Add(s.WithOpName("z"), x, y); + + MetaGraphDef meta_graph; + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + ItemConfig config; + config.feed_nodes.insert("x"); + config.fetch_nodes.insert("z"); + + std::unique_ptr<GrapplerItem> item = + GrapplerItemFromMetaGraphDef("0", meta_graph, config); + ASSERT_TRUE(item != nullptr); + + EXPECT_EQ(item->feed.size(), 1); + EXPECT_EQ(item->fetch.size(), 1); + EXPECT_EQ(item->feed[0].first, "x"); + EXPECT_EQ(item->fetch[0], "z"); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 3521669b63..cbf5c8e038 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include <unordered_set> - +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" @@ -102,6 +101,22 @@ bool IsConjugateTranspose(const NodeDef& node) { return node.op() == "ConjugateTranspose"; } +bool IsControlFlow(const NodeDef& node) { + // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative + // string comparison. + static const gtl::FlatSet<string>* const kControlFlowOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ + "ControlTrigger", + "Enter", + "Exit", + "LoopCond", + "Merge", + "NextIteration", + "Switch", + })); + return kControlFlowOps->count(node.op()) > 0; +} + bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } bool IsConv2DBackpropFilter(const NodeDef& node) { @@ -140,26 +155,26 @@ bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing, // e.g. inv. 
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { - static const std::unordered_set<string>* monotonic_non_decreasing_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1", "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint", "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh", })); - static const std::unordered_set<string>* monotonic_non_increasing_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "Inv", "Reciprocal", "Erfc", "Rsqrt", "Neg", })); - if (monotonic_non_decreasing_ops->count(node.op()) > 0) { + if (kMonotonicNonDecreasingOps->count(node.op()) > 0) { if (is_non_decreasing) { *is_non_decreasing = true; } return true; - } else if (monotonic_non_increasing_ops->count(node.op()) > 0) { + } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) { if (is_non_decreasing) { *is_non_decreasing = false; } @@ -425,8 +440,44 @@ bool IsSwitch(const NodeDef& node) { return op == "Switch" || op == "RefSwitch"; } +bool IsSymbolicGradient(const NodeDef& node) { + return node.op() == "SymbolicGradient"; +} + bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; } +bool IsTensorArray(const NodeDef& node) { + static const gtl::FlatSet<string>* const kTensorArrayOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ + "TensorArray", + "TensorArrayV2", + "TensorArrayV3", + "TensorArrayGrad", + "TensorArrayGradV2", + "TensorArrayGradV3", + "TensorArrayGradWithShape", + "TensorArrayWrite", + "TensorArrayWriteV2", + "TensorArrayWriteV3", + "TensorArrayRead", + "TensorArrayReadV2", + "TensorArrayReadV3", + "TensorArrayConcat", + "TensorArrayConcatV2", + "TensorArrayConcatV3", + "TensorArraySplit", + "TensorArraySplitV2", + "TensorArraySplitV3", + 
"TensorArraySize", + "TensorArraySizeV2", + "TensorArraySizeV3", + "TensorArrayClose", + "TensorArrayCloseV2", + "TensorArrayCloseV3", + })); + return kTensorArrayOps->count(node.op()) > 0; +} + bool IsTile(const NodeDef& node) { return node.op() == "Tile"; } bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; } @@ -538,30 +589,29 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate) OPDEF_PROPERTY_HELPER(Commutative, commutative) bool IsInvolution(const NodeDef& node) { - static const std::unordered_set<string>* involution_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ - "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"})); - return involution_ops->count(node.op()) > 0; + static const gtl::FlatSet<string>* const kInvolutionOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert", + "Neg", "LogicalNot"})); + return kInvolutionOps->count(node.op()) > 0; } bool IsValueAndOrderAndShapePreserving(const NodeDef& node) { if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { return true; } - static const std::unordered_set<string>* - value_and_order_and_shape_preserving_ops = - CHECK_NOTNULL((new const std::unordered_set<string>{ - "CheckNumerics", - "DebugGradientIdentity", - "DeepCopy" - "Enter", - "Exit", - "PreventGradient", - "Print", - "Snapshot", - "StopGradient", - })); - return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 || + static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps = + CHECK_NOTNULL((new const gtl::FlatSet<string>{ + "CheckNumerics", + "DebugGradientIdentity", + "DeepCopy", + "Enter", + "Exit", + "PreventGradient", + "Print", + "Snapshot", + "StopGradient", + })); + return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 || IsIdentity(node); } @@ -569,31 +619,31 @@ bool IsValueAndOrderPreserving(const NodeDef& node) { if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { return true; } - static const std::unordered_set<string>* 
value_and_order_preserving_ops = - CHECK_NOTNULL((new const std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps = + CHECK_NOTNULL((new const gtl::FlatSet<string>{ "ExpandDims", "Reshape", "Squeeze", })); - return value_and_order_preserving_ops->count(node.op()) > 0 || + return kValueAndOrderPreservingOps->count(node.op()) > 0 || IsValueAndOrderAndShapePreserving(node); } bool IsValuePreserving(const NodeDef& node) { - static const std::unordered_set<string>* value_preserving_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kValuePreservingOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "InvertPermutation", "Reverse", "Roll", "Transpose", })); return IsValueAndOrderPreserving(node) || - value_preserving_ops->count(node.op()) > 0; + kValuePreservingOps->count(node.op()) > 0; } bool IsUnaryElementWise(const NodeDef& node) { - static const std::unordered_set<string>* element_wise_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kElementWiseOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "Abs", "Acos", "Acosh", @@ -642,7 +692,7 @@ bool IsUnaryElementWise(const NodeDef& node) { "Tan" "Tanh", })); - return element_wise_ops->count(node.op()) > 0 || + return kElementWiseOps->count(node.op()) > 0 || IsValueAndOrderAndShapePreserving(node); } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 25ab6b65ac..d4e0159e81 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -46,6 +46,7 @@ bool IsConjugateTranspose(const NodeDef& node); bool IsConcat(const NodeDef& node); bool IsConcatOffset(const NodeDef& node); bool IsConstant(const NodeDef& node); +bool IsControlFlow(const NodeDef& node); bool IsConv2D(const NodeDef& node); bool IsConv2DBackpropFilter(const NodeDef& node); bool IsConv2DBackpropInput(const NodeDef& node); @@ -149,7 +150,9 @@ bool 
IsStridedSliceGrad(const NodeDef& node); bool IsSub(const NodeDef& node); bool IsSum(const NodeDef& node); bool IsSwitch(const NodeDef& node); +bool IsSymbolicGradient(const NodeDef& node); bool IsTanhGrad(const NodeDef& node); +bool IsTensorArray(const NodeDef& node); bool IsTile(const NodeDef& node); bool IsTranspose(const NodeDef& node); bool IsTruncateDiv(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 960d1addb3..c708f84948 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -525,6 +525,7 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/utils:colocation", @@ -541,6 +542,7 @@ tf_cuda_cc_test( ":custom_graph_optimizer_registry", ":meta_optimizer", "//tensorflow/cc:cc_ops", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "//tensorflow/core:test", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 75ed12635e..7d5014ee0a 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { for (int i = 0; i < output->input_size(); ++i) { auto input = output->input(i); - string name = ParseNodeName(input, &position); + StringPiece name = ParseNodeNameAsStringPiece(input, &position); if (name == node.name() && /*control input*/ position < 0) { return true; } @@ -1568,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { for (NodeDef* output : 
outputs) { if (IsControlInput(output->input(0))) continue; int port; - const string node_name = ParseNodeName(output->input(0), &port); + const StringPiece node_name = + ParseNodeNameAsStringPiece(output->input(0), &port); if (node_name == node.name()) { tails->insert(ChainLink(output, port)); } else { @@ -1618,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } else { for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) { int port; - const string node_name = ParseNodeName(new_tail->input(0), &port); + const StringPiece node_name = + ParseNodeNameAsStringPiece(new_tail->input(0), &port); if (node_name != tail->name()) { return Status::OK(); } @@ -2929,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const { for (const auto& input : node.input()) { int pos; - string node_name = ParseNodeName(input, &pos); - h = Hash64CombineUnordered(Hash64(node_name), h); + const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos); + h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h); h = Hash64CombineUnordered(std::hash<int>()(pos), h); } for (const auto& attr : node.attr()) { @@ -3247,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, optimized_graph_ = &optimized_item.graph; node_map_.reset(new NodeMap(optimized_graph_)); + // Disable restricted graph rewrites. + options_.unary_ops_composition &= + item.allowed_optimizations.non_differentiable_rewrites; + if (options_.dedup_computations) { DedupComputations(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ca5d3a6dfd..3d0d95bba7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices( // We can't do anything if we don't know the rank of the input. 
return Status::OK(); } - const int rank = input_prop.shape().dim_size(); - if (rank == 0) { + const int input_rank = input_prop.shape().dim_size(); + if (input_rank < 1) { // Unexpected graph, don't try to change it. return Status::OK(); } + const OpInfo::TensorProperties& reduction_indices_prop = input_props[1]; + DataType dtype = reduction_indices_prop.dtype(); + if (dtype != DT_INT32 && dtype != DT_INT64) { + return Status::OK(); + } + PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape()); + const int num_reduction_indices = reduction_indices_shape.num_elements(); + const std::vector<OpInfo::TensorProperties>& output_props = properties.GetOutputProperties(node->name()); if (output_props.size() != 1) { return Status::OK(); } - const bool keep_dims = - node->attr().count("keep_dims") && node->attr().at("keep_dims").b(); const OpInfo::TensorProperties& output_prop = output_props[0]; - PartialTensorShape output_shape(output_prop.shape()); - if (output_shape.num_elements() != 1) { - bool full_reduction = false; + const int output_rank = + output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size(); + + bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank; + if (!full_reduction) { + // A full reduction will generate a tensor of one of the shapes + // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of + // elements in the output of the reduction, we may deduce it from reshape + // nodes following it. for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) { - if (!IsReshape(*fanout) && !keep_dims) { - // Depending on how it's setup, a full reduction will generate a tensor - // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we - // rely on the existence of a reshape node following the reduction to - // ensure that the fanout is fed a scalar of the right shape. 
+ full_reduction = false; + if (!IsReshape(*fanout)) { return Status::OK(); } const std::vector<OpInfo::TensorProperties>& reshape_props = @@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices( } } - const OpInfo::TensorProperties& reduction_prop = input_props[1]; - DataType dtype = reduction_prop.dtype(); - if (dtype != DT_INT32 && dtype != DT_INT64) { - return Status::OK(); - } - // We know it's a full reduction. We can generate the set of indices to - // reduce. + // We know it's a full reduction. We can generate the full set of indices to + // reduce as a constant node. string const_name = OptimizedNodeName(*node, "-reduction_indices"); if (node_map_->GetNode(const_name)) { return Status::OK(); } NodeDef* reduction_indices = graph_->add_node(); - Tensor value(dtype, TensorShape({rank})); - for (int i = 0; i < rank; ++i) { + Tensor value(dtype, TensorShape({input_rank})); + for (int i = 0; i < input_rank; ++i) { if (dtype == DT_INT32) { value.vec<int32>()(i) = i; } else { @@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices( } TF_RETURN_IF_ERROR( CreateNodeDef(const_name, TensorValue(&value), reduction_indices)); + reduction_indices->set_device(node->device()); string ctrl_dep = AddControlDependency(node->input(1), graph_, node_map_.get()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index b09360a2c2..fab01edfed 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) { } TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output input = - ops::Placeholder(s.WithOpName("input"), DT_FLOAT, - ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); - Output indices = 
ops::Placeholder(s.WithOpName("indices"), DT_INT32); - Output sum = ops::Sum(s.WithOpName("sum"), input, indices); - Output size = ops::Const(s.WithOpName("size"), 1, {1}); - Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + for (bool use_reshape : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + // If use_reshape is false, we need to know the number of indices to apply + the rewrite. + Output indices = ops::Placeholder( + s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + if (use_reshape) { + Output size = ops::Const(s.WithOpName("size"), 1, {1}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + } - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch.push_back("reshape"); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back(use_reshape ? 
"reshape" : "sum"); - auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); - Tensor indices_t(DT_INT32, TensorShape({2})); - indices_t.flat<int>()(0) = 0; - indices_t.flat<int>()(1) = 1; - auto tensors_expected = EvaluateNodes( - item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors_expected.size()); + auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); + Tensor indices_t(DT_INT32, TensorShape({2})); + indices_t.flat<int>()(0) = 0; + indices_t.flat<int>()(1) = 1; + auto tensors_expected = EvaluateNodes( + item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors_expected.size()); - ConstantFolding optimizer(nullptr /* cpu_device */); - GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - // Run a second time to make sure the optimization is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Run a second time to make sure the optimization is idempotent. 
+ item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - int found = 0; - for (const auto& node : output.node()) { - if (node.name() == "ConstantFolding/sum-reduction_indices") { - ++found; - EXPECT_EQ("Const", node.op()); - EXPECT_EQ("^indices", node.input(0)); - EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape()) - .num_elements()); - } else if (node.name() == "sum") { - ++found; - EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); - } else if (node.name() == "indices") { - ++found; + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "ConstantFolding/sum-reduction_indices") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^indices", node.input(0)); + EXPECT_EQ(2, + TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "sum") { + ++found; + EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); + } else if (node.name() == "indices") { + ++found; + } } + EXPECT_EQ(3, found); + + auto tensors = EvaluateNodes(output, item.fetch, + {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); } - EXPECT_EQ(3, found); +} - auto tensors = EvaluateNodes(output, item.fetch, - {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); +TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) { + for (bool input_rank_known : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + (input_rank_known ? 
ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape( + PartialTensorShape({-1, -1}))) + : ops::Placeholder(s.WithOpName("input"), DT_FLOAT)); + Output indices = + ops::Placeholder(s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape( + PartialTensorShape({input_rank_known ? 1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("sum"); + + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + CompareGraphs(item.graph, output); + } } TEST_F(ConstantFoldingTest, LargeConstant) { diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index cf305cebe1..ee7c14e3ab 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -31,6 +32,7 @@ tf_cc_test( visibility = ["//visibility:public"], deps = [ ":filter_fusion", + ":graph_test_utils", ":graph_utils", "//tensorflow/core:framework", "//tensorflow/core:test", @@ -87,11 +89,12 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":graph_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", - "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -121,11 +124,12 @@ cc_library( ], visibility = ["//visibility:public"], deps = 
[ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", - "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -135,6 +139,7 @@ tf_cc_test( visibility = ["//visibility:public"], deps = [ ":graph_utils", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -146,6 +151,62 @@ tf_cc_test( ) cc_library( + name = "graph_test_utils", + testonly = 1, + srcs = ["graph_test_utils.cc"], + hdrs = [ + "graph_test_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:testlib", + ] + tf_protos_all(), +) + +cc_library( + name = "hoist_random_uniform", + srcs = ["hoist_random_uniform.cc"], + hdrs = [ + "hoist_random_uniform.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + ":graph_utils", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "hoist_random_uniform_test", + srcs = ["hoist_random_uniform_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_test_utils", + ":graph_utils", + ":hoist_random_uniform", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + 
"//tensorflow/core/grappler:grappler_item", + ] + tf_protos_all(), +) + +cc_library( name = "latency_all_edges", srcs = ["latency_all_edges.cc"], hdrs = [ @@ -256,7 +317,7 @@ cc_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core:ptr_util", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -265,6 +326,7 @@ tf_cc_test( srcs = ["map_and_filter_fusion_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_and_filter_fusion", "//tensorflow/core:framework", @@ -294,6 +356,7 @@ cc_library( "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -302,6 +365,7 @@ tf_cc_test( srcs = ["map_fusion_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_fusion", "//tensorflow/core:framework", @@ -339,6 +403,7 @@ tf_cc_test( srcs = ["map_parallelization_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_parallelization", "//tensorflow/core:framework", @@ -422,6 +487,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":filter_fusion", + ":hoist_random_uniform", ":latency_all_edges", ":map_and_batch_fusion", ":map_and_filter_fusion", @@ -458,7 +524,9 @@ cc_library( deps = [ ":function_utils", ":graph_utils", + "//tensorflow/cc:ops", "@com_google_absl//absl/strings", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -474,6 +542,7 @@ tf_cc_test( srcs = ["vectorization_utils_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_utils", ":function_utils", 
":vectorization_utils", "//tensorflow/core:framework", @@ -483,7 +552,10 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + # For ops we need registered + "//tensorflow/core/kernels/data:dataset_ops", "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:logging_ops", "//tensorflow/tools/graph_transforms:transform_utils", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc index c71aa6e804..1ad495bbad 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc @@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node, fused_node.set_op("FilterDataset"); fused_node.add_input(first_filter_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = first_filter_node.attr().at("predicate"); *attr.mutable_func()->mutable_name() = fused_function.signature().name(); (*fused_node.mutable_attr())["predicate"] = std::move(attr); - copy_attribute("Targuments", first_filter_node, &fused_node); + graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, second_filter_node, &fused_node); + graph_utils::CopyAttribute(key, second_filter_node, &fused_node); return fused_node; } @@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // functions, or make sure that optimization passes run after filter // fusion. 
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate)); - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. nodes_to_delete.insert(first_filter_node->name()); nodes_to_delete.insert(second_filter_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc index 12b1924efd..c8becc5cc0 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -28,14 +28,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "FilterDataset", {string(input_node_name)}, - {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeFilterNode; TEST(FilterFusionTest, FuseTwoFilterIntoOne) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc index e95ea1a4c1..311df15bc2 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc @@ -14,31 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace grappler { namespace function_utils { -namespace { - -template <typename Predicate, typename Collection> -std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate, - const Collection& collection) { - std::vector<int> indices = {}; - unsigned idx = 0; - for (auto&& element : collection) { - if (predicate(element)) { - indices.push_back(idx); - } - idx++; - } - return indices; -} - -} // namespace FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name, const string& output, int position) @@ -152,32 +137,27 @@ bool ContainsFunctionOutputWithName(StringPiece name, } int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, function.signature().input_arg()); - return indices.empty() ? -1 : indices.front(); } int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, function.signature().output_arg()); - return indices.empty() ? 
-1 : indices.front(); } int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, function.node_def()); - return indices.empty() ? -1 : indices.front(); } int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, function.node_def()); - - return indices.empty() ? -1 : indices.front(); } void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc new file mode 100644 index 0000000000..b2eec7220e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { +namespace graph_tests_utils { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "MapDataset", {string(input_node_name)}, + {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<DataType>{}}}); +} + +NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "FilterDataset", {string(input_node_name)}, + {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<TensorShape>{}}}); +} + +} // end namespace graph_tests_utils +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h new file mode 100644 index 0000000000..ca0fde997d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace grappler { +namespace graph_tests_utils { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name = "XTimesTwo"); + +NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name = "IsZero"); + +} // end namespace graph_tests_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 2dd9ee822e..b863a25dc5 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -201,25 +202,22 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) { int FindGraphFunctionWithName(StringPiece name, const FunctionDefLibrary& library) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&name](const FunctionDef& function) { return function.signature().name() == name; }, library.function()); - return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, graph.node()); - return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); - return indices.empty() ? 
-1 : indices.front(); } std::vector<int> FindAllGraphNodesWithOp(const string& op, @@ -260,6 +258,41 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, } function->mutable_signature()->set_name(std::move(name)); } + +void CopyAttribute(const string& attribute_name, const NodeDef& from, + NodeDef* to_node) { + (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name); +} + +void ConcatAttributeList(const string& attribute_name, const NodeDef& first, + const NodeDef& second, NodeDef* to_node) { + CopyAttribute(attribute_name, first, to_node); + (*to_node->mutable_attr()) + .at(attribute_name) + .mutable_list() + ->MergeFrom(second.attr().at(attribute_name).list()); +} + +Status EnsureNodeNamesUnique(Graph* g) { + // Modeled after Scope::Impl::GetUniqueName + std::unordered_map<string, int> name_map; + + for (auto node : g->op_nodes()) { + const string& prefix = node->name(); + if (auto entry = gtl::FindOrNull(name_map, prefix)) { + string unique_name; + do { + unique_name = strings::StrCat(prefix, "_", ++(*entry)); + } while (name_map.find(unique_name) != name_map.end()); + name_map.insert({unique_name, 0}); + node->set_name(std::move(unique_name)); + } else { + name_map.insert({node->name(), 0}); + } + } + + return Status::OK(); +} } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index b117482db2..d130fee204 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" @@ -31,6 +32,21 @@ namespace tensorflow { namespace grappler { namespace graph_utils { +// Returns the index of the first element in collection that fulfills predicate. +// If no such element exists, returns -1. +template <typename Predicate, typename Collection> +int GetFirstElementIndexWithPredicate(const Predicate& predicate, + const Collection& collection) { + unsigned idx = 0; + for (auto&& element : collection) { + if (predicate(element)) { + return idx; + } + idx++; + } + return -1; +} + // Adds a node to the graph. NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<string>& inputs, @@ -101,11 +117,29 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op, // is unique across the graph. void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node); -// Sets the node name using the `prefix` name as a prefix while guaranteeing the -// name is unique across the graph. +// Sets the function name using the `prefix` name as a prefix while guaranteeing +// the name is unique across the function library. void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, FunctionDef* function); +// Copies attribute having name `attribute_name` from node `from` to node +// `to_node`. +void CopyAttribute(const string& attribute_name, const NodeDef& from, + NodeDef* to_node); + +// Concatenates list attribute having name `attribute_name` from `first` and +// `second` node, setting it to `to_node`. 
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first, + const NodeDef& second, NodeDef* to_node); + +// Checks that all nodes in the graphs have unique names, and sets their names +// to be unique if they are not already. This is necessary as Graph does not +// have the provisions to deduplicate names, and name deduplication elsewhere +// in tensorflow happens in other layers (for example, in the Scope class of the +// C++ API). Note that the nodes in the graph are identified by their id, +// and renaming nodes does not mutate any edges. +Status EnsureNodeNamesUnique(Graph* g); + } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 6877c207c4..4ab6d71532 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -24,6 +25,18 @@ namespace grappler { namespace graph_utils { namespace { +TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) { + std::vector<int> vec({1, 2, 3, 4, 5, 6}); + auto result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 3 == 0; }, vec); + + EXPECT_EQ(result, 2); + + result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 7 == 0; }, vec); + EXPECT_EQ(result, -1); +} + TEST(GraphUtilsTest, AddScalarConstNodeBool) { GraphDef graph_def; MutableGraphView graph(&graph_def); @@ -217,6 +230,33 @@ TEST(GraphUtilsTest, GetInputNode) { EXPECT_EQ(GetInputNode(*node1, graph), nullptr); } +TEST(GraphUtilsTest, EnsureNodeNamesUnique) { + Graph g(OpRegistry::Global()); + + Node *const_0, *const_1, *const_2; + + // Arbitrary const + Tensor tensor(DT_INT32, {}); + tensor.scalar<int32>()() = 5; + + for (auto node : {&const_0, &const_1}) { + TF_EXPECT_OK(NodeBuilder("Const", "Const") + .Attr("value", tensor) + .Attr("dtype", DT_INT32) + .Finalize(&g, node)); + } + // Make sure generated name doesn't clash with existing name either + TF_EXPECT_OK(NodeBuilder("Const_1", "Const") + .Attr("value", tensor) + .Attr("dtype", DT_INT32) + .Finalize(&g, &const_2)); + + TF_EXPECT_OK(EnsureNodeNamesUnique(&g)); + EXPECT_NE(const_0->name(), const_1->name()); + EXPECT_NE(const_1->name(), const_2->name()); + EXPECT_NE(const_0->name(), const_2->name()); +} + } // namespace } // namespace graph_utils } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc new file mode 100644 index 0000000000..ce0b2db039 --- /dev/null +++ 
b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node, + const FunctionDef& stateless_function, + MutableGraphView* graph) { + NodeDef stateless_map; + graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(), + &stateless_map); 
+ + stateless_map.set_op("MapDataset"); + stateless_map.add_input(zip_node.name()); + // Add placeholders. + for (int i = 1; i < map_node.input_size(); i++) + stateless_map.add_input(map_node.input(i)); + + auto attr = map_node.attr().at("f"); + *attr.mutable_func()->mutable_name() = stateless_function.signature().name(); + *attr.mutable_func()->mutable_attr() = stateless_function.attr(); + (*stateless_map.mutable_attr())["f"] = std::move(attr); + + graph_utils::CopyAttribute("Targuments", map_node, &stateless_map); + for (auto key : {"output_shapes", "output_types"}) + graph_utils::CopyAttribute(key, map_node, &stateless_map); + + if (const auto* attr = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism")) + (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr; + + return stateless_map; +} + +NodeDef MakeRandomDataset(const NodeDef& random_uniform_node, + MutableGraphView* graph) { + NodeDef random_dataset; + random_dataset.set_op("RandomDataset"); + graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(), + &random_dataset); + + const auto* seed = graph_utils::AddScalarConstNode<int64>( + random_uniform_node.attr().at("seed").i(), graph); + const auto* seed2 = graph_utils::AddScalarConstNode<int64>( + random_uniform_node.attr().at("seed2").i(), graph); + + random_dataset.add_input(seed->name()); + random_dataset.add_input(seed2->name()); + + (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape(); + (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type( + DT_INT64); + + return random_dataset; +} + +NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { + NodeDef batch_dataset; + batch_dataset.set_op("BatchDatasetV2"); + graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(), + &batch_dataset); + const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph); + const auto* drop_reminder = graph_utils::AddScalarConstNode(false, 
graph); + batch_dataset.add_input(random_dataset.name()); + batch_dataset.add_input(batch_size->name()); + batch_dataset.add_input(drop_reminder->name()); + + (*batch_dataset.mutable_attr())["output_shapes"] + .mutable_list() + ->add_shape() + ->mutable_dim() + ->Add() + ->set_size(-1); + (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type( + DT_INT64); + + return batch_dataset; +} + +NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node, + MutableGraphView* graph) { + NodeDef zip_node; + graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(), + &zip_node); + + zip_node.set_op("ZipDataset"); + zip_node.add_input(first_node.name()); + zip_node.add_input(second_node.name()); + + for (auto key : {"output_shapes", "output_types"}) + graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node); + + (*zip_node.mutable_attr())["N"].set_i(2); + + return zip_node; +} + +// We need to insert our argument before the placeholders, which are the last +// arguments. +OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) { + int new_argument_idx = signature->input_arg_size() - num_placeholders; + signature->add_input_arg(); + for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) { + signature->mutable_input_arg()->SwapElements(i - 1, i); + } + auto* seed_arg = signature->mutable_input_arg(new_argument_idx); + seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx)); + seed_arg->set_type(DT_INT64); + + return seed_arg; +} + +// Make function that uses `StatelessRandomUniform` instead of `RandomUniform` +// to make it less statefull. The function can still be stateful, but in when +// other stateful ops are e.g. `Assert`, then it will be parallelizable. 
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function, + bool is_stateful, + int num_placeholders, + FunctionDefLibrary* library) { + FunctionDef* stateless_function = library->add_function(); + *stateless_function = map_function; + if (is_stateful) + stateless_function->mutable_signature()->set_is_stateful(is_stateful); + graph_utils::SetUniqueGraphFunctionName("stateless_function", library, + stateless_function); + + auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(), + num_placeholders); + + auto* const random_uniform = stateless_function->mutable_node_def( + function_utils::FindFunctionNodeWithOp("RandomUniform", + *stateless_function)); + + // Replace RandomUniform node with StatelessRandomUniform. + random_uniform->set_op("StatelessRandomUniform"); + random_uniform->add_input(seed_arg->name()); + (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64); + random_uniform->mutable_attr()->erase("seed"); + random_uniform->mutable_attr()->erase("seed2"); + + return stateless_function; +} +// This function returns true if function is stateful and has single +// RandomUniform op and no other stateful ops except Assert. +// `is_stateful_after_hoisting` is set to true if RandomUniform is the only +// stateful op and hoisting can be performed. +bool CanHoistRandomUniform(const FunctionDef& map_function, + const FunctionLibraryDefinition& library, + bool* is_stateful_after_hoisting, + const NodeDef** random_uniform_op) { + if (!map_function.signature().is_stateful()) return false; + *is_stateful_after_hoisting = true; + + bool have_other_stateful_ops = false; + + for (const auto& node : map_function.node_def()) { + const OpDef* op_def; + TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def)); + // Skip stateless nodes and assert, as it does not actually have a state. 
+ if (!op_def->is_stateful()) continue; + + if (op_def->name() == "Assert") { + have_other_stateful_ops = true; + continue; + } + + // TODO(prazek): For now we only handle RandomUniform, we should handle + // RandomUniformInt as well. + if (op_def->name() != "RandomUniform") return false; + + // TODO(prazek): For now we can only hoist single RandomUniform. + if (*random_uniform_op != nullptr) return false; + + *random_uniform_op = &node; + } + + if (!have_other_stateful_ops) *is_stateful_after_hoisting = false; + + // Have we found single RandomUniform? + return *random_uniform_op != nullptr; +} + +int NumberOfPlaceholders(const NodeDef& map_node) { + // First input of MapDataset is the argument to the function. Rest of the + // inputs are placeholders. + return map_node.input_size() - 1; +} + +} // namespace + +Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + + auto get_map_node = [](const NodeDef& node) -> const NodeDef* { + // TODO(prazek): we could also handle ParallelMapDataset and + // MapAndBatchDataset. 
+ if (node.op() == "MapDataset") return &node; + return nullptr; + }; + + for (const NodeDef& node : item.graph.node()) { + const NodeDef* map_node = get_map_node(node); + if (!map_node) continue; + + const auto& fun = map_node->attr().at("f"); + const FunctionDef* func = function_library.Find(fun.func().name()); + + const NodeDef* random_uniform_op = nullptr; + bool is_stateful_after_hoisting = true; + if (!CanHoistRandomUniform(*func, function_library, + &is_stateful_after_hoisting, &random_uniform_op)) + continue; + const auto* random_seed_dataset = + graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph)); + + const auto* batch_dataset = + graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph)); + + const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph); + + const auto* zip_node = + graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph)); + + const auto* stateless_func = MakeLessStatefulFunction( + *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node), + output->mutable_library()); + + const auto* stateless_map = graph.AddNode( + MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph)); + + graph.ReplaceInput(*map_node, *stateless_map); + + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. 
+ nodes_to_delete.insert(map_node->name()); + } + + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h new file mode 100644 index 0000000000..d1bcf6782d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimization hoists instances of `random_uniform` out of a function +// with the aim of making it stateless. It creates a new function that takes a +// random seed as an extra argument and uses `stateless_random_uniform` instead +// of `random_uniform` to make it stateless. 
+// It also creates RandomDataset(seed).batch(2), which is zipped with old input +// to the map. The batching in RandomDataset is because we need 2 seeds for +// `stateless_random_uniform`. +// TODO(prazek): for now only `RandomUniform` is handled, but we could handle +// `RandomUniformInt` similarly. +class HoistRandomUniform : public CustomGraphOptimizer { + public: + HoistRandomUniform() = default; + ~HoistRandomUniform() override = default; + + string name() const override { return "hoist_random_uniform"; }; + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc new file mode 100644 index 0000000000..455459e3f6 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +TEST(HoistRandomUniform, SimpleHoisting) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, + {{"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<DataType>{}}}), + graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"), + NDef("cache", "CacheDataset", {"map1", "filename"}, {})}, + // FunctionLib + { + test::function::RandomUniform(), + }); + + HoistRandomUniform optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output)); + const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output); + const int zip_dataset_id = + graph_utils::FindGraphNodeWithOp("ZipDataset", output); + const int random_dataset_id = + graph_utils::FindGraphNodeWithOp("RandomDataset", output); + const int batch_random_id = + graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output); + 
ASSERT_NE(random_dataset_id, -1); + ASSERT_NE(zip_dataset_id, -1); + ASSERT_NE(new_map_id, -1); + ASSERT_NE(batch_random_id, -1); + + const auto& new_map = output.node(new_map_id); + const auto& zip = output.node(zip_dataset_id); + const auto& random = output.node(random_dataset_id); + const auto& batch = output.node(batch_random_id); + + ASSERT_EQ(new_map.input_size(), 1); + EXPECT_EQ(new_map.input(0), zip.name()); + + ASSERT_EQ(zip.input_size(), 2); + EXPECT_EQ(zip.input(0), "range"); + EXPECT_EQ(zip.input(1), batch.name()); + + ASSERT_EQ(batch.input_size(), 3); + EXPECT_EQ(batch.input(0), random.name()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index 63945b8b9e..e66766eb23 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node, // Set `f` and `Targuments` attributes. for (auto key : {"f", "Targuments"}) { - (*new_node.mutable_attr())[key] = map_node.attr().at(key); + graph_utils::CopyAttribute(key, map_node, &new_node); } + // Set `output_types` and `output_shapes` attributes. for (auto key : {"output_shapes", "output_types"}) { - (*new_node.mutable_attr())[key] = batch_node.attr().at(key); + graph_utils::CopyAttribute(key, batch_node, &new_node); } return new_node; } diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc index f1844a141c..c4868eacbb 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node, fused_node.set_op("MapDataset"); fused_node.add_input(map_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = map_node.attr().at("f"); attr.mutable_func()->set_name(fused_function.signature().name()); (*fused_node.mutable_attr())["f"] = std::move(attr); - copy_attribute("Targuments", map_node, &fused_node); + graph_utils::CopyAttribute("Targuments", map_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, map_node, &fused_node); + graph_utils::CopyAttribute(key, map_node, &fused_node); + + if (const auto* attr = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism")) + (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr; // Add the predicate output attributes. (*fused_node.mutable_attr())["output_types"] diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc index f029a093fa..6e6da37d7c 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -27,24 +28,8 @@ limitations under the License. namespace tensorflow { namespace grappler { namespace { - -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} - -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "FilterDataset", {string(input_node_name)}, - {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeFilterNode; +using graph_tests_utils::MakeMapNode; TEST(MapAndFilterFusionTest, FuseMapAndFilter) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index a78ecb09f7..bd943342e8 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node, NodeDef fused_node; graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), &fused_node); - fused_node.set_op("MapDataset"); fused_node.add_input(parent_map_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = parent_map_node.attr().at("f"); *attr.mutable_func()->mutable_name() = fused_function.signature().name(); (*fused_node.mutable_attr())["f"] = std::move(attr); - copy_attribute("Targuments", parent_map_node, &fused_node); - + graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, map_node, &fused_node); + graph_utils::CopyAttribute(key, map_node, &fused_node); + auto value_or_false = [](const AttrValue* attr) { + if (!attr) return false; + return attr->b(); + }; + + const auto* first_parallelism = + gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism"); + const auto* second_parallelism = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"); + // Some graphs cannot execute with use_inter_op_parallelism=False, so we need + // to set it to true if one of the ops have it set to true. + if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) { + (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true); + } return fused_node; } @@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // fusion. 
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function)); - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. nodes_to_delete.insert(parent_map_node->name()); nodes_to_delete.insert(map_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc index b25dfbd0b8..8889f9dddd 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -28,14 +29,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeMapNode; TEST(MapFusionTest, FuseTwoMapNodesIntoOne) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc index 305325e434..782c9f48b7 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -84,9 +84,6 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item, auto* parallel_map = 
graph.AddNode(MakeParallelMap(*map_node, &graph)); graph.ReplaceInput(*map_node, *parallel_map); - - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. nodes_to_delete.insert(map_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc index b2a5d9b6af..9fdfe8af30 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -28,16 +28,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, - StringPiece function_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} - +using graph_tests_utils::MakeMapNode; const char stateless_fun_name[] = "XTimesTwo"; const char stateful_fun_name[] = "RandomUniform"; diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 7a2f1910da..a9254ed58b 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -35,10 +35,6 @@ namespace tensorflow { namespace grappler { namespace { -void CopyAttribute(const string& attr_name, 
const NodeDef& from, NodeDef* to) { - (*to->mutable_attr())[attr_name] = from.attr().at(attr_name); -} - // Returns a FunctionDef containing a MapDefun op that wraps the original // function. FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, @@ -48,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, // Function inputs and outputs are the same as original, just // with different shapes. *vectorized_func->mutable_signature() = orig_func.signature(); - graph_utils::SetUniqueGraphFunctionName("vectorized_function", library, + graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library, vectorized_func); // Add MapDefun node @@ -61,7 +57,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, for (const string& k : {"f", "output_types", "output_shapes"}) { // Function, output types and (unbatched) shapes are the same as the // original map node. - CopyAttribute(k, map_node, map_defun_node); + graph_utils::CopyAttribute(k, map_node, map_defun_node); } // Get types of input arguments from original map function @@ -71,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, map_defun_node->add_input(input.name()); } (*map_defun_node->mutable_attr())["Targuments"] = t_args; + AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node); // Set return values to match output names string output_prefix = strings::StrCat(map_defun_node->name(), ":output:"); @@ -90,21 +87,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, // efficient vectorization with VectorizeMapDefun. FunctionDef* vectorized_func = CreateMapDefunWrapper(map_node, orig_func, library); - NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0); - DCHECK_EQ(map_defun_node->op(), "MapDefun"); - - // Create a copy of the original function so that we can mutate it, and - // attach that to the map defun node. 
- FunctionDef* map_defun_fn = library->add_function(); - *map_defun_fn = orig_func; - graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library, - map_defun_fn); - (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name( - map_defun_fn->signature().name()); - - vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn, - map_defun_node); - return vectorized_func; + const NodeDef& map_defun_node = vectorized_func->node_def(0); + DCHECK_EQ(map_defun_node.op(), "MapDefun"); + + // TODO(b/116285210): Unreferenced functions should get cleaned up later + FunctionDef* result; + Status s = vectorization_utils::VectorizeMapDefun( + *vectorized_func, map_defun_node, library, &result); + + if (!s.ok()) { + LOG(ERROR) << "VectorizeMapDefun failed: " << s; + return vectorized_func; + } + return result; } bool IsOutputShapesFullyDefined(const NodeDef& node) { @@ -195,13 +190,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node, } // Set attrs - CopyAttribute("Targuments", old_map_node, &map_node); + graph_utils::CopyAttribute("Targuments", old_map_node, &map_node); auto& func_attr = (*map_node.mutable_attr())["f"]; func_attr.mutable_func()->set_name(vectorized_func.signature().name()); for (auto key : {"output_shapes", "output_types"}) { - CopyAttribute(key, old_batch_node, &map_node); + graph_utils::CopyAttribute(key, old_batch_node, &map_node); } + + (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true); + return map_node; } diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc index ed1bd6bc97..f4faf41549 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc @@ -30,72 +30,51 @@ namespace { using test::function::GDef; using test::function::NDef; -void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims, - TensorShapeProto* 
t) { - for (size_t i = 0; i < dims.size(); ++i) { - auto* d = t->add_dim(); - d->set_size(dims[i]); - } -} - -AttrValue MakeShapeListAttr( - const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) { - AttrValue shapes_attr; - for (size_t i = 0; i < shapes.size(); ++i) { - MakeTensorShapeProtoHelper(shapes[i], - shapes_attr.mutable_list()->add_shape()); - } - - return shapes_attr; -} - -NodeDef MakeMapNodeHelper( - StringPiece name, StringPiece input_node_name, StringPiece function_name, - StringPiece map_op_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { +NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name, + StringPiece function_name, StringPiece map_op_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { return test::function::NDef( name, map_op_name, {string(input_node_name)}, {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, {"Targuments", {}}, - {"output_shapes", MakeShapeListAttr(output_shapes)}, + {"output_shapes", output_shapes}, {"output_types", output_types}}); } -NodeDef MakeMapNode( - StringPiece name, StringPiece input_node_name, StringPiece function_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset", output_shapes, output_types); } -NodeDef MakeBatchNode( - StringPiece name, StringPiece input_node_name, - StringPiece input_batch_size_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { - return NDef(name, "BatchDataset", - {string(input_node_name), string(input_batch_size_name)}, - 
{{"output_types", output_types}, - {"output_shapes", MakeShapeListAttr(output_shapes)}}); +NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name, + StringPiece input_batch_size_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { + return NDef( + name, "BatchDataset", + {string(input_node_name), string(input_batch_size_name)}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); } -NodeDef MakeBatchV2Node( - StringPiece name, StringPiece input_node_name, - StringPiece input_batch_size_name, StringPiece input_drop_remainder_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { - return NDef(name, "BatchDatasetV2", - {string(input_node_name), string(input_batch_size_name), - string(input_drop_remainder_name)}, - {{"output_types", output_types}, - {"output_shapes", MakeShapeListAttr(output_shapes)}}); +NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name, + StringPiece input_batch_size_name, + StringPiece input_drop_remainder_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { + return NDef( + name, "BatchDatasetV2", + {string(input_node_name), string(input_batch_size_name), + string(input_drop_remainder_name)}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); } -NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) { +NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) { return NDef(name, "RangeDataset", inputs, - {{"output_shapes", MakeShapeListAttr({{}})}, + {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}, {"output_types", gtl::ArraySlice<DataType>({DT_INT64})}}); } @@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) { item.graph = GDef( {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), NDef("input", "InputDataset", {}, - 
{{"output_shapes", MakeShapeListAttr({{}})}}), + {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}), MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}), MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})}, // FunctionLib @@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) { TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); } +TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) { + GrapplerItem item; + item.graph = GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + MakeRangeNode("range", {"start", "stop", "step"}), + MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}), + MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})}, + // FunctionLib + {FunctionDefHelper::Create( + "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {}, + {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}}, + {{"res", "o:z"}, {"res2", "o:z"}})}); + MapVectorization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(), + 1); + EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(), + 1); + const NodeDef& map_node = + output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output)); + const NodeDef& batch_node = + output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output)); + EXPECT_EQ(map_node.input(0), batch_node.name()); + EXPECT_EQ(batch_node.input(0), "range"); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index cb0ff670e8..99c4afa634 100644 
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster, // Set `output_types` and `output_shapes` attributes. for (auto key : {"output_shapes", "output_types"}) { - (*new_node.mutable_attr())[key] = repeat_node.attr().at(key); + graph_utils::CopyAttribute(key, repeat_node, &new_node); } return new_node; }; diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD index 1462cb234d..985d6c6c3a 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -9,13 +9,24 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") VECTORIZER_DEPS = [ ":vectorizer_registry", - "//tensorflow/core/grappler/optimizers/data:function_utils", + "//tensorflow/core/grappler/optimizers/data:graph_utils", ] + tf_protos_all() cc_library( + name = "wrapped_tensor", + hdrs = ["wrapped_tensor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "vectorizer", hdrs = ["vectorizer.h"], deps = [ + ":wrapped_tensor", + "//tensorflow/core:core_cpu", "//tensorflow/core:lib", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc index c1739737a0..f445157531 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -14,41 +14,38 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class CastVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { - if (inputs.size() != 1) { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { + Status s; + if (node.num_inputs() != 1) { return errors::Internal("Cast op should only have one input."); } - // Add new Cast node - NodeDef* new_cast_node = outer_scope->add_node_def(); - *new_cast_node = node; - new_cast_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", node.name()), outer_scope, - new_cast_node); - new_cast_node->set_input(0, inputs[0]); + // Add new Cast node with the same op and attrs as the original node + auto new_cast_node = outer_scope->AddNode(node.def(), &s); + TF_RETURN_IF_ERROR(s); - // Add the output mapping to conversion map - (*conversion_map)[strings::StrCat(node.name(), ":y:0")] = - strings::StrCat(new_cast_node->name(), ":y:0"); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node, + 0); + // Add output mappings + outputs->push_back({new_cast_node, 0, true}); return Status::OK(); } }; REGISTER_VECTORIZER("Cast", CastVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc 
b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc index 776d3179c5..f1ba741821 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc @@ -14,40 +14,38 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class UnpackVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { - if (inputs.size() != 1) { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { + Status s; + if (node.num_inputs() != 1 || inputs.size() != 1) { return errors::Internal("Unpack op should only have one input."); } - // Add new Unpack node - NodeDef* new_unpack_node = outer_scope->add_node_def(); - *new_unpack_node = node; - new_unpack_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", node.name()), outer_scope, - new_unpack_node); + // Add new Unpack node with the same op and attrs as the original node + auto new_unpack_node = outer_scope->AddNode(node.def(), &s); + TF_RETURN_IF_ERROR(s); // Increment "axis" attr by 1: - (*new_unpack_node->mutable_attr())["axis"].set_i( - node.attr().at("axis").i() + 1); - new_unpack_node->set_input(0, inputs[0]); + int new_axis = node.def().attr().at("axis").i() + 1; + new_unpack_node->AddAttr("axis", new_axis); - // Add the 
output mappings to conversion map - int num = new_unpack_node->attr().at("num").i(); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, + new_unpack_node, 0); + + // Add the output mappings + int num = node.def().attr().at("num").i(); for (int i = 0; i < num; ++i) { - (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] = - strings::StrCat(new_unpack_node->name(), ":output:", i); + outputs->push_back({new_unpack_node, i, true}); } return Status::OK(); @@ -56,6 +54,6 @@ class UnpackVectorizer : public Vectorizer { REGISTER_VECTORIZER("Unpack", UnpackVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h index d341dbba7d..8d4676aae0 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h @@ -17,13 +17,13 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { // Interface for vectorization of TensorFlow operations. See `CastVectorizer` // for an example. 
@@ -31,19 +31,19 @@ class Vectorizer { public: virtual ~Vectorizer() {} - // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope` + // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope` // that produce the same vector output(s) as executing `node`'s op - // on elements of the vector inputs, and adding mappings to `conversion_map` - // from old output tensor names to new (vectorized) output tensor names. - // The new node(s) collectively have the same number of inputs and outputs as - // the node being converted, and use the tensor names in `inputs` as their - // inputs. - virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) = 0; + // on elements of `inputs`. The new Node(s) collectively have the + // same number of input and output ports as the node being converted. + // Adds edges between the newly created nodes and nodes in `inputs`, and adds + // mappings to the new nodes' output ports to `outputs`, where the i'th + // value in `outputs` corresponds to the i'th output port of the node + // to be converted. + virtual Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) = 0; }; -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc index a6551e36ac..e1cf77a7d5 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc @@ -19,7 +19,6 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace vectorization_utils { VectorizerRegistry* VectorizerRegistry::Global() { static VectorizerRegistry* registry = new VectorizerRegistry; @@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type, vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>( op_type, std::move(vectorizer))); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h index 16159d47ca..ad54c74933 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h @@ -23,7 +23,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { // A global VectorizerRegistry is used to hold all the vectorizers. 
class VectorizerRegistry { @@ -59,16 +58,12 @@ class VectorizerRegistration { #define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) -#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ - static ::tensorflow::grappler::vectorization_utils:: \ - vectorizer_registration::VectorizerRegistration \ - vectorizer_registration_##ctr( \ - op_type, \ - ::std::unique_ptr< \ - ::tensorflow::grappler::vectorization_utils::Vectorizer>( \ - new vectorizer())) +#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ + static ::tensorflow::grappler::vectorizer_registration:: \ + VectorizerRegistration vectorizer_registration_##ctr( \ + op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \ + new vectorizer())) -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc index 86e303564b..054aeb9a8f 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc @@ -20,13 +20,12 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace vectorization_utils { class TestVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { return Status::OK(); } }; @@ -39,12 +38,14 @@ TEST(TestVectorizer, TestTestVectorizer) { auto vectorizer = VectorizerRegistry::Global()->Get("test_op"); EXPECT_NE(vectorizer, nullptr); - FunctionDef function; - NodeDef node; - std::map<string, string> conversion_map; - EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok()); + Graph g(OpRegistry::Global()); + NodeDef node_def; + Status s; + Node* node = g.AddNode(node_def, &s); + std::vector<WrappedTensor> inputs, outputs; + EXPECT_TRUE( + vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok()); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h new file mode 100644 index 0000000000..4439b4ab4e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace grappler { + +// Represents a tensor that has been vectorized. +struct WrappedTensor { + Node* const node; + const int output_index; + + // Whether the tensor is stacked, i.e. represents the results of applying + // the operation on all slices of the input, where each row i of the + // tensor corresponds to the op's output on slice i of the input. False + // if the tensor is not stacked, i.e. represents the result of the op on + // a single slice of the input, where the result does not vary between + // slices. + bool stacked; + + WrappedTensor(Node* node, int output_index, bool stacked) + : node(node), output_index(output_index), stacked(stacked) {} +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index cb56b65985..ba857ab5d9 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -17,274 +17,588 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { namespace vectorization_utils { -using function_utils::FunctionDefTensorDesc; - namespace { -void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node, - const string& output_retval, const DataType t) { - // Set to unknown shape - TensorShapeProto tensor_shape_proto; - PartialTensorShape().AsProto(&tensor_shape_proto); +// Describes a tensor with its operation Node and output position +typedef std::pair<Node*, int> TensorDesc; - function_utils::AddFunctionOutputWithUniqueName( - "vectorized_out", output_retval, map_defun_fn, t); +const char* const kRetValOp = "_Retval"; - *(*map_defun_node->mutable_attr())["output_shapes"] - .mutable_list() - ->add_shape() = tensor_shape_proto; - (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t); +void ReplaceEdgeSources(const TensorDesc& 
old_src, const TensorDesc& new_src, + Graph* graph) { + // NOTE: We need two for loops here because we can't mutate the set of output + // edges as we iterate over them. + std::vector<const Edge*> edges_to_replace; + for (auto edge : old_src.first->out_edges()) { + if (edge->src_output() == old_src.second) { + edges_to_replace.push_back(edge); + } + } + for (auto edge : edges_to_replace) { + graph->AddEdge(new_src.first, new_src.second, edge->dst(), + edge->dst_input()); + graph->RemoveEdge(edge); + } } -void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node, int output_position) { - DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size()) - << "Trying to remove output that doesn't exist. Output number: " - << output_position; +Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, + const TensorDesc& output) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DataType type = output.first->output_type(output.second); + int index = map_defun_fn->ret_nodes.size(); - int num_later_outputs = - map_defun_fn->signature().output_arg_size() - output_position - 1; + NodeDef ret_node_def; + ret_node_def.set_name("map_out"); + ret_node_def.set_op(kRetValOp); + AddNodeAttr("T", type, &ret_node_def); + AddNodeAttr("index", index, &ret_node_def); - // Remove from map_defun_fn's ret dict and output args - map_defun_fn->mutable_ret()->erase( - map_defun_fn->signature().output_arg(output_position).name()); - map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange( - output_position, 1); + Status s; + Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s); + TF_RETURN_IF_ERROR(s); - // Renumber outputs that come after - for (int i = 0; i < num_later_outputs; ++i) { - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node->name(), - ":output:", output_position + i + 1), - strings::StrCat(map_defun_node->name(), - ":output:", 
output_position + i), - outer_scope); - } - map_defun_node->mutable_attr() - ->at("output_shapes") - .mutable_list() - ->mutable_shape() - ->DeleteSubrange(output_position, 1); - map_defun_node->mutable_attr() - ->at("output_types") - .mutable_list() - ->mutable_type() - ->ExtractSubrange(output_position, 1, nullptr); + map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0); + map_defun_fn->ret_nodes.push_back(ret_node); + map_defun_fn->ret_types.push_back(type); + + return s; } -int FindOutputToConvert(const FunctionDef& function, - const std::set<string>& unconvertible, - FunctionDefTensorDesc* f) { - for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) { - const string& ret_key = function.signature().output_arg(i).name(); - *f = FunctionDefTensorDesc(function.ret().at(ret_key)); +void RemoveMapDefunOutput(int output_position, Graph* outer_scope, + FunctionBody* map_defun_fn, Node* map_defun_node) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DCHECK_LT(output_position, map_defun_fn->ret_nodes.size()) + << "Trying to remove output that doesn't exist. 
Output number: " + << output_position; - if (unconvertible.find(f->node_name) == unconvertible.end()) { - return i; - } + int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1; + + // Modify map_defun_fn's signature and remove the output node from its graph + map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]); + map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() + + output_position); + map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() + + output_position); + + // Renumber the nodes and edges that come after + for (int i = 0; i < num_later_outputs; ++i) { + ReplaceEdgeSources({map_defun_node, output_position + i + 1}, + {map_defun_node, output_position + i}, outer_scope); + // Each ret node has an "index" attr that has to be updated + map_defun_fn->ret_nodes[output_position + i]->AddAttr("index", + output_position + i); } - return -1; } // Helper class that vectorizes the body of a MapDefun node, adding new // operations to the graph that collectively compute the same value as what // running the MapDefun function on slices of the input would produce. -// Each instance of the class encapsulates all the data necessary to vectorize a -// MapDefun op in place. +// This class transforms the input FunctionDefs into their corresponding +// Graph objects and works on the graphs directly, then converts them back +// to FunctionDefs when GetResult is called. class Vectorization { public: - Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) - : outer_scope_(outer_scope), - map_defun_fn_(map_defun_fn), - map_defun_node_(map_defun_node) {} + explicit Vectorization(FunctionDefLibrary* lib) + : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {} - // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in - // the outer_scope_, until there are no convertible outputs remaining. - // This method is idempotent. 
- void Vectorize(); + // Adds the vectorized function and new map_defun_fn to lib, and points + // vectorized_function to the former. Returns an error status if + // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere + // along the way. + Status Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDef** result); private: - // Vectorizes the map defun function's output at output_position - Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc); - // Given a descriptor of the original output tensor, gets a string - // corresponding to the converted output tensor. - Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc, - string* converted); - Status AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc); + // Converts FunctionDefs to Graphs and adds mappings from + // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_. + Status Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node); + + // Converts Graphs back to FunctionDefs and adds them to `lib_`. + Status GetResult(FunctionDef** vectorized_function); + + // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in + // `outer_scope_`, until there are no convertible outputs remaining. + void VectorizeHelper(); + + // Vectorizes map_defun_fn's output at output_position. + Status ConvertOutput(int output_position); // Adds mappings from node's outputs tensors to converted output tensors, // creating the necessary new node(s). Generally, the steps to convert an op // are: - // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_, - // and modify map_defun_node_ attrs accordingly - // 2) Create new node(s) in outer_scope_ that act on batched input tensors. + // 1) Create new node(s) in `outer_scope_` that act on batched input tensors. 
// These operations collectively compute the same value as what running // the original operation on slices of the input tensors would produce. // For example, a Cast op in MapDefun translates to a Cast op in - // outer_scope_, since the vectorized version of Cast is itself. - // 3) Set inputs of new node(s) to the corresponding converted inputs (that - // are now outputs of map_defun_node_) - // 4) For each output of the old node, add the mapping of output strings to - // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0") - Status AddConversionMappingFromOp(const NodeDef& node, - const FunctionDefTensorDesc& output_desc); - - // Maps a tensor name to the name of the corresponding vectorized tensor. For - // example, "Cast:y:0" -> "Vectorize/Cast:y:0" - std::map<string, string> conversion_map_; - // Unconvertible node names - std::set<string> unconvertible_; - - FunctionDef* outer_scope_; - FunctionDef* map_defun_fn_; - NodeDef* map_defun_node_; + // `outer_scope_`, since the vectorized version of Cast is itself. + // 2) Promote the inputs of the op inputs to outputs of the + // `map_defun_node_` and `map_defun_fn_`. + // 3) Add edges between the promoted inputs (that are now outputs of + // `map_defun_node`) and the inputs ports of the new node(s). + // 4) For each output of the old node, add the mapping of output tensors to + // the conversion map. + Status AddConversionMapping(Node* op_node); + + // Given a tensor t in `unstacked`, stacks it by doing the equivalent of + // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of + // inputs to `map_defun_node_`. This stacked tensor will be compatible with + // the expected output shape of `map_defun_node_`. + // This is equivalent to the _stack function in python Pfor. + Status StackTensor(WrappedTensor* unstacked, TensorDesc* result); + + // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by + // doing a depth-first search from the ret nodes. 
Lifts nodes that are + // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly + // and add mappings to `conversion_map_`. + Status AddUnstackedNodeMappings(); + + // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor + // is unstacked. + bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status); + + // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input + // nodes to `conversion_map_`. + Status AddArgNodeMappings(); + + // Maps a tensor to the corresponding WrappedTensor. For example, + // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true) + std::map<TensorDesc, WrappedTensor> conversion_map_; + + // Unconvertible ret nodes + std::set<Node*> unconvertible_; + + FunctionDefLibrary* lib_; // Not owned + FunctionLibraryDefinition lib_def_; + // Note that FunctionBody has a pointer to a Graph object that corresponds + // to the function's subgraph, with additional kArgOp and kRetValOp nodes + // that denote that function arguments and return values. These nodes have the + // attrs "T" for the type, and "index" for the argument / retval index + // respectively. FunctionBody also keeps track of arg/ret_nodes and + // arg/ret_types, that should be ordered according to argument/output indices. + std::unique_ptr<Graph> outer_scope_; + std::unique_ptr<FunctionBody> map_defun_fn_; + Node* map_defun_node_ = nullptr; // Owned by `outer_scope` + + // Caches the loop_len_node_ needed for tiling unstacked output. This + // corresponds to a vector with one element. 
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope` + Status status_; }; -Status Vectorization::AddConversionMappingFromOp( - const NodeDef& node, const FunctionDefTensorDesc& output_desc) { - for (const string& input_name : node.input()) { - if (IsControlInput(input_name)) { +Status Vectorization::AddConversionMapping(Node* op_node) { + for (auto edge : op_node->in_edges()) { + if (edge->IsControlEdge()) { return errors::InvalidArgument( "Vectorizing outputs with control inputs is currently not " "supported."); } } - // TODO(rachelim): Have some mechanism for registering converters and some - // uniform, simpler way to represent them. + auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string()); + if (vectorizer == nullptr) { + return errors::Unimplemented("No vectorizer registered for op: ", + op_node->type_string()); + } + std::vector<WrappedTensor> inputs, outputs; + inputs.reserve(op_node->num_inputs()); + outputs.reserve(op_node->num_outputs()); + + std::vector<const Edge*> input_edges; + TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges)); + + // The inputs for the node to be converted may already have been converted + // themselves. For those that are not, we promote them to MapDefun outputs. + for (size_t i = 0; i < op_node->num_inputs(); ++i) { + auto edge = input_edges[i]; + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + inputs.push_back(*found); + } else { + // TODO(rachelim): Handle the case where unconverted inputs are unstacked. + // We assume that all unconverted inputs will be stacked, since we + // converted all unstacked nodes in `Initialize`. However, it's actually + // possible that yet-unconverted nodes may produce unstacked outputs after + // they are vectorized. (For example, see the "Shape" converter in + // tensorflow/python/ops/parallel_for/pfor.py). 
If a vectorizer expects + // an unstacked input but receives a stacked one, vectorizer->Vectorize + // will return an error. + TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, + {edge->src(), edge->src_output()})); + int output_index = map_defun_fn_->ret_nodes.size() - 1; + inputs.push_back({map_defun_node_, output_index, true}); + } + } - DataTypeVector types; - const OpDef* op_def = nullptr; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def)); - TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types)); + TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), + std::move(inputs), &outputs)); - std::vector<string> promoted_inputs; - promoted_inputs.reserve(node.input_size()); - for (int i = 0; i < node.input_size(); ++i) { - promoted_inputs.push_back(strings::StrCat( - map_defun_node_->name(), - ":output:", map_defun_fn_->signature().output_arg_size() + i)); + if (op_node->num_outputs() != outputs.size()) { + return errors::Internal( + "Number of vectorizer outputs does not match. Expected: ", + op_node->num_outputs(), " Actual: ", outputs.size()); } - auto vectorizer = VectorizerRegistry::Global()->Get(node.op()); - if (vectorizer == nullptr) { - return errors::Unimplemented("No vectorizer registered for op: ", - node.op()); + // Add output mappings. 
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) { + conversion_map_.insert({{op_node, i}, outputs[i]}); } - TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_, - &conversion_map_)); + return Status::OK(); +} + +Status Vectorization::ConvertOutput(int output_position) { + // ret_edge->src() is the actual op that generated the retval, and + // ret_edge->dst() is the retval node whose op is "_Retval" + const Edge* ret_edge; + TF_RETURN_IF_ERROR( + map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge)); + + TensorDesc output({ret_edge->src(), ret_edge->src_output()}); + TensorDesc converted_output; + + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. + auto found = gtl::FindOrNull(conversion_map_, output); + if (!found) { + TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); + found = &conversion_map_.at(output); + } - // If we get here, the conversion was successful, so we promote the inputs - // of the ops to MapDefun outputs. - for (int i = 0; i < types.size(); ++i) { - AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]); + if (found->stacked) { + converted_output = {found->node, found->output_index}; + } else { + // Some outputs may be unstacked if they don't derive from arg nodes + // (for example, if a function returns a constant). For these, we + // have to add extra nodes to tile it in the 0th dimension. 
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); } + ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, + outer_scope_.get()); + RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(), + map_defun_node_); + return Status::OK(); } -Status Vectorization::AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc) { - int input_index = function_utils::FindFunctionInputWithName( - output_desc.node_name, *map_defun_fn_); - if (input_index == -1) { - return errors::Internal("Cannot convert non-existent input."); +Status Vectorization::Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, + FunctionDef** result) { + TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node)); + VectorizeHelper(); + return GetResult(result); +} + +void Vectorization::VectorizeHelper() { + while (true) { + int output_position = graph_utils::GetFirstElementIndexWithPredicate( + [this](Node* n) { + return this->unconvertible_.find(n) == this->unconvertible_.end(); + }, + map_defun_fn_->ret_nodes); + + // No outputs left to convert + if (output_position == -1) break; + + Status s = ConvertOutput(output_position); + if (!s.ok()) { + Node* output_node = map_defun_fn_->ret_nodes.at(output_position); + VLOG(2) << "Could not convert the output at node: " + << output_node->DebugString() << "\nError: " << s; + unconvertible_.insert(output_node); + } + } + + // If we've converted all the outputs of the MapDefun function, we no longer + // need the MapDefun node and can delete it. 
+ if (map_defun_fn_->ret_nodes.empty()) { + outer_scope_->RemoveNode(map_defun_node_); + } else { + // Update MapDefun node attrs accordingly + DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size()); + map_defun_node_->AddAttr( + "output_shapes", + std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size())); + map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); + } +} + +Status Vectorization::Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node) { + // Convert outer_scope and map_defun_fn to FunctionBodys so we can + // work on Graphs directly. + const FunctionDef* map_defun_fn = + lib_def_.Find(map_defun_node.attr().at("f").func().name()); + + if (map_defun_fn == nullptr) { + return errors::NotFound("Could not find function with name ", + map_defun_node.attr().at("f").func().name(), + " in function library."); + } + + auto get_func_sig = [this](const string& op, const OpDef** sig) { + return this->lib_def_.LookUpOpDef(op, sig); + }; + + FunctionBody* outer_fn; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_, + get_func_sig, &outer_fn)); + // We don't need outer_fn, just the graph + outer_scope_.reset(outer_fn->graph); + outer_fn->graph = nullptr; + delete outer_fn; + + FunctionBody* tmp; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_, + get_func_sig, &tmp)); + map_defun_fn_.reset(tmp); + + // Find the MapDefun node in outer_scope_ + int node_id = graph_utils::GetFirstElementIndexWithPredicate( + [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); }, + outer_scope_->nodes()); + if (node_id == -1) { + return errors::NotFound("Could not find node with name ", + map_defun_node.name(), " in outer_scope."); } + map_defun_node_ = outer_scope_->FindNodeId(node_id); + + TF_RETURN_IF_ERROR(AddArgNodeMappings()); + + TF_RETURN_IF_ERROR(AddUnstackedNodeMappings()); + loop_len_node_ = nullptr; - conversion_map_[output_desc.full_str] = 
map_defun_node_->input(input_index); return Status::OK(); } -Status Vectorization::ConvertOutputHelper( - const FunctionDefTensorDesc& output_desc, string* converted) { - // It's possible the output already has a mapping, if it comes from a node - // that has already been converted. - if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) { - *converted = *found; - return Status::OK(); +// TODO(rachelim): It might be profitable to use the C++ API for this instead of +// NodeBuilder +Status Vectorization::StackTensor(WrappedTensor* unstacked, + TensorDesc* result) { + // Note that all these nodes are necessary as the size of the batch may not be + // constant. + if (unstacked->stacked) { + return errors::Internal("Can only stack unstacked tensor."); } - int index = function_utils::FindFunctionNodeWithName(output_desc.node_name, - *map_defun_fn_); - if (index == -1) { // The output comes from an input - TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc)); - } else { - TF_RETURN_IF_ERROR(AddConversionMappingFromOp( - map_defun_fn_->node_def(index), output_desc)); + Graph* g = outer_scope_.get(); + auto node_builder = [](StringPiece op) { + return NodeBuilder(strings::StrCat("vectorized/stack/", op), op); + }; + + auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph, + Node** result) { + TF_RETURN_IF_ERROR(val.status); + return node_builder("Const") + .Attr("value", val.tensor) + .Attr("dtype", val.tensor.dtype()) + .Finalize(graph, result); + }; + + // If loop_len_node_ hasn't been created yet, add the node and cache it. 
+ if (loop_len_node_ == nullptr) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node)); + + Node* shape_node; + TF_RETURN_IF_ERROR( + node_builder("Shape").Input(input_node).Finalize(g, &shape_node)); + + Node* const_vec_0; + TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0)); + Node* const_vec_1; + TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1)); + + Node* strided_slice_node; + TF_RETURN_IF_ERROR(node_builder("StridedSlice") + .Input(shape_node) // input + .Input(const_vec_0) // begin + .Input(const_vec_1) // end + .Input(const_vec_1) // strides + .Finalize(g, &strided_slice_node)); + + // Produces a vector of length 1 + TF_RETURN_IF_ERROR(node_builder("Reshape") + .Input(strided_slice_node) // tensor + .Input(const_vec_1) // shape + .Finalize(g, &loop_len_node_)); } - *converted = conversion_map_.at(output_desc.full_str); + + Node* ones_shape; + TF_RETURN_IF_ERROR(node_builder("Shape") + .Input(unstacked->node) // input + .Finalize(g, &ones_shape)); + + Node* ones; + TF_RETURN_IF_ERROR( + node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones)); + + Node* const_0; + TF_RETURN_IF_ERROR(make_const(0, g, &const_0)); + + Node* multiples; + TF_RETURN_IF_ERROR(node_builder("Concat") + .Input(const_0) // concat_dim + .Input({{loop_len_node_, 0}, {ones, 0}}) // values + .Finalize(g, &multiples)); + + Node* expand_dims; + TF_RETURN_IF_ERROR(node_builder("ExpandDims") + .Input(unstacked->node) // input + .Input(const_0) // dim + .Finalize(g, &expand_dims)); + + TF_RETURN_IF_ERROR(node_builder("Tile") + .Input(expand_dims) // input + .Input(multiples) // multiples + .Finalize(g, &result->first)); + result->second = 0; return Status::OK(); } -Status Vectorization::ConvertOutput(int output_position, - const FunctionDefTensorDesc& output_desc) { - string converted_output_name; - TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name)); +Status Vectorization::AddArgNodeMappings() { + for (auto arg_node : 
map_defun_fn_->arg_nodes) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node( + arg_node->attrs().Find("index")->i(), &input_node)); - // Remove the old output and make everything that referenced it point - // to the new string - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node_->name(), ":output:", output_position), - converted_output_name, outer_scope_); - RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_, - output_position); + conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}}); + // Control inputs + conversion_map_.insert({{arg_node, Graph::kControlSlot}, + {input_node, Graph::kControlSlot, true}}); + } return Status::OK(); } -void Vectorization::Vectorize() { - while (true) { - FunctionDefTensorDesc desc; - int output_position = - FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc); - if (output_position == -1) break; +bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, + Status* status) { + if (auto found = gtl::FindOrNull(conversion_map_, tensor)) { + return !found->stacked; + } - if (!ConvertOutput(output_position, desc).ok()) { - unconvertible_.insert(desc.node_name); + if (tensor.first->op_def().is_stateful()) { + // We don't lift stateful nodes directly out of the MapDefun, since they may + // have to be executed N times. + return false; + } + + bool is_unstacked = true; + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. + if (edge->src()->IsSource()) continue; + + // A node is unstacked if all of its inputs are unstacked + is_unstacked &= AddUnstackedNodeMappingsHelper( + {edge->src(), edge->src_output()}, status); + } + + if (!is_unstacked) { + return false; + } + + // If the node is unstacked, we copy it into outer_scope_ and + // add it to the map. 
Note that we don't clean up the nodes that are copied + // in map_defun_fn_, and rely on them being pruned out later. + Node* node = outer_scope_->AddNode(tensor.first->def(), status); + if (!status->ok()) return true; + + // Add input edges to nodes that should already have been lifted. + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. + if (edge->src()->IsSource()) continue; + + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + outer_scope_->AddEdge(found->node, found->output_index, node, + edge->dst_input()); + } else { + status->Update(errors::Internal( + "Could not find input conversion even though we did depth first " + "conversion.")); } } - // If we've converted all the outputs of the MapDefun function, we no longer - // need the MapDefun node and can delete it. - if (map_defun_fn_->signature().output_arg_size() == 0) { - outer_scope_->mutable_node_def()->DeleteSubrange( - function_utils::FindFunctionNodeWithName(map_defun_node_->name(), - *outer_scope_), - 1); + // Add output mappings + for (int i = 0; i < tensor.first->num_outputs(); ++i) { + conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)}); + } + conversion_map_.insert({{tensor.first, Graph::kControlSlot}, + WrappedTensor(node, Graph::kControlSlot, false)}); + + return true; +} + +Status Vectorization::AddUnstackedNodeMappings() { + SetVector<Node*> unstacked_nodes; + Status s; + for (const auto& ret_node : map_defun_fn_->ret_nodes) { + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge)); + AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s); + TF_RETURN_IF_ERROR(s); } + return Status::OK(); +} - if (!unconvertible_.empty()) { - VLOG(2) << "The following nodes could not be converted: [" - << absl::StrJoin(unconvertible_, ", ") << "]."; +Status Vectorization::GetResult(FunctionDef** 
vectorized_function) { + TF_RETURN_IF_ERROR(status_); + TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get())); + TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph)); + + if (!map_defun_fn_->ret_nodes.empty()) { + FunctionDef* map_defun_fn = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn)); + + AttrValue func_attr; + func_attr.mutable_func()->set_name(map_defun_fn->signature().name()); + map_defun_node_->AddAttr("f", func_attr); } + + *vectorized_function = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_, + *vectorized_function); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *outer_scope_, (*vectorized_function)->signature().name(), + *vectorized_function)); + return Status::OK(); } + } // namespace -void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) { - Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize(); +Status VectorizeMapDefun(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDefLibrary* lib, + FunctionDef** result) { + *result = nullptr; + return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result); } } // end namespace vectorization_utils diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h index bb405faa77..bd7d390900 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h @@ -24,22 +24,28 @@ namespace tensorflow { namespace grappler { namespace vectorization_utils { -// Given a function, `map_defun_fn`, that is mapped across some input vector -// elements via a MapDefun operation, `VectorizeMapDefun` attempts to -// vectorize the 
MapDefun by "lifting" operations from the `map_defun_fn` to the -// `outer_scope`; that is, replacing `map_defun_fn` operations with new -// `outer_scope` operations that produce the same vector output(s) as executing -// the `map_defun_fn` operations on elements of vector input(s) would. If all -// `map_defun_fn` operations are successfully lifted, `map_defun_node` is -// eliminated from `outer_scope` altogether. However, if some operations cannot -// be lifted, and this vectorization only succeeds partially, `map_defun_node` -// remains to be used for operations that were not lifted. +// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`) +// that maps a function in lib across some input vector elements, +// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope` +// by "lifting" operations from the MapDefun function to the new function +// (`result`); that is, replacing operations in the MapDefun function with +// operations that produce the same vector output(s) as executing the original +// operations on elements of vector input(s) would. If all operations in the +// MapDefun function are successfully lifted, `result` has no MapDefun node +// altogether. However, if some operations cannot be lifted, and this +// vectorization only succeeds partially, a MapDefun node remains in `result` to +// be used for operations that were not lifted, and the modified MapDefun +// function is added to `lib`. The newly vectorized function `result` is also +// added to `lib`. +// +// Returns Status::OK() if the vectorization is completely or partially +// successful. Otherwise, returns an error, and sets `result` to nullptr. // // Example: // If the input to the `VectorizeMapDefun` function is a MapDefun // whose `map_defun_fn` performs the Cast operation, the vectorization will // eliminate the MapDefun. This is because the Cast operation supports -// any tensor shape and can thus be lifted to the `outer_scope`. 
+// any tensor shape and can thus be lifted to `result`. // // Before: // @@ -68,7 +74,7 @@ namespace vectorization_utils { // // After: // -// outer_scope +------+ +// result +------+ // +---------------+ Arg0 +---------+ // | +---+--+ | // | | | @@ -80,8 +86,9 @@ namespace vectorization_utils { // +---------------+ Ret0 +---------+ // +------+ // -void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node); +Status VectorizeMapDefun(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDefLibrary* lib, + FunctionDef** result); } // end namespace vectorization_utils } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index e129fa9237..a6020e36bb 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" @@ -54,12 +55,18 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, func.set_name(function_name); NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn); graph_transforms::SetNodeAttr("Targuments", t_arguments, node); + graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node); graph_transforms::SetNodeAttr("output_types", output_types, node); graph_transforms::SetNodeAttr("output_shapes", output_shapes, node); graph_transforms::SetNodeAttr("f", func, node); return node; } +string GetRetval(const FunctionDef& function_def, int index) { + return function_def.ret().at( + function_def.signature().output_arg(index).name()); +} + // TODO(rachelim): Use FunctionDefHelper::Create instead FunctionDef CreateFunction( StringPiece name, const std::vector<std::pair<string, DataType>>& inputs, @@ -85,7 +92,6 @@ FunctionDef CreateFunction( return func; } -TEST(FunctionDefInputDescTest, ConstructedCorrectly) {} // Before: // @@ -133,10 +139,17 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - EXPECT_EQ(outer.ret().at("mapdefun"), "ret0"); - EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1"); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized); + LOG(ERROR) << s; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + 
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_EQ(GetRetval(*vectorized, 0), "ret0"); + EXPECT_EQ(GetRetval(*vectorized, 1), "ret1"); } // Before: @@ -149,12 +162,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // | +-----------+ Arg0 +---+ Arg1 +----+ | // | | +---+--+ +---+--+ | | // | | | | | | -// | | +------+ | +---v--+ | | -// | | |Const | | | Op0 | | | -// | | +---v--+ | +---+--+ | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | // | | | | | | | // | | | +---v--+ +---v--+ | | -// | | +---| XOp1 | | XOp2 | | | +// | | +---| XOp1 | | Cast | | | // | | +---+--+ +---+--+ | | // | | | | | | // | | MapDefun +---v--+ +---v--+ | | @@ -165,23 +178,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // +---------------+ Ret0 +---+ Ret1 +--------+ // +------+ +------+ // -// where XOp1 and XOp2 are not convertible. +// where XOp1 is not convertible. // // After: // -// No change because the ops are not convertible. +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ | | +// | +-----------+ Arg0 +-+ | | +// | | +---+--+ | | | +// | | | | | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | +// | | | | | | | +// | | | +---v--+ | +---v--+ | +// | | +---| XOp1 | | | Cast | | +// | | +---+--+ | +---+--+ | +// | | | | | | +// | | MapDefun +---v--+ | | | +// | +-----------+ Ret0 +-+ | | +// | +---+--+ | | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ // TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, - {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}}); + {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}}); + // TODO(rachelim): If we ever write a converter for MatMul, we have to + // 
change this test. NodeDef* x_op1 = - function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner); + function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner); CHECK_NOTNULL(x_op1); + graph_transforms::SetNodeAttr("T", DT_INT32, x_op1); - NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner); - CHECK_NOTNULL(x_op2); + NodeDef* cast_node = + AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner); + CHECK_NOTNULL(cast_node); FunctionDef outer = CreateFunction( "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}}, @@ -193,12 +233,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); - // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); - EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto map_defun_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized)); + // The Cast node should be converted just fine. + EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0"); + + // The inner function should only have one retval. 
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1); } // Before: @@ -257,14 +307,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -330,16 +385,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + 
function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -411,21 +471,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), "x"); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -486,7 +551,7 @@ TEST(VectorizeMapDefunTest, 
VectorizeDefunChainedConvertibleOps) { {"ret1", "MyUnstack:output:1"}, {"ret2", "MyUnstack:output:2"}}); NodeDef* cast_op = - AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner); CHECK_NOTNULL(cast_op); NodeDef* unstack_op = AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner); @@ -505,25 +570,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0")); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), 
strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 2); + EXPECT_EQ(vectorized->node_def_size(), 2); } // Before: @@ -561,9 +631,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}}, {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); - // The attrs aren't relevant - NodeDef* print_op = - function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner); + NodeDef* print_op = function_utils::AddNode( + "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner); + graph_transforms::SetNodeAttr("T", DT_INT32, print_op); + graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}), + print_op); CHECK_NOTNULL(print_op); NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64, false, &inner); @@ -578,11 +650,278 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); + // We check this somewhat manually as the names of nodes may have changed + EXPECT_EQ(vectorized->node_def_size(), 1); + const NodeDef& map_defun_node = vectorized->node_def(0); + EXPECT_EQ(map_defun_node.op(), "MapDefun"); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + + const NodeDef& print_node = map_defun_fn->node_def( + function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn)); + const NodeDef& cast_node = map_defun_fn->node_def( + 
function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn)); + string control_input = strings::StrCat("^", print_node.name()); + EXPECT_TRUE(cast_node.input(0) == control_input || + cast_node.input(1) == control_input); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. 
+// +TEST(VectorizeMapDefunTest, VectorizeConst) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Const:output:0"}}); + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized)); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. 
+// +TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Const", *vectorized)); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node.name()); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | +------+ +------+ | | +// | | |Const | |Const | | | +// | | +---+--+ +---+--+ | | +// | | : +---v--+ | | +// | | ::::::> Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | | +// | +------+ | +// | +------+ |Const | | +// | |Const | +---+--+ | +// | +---+--+ | | +// | : +---v--+ | +// | 
::::::> Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | +Stack*+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2), + FunctionDefHelper::Const("ConstDep", 3)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64, + false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto find_const = [vectorized](int val) -> const NodeDef* { + for (const auto& n : vectorized->node_def()) { + if (n.attr().at("value").tensor().int_val(0) == val) { + return &n; + } + } + return nullptr; + }; + + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = find_const(2); + auto const_dep_node = find_const(3); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node->name()); + EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name())); } // TODO(rachelim): More test cases when we get around to implementing them: diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc 
b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index c59645e5f2..3f33b16ba8 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -106,7 +107,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer( MK_OPT("scoped_allocator", new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); - MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); + MK_OPT("pin_to_host", + new PinToHostOptimizer(cfg_.pin_to_host_optimization())); return std::unique_ptr<GraphOptimizer>(); } @@ -115,6 +117,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer( Status MetaOptimizer::InitializeOptimizers( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { + if (cfg_.disable_meta_optimizer()) { + return Status::OK(); + } if (!cfg_.disable_model_pruning()) { optimizers->push_back(MakeUnique<ModelPruner>()); } @@ -172,11 +177,12 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>( cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); } - return InitializeCustomGraphOptimizers(optimizers); + return InitializeCustomGraphOptimizers(std::set<string>(), optimizers); } Status MetaOptimizer::InitializeOptimizersByName( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { + std::set<string> initialized_custom_optimizers; for (const string& optimizer_name : cfg_.optimizers()) { auto optimizer = MakeNewOptimizer(optimizer_name); if (optimizer) { @@ -190,18 +196,26 @@ Status MetaOptimizer::InitializeOptimizersByName( if 
(custom_optimizer) { VLOG(2) << "Registered custom graph optimizer: " << optimizer_name; - TF_RETURN_IF_ERROR(custom_optimizer->Init()); + TF_RETURN_IF_ERROR(custom_optimizer->Init( + GetCustomGraphOptimizerConfig(optimizer_name))); optimizers->push_back(std::move(custom_optimizer)); + initialized_custom_optimizers.insert(optimizer_name); } else { VLOG(2) << "Can't register an optimizer by name: " << optimizer_name; } } - return InitializeCustomGraphOptimizers(optimizers); + return InitializeCustomGraphOptimizers(initialized_custom_optimizers, + optimizers); } Status MetaOptimizer::InitializeCustomGraphOptimizers( + const std::set<string>& pre_initialized_optimizers, std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { for (const auto& optimizer_config : cfg_.custom_optimizers()) { + if (pre_initialized_optimizers.find(optimizer_config.name()) != + pre_initialized_optimizers.end()) { + continue; + } // Initialize the ExperimentalImplementationSelector here instead of // CustomizeOptimizer registry, due the static link issue in TensorRT for // double registry. @@ -237,6 +251,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers( return Status::OK(); } +const RewriterConfig::CustomGraphOptimizer* +MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const { + for (const auto& config : cfg_.custom_optimizers()) { + if (config.name() == name) { + return &config; + } + } + return nullptr; +} + Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes @@ -391,6 +415,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, FunctionLibraryDefinition flib(OpRegistry::Global(), optimized_graph->library()); + // Find functions for which we might need to compute a gradient at runtime. 
+ gtl::FlatSet<string> differentiable_functions; + for (const NodeDef& node : optimized_graph->node()) { + if (IsSymbolicGradient(node)) { + const auto* f_attr = gtl::FindOrNull(node.attr(), "f"); + if (f_attr) differentiable_functions.insert(f_attr->func().name()); + } + } + // Optimize each function only once. std::unordered_set<string> optimized_funcs; bool optimize_function_library = true; @@ -406,6 +439,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Skip parametrized functions (function type or body is defined only at // function call time by caller node attributes). + // They should be specialized to their instantiation type parameters by + // the function optimizer, before we can optimize function body. if (IsParametrized(func)) continue; VLOG(3) << "Optimize function: function=" << func_name; @@ -420,6 +455,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( func, flib, item.graph.versions().producer(), &func_item)); + // If we need to compute the gradient of optimized function at runtime, we + // can't perform non-differentiable rewrites. + if (differentiable_functions.find(func_name) != + differentiable_functions.end()) { + func_item.allowed_optimizations.non_differentiable_rewrites = false; + } + // Optimize function body graph. 
GraphDef optimized_func_graph; TF_RETURN_IF_ERROR( @@ -470,6 +512,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, } bool MetaOptimizerEnabled(const RewriterConfig& cfg) { + if (cfg.disable_meta_optimizer()) { + return false; + } return !cfg.disable_model_pruning() || cfg.layout_optimizer() != RewriterConfig::OFF || cfg.function_optimization() != RewriterConfig::OFF || diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 831c5e37c0..99a0a33ffa 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -54,7 +54,11 @@ class MetaOptimizer : public GraphOptimizer { std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; // Initialize active optimizers from RewriterConfig.custom_optimizers. Status InitializeCustomGraphOptimizers( + const std::set<string>& pre_initialized_optimizers, std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; + // Returns the config for a custom graph optimizer. Null if none was found. + const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig( + const string& name) const; // Run optimization pass over a single GrapplerItem. Meta optimizer might run // multiple such passes: 1) for the main graph 2) for the function library diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index e74e0f7501..3f3f43382f 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -71,6 +72,59 @@ class TestGraphOptimizer : public TestOptimizer { REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer); +class TestOptimizerWithParams : public TestOptimizer { + public: + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + CHECK(config != nullptr); + return Status::OK(); + } +}; + +REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams); + +// Record various properties of the GrapplerItems passed for optimization. +class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer { + public: + static void SetAllowedOptimizations( + gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>* + allowed_optimizations) { + allowed_optimizations_ = allowed_optimizations; + } + static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; } + + GrapplerItemPropertiesAccumulator() {} + string name() const override { + return "grappler_item_properties_accumulator"; + } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + *optimized_graph = item.graph; + if (allowed_optimizations_) { + allowed_optimizations_->insert({item.id, item.allowed_optimizations}); + } + return Status::OK(); + } + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>* + allowed_optimizations_; +}; + +gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>* + GrapplerItemPropertiesAccumulator::allowed_optimizations_; + 
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator); + class MetaOptimizerTest : public GrapplerTest {}; TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { @@ -90,6 +144,25 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } +TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizerWithParams"); + auto* custom_config = rewriter_config.add_custom_optimizers(); + custom_config->set_name("TestOptimizerWithParams"); + (*custom_config->mutable_parameter_map())["foo"] = AttrValue(); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; @@ -305,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]); } +TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) { + using test::function::NDef; + using FDH = FunctionDefHelper; + + // We will record what type of optimizations meta optimizer allows for each + // GrapplerItem (main graph and graphs for each function). + gtl::FlatMap<string, GrapplerItem::AllowedOptimizations> + allowed_optimizations; + GrapplerItemPropertiesAccumulator::SetAllowedOptimizations( + &allowed_optimizations); + + // Just record properties of optimized Grappler items. 
+ RewriterConfig rewriter_config; + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator"); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, rewriter_config); + + // Define simple function library with two identical mul functions. + FunctionDef mul_func_1 = FunctionDefHelper::Create( + "MyMul1", {"x:float", "y:float"}, {"z:float"}, {}, + {{{"mul"}, "Mul", {"x", "y"}, {}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "mul:z:0"}}); + + FunctionDef mul_func_2 = FunctionDefHelper::Create( + "MyMul2", {"x:float", "y:float"}, {"z:float"}, {}, + {{{"mul"}, "Mul", {"x", "y"}, {}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "mul:z:0"}}); + + // Tensorflow graph: + // + // x0 = tf.Placeholder(tf.float); + // x1 = tf.Placeholder(tf.float); + // dy = tf.Placeholder(tf.float); + // + // mul_1 = MyMul1(x0, x1); + // mul_2 = MyMul2(x0, x1); + // dx = SymbolicGradient({x0, x1, dy}, f=MyMul2) + GrapplerItem item; + item.id = "main"; + item.graph = test::function::GDef( + {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + // Calls into function library + NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice), + NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice), + // Symbolic gradient of a MyMul2 + NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"}, + {{"f", FDH::FunctionRef("MyMul2", {})}, + {"Tin", DataTypeSlice{DT_FLOAT}}, + {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}, + kDevice)}, + // FunctionLib + {mul_func_1, mul_func_2}); + item.fetch = {"mul_1", "mul_2", "dx"}; + + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + // Our custom optimizer must be called for the main graph and for the two + // functions. 
+ ASSERT_EQ(allowed_optimizations.size(), 3); + + auto allowed_optimizations_main = + gtl::FindOrNull(allowed_optimizations, "main"); + ASSERT_NE(allowed_optimizations_main, nullptr); + EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites); + + auto allowed_optimizations_my_mul_1 = + gtl::FindOrNull(allowed_optimizations, "MyMul1"); + ASSERT_NE(allowed_optimizations_my_mul_1, nullptr); + EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites); + + auto allowed_optimizations_my_mul_2 = + gtl::FindOrNull(allowed_optimizations, "MyMul2"); + ASSERT_NE(allowed_optimizations_my_mul_2, nullptr); + EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc index 2190d38937..29a3b2b74c 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -25,23 +25,67 @@ limitations under the License. #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace grappler { + namespace internal { +namespace { // TODO(williamchan): Change this constant to be something smarter, maybe // dynamically determined. constexpr int64 kTensorMaxSize = 64; -// Find KernelDef for `node`. -Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) { - // Try find KernelDef for node.device, else GPU or CPU. 
- for (const DeviceType& device : - {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) { - Status s = FindKernelDef(device, node, kdef, nullptr); +struct OpDevicePortHasher { + std::size_t operator()(const std::tuple<string, string, int>& x) const { + uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x))); + + return Hash64Combine(code, hash<int>()(std::get<2>(x))); + } +}; +using OpDevicePortOnHostMap = + gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>; + +// All the nodes that should be blacklisted and not swapped. +bool IsBlacklisted(const NodeDef& node) { + return + // Collective ops should not be swapped. + IsCollective(node) || + // ControlFlow ops should not be swapped. + IsControlFlow(node) || + // NoOp ops should not be swapped (due to group dependencies). + IsNoOp(node); +} + +// Check if Tensor is integer and small size. +bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) { + // Check type to be int32 or int64. + if (prop.dtype() != DataType::DT_INT32 && + prop.dtype() != DataType::DT_INT64) { + return false; + } + + // Check size known and small. + const int64 size = NumCoefficients(prop.shape()); + if (size < 0 || size > kTensorMaxSize) { + return false; + } + + return true; +} + +// Find KernelDef for `node`, greedily return first found from `devices`. +Status TryFindKernelDef(const std::vector<DeviceType>& devices, + const NodeDef& node, const KernelDef** kdef) { + for (const DeviceType& device : devices) { + const KernelDef* kernel = nullptr; + Status s = FindKernelDef(device, node, &kernel, nullptr); if (s.ok()) { + if (kdef) { + *kdef = kernel; + } return Status::OK(); } } @@ -49,96 +93,239 @@ Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) { return errors::NotFound("Could not find KernelDef for op: ", node.op()); } -// Check if all node's inputs are pinned to CPU memory. 
-bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) { - // Loop through all the inputs excluding the controlling nodes. - for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) { - // Check if (the fanin) op's device is on CPU. - if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) { - continue; - } - - // Check if (the fanin) op's output port is pinned to HostMemory. - const OpDef* fanin_odef = nullptr; - Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef); - if (!s.ok()) { - LOG(INFO) << "Could not find OpDef for : " << fanin.node->op(); - return false; - } +// Checks if a node's output port is host friendly. +// Roughly this means checking if the output port is on Host memory. +Status IsNodeOutputPortHostFriendly( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { + *is_candidate = false; - const int output_arg_id = - OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id); - if (output_arg_id < 0) { - LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n" - << node.DebugString() << "\n" - << fanin.node->DebugString() << "\n" - << fanin_odef->DebugString(); - return false; - } + // Make sure we are not a blacklisted op. + if (IsBlacklisted(node)) { + return Status::OK(); + } - const KernelDef* fanin_kdef = nullptr; - s = TryFindKernelDef(*fanin.node, &fanin_kdef); - if (!s.ok()) { - LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op(); - return false; - } + // Check to make sure we have the right properties (i.e., statically shaped). + if (!properties->has_properties()) { + // This is an expensive call, call it lazily. 
+ TF_RETURN_IF_ERROR(properties->InferStatically( + /*assume_valid_feeds=*/false)); + } + const auto& output_properties = properties->GetOutputProperties(node.name()); + if (port_id >= output_properties.size()) { + LOG(WARNING) << "port_id=" << port_id + << " but output_properties.size()=" << output_properties.size() + << "\n" + << node.DebugString(); + return Status::OK(); + } + if (!IsTensorIntegerAndSmall(output_properties[port_id])) { + return Status::OK(); + } - bool fanin_pinned = false; - for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) { - if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) { - fanin_pinned = true; - break; + // These nodes may be optimized away downstream (even if pinned to Host), we + // should (recusively) check their source. + if (IsIdentity(node)) { + for (const auto& fanin : graph.GetFanins(node, false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); } } + *is_candidate = true; + return Status::OK(); + } + + // Check if op's device is on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; + return Status::OK(); + } + + // Check `op_device_outport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + const std::tuple<string, string, int> cache_key(node.op(), node.device(), + port_id); + auto it = op_device_outport_pinned_to_host_cache->find(cache_key); + if (it != op_device_outport_pinned_to_host_cache->end()) { + *is_candidate = it->second; + return Status::OK(); + } + + // Check if op's output port is pinned to HostMemory. 
+ const OpDef* op = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); + if (!s.ok()) { + LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); + return Status::OK(); + } + + // Map the port_id to output_arg_id. + const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id); + if (output_arg_id < 0) { + LOG(WARNING) << "Invalid port: " << port_id << "!\n" + << node.DebugString() << "\n" + << op->DebugString(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); + return Status::OK(); + } - if (!fanin_pinned) { - return false; + // Find the kernel. + const KernelDef* kernel = nullptr; + s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, + &kernel); + if (!s.ok()) { + LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); + return Status::OK(); + } + + // Check if the output_arg is pinned to Host. + for (const string& host_memory_arg : kernel->host_memory_arg()) { + if (op->output_arg(output_arg_id).name() == host_memory_arg) { + *is_candidate = true; + break; } } - return true; + op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate); + + return Status::OK(); } -bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) { - // Check if Tensor is integer and small size. +// Checks if a node's input port is Host friendly. +// Roughly this means checking if the input port is on Host memory. +bool IsNodeInputPortHostFriendly( + const NodeDef& node, int port_id, + OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) { + // If node is on Host, assume its inputs are Host friendly. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + return true; + } - // Check type to be int32 or int64. 
- if (prop.dtype() != DataType::DT_INT32 && - prop.dtype() != DataType::DT_INT64) { - return false; + // Check `op_device_inport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id); + auto it = op_device_inport_pinned_to_host_cache->find(cache_key); + if (it != op_device_inport_pinned_to_host_cache->end()) { + return it->second; } - // Check size known and small. - const int64 size = NumCoefficients(prop.shape()); - if (size < 0 || size > kTensorMaxSize) { + // Check if op's input port is pinned to HostMemory. + const OpDef* op = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); + if (!s.ok()) { + LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); + return false; + } + const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id); + + // Find the kernel. + const KernelDef* kernel = nullptr; + s = internal::TryFindKernelDef( + {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel); + if (!s.ok()) { + LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); return false; } - return true; + // Check if the input_arg is pinned to Host. + for (const string& host_memory_arg : kernel->host_memory_arg()) { + if (op->input_arg(input_arg_id).name() == host_memory_arg) { + op_device_inport_pinned_to_host_cache->emplace(cache_key, true); + return true; + } + } + + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); + + return false; } -bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties, - const NodeDef& node) { - for (const auto& prop : properties.GetInputProperties(node.name())) { +// Checks if a node is a candidate to pin to Host. +// The rough algorithm is as follows: +// 1] Check if node is blacklisted. 
+// 2] Check if node can run on Host. +// 3] Check all input/outputs are Host "friendly" (atm, friendly means small, +// ints, and pinned to Host). +Status IsNodeHostCandidate( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { + *is_candidate = false; + + // Skip these node types. + if (IsBlacklisted(node)) { + return Status::OK(); + } + + // Check if node already on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; + return Status::OK(); + } + + // Check the node can be run on CPU. + Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr); + if (!s.ok()) { + return Status::OK(); + } + + // Check all outputs are Host friendly. + if (!properties->has_properties()) { + // This is an expensive call, call it lazily. + TF_RETURN_IF_ERROR(properties->InferStatically( + /*assume_valid_feeds=*/false)); + } + for (const auto& prop : properties->GetOutputProperties(node.name())) { if (!IsTensorIntegerAndSmall(prop)) { - return false; + return Status::OK(); } } - for (const auto& prop : properties.GetOutputProperties(node.name())) { - if (!IsTensorIntegerAndSmall(prop)) { - return false; + // Check all inputs are Host friendly. 
+ for (const GraphView::OutputPort& fanin : + graph.GetFanins(node, /*include_controlling_nodes=*/false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); } } - return true; + + *is_candidate = true; + return Status::OK(); } -string TryFindHostDevice(const gtl::FlatSet<string>& devices, - bool has_device_cpu, const string& device) { +bool IsTPUGraphDef(const GraphDef& def) { + for (const auto& node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || + node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} +} // end namespace + +// Tries to swap `device` to a Host device from `devices`. Returns true iff +// there was a swap. +bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, string* device) { // Force this node onto the CPU. - if (device.empty() && has_device_cpu) { - return "/device:CPU:0"; - } else if (str_util::StrContains(device, DEVICE_GPU)) { + if (device->empty() && has_device_cpu) { + *device = "/device:CPU:0"; + return true; + } else if (str_util::StrContains(*device, DEVICE_GPU)) { // Sometimes the cluster can have: // devices = {"/device:CPU:0", "/device:XLA_GPU:0"} // and we need to handle them properly. @@ -146,30 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices, {std::pair<string, string>("GPU", "CPU:0"), std::pair<string, string>("/device", "/device:CPU:0")}) { const string device_host = - strings::StrCat(device.substr(0, device.rfind(device_match.first)), + strings::StrCat(device->substr(0, device->rfind(device_match.first)), device_match.second); if (devices.find(device_host) != devices.end()) { - return device_host; + *device = device_host; + return true; } } } - // We couldn't find an appropriate Host device, return original device. 
- return device; -} - -bool IsTPUGraphDef(const GraphDef& def) { - for (const auto& node : def.node()) { - if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || - node.op() == "TPUPartitionedCall") { - return true; - } - } + // We couldn't find an appropriate Host device, return false. return false; } -// All the nodes that should be blacklisted and not swapped. -bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); } } // end namespace internal Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -182,7 +358,6 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } GraphProperties properties(item); - bool has_properties = false; GraphView graph(optimized_graph); gtl::FlatSet<string> devices; @@ -202,45 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // All the Const nodes, and their original devices in topological order. std::vector<std::pair<NodeDef*, string>> const_nodes; - for (auto& node : *optimized_graph->mutable_node()) { - // Check if node already on CPU. - if (str_util::StrContains(node.device(), DEVICE_CPU)) { - continue; - } - - // Skip these node types. - if (internal::IsBlacklisted(node)) { - continue; - } + // Cache to map {op, device, port} -> bool on whether it is pinned to host. + internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache; + internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache; - // Check the node can be run on CPU. - Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr); - if (!s.ok()) { - continue; - } - - // Check all input's are pinned to CPU. - if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) { - continue; - } - - if (!has_properties) { - // This is an expensive call, call it lazily. - TF_RETURN_IF_ERROR(properties.InferStatically(false)); - has_properties = true; - } - - // Check all inputs and outputs are integers and small. 
- if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) { + for (auto& node : *optimized_graph->mutable_node()) { + bool is_candidate = false; + TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate( + graph, &properties, node, &op_device_outport_pinned_to_host_cache, + &is_candidate)); + if (!is_candidate) { continue; } - if (IsConstant(node)) { - const_nodes.emplace_back(&node, node.device()); + const string original_device = node.device(); + const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu, + node.mutable_device()); + // Keep track of all Const nodes that we swapped. + if (swapped && IsConstant(node)) { + const_nodes.emplace_back(&node, original_device); } - // Try and swap the device to Host. - node.set_device( - internal::TryFindHostDevice(devices, has_device_cpu, node.device())); } // Traverse all `const_nodes`, and map them back to GPU greedily. @@ -248,10 +404,13 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, NodeDef* node = it.first; const string& device = it.second; - // Check all the consumers of this node, if any of them are on the original - // device, swap this node back onto the original device. + // Check all the consumers of this node, if any of them are not on CPU, swap + // this node back onto the original device. for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) { - if (fanout.node->device() == device) { + // The consumer is not Host friendly, swap it back to the original device. 
+ if (!internal::IsNodeInputPortHostFriendly( + *fanout.node, fanout.port_id, + &op_device_inport_pinned_to_host_cache)) { node->set_device(device); break; } diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h index d557a03463..bed4a9ef95 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h @@ -26,8 +26,8 @@ namespace tensorflow { namespace grappler { namespace internal { // Try and find an appropriate Host device in `devices` given `device`. -string TryFindHostDevice(const gtl::FlatSet<string>& devices, - bool has_device_cpu, const string& device); +bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, string* device); } // end namespace internal // Optimize TensorFlow ops that should be swapped into the CPU to avoid diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc index 173cb3fe3c..9bb030b220 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -28,30 +28,60 @@ namespace { class PinToHostOptimizerTest : public GrapplerTest {}; -TEST_F(PinToHostOptimizerTest, TryFindHostDevice) { +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) { gtl::FlatSet<string> devices = {}; - EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC")); - - devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"), - "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"), - "/device:CPU:0"); - - devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, 
false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_CPU:0"); - - devices = {"/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_GPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_GPU:*"); + + string device = "ABC"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "ABC"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, 
&device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:*"); } TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) { @@ -160,6 +190,48 @@ TEST_F(PinToHostOptimizerTest, NoSwap) { EXPECT_EQ(found, 3); } +TEST_F(PinToHostOptimizerTest, Identity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // `a,c` is on GPU, `e` is on CPU, consequently `e` should not be swapped. + // `b` should be placed onto Host since `c` pins the input to Host memory. + Output a = + ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64}); + Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2}); + Output c = + ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b); + Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c); + Output e = ops::Multiply(s.WithOpName("e"), d, d); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "a" || node.name() == "c") { + EXPECT_EQ(node.device(), "/device:GPU:0"); + } else if (node.name() == "b") { + // If CUDA, then there is a GPU kernel registration that is pinned to Host + // memory. Consequently, `b` will be mapped to Host correct if there is + // a GPU kernel registered. 
+#if GOOGLE_CUDA + EXPECT_EQ(node.device(), "/device:CPU:0"); +#else + EXPECT_TRUE(node.device().empty()); +#endif + } else if (node.name() == "d") { + EXPECT_EQ(node.device(), "/device:CPU:0"); + } else if (node.name() == "e") { + EXPECT_TRUE(node.device().empty()); + } + ++found; + } + EXPECT_EQ(found, 5); +} + TEST_F(PinToHostOptimizerTest, PortIdToArgId) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3}); diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 008a289cfd..9ada8b7ff9 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) { Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, GraphDef* optimized_graph) { GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + bool inferred_properties = false; GraphView graph(const_cast<GraphDef*>(&item.graph)); // During inference, most of the inputs to FusedBatchNorm are constant, and we // can therefore replace the op with a much cheaper set of primitives. + optimized_graph->mutable_node()->Reserve(item.graph.node_size()); for (const NodeDef& node : item.graph.node()) { if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") { bool optimizable = (node.attr().count("T") == 0 || @@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, !node.attr().at("is_training").b()); if (optimizable) { int const_inputs = 0; + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. 
+ TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& props = properties.GetInputProperties(node.name()); for (const auto& prop : props) { if (prop.has_value()) { diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index 4542d17ccc..6ccb1cd783 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -33,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph = item.graph; GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + bool inferred_properties = false; GraphView graph(optimized_graph); // The product of all the dimensions in a tensor shape can be expressed more @@ -55,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } const GraphView::OutputPort reduce_indices = graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1)); + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. + TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& prop = properties.GetOutputProperties(reduce_indices.node->name()); if (prop.size() < reduce_indices.port_id) { @@ -92,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (!IsSize(*input1.node) || !IsSize(*input2.node)) { continue; } + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. 
+ TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& prop1 = properties.GetInputProperties(input1.node->name()); const auto& prop2 = properties.GetInputProperties(input2.node->name()); if (prop1.size() != 1 || prop2.size() != 1) { diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index db6e4e6852..5867d01324 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -156,45 +156,6 @@ bool IsControlInput(const string& name) { return !name.empty() && name[0] == '^'; } -string NodeName(const string& name) { - int position; - return ParseNodeName(name, &position); -} - -int NodePosition(const string& name) { - int position; - ParseNodeNameAsStringPiece(name, &position); - return position; -} - -int NodePositionIfSameNode(const string& input_name, const string& node_name) { - const bool is_ctrl = input_name[0] == '^'; - auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); - auto node_it = node_name.begin(); - if (node_name.empty() || - std::distance(input_it, input_name.end()) < node_name.size()) { - return -2; - } - while (node_it != node_name.end()) { - if (*input_it++ != *node_it++) { - return -2; - } - } - if (input_it == input_name.end()) { - return is_ctrl ? -1 : 0; - } else if (*input_it++ == ':') { - StringPiece remaining(&(*input_it), - std::distance(input_it, input_name.end())); - int position; - if (!strings::safe_strto32(remaining, &position)) { - return -2; - } - return is_ctrl ? -1 : position; - } else { - return -2; - } -} - string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter) { if (!name.empty()) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 296ee1678e..95126d470c 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { @@ -102,40 +101,92 @@ bool IsControlInput(const string& name); // True iff 'name1' and 'name2' refer to the same input. bool IsSameInput(const string& name1, const string& name2); +// Returns the trailing position number (or zero if no number is present) if +// NodeName(input_name) is equal to node_name. Returns -1 for control inputs. +// Returns -2 if NodeName(input_name) is not equal to node_name. +// Note: This function is used very heavily, and this hand-optimized +// version is 3-4x faster than the version using Scanner, which it replaced. +// This is worth the reduction in readability. +inline int NodePositionIfSameNode(const string& input_name, + const string& node_name) { + if (input_name.empty()) return -2; + const bool is_ctrl = input_name[0] == '^'; + auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); + auto node_it = node_name.begin(); + if (node_name.empty() || + std::distance(input_it, input_name.end()) < node_name.size()) { + return -2; + } + while (node_it != node_name.end()) { + if (*input_it++ != *node_it++) { + return -2; + } + } + if (input_it == input_name.end()) { + return is_ctrl ? -1 : 0; + } else if (*input_it++ == ':') { + StringPiece remaining(&(*input_it), + std::distance(input_it, input_name.end())); + int position; + if (!strings::safe_strto32(remaining, &position)) { + return -2; + } + return is_ctrl ? -1 : position; + } else { + return -2; + } +} + // Return the node name corresponding to 'name' if name is valid, or the empty // string otherwise. -string NodeName(const string& name); +inline StringPiece NodeNameAsStringPiece(const string& name) { + static const string empty; + if (name.empty()) return StringPiece(empty); + const auto begin_it = name[0] == '^' ? 
name.begin() + 1 : name.begin(); + auto end_it = begin_it; + while (end_it != name.end() && *end_it != ':') { + ++end_it; + } + if (end_it != name.end() && *end_it != ':') { + return StringPiece(empty); + } + return StringPiece(&(*begin_it), std::distance(begin_it, end_it)); +} -// Get the trailing position number ":{digits}" (if any) of a node name. -// Returns -1 for control inputs. -int NodePosition(const string& name); +// Return the node name corresponding to 'name' if name is valid, or the empty +// string otherwise. +inline string NodeName(const string& name) { + return string(NodeNameAsStringPiece(name)); +} +// Returns the node name and position in a single call. inline StringPiece ParseNodeNameAsStringPiece(const string& name, int* position) { - // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any) - // to get a node name. - strings::Scanner scan(name); - scan.ZeroOrOneLiteral("^") - .RestartCapture() - .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) - .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); - StringPiece capture; - StringPiece remaining; - if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) { + static const string empty; + if (name.empty()) { *position = 0; - static const string empty; return StringPiece(empty); - } else { - if (name[0] == '^') { - *position = -1; - } else if (remaining.empty()) { - *position = 0; - } else { - // Skip the first ':' character. - CHECK(strings::safe_strto32(remaining.substr(1), position)); + } + const bool is_ctrl = name[0] == '^'; + const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin(); + *position = is_ctrl ? 
-1 : 0; + auto end_it = begin_it; + while (end_it != name.end() && *end_it != ':') { + ++end_it; + } + const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it)); + if (end_it != name.end()) { + if (*end_it != ':') { + return StringPiece(empty); + } else if (!is_ctrl) { + ++end_it; + StringPiece remaining(&(*end_it), std::distance(end_it, name.end())); + if (!strings::safe_strto32(remaining, position)) { + return StringPiece(empty); + } } - return capture; } + return node_name; } // Returns the node name and position in a single call. @@ -143,10 +194,11 @@ inline string ParseNodeName(const string& name, int* position) { return string(ParseNodeNameAsStringPiece(name, position)); } -// Returns NodePosition(input_name) if NodeName(input_name) == node_name. -// Otherwise returns -2; -// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0. -int NodePositionIfSameNode(const string& input_name, const string& node_name); +inline int NodePosition(const string& name) { + int position; + ParseNodeNameAsStringPiece(name, &position); + return position; +} // Add a prefix to a node name with a custom delimiter. 
string AddPrefixToNodeName(const string& name, const string& prefix, diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index a428aea7f5..6861fb423c 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -41,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration, tensorflow::NameRangeMap outputs_range_map; TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode( node, registration.op_def, nullptr, &outputs_range_map)); - connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map); + connectivity->RegisterFunctionBodyOutputs(node.name(), + std::move(outputs_range_map)); return Status::OK(); } @@ -75,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders( } // namespace void GrapplerFunctionConnectivity::RegisterInputArgExpansion( - const InputArgExpansion& input_arg_expansion) { - const auto& input_name = input_arg_expansion.input_name; + InputArgExpansion input_arg_expansion) { + string input_name = input_arg_expansion.input_name; const auto& placeholders = input_arg_expansion.placeholders; - input_arg_expansions_.emplace(input_name, input_arg_expansion); + for (int i = 0; i < placeholders.size(); ++i) { const string& placeholder = input_arg_expansion.placeholders[i]; - input_arg_placeholders_.emplace( - placeholder, InputArgPlaceholder{input_name, /*position=*/i}); + input_arg_placeholders_.insert( + {placeholder, InputArgPlaceholder{input_name, /*position=*/i}}); } + input_arg_expansions_.insert( + {std::move(input_name), std::move(input_arg_expansion)}); } void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs( - const string& node_name, const tensorflow::NameRangeMap& outputs) { - function_body_outputs_[node_name] = outputs; + const string& node_name, tensorflow::NameRangeMap&& outputs) { + function_body_outputs_[node_name] = std::move(outputs); } Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( 
@@ -174,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( const auto& output_range = output->second; if (position == -1) { + graph_def_inputs->reserve(graph_def_inputs->size() + + output_range.second - output_range.first); // If position is not defined expand node output range for (int i = output_range.first; i < output_range.second; ++i) { - i == 0 ? graph_def_inputs->push_back(node_name) - : graph_def_inputs->push_back( - strings::StrCat(node_name, ":", i)); + graph_def_inputs->push_back( + i == 0 ? node_name : strings::StrCat(node_name, ":", i)); } } else { if (position > (output_range.second - output_range.first)) { @@ -187,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( " position: ", position, " (out of range)"); } int pos = output_range.first + position; - pos == 0 ? graph_def_inputs->push_back(node_name) - : graph_def_inputs->push_back( - strings::StrCat(node_name, ":", pos)); + graph_def_inputs->push_back( + pos == 0 ? node_name : strings::StrCat(node_name, ":", pos)); } return Status::OK(); @@ -211,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs( } function_body_node->clear_input(); - for (const string& expanded_input : expanded_inputs) - function_body_node->add_input(expanded_input); + for (string& expanded_input : expanded_inputs) + function_body_node->add_input(std::move(expanded_input)); return Status::OK(); } @@ -323,7 +326,7 @@ GrapplerFunctionItem::GrapplerFunctionItem( // Fill the feed nodes with input placeholders. 
for (const InputArgExpansion& input_arg : input_arg_expansions_) { for (const string& placeholder : input_arg.placeholders) { - feed.emplace_back(placeholder, Tensor()); + feed.push_back({placeholder, Tensor()}); input_arg_placeholders_.insert(placeholder); } } @@ -460,7 +463,7 @@ Status InstantiationBodyParameters( auto it = func_instantiation_attr.find(placeholder); if (it != func_instantiation_attr.end()) { - body_parameters->emplace(placeholder, it->second); + body_parameters->insert({placeholder, it->second}); } else { return errors::InvalidArgument("Can't resolve placeholder: ", placeholder); @@ -498,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, // GraphDef input format (name[:position]) GrapplerFunctionConnectivity connectivity; - std::vector<InputArgExpansion> inputs; - std::vector<OutputArgExpansion> outputs; - std::vector<string> keep_nodes; - // Function body shares the library with the graph that instantiated it. GraphDef function_body; *function_body.mutable_library() = flib.ToProto(); @@ -518,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, } } + std::vector<InputArgExpansion> inputs; + inputs.reserve(signature.input_arg_size()); + // For each input argument create a placeholder in function body. 
for (const OpDef::ArgDef& input : signature.input_arg()) { if (!input.type_list_attr().empty() || !input.number_attr().empty()) { @@ -542,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, /*is_ref*/ input.is_ref(), /*placeholders=*/{input.name()}}; connectivity.RegisterInputArgExpansion(input_expansion); - inputs.push_back(input_expansion); + inputs.push_back(std::move(input_expansion)); } + std::vector<string> keep_nodes; // Add all function nodes to the function body for (const NodeDef& func_def_node : func.node_def()) { NodeDef* new_node = function_body.add_node(); @@ -572,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node)); } + std::vector<OutputArgExpansion> outputs; + outputs.reserve(signature.output_arg_size()); // Add function outputs for (const OpDef::ArgDef& out : signature.output_arg()) { std::vector<string> output_tensors; @@ -589,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, OutputArgExpansion output{/*output_name=*/out.name(), /*data_type=*/output_data_type, /*is_ref=*/out.is_ref(), - /*output_tensors=*/output_tensors}; - outputs.push_back(output); + /*output_tensors=*/std::move(output_tensors)}; + outputs.push_back(std::move(output)); } bool is_stateful = signature.is_stateful(); diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 733caf325f..ef944ced09 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <string> +#include <unordered_map> #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -70,9 +71,9 @@ struct OutputArgExpansion { // and fold it back when doing backward conversion. 
class GrapplerFunctionConnectivity { public: - void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion); + void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion); void RegisterFunctionBodyOutputs(const string& node_name, - const tensorflow::NameRangeMap& outputs); + tensorflow::NameRangeMap&& outputs); // Expand input encoded in FunctionDef format (name[:output][:position]) into // multiple inputs in GraphDef format (name[:position]). diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 6b787a6910..9b6c1f690b 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -371,6 +371,25 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl); BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0); BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end); +#define BM_ParseNodeNameAsStringPiece(I, NAME) \ + static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \ + string input = I; \ + for (int i = 0; i < iters; ++i) { \ + int position; \ + const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \ + CHECK_GE(position, -1); \ + CHECK(!name.empty()); \ + } \ + } \ + BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME) + +BM_ParseNodeNameAsStringPiece("foo", foo); +BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz); +BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl); +BM_ParseNodeNameAsStringPiece("foo:123", foo123); +BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123); +BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl); + } // namespace } // namespace grappler } // namespace tensorflow |