diff options
author | 2018-03-09 10:22:16 -0800 | |
---|---|---|
committer | 2018-03-09 10:26:07 -0800 | |
commit | 48fc3bc388b09c67482db9751b6eab1d89ae140e (patch) | |
tree | 17fc5a2881129b33fe5aa03546b056f61171a27e | |
parent | 41a12df5de7d767a1a872348f3ba630350fcc78e (diff) |
Implement partial constant folding for Concat.
PiperOrigin-RevId: 188501394
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties.h | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties_test.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 143 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.h | 2 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding_test.cc | 100 |
8 files changed, 261 insertions, 28 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 243ca9121c..817247e379 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -1182,5 +1182,12 @@ GraphProperties::GetOutputProperties(const string& node_name) const { return missing_properties_; } +void GraphProperties::ClearInputProperties(const string& node_name) { + input_properties_.erase(node_name); +} +void GraphProperties::ClearOutputProperties(const string& node_name) { + output_properties_.erase(node_name); +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 6fc53a7f2e..5aa4962072 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -64,6 +64,8 @@ class GraphProperties { const string& node_name) const; const std::vector<OpInfo::TensorProperties>& GetOutputProperties( const string& node_name) const; + void ClearInputProperties(const string& node_name); + void ClearOutputProperties(const string& node_name); static void FillTensorPropertiesFromContext( const shape_inference::ShapeHandle&, const DataType&, diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 5012069118..284d9d409b 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -113,6 +113,33 @@ TEST_F(GraphPropertiesTest, StaticProperties) { } } +TEST_F(GraphPropertiesTest, ClearProperties) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + GraphProperties properties(item); + Status s = properties.InferStatically(true); + TF_CHECK_OK(s); + + for (const auto& node : item.graph.node()) { + if (node.op() == "RandomStandardNormal") { + EXPECT_EQ(1, properties.GetInputProperties(node.name()).size()); + const auto props = properties.GetOutputProperties(node.name()); + properties.ClearOutputProperties(node.name()); + const auto cleared_props = properties.GetOutputProperties(node.name()); + EXPECT_TRUE(cleared_props.empty()); + } else if (node.op() == "AddN") { + const auto in_props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, in_props.size()); + properties.ClearInputProperties(node.name()); + const auto cleared_props = properties.GetInputProperties(node.name()); + EXPECT_TRUE(cleared_props.empty()); + } + } +} + TEST_F(GraphPropertiesTest, DynamicProperties) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, cluster_->GetDeviceNames()); diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 8cf1402ae8..ae71094079 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -72,6 +72,10 @@ bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; } bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; } +bool IsConcat(const NodeDef& node) { + return node.op() == "Concat" || node.op() == "ConcatV2"; +} + bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; } bool IsConstant(const NodeDef& node) { return node.op() == "Const"; } @@ -213,6 +217,8 @@ bool IsNextIteration(const NodeDef& node) { return op == "NextIteration" || op == "RefNextIteration"; } +bool IsPack(const NodeDef& node) { return node.op() == "Pack"; } + bool IsPad(const NodeDef& node) { const auto& op = node.op(); return op == "Pad" || op == "PadV2"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index a7c33ef97b..690275da7c 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -40,6 +40,7 @@ bool IsCast(const NodeDef& node); bool IsComplex(const NodeDef& node); bool IsComplexAbs(const NodeDef& node); bool IsConj(const NodeDef& node); +bool IsConcat(const NodeDef& node); bool IsConcatOffset(const NodeDef& node); bool IsConstant(const NodeDef& node); bool IsConv2D(const NodeDef& node); @@ -85,6 +86,7 @@ bool IsMul(const NodeDef& node); bool IsMatMul(const NodeDef& node); bool IsNextIteration(const NodeDef& node); bool IsPad(const NodeDef& node); +bool IsPack(const NodeDef& node); bool IsNoOp(const NodeDef& node); bool IsNotEqual(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 31dc1b73e1..4036ea3f16 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1510,7 +1510,7 @@ Status ConstantFolding::ReplaceOperationWithConstant( } Status ConstantFolding::SimplifyGraph(GraphDef* output, - const GraphProperties& properties, + GraphProperties* properties, bool use_shape_info) { const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; for (int i = 0; i < output->node_size(); ++i) { @@ -1520,7 +1520,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, if (use_shape_info && (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) { const auto& shape = - properties.GetInputProperties(node->name())[0].shape(); + properties->GetInputProperties(node->name())[0].shape(); // The node is replaceable iff // unknown_rank == false && (dim_size == 0 || all dims have size 1) bool replaceable = !shape.unknown_rank(); @@ -1649,7 +1649,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, graph_modified_ = true; continue; } - if (use_shape_info && IsSimplifiableReshape(*node, properties)) { + if (use_shape_info && IsSimplifiableReshape(*node, *properties)) { DataType output_type = node->attr().at("T").type(); node->set_op("Identity"); node->clear_attr(); @@ -1667,8 +1667,8 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, // Simplify arithmetic operations with ones or zeros. if (use_shape_info && (is_mul || is_matmul || is_add || is_sub || is_any_div) && - properties.HasInputProperties(node->name()) && - properties.HasOutputProperties(node->name())) { + properties->HasInputProperties(node->name()) && + properties->HasOutputProperties(node->name())) { const NodeDef* x = node_map_->GetNode(node->input(0)); const NodeDef* y = node_map_->GetNode(node->input(1)); if (x == nullptr || y == nullptr) { @@ -1676,12 +1676,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, node->DebugString()); } const TensorShapeProto& output_shape = - properties.GetOutputProperties(node->name())[0].shape(); + properties->GetOutputProperties(node->name())[0].shape(); // Simplify element-wise multiplication by ones or addition/subtraction // of zeros. const TensorShapeProto& y_shape = - properties.GetInputProperties(node->name())[1].shape(); + properties->GetInputProperties(node->name())[1].shape(); const bool x_is_zero = IsZeros(*x); const bool x_is_one = IsOnes(*x); const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); @@ -1708,7 +1708,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, } const TensorShapeProto& x_shape = - properties.GetInputProperties(node->name())[0].shape(); + properties->GetInputProperties(node->name())[0].shape(); const bool y_is_zero = IsZeros(*y); const bool y_is_one = IsOnes(*y); const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); @@ -1921,13 +1921,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, // folding of ops when more than one but not all inputs are constant. // For AddN and AccumulateNV2, we may furthermore reorder inputs, since // addition is commutative. - // TODO(rmlarsen): Concat/Pack/ParallelConcat which are not commutative, so - // we have to preserve order and can only push consecutive runs of constant - // inputs into sub-nodes. + const int num_non_control_inputs = NumNonControlInputs(*node); if (IsAggregate(*node) && IsCommutative(*node) && - NumNonControlInputs(*node) > 2) { + num_non_control_inputs > 2) { const int num_control_inputs = - node->input_size() - NumNonControlInputs(*node); + node->input_size() - num_non_control_inputs; std::vector<int> const_inputs; std::vector<int> nonconst_inputs; for (int i = 0; i < node->input_size(); ++i) { @@ -1943,7 +1941,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, } // Promote AccumulateNV2 with all constant inputs to AddN, since it is // a fake node that cannot be constant folded by itself. - if (const_inputs.size() == NumNonControlInputs(*node) && + if (const_inputs.size() == num_non_control_inputs && node->op() == "AccumulateNV2") { node->set_op("AddN"); node->mutable_attr()->erase("shape"); @@ -1953,7 +1951,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const string new_node_name = OptimizedNodeName( *node, strings::StrCat("_partial_split_", const_inputs.size())); if (1 < const_inputs.size() && - const_inputs.size() < NumNonControlInputs(*node) && + const_inputs.size() < num_non_control_inputs && !node_map_->NodeExists(new_node_name)) { NodeDef* added_node = output->add_node(); *added_node = *node; @@ -1987,8 +1985,121 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const_inputs.size() - 1); (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs); + properties->ClearInputProperties(node->name()); (*added_node->mutable_attr())["N"].set_i(const_inputs.size()); graph_modified_ = true; + continue; + } + } + + // Partial constant folding for Concat which is not commutative, so + // we have to preserve order and can only push consecutive runs of constant + // inputs into sub-nodes. + if (IsConcat(*node) && num_non_control_inputs > 3) { + bool already_optimized = false; + const string optimized = strings::StrCat(node->name(), "_partial_split_"); + for (const string& input : node->input()) { + if (input.rfind(optimized) != string::npos) { + already_optimized = true; + break; + } + } + if (already_optimized) { + continue; + } + int axis_arg = -1; + int begin = 0; + int end = num_non_control_inputs; + if (node->op() == "Concat") { + begin = 1; + axis_arg = 0; + } else if (node->op() == "ConcatV2") { + end = num_non_control_inputs - 1; + axis_arg = num_non_control_inputs - 1; + } else { + continue; + } + + const NodeDef* axis_arg_node = + node_map_->GetNode(NodeName(node->input(axis_arg))); + if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) { + // We cannot constant fold Concat unless we know the axis. + // Skip node. + continue; + } + + // We search for consecutive runs of constant inputs in the range + // [begin:end[ and push then down into child nodes. + std::vector<std::pair<int, int>> constant_input_runs; + int first = begin; + int last = begin; + while (last < end) { + while (first < end && !IsReallyConstant(*node_map_->GetNode( + NodeName(node->input(first))))) { + ++first; + } + // Invariant: node[first] is constant || first >= end. + last = first + 1; + while (last < end && IsReallyConstant(*node_map_->GetNode( + NodeName(node->input(last))))) { + ++last; + } + // Invariant: node[last] is not constant || last >= end + // Discard intervals shorter than 2 elements. + if (first < end && (last - first) > 1) { + constant_input_runs.emplace_back(first, last); + } + first = last; + } + + std::set<int> inputs_to_delete; + for (auto interval : constant_input_runs) { + // Push the constant inputs in the interval to a child node than can be + // constant folded. + const string new_node_name = OptimizedNodeName( + *node, strings::StrCat("_partial_split_", interval.first)); + if (node_map_->NodeExists(new_node_name)) { + break; + } + NodeDef* added_node = output->add_node(); + *added_node = *node; + added_node->set_name(new_node_name); + node_map_->AddNode(added_node->name(), added_node); + added_node->clear_input(); + for (int i = interval.first; i < interval.second; ++i) { + added_node->add_input(node->input(i)); + node_map_->UpdateOutput(NodeName(node->input(i)), node->name(), + added_node->name()); + if (i != interval.first) { + inputs_to_delete.insert(i); + } + } + added_node->add_input(node->input(axis_arg)); + (*added_node->mutable_attr())["N"].set_i(interval.second - + interval.first); + node_map_->AddOutput(NodeName(node->input(axis_arg)), + added_node->name()); + + // Overwrite the first constant input with the result of the added + // child node. + node->set_input(interval.first, added_node->name()); + node_map_->AddOutput(added_node->name(), node->name()); + } + if (!constant_input_runs.empty()) { + graph_modified_ = true; + if (!inputs_to_delete.empty()) { + // Fix up the inputs to the original node. + std::vector<string> tmp(node->input().begin(), node->input().end()); + node->clear_input(); + for (int i = 0; i < tmp.size(); ++i) { + if (inputs_to_delete.find(i) == inputs_to_delete.end()) { + node->add_input(tmp[i]); + } + } + (*node->mutable_attr())["N"].set_i(node->input_size() - 1); + properties->ClearInputProperties(node->name()); + } + continue; } } } @@ -2030,7 +2141,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, TF_RETURN_IF_ERROR(FoldGraph(output)); node_map_.reset(new NodeMap(output)); - TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info)); + TF_RETURN_IF_ERROR(SimplifyGraph(output, &properties, can_use_shape_info)); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 2fd59c7f9c..13ecfcd281 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -92,7 +92,7 @@ class ConstantFolding : public GraphOptimizer { bool IsSimplifiableReduction(const NodeDef& node) const; bool IsSimplifiableReshape(const NodeDef& node, const GraphProperties& properties) const; - Status SimplifyGraph(GraphDef* output, const GraphProperties& properties, + Status SimplifyGraph(GraphDef* output, GraphProperties* properties, bool use_shape_info); Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 4b9770889f..9214695eb6 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -188,20 +188,19 @@ TEST_F(ConstantFoldingTest, NeutralElement) { Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros); Output sub2 = ops::Sub(s.WithOpName("sub2"), zeros, y); Output concat = - ops::Concat(s.WithOpName("concat"), - {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1, - matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2}, - 0); + ops::Stack(s.WithOpName("stack"), + {mul1, mul2, mul3, mul4, mul5, mul6, div1, div2, matmul1, + matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"concat", "matmul3", "matmul4"}; + item.fetch = {"stack", "matmul3", "matmul4"}; ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(28, output.node_size()); + EXPECT_EQ(27, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); const string& name = node.name(); @@ -1626,19 +1625,19 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) { Output acc4 = fun(s.WithOpName("acc4"), {c1, y, c2}); Output acc5 = fun(s.WithOpName("acc5"), {x, c1, c2}); Output acc6 = fun(s.WithOpName("acc6"), {x, c1, y, c2}); - Output concat = ops::Concat(s.WithOpName("concat"), - {acc0, acc1, acc2, acc3, acc4, acc5, acc6}, 0); + Output stack = ops::Stack(s.WithOpName("stack"), + {acc0, acc1, acc2, acc3, acc4, acc5, acc6}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"concat"}; + item.fetch = {"stack"}; ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(17, output.node_size()); + EXPECT_EQ(16, output.node_size()); for (const NodeDef& node : output.node()) { if (node.name() == "acc0") { EXPECT_EQ("Const", node.op()); @@ -1696,7 +1695,86 @@ TEST_F(ConstantFoldingTest, PartialFolding_AssociativeAndCommutative) { } } -TEST_F(ConstantFoldingTest, IdenticalN) { +TEST_F(ConstantFoldingTest, PartialFolding_Concat) { + Scope s = Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output z = ops::Placeholder(s.WithOpName("z"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output axis = ops::Const(s.WithOpName("axis"), 0, {}); + Output c1 = ops::Const(s.WithOpName("c1"), 1.0f, {2, 2}); + Output c2 = ops::Const(s.WithOpName("c2"), 2.0f, {2, 2}); + Output concat0 = ops::Concat(s.WithOpName("concat0"), {c1, c2, c1}, axis); + Output concat1 = ops::Concat(s.WithOpName("concat1"), {x, y, z}, axis); + Output concat2 = ops::Concat(s.WithOpName("concat2"), {c1, x, y}, axis); + Output concat3 = ops::Concat(s.WithOpName("concat3"), {c1, c2, z}, axis); + Output concat4 = ops::Concat(s.WithOpName("concat4"), {c1, y, c2}, axis); + Output concat5 = ops::Concat(s.WithOpName("concat5"), {x, c1, c2}, axis); + Output concat6 = ops::Concat(s.WithOpName("concat6"), {x, c1, y, c2}, axis); + Output concat7 = ops::Concat(s.WithOpName("concat7"), {x, y, c1, c2}, axis); + Output concat8 = ops::Concat(s.WithOpName("concat8"), {x, c1, c2, y}, axis); + Output concat9 = ops::Concat(s.WithOpName("concat9"), {c1, c2, x, y}, axis); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"concat0", "concat1", "concat2", "concat3", "concat4", + "concat5", "concat6", "concat7", "concat8", "concat9"}; + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(21, output.node_size()); + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "concat0") { + EXPECT_EQ("Const", node.op()); + } else if (node.name() == "concat3") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("ConstantFolding/concat3_partial_split_0", node.input(0)); + EXPECT_EQ("z", node.input(1)); + EXPECT_EQ("axis", node.input(2)); + } else if (node.name() == "concat5") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("ConstantFolding/concat5_partial_split_1", node.input(1)); + EXPECT_EQ("axis", node.input(2)); + } else if (node.name() == "concat7") { + EXPECT_EQ(4, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("y", node.input(1)); + EXPECT_EQ("ConstantFolding/concat7_partial_split_2", node.input(2)); + EXPECT_EQ("axis", node.input(3)); + } else if (node.name() == "concat8") { + EXPECT_EQ(4, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("ConstantFolding/concat8_partial_split_1", node.input(1)); + EXPECT_EQ("y", node.input(2)); + EXPECT_EQ("axis", node.input(3)); + } else if (node.name() == "concat9") { + EXPECT_EQ(4, node.input_size()); + EXPECT_EQ("ConstantFolding/concat9_partial_split_0", node.input(0)); + EXPECT_EQ("x", node.input(1)); + EXPECT_EQ("y", node.input(2)); + EXPECT_EQ("axis", node.input(3)); + } else if (StringPiece(node.name()).starts_with("ConstantFolding/")) { + EXPECT_EQ("Const", node.op()); + } else { + EXPECT_EQ(item.graph.node(i).DebugString(), node.DebugString()); + } + } + + auto tensors_expected = EvaluateNodes(item.graph, {"concat0"}); + auto tensors = EvaluateNodes(output, {"concat0"}); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); +} + +TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); Output x = ops::Placeholder(scope.WithOpName("x"), DT_FLOAT, ops::Placeholder::Shape(TensorShape({}))); |