diff options
author | 2017-07-25 15:25:13 -0700 | |
---|---|---|
committer | 2017-07-25 15:28:32 -0700 | |
commit | 136494d3295a23e3ed0612773f224243915463b7 (patch) | |
tree | 8f22ce36f07c65c26ed64263d49922bd2862b6ff | |
parent | 07249f08867369899d39fc60442febdf1e36e6b5 (diff) |
Prune trivial ops (concatenation of a single tensor, AddN of a single tensor,
...)
PiperOrigin-RevId: 163131793
-rw-r--r-- | tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc | 14 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/model_pruner.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/model_pruner_test.cc | 66 |
5 files changed, 93 insertions, 20 deletions
diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc index 446ae2df64..b1ec35e268 100644 --- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc @@ -48,9 +48,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, for (int i = 0; i < num_stages; i++) { std::vector<Output> this_stage; for (int j = 0; j < width; j++) { - Output combine = AddN( - s.WithDevice(device_names[use_multiple_devices ? j : 0]), last_stage); - this_stage.push_back(combine); + if (last_stage.size() == 1) { + Output unary_op = + Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]), + last_stage[0]); + this_stage.push_back(unary_op); + } else { + Output combine = + AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]), + last_stage); + this_stage.push_back(combine); + } } last_stage = this_stage; } diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 9b2584f970..8584681220 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -18,6 +18,11 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool IsAddN(const NodeDef& node) { + const auto op = node.op(); + return op == "AddN"; +} + bool IsConcat(const NodeDef& node) { const auto op = node.op(); return op == "Concat" || op == "ConcatV2"; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 9c9dd22e2c..d83cb777ed 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -21,6 +21,7 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool IsAddN(const NodeDef& node); bool IsConcat(const NodeDef& node); bool IsConstant(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index df9aca8aa3..e313155563 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -26,6 +26,29 @@ limitations under the License. namespace tensorflow { namespace grappler { +int NumNonControlInputs(const NodeDef& node) { + int num_inputs = node.input_size(); + for (int i = 0; i < node.input_size(); ++i) { + if (!node.input(i).empty() && node.input(i)[0] == '^') { + num_inputs--; + } + } + return num_inputs; +} + +bool IsTrivialOp(const NodeDef& node) { + // Remove the stop gradient nodes since they serve no purpose once the graph + // is built. Also remove Identity ops. + if (IsStopGradient(node) || IsIdentity(node)) { + return true; + } + if (IsAddN(node) && NumNonControlInputs(node) <= 1) { + return true; + } + + return false; +} + Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* pruned_graph) { GraphRewriter rewriter(item); @@ -43,9 +66,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, std::unordered_set<const NodeDef*> nodes_to_delete; for (auto& node : item.graph.node()) { - // Remove the stop gradient nodes since they serve no purpose once the graph - // is built. Also remove Identity ops. - if (!IsStopGradient(node) && !IsIdentity(node)) { + if (!IsTrivialOp(node)) { continue; } // Don't remove nodes that must be preserved. diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index fdfb3f41cf..72d9c7bf27 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -57,10 +57,10 @@ TEST_F(ModelPrunerTest, StopGradientPruning) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::AddN(s.WithOpName("b"), {a}); + Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::StopGradient(s.WithOpName("c"), b); Output d = ops::StopGradient(s.WithOpName("d"), c); - Output e = ops::AddN(s.WithOpName("e"), {d}); + Output e = ops::Sqrt(s.WithOpName("e"), {d}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -93,10 +93,10 @@ TEST_F(ModelPrunerTest, IdentityPruning) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::AddN(s.WithOpName("b"), {a}); + Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), b); Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::AddN(s.WithOpName("e"), {d}); + Output e = ops::Sqrt(s.WithOpName("e"), {d}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -126,15 +126,53 @@ TEST_F(ModelPrunerTest, IdentityPruning) { EXPECT_EQ(NodeName(b.name()), new_c.input(0)); } -TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { +TEST_F(ModelPrunerTest, NoOpPruning) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); Output b = ops::AddN(s.WithOpName("b"), {a}); + Output c = ops::AddN(s.WithOpName("c"), {b}); + Output d = ops::AddN(s.WithOpName("d").WithControlDependencies(b), {c}); + Output e = ops::AddN(s.WithOpName("e"), {d}); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ModelPruner pruner; + GraphDef output; + Status status = pruner.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(5, output.node_size()); + const NodeDef& new_a = output.node(0); + EXPECT_EQ(NodeName(a.name()), new_a.name()); + const NodeDef& new_b = output.node(1); + EXPECT_EQ(NodeName(b.name()), new_b.name()); + const NodeDef& new_c = output.node(2); + EXPECT_EQ(NodeName(c.name()), new_c.name()); + const NodeDef& new_d = output.node(3); + EXPECT_EQ(NodeName(d.name()), new_d.name()); + const NodeDef& new_e = output.node(4); + EXPECT_EQ(NodeName(e.name()), new_e.name()); + + EXPECT_EQ(1, new_e.input_size()); + EXPECT_EQ(NodeName(d.name()), new_e.input(0)); + EXPECT_EQ(2, new_d.input_size()); + EXPECT_EQ(NodeName(b.name()), new_d.input(0)); + EXPECT_EQ(1, new_c.input_size()); + EXPECT_EQ(NodeName(b.name()), new_c.input(0)); +} + +TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { + // Build a simple graph with a few trivially prunable ops. + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); + Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), b); Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::AddN(s.WithOpName("e").WithControlDependencies(c), {d}); + Output e = ops::Sqrt(s.WithOpName("e").WithControlDependencies(c), {d}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -166,11 +204,11 @@ TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::AddN(s.WithOpName("b"), {a}); - Output c = ops::AddN(s.WithOpName("c"), {a}); + Output b = ops::Sqrt(s.WithOpName("b"), {a}); + Output c = ops::Sqrt(s.WithOpName("c"), {a}); Output d = ops::Identity(s.WithOpName("d"), c); Output e = ops::Identity(s.WithOpName("e"), d); - Output f = ops::AddN(s.WithOpName("f"), {e}); + Output f = ops::Sqrt(s.WithOpName("f"), {e}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -216,7 +254,7 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::AddN(s.WithOpName("b"), {a}); + Output b = ops::Sqrt(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), b); GrapplerItem item; @@ -243,13 +281,13 @@ TEST_F(ModelPrunerTest, PruningPerservesCrossDeviceIdentity) { // Node i1 should be preserved. Output i1 = ops::Identity(s.WithOpName("i1").WithDevice("/gpu:0"), c); - Output a1 = ops::AddN(s.WithOpName("a1").WithDevice("/gpu:0"), {i1}); - Output a2 = ops::AddN(s.WithOpName("a2").WithDevice("/gpu:0"), {i1}); + Output a1 = ops::Sqrt(s.WithOpName("a1").WithDevice("/gpu:0"), {i1}); + Output a2 = ops::Sqrt(s.WithOpName("a2").WithDevice("/gpu:0"), {i1}); // Node i2 should be pruned since it resides on the sender's device. Output i2 = ops::Identity(s.WithOpName("i2").WithDevice("/cpu:0"), c); - Output a3 = ops::AddN(s.WithOpName("a3").WithDevice("/gpu:0"), {i2}); - Output a4 = ops::AddN(s.WithOpName("a4").WithDevice("/gpu:0"), {i2}); + Output a3 = ops::Sqrt(s.WithOpName("a3").WithDevice("/gpu:0"), {i2}); + Output a4 = ops::Sqrt(s.WithOpName("a4").WithDevice("/gpu:0"), {i2}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); |