From 9037e241de1e64044ff55ab539ccc1fb013c178a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 14 Mar 2018 20:39:10 -0700
Subject: Enable Add/AddN tree rewrite for symbolically equal shapes.

1) Rewrite a tree of Add/AddN ops into a single AddN when all shapes are
   symbolically equal
2) Look up shape properties using GraphProperties instead of direct access
   to Node attributes

PiperOrigin-RevId: 189131726
---
 .../grappler/optimizers/arithmetic_optimizer.cc   | 173 ++++++++++++---------
 .../grappler/optimizers/arithmetic_optimizer.h    |   3 +
 .../optimizers/arithmetic_optimizer_test.cc       |  61 +++++++-
 tensorflow/core/grappler/utils.cc                 |  26 ++++
 tensorflow/core/grappler/utils.h                  |   4 +
 tensorflow/core/grappler/utils_test.cc            |  41 +++++
 6 files changed, 231 insertions(+), 77 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index c0fcfaf428..675cd8f072 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -197,35 +197,39 @@ bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
 
 const char kOutputShapesAttr[] = "_output_shapes";
 
-PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
-  int output_pos;
-  string node_name = ParseNodeName(input, &output_pos);
-  const NodeDef* input_node = node_map.GetNode(node_name);
-  auto attr = input_node->attr();
-  if (attr.find(kOutputShapesAttr) == attr.end()) {
-    return PartialTensorShape();  // unknown shape
-  } else {
-    return attr.at(kOutputShapesAttr).list().shape(output_pos);
-  }
+// Shape is symbolically defined if it has a known rank and each dimension is
+// either defined or an unknown symbol (dim.size <= -2).
+bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
+  return !shape.unknown_rank() &&
+         std::all_of(
+             shape.dim().begin(), shape.dim().end(),
+             [](const TensorShapeProto::Dim& dim) { return dim.size() != -1; });
+}
+
+bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) {
+  return ShapeIsSymbolicallyDefined(properties.shape());
 }
 
-bool ShapesEqual(const string& input_x, const string& input_y,
-                 const NodeMap& node_map) {
-  PartialTensorShape x_shape = GetInputShape(input_x, node_map);
-  PartialTensorShape y_shape = GetInputShape(input_y, node_map);
-  if (x_shape.unknown_rank() || y_shape.unknown_rank() ||
-      x_shape.dims() != y_shape.dims()) {
+bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
+                             const TensorShapeProto& right) {
+  if (left.unknown_rank() || right.unknown_rank() ||
+      left.dim_size() != right.dim_size()) {
     return false;
   }
-  for (int i = 0; i < x_shape.dims(); ++i) {
-    if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 ||
-        x_shape.dim_size(i) != y_shape.dim_size(i)) {
+  for (int i = 0; i < left.dim_size(); ++i) {
+    if (left.dim(i).size() == -1 || right.dim(i).size() == -1 ||
+        left.dim(i).size() != right.dim(i).size()) {
       return false;
     }
   }
   return true;
 }
 
+bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
+                             const OpInfo::TensorProperties& right) {
+  return ShapesSymbolicallyEqual(left.shape(), right.shape());
+}
+
 // Returns whether `reshape` is an identity op. The tensor that `reshape`
 // reshapes is the `output_pos`-th output of node `input`.
 bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
@@ -290,16 +294,19 @@ NodeDef* GetTailOfValuePreservingChain(
 struct ArithmeticOptimizerContext {
   ArithmeticOptimizerContext(
       const std::unordered_set<string>* nodes_to_preserve,
-      GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
+      GraphDef* optimized_graph, GraphProperties* graph_properties,
+      NodeMap* node_map, FrameMap* frame_map,
       SetVector<NodeDef*>* nodes_to_simplify)
       : nodes_to_preserve(nodes_to_preserve),
         optimized_graph(optimized_graph),
+        graph_properties(graph_properties),
        node_map(node_map),
         frame_map(frame_map),
         nodes_to_simplify(nodes_to_simplify) {}
 
   const std::unordered_set<string>* nodes_to_preserve;
   GraphDef* optimized_graph;
+  GraphProperties* graph_properties;
   NodeMap* node_map;
   FrameMap* frame_map;
   SetVector<NodeDef*>* nodes_to_simplify;
@@ -388,7 +395,7 @@ class ArithmeticOptimizerStage {
     ctx_.nodes_to_simplify->PushBack(node);
   }
 
-  // Get a node by input name from a node map. Return a error if node was not
+  // Get a node by input name from a node map. Return an error if node was not
   // found.
   Status GetInputNode(const string& input, NodeDef** node) const {
     string node_name = NodeName(input);
@@ -401,22 +408,31 @@ class ArithmeticOptimizerStage {
     return Status::OK();
   }
 
-  // Get input shape from a node map. If node doesn't exists return unknown
-  // shape.
-  PartialTensorShape GetInputShape(const string& input) const {
-    int position;
-    string node_name = ParseNodeName(input, &position);
-    NodeDef* node;
-    Status node_status = GetInputNode(node_name, &node);
-    if (!node_status.ok()) {
-      return PartialTensorShape();  // unknown shape
+  // Look up tensor properties by name. The tensor name might have a non-zero
+  // port number. Return an error if the tensor node doesn't exist in the
+  // graph, or doesn't have properties defined for the requested port.
+  Status GetTensorProperties(const string& tensor,
+                             OpInfo::TensorProperties* properties) const {
+    int port;
+    string tensor_node_name = ParseNodeName(tensor, &port);
+    if (port < 0) {
+      return errors::InvalidArgument(
+          "Can't get tensor properties of control dependency ", tensor);
     }
-    auto attr = node->attr();
-    if (attr.find(kOutputShapesAttr) == attr.end()) {
-      return PartialTensorShape();  // unknown shape
-    } else {
-      return attr.at(kOutputShapesAttr).list().shape(position);
+
+    const auto& output_properties =
+        ctx_.graph_properties->GetOutputProperties(tensor_node_name);
+    auto num_outputs = output_properties.size();
+
+    if (num_outputs == 0 || port > num_outputs - 1) {
+      return errors::InvalidArgument(
+          "Node ", tensor_node_name,
+          " is missing output properties at position :", port,
+          " (num_outputs=", num_outputs, ")");
     }
+
+    properties->CopyFrom(output_properties[port]);
+    return Status::OK();
   }
 
   NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
@@ -509,8 +525,8 @@ class ArithmeticOptimizerStage {
 // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
 // original inputs of absorbed nodes.
 //
-// All nodes in a Add/AddN subgraph must have fully specified and identical
-// shape. All nodes must have the same device placement.
+// All nodes in an Add/AddN subgraph must have symbolically equal shape. All
+// nodes must have the same device placement.
 //
 // Example:
 //      AddN_1
 //     /      \
 //  AddN_2     q
@@ -533,16 +549,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     if (!IsRewritable(node)) {
       return false;
     }
-    // and must have fully defined shape
-    // TODO(ezhulenev): support partially defined shapes, when we can prove that
-    // unknown dimensions in the rewritten subgraph are the same.
-    PartialTensorShape shape = GetInputShape(node->name());
-    if (!shape.IsFullyDefined()) {
-      return false;
-    }
-    // and must have inputs of fully defined shape identical to the output
-    // TODO(ezhulenev): relax this condition to support equal unknown dimensions
-    return HasAllInputsOfIdenticalShape(*node, shape);
+
+    // shape must be symbolically defined and all inputs compatible with it
+    OpInfo::TensorProperties properties;
+    Status has_properties = GetTensorProperties(node->name(), &properties);
+    return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
+           HasAllInputsOfSymbolicallyEqualShape(*node, properties);
   }
 
   Status TrySimplify(const NodeDef* node,
@@ -567,23 +579,26 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
   //   input_nodes: [x, y, z, w, q, e]
   struct AddOpsGroup {
     const NodeDef* root_node;
-    PartialTensorShape root_shape;
+    TensorShapeProto root_shape;
     // Add/AddN operations below the root level that were absorbed by this
    // group
     std::vector<const NodeDef*> absorbed_nodes;
     // Inputs of absorbed nodes that will be forwarded to rewritten AddN node
     std::vector<string> inputs;
   };
 
-  // Check if all inputs are fully defined and identical to expected shape
-  bool HasAllInputsOfIdenticalShape(const NodeDef& node,
-                                    const PartialTensorShape& shape) const {
+  // Check if all inputs have symbolically equal shapes
+  bool HasAllInputsOfSymbolicallyEqualShape(
+      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
     const AddOpsRewriteStage* self = this;
-    return std::all_of(node.input().begin(), node.input().end(),
-                       [self, &shape](const string& input) {
-                         auto input_shape = self->GetInputShape(input);
-                         return input_shape.IsFullyDefined() &&
-                                input_shape.IsIdenticalTo(shape);
-                       });
+    return std::all_of(
+        node.input().begin(), node.input().end(),
+        [self, &properties](const string& input) {
+          OpInfo::TensorProperties input_properties;
+          Status has_input_properties =
+              self->GetTensorProperties(input, &input_properties);
+          return has_input_properties.ok() &&
+                 ShapesSymbolicallyEqual(properties, input_properties);
+        });
   }
 
   // TODO(ezhulenev): use GraphRewriter?
@@ -614,27 +629,25 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     if (!node_status.ok()) {
       return false;
     }
-
-    PartialTensorShape shape = GetInputShape(name);
-    CHECK(shape.IsIdenticalTo(group.root_shape))
-        << "Cannot absorb a node of incompatible shape";
-
     // check basic preconditions
     if (!IsRewritable(node)) {
       return false;
     }
-    // with a single output consumer (presumably if we reach this node from
+    // with a single output data consumer (presumably if we reach this node from
     // previously absorbed or a root node, it means that this node is not used
     // as an input to any other op, outside of the group)
-    if (ctx_.node_map->GetOutputs(node->name()).size() != 1) {
+    if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
       return false;
     }
     // must be on the same device as a root node
     if (node->device() != group.root_node->device()) {
       return false;
     }
-    // All input shapes must be fully defined and equal to the node shape
-    return HasAllInputsOfIdenticalShape(*node, shape);
+    // All input shapes must be symbolically defined and equal to the node shape
+    OpInfo::TensorProperties properties;
+    Status has_properties = GetTensorProperties(name, &properties);
+    return has_properties.ok() &&
+           HasAllInputsOfSymbolicallyEqualShape(*node, properties);
   }
 
   // Node requirements both for a root node and an absorbed node
@@ -660,15 +673,19 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
   }
 
   // Check that the optimized group node name doesn't exist. It might happen if
-  // graph optimized multiple times without pruning beween invocations.
+  // the graph is optimized multiple times without pruning between invocations.
   bool IsRewritten(const AddOpsGroup& group) const {
     return ctx_.node_map->NodeExists(AddOpsGroupName(group));
   }
 
   // Create an AddOpsGroup with a root in a given node
   Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
+    OpInfo::TensorProperties root_node_output_properties;
+    TF_RETURN_IF_ERROR(
+        GetTensorProperties(root_node->name(), &root_node_output_properties));
+
     group->root_node = root_node;
-    group->root_shape = GetInputShape(root_node->name());
+    group->root_shape = root_node_output_properties.shape();
 
     group->absorbed_nodes.reserve(root_node->input_size());
     for (int i = 0; i < root_node->input_size(); ++i) {
@@ -737,6 +754,9 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
       added_node->add_input(input);
     }
 
+    // Add frame dependencies that the original node might have had.
+    AddFrameControlDeps(group.root_node, {added_node}, "", {});
+
     VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
             << " Add/AddN nodes from the graph";
 
@@ -891,8 +911,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
           mul_node->input(0) == common_factor ?
           1 : 0;
       unique_factors->push_back(mul_node->input(unique_factor_index));
       if (i > 0 && !IsAdd(*node)) {
-        *shapes_match = ShapesEqual(unique_factors->front(),
-                                    unique_factors->back(), *ctx_.node_map);
+        OpInfo::TensorProperties lhs;
+        OpInfo::TensorProperties rhs;
+        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
+        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
+        *shapes_match = ShapesSymbolicallyEqual(lhs, rhs);
       }
     }
     return Status::OK();
@@ -1627,8 +1650,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
   }
 
   const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
-                                       node_map_.get(), &frame_map_,
-                                       &nodes_to_simplify);
+                                       graph_properties_.get(), node_map_.get(),
+                                       &frame_map_, &nodes_to_simplify);
 
   std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
 
@@ -1660,8 +1683,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     const NodeDef* node = nodes_to_simplify.PopBack();
 
     // TODO(ezhulenev): move all rewrites into separate stages
-    string simplified_tensor =
-        TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+    string simplified_tensor = "";
+    if (options_.enable_try_simplify_and_replace) {
+      simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+    }
 
     // if it was not simplified try to run it through all configured stages
     if (simplified_tensor.empty()) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index d5a7af5ba6..2c6b52c072 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -55,6 +55,9 @@ class ArithmeticOptimizer : public GraphOptimizer {
 
   // Granular control for arithmetic optimizer stages
   struct ArithmeticOptimizerOptions {
+    // TODO(ezhulenev): flag to disable TrySimplifyAndReplaceUses in tests.
+    // Remove when all optimizers have been migrated to separate stages.
+    bool enable_try_simplify_and_replace = true;
     bool combine_add_to_addn = true;
     bool hoist_common_factor_out_of_aggregation = true;
     bool remove_inverse_transpose = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e1f47625c1..d677aee589 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -89,6 +89,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
   // should explicitly enable required optimization for test isolation
   void DisableAllStages(ArithmeticOptimizer* optimizer) {
     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
+    options.enable_try_simplify_and_replace = false;
     options.combine_add_to_addn = false;
     options.hoist_common_factor_out_of_aggregation = false;
     options.remove_inverse_transpose = false;
@@ -1270,7 +1271,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
   EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
 }
 
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   tensorflow::Scope sx = s.NewSubScope("x");
   tensorflow::Scope sy = s.NewSubScope("y");
@@ -1322,7 +1323,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
   EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
 }
 
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1395,7 +1396,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
   EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
 }
 
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
   auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1440,5 +1441,59 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
   EXPECT_EQ("c", collapsed_add->input(3));
 }
 
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  // unknown input shape propagated symbolically through the graph
+  auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
+
+  // [a, b, c] have symbolically equal shapes
+  auto a = ops::Sqrt(s.WithOpName("a"), input);
+  auto b = ops::Square(s.WithOpName("b"), input);
+  auto c = ops::Round(s.WithOpName("c"), input);
+
+  // [add_ab, add_abc] shape must be inferred from inputs
+  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+  GrapplerItem item;
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyAddToAddNCombining(&optimizer);
+
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  // We expect the following rewrite(s) to occur:
+  //
+  //      +
+  //     / \
+  //    +   c      -->      AddN(a, b, c)
+  //   / \
+  //  a   b
+  EXPECT_EQ(6, output.node_size());
+
+  NodeMap node_map(&output);
+
+  // check add tree was replaced with AddN
+  const NodeDef* collapsed_add =
+      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+  ASSERT_TRUE(collapsed_add != nullptr);
+  EXPECT_EQ("AddN", collapsed_add->op());
+  EXPECT_EQ(3, collapsed_add->input_size());
+  EXPECT_EQ("a", collapsed_add->input(0));
+  EXPECT_EQ("b", collapsed_add->input(1));
+  EXPECT_EQ("c", collapsed_add->input(2));
+
+  // check output was rewired to new node
+  const NodeDef* updated_outputs = node_map.GetNode("outputs");
+  ASSERT_TRUE(updated_outputs != nullptr);
+  EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index eb1f882ff1..829bfe9e31 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -40,6 +40,16 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
   tensor->flat<T>()(0) = static_cast<T>(value);
   return true;
 }
+
+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+// TODO(ezhulenev): move to op_types.h. Requires breaking a circular dependency.
+// TODO(ezhulenev): what about Identity passing a tensor to a Shape consumer?
+bool IsShapeConsumer(const NodeDef& node) {
+  const string& op = node.op();
+  return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
+}
+
 }  // namespace
 
 NodeMap::NodeMap(GraphDef* graph) {
@@ -270,6 +280,22 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
   return num_outputs;
 }
 
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
+  int num_data_outputs = 0;
+  for (const NodeDef* output : node_map.GetOutputs(node.name())) {
+    if (IsShapeConsumer(*output)) continue;
+
+    for (int i = 0; i < output->input_size(); ++i) {
+      const string& input = output->input(i);
+      if (!IsControlInput(input) && NodeName(input) == node.name()) {
+        ++num_data_outputs;
+        break;
+      }
+    }
+  }
+  return num_data_outputs;
+}
+
 // Returns the data type in attribute `attr_name` of `node`. If that attribute
 // doesn't exist, returns DT_INVALID.
 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index fbd38c1531..7aa31939f5 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -144,6 +144,10 @@ int NumNonControlInputs(const NodeDef& node);
 // Number of connected non-control outputs.
 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
 
+// Number of connected non-control data outputs (ops that consume the output
+// tensor's data, not just its shape).
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
+
 // Removes redundant control inputs from node.
 void DedupControlInputs(NodeDef* node);
 
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index eabce5b5ee..49a1996d25 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -292,6 +292,47 @@ TEST_F(UtilsTest, DedupControlInputs) {
   EXPECT_EQ("gnu", foo.input(1));
 }
 
+TEST_F(UtilsTest, NumNonControlOutputs) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  // *) The Round node has a control dependency edge from Add, which
+  //    is not shown in this diagram (ASCII graphics limitation).
+  //
+  //   *Round    [Sqrt, Shape]
+  //      |          |
+  //      | ctrl     |
+  //     Mul ------> Add
+  //     / \        / \
+  //    x   y      a   b
+  auto x = ops::Variable(s.WithOpName("x"), {1, 2}, DT_FLOAT);
+  auto y = ops::Variable(s.WithOpName("y"), {1, 2}, DT_FLOAT);
+  auto a = ops::Variable(s.WithOpName("a"), {1, 2}, DT_FLOAT);
+  auto b = ops::Variable(s.WithOpName("b"), {1, 2}, DT_FLOAT);
+
+  auto mul = ops::Multiply(s.WithOpName("mul"), x, y);
+  auto add = ops::Add(s.WithOpName("add").WithControlDependencies(mul), a, b);
+
+  auto shape = ops::Shape(s.WithOpName("shape"), add);
+  auto sqrt = ops::Sqrt(s.WithOpName("sqrt"), add);
+
+  auto round =
+      ops::Round(s.WithOpName("round").WithControlDependencies(add), mul);
+
+  GraphDef graph;
+  TF_CHECK_OK(s.ToGraphDef(&graph));
+  NodeMap node_map(&graph);
+
+  const NodeDef* add_node = node_map.GetNode("add");
+  ASSERT_TRUE(add_node != nullptr);
+
+  // [a, b] are the only non-control inputs
+  EXPECT_EQ(2, NumNonControlInputs(*add_node));
+  // [sqrt, shape] are non-control outputs
+  EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
+  // sqrt is the only data output
+  EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
+}
+
 TEST_F(UtilsTest, DeleteNodes) {}
 
 }  // namespace
-- 
cgit v1.2.3
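
A note on the rewrite's core predicate. The sketch below (not part of the
patch) restates the comparison that ShapesSymbolicallyEqual() performs, using
a hypothetical stand-in Shape struct rather than the real TensorShapeProto.
It follows the dimension convention the patch assumes: a size >= 0 is known,
-1 is the anonymous "unknown", and <= -2 is a named unknown symbol assigned by
symbolic shape inference, so two occurrences of the same symbol compare equal
while anonymous unknowns never match, not even themselves.

    // Minimal self-contained sketch; `Shape` is a hypothetical stand-in for
    // TensorShapeProto, which stores dims as messages rather than integers.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct Shape {
      bool unknown_rank = false;
      std::vector<long long> dims;  // >= 0: known, -1: unknown, <= -2: symbol
    };

    // Mirrors ShapesSymbolicallyEqual() from the patch: both ranks must be
    // known and equal, no dimension may be the anonymous -1, and each
    // dimension pair must match exactly -- which is what lets two occurrences
    // of the same "-2" symbol compare equal.
    bool SymbolicallyEqual(const Shape& a, const Shape& b) {
      if (a.unknown_rank || b.unknown_rank || a.dims.size() != b.dims.size()) {
        return false;
      }
      for (std::size_t i = 0; i < a.dims.size(); ++i) {
        if (a.dims[i] == -1 || b.dims[i] == -1 || a.dims[i] != b.dims[i]) {
          return false;
        }
      }
      return true;
    }

    int main() {
      Shape x{false, {-2, 2}};  // e.g. Sqrt(input) in the new test above
      Shape y{false, {-2, 2}};  // e.g. Square(input): same symbol, same rank
      Shape z{false, {-1, 2}};  // first dimension is the anonymous unknown
      assert(SymbolicallyEqual(x, y));   // shared symbol -2 matches
      assert(!SymbolicallyEqual(x, z));  // -1 never matches anything...
      assert(!SymbolicallyEqual(z, z));  // ...not even itself
      return 0;
    }

This is also why the new AddOpsRewrite_AddOpsOfSymbolicallyEqualShape test
collapses the tree: presumably shape inference propagates the same symbolic
shape (e.g. {-2, 2}) from the shared input to Sqrt, Square, and Round, so all
shapes compare equal even though the first dimension's concrete size is
unknown.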