about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-07-25 15:25:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-25 15:28:32 -0700
commit136494d3295a23e3ed0612773f224243915463b7 (patch)
tree8f22ce36f07c65c26ed64263d49922bd2862b6ff
parent07249f08867369899d39fc60442febdf1e36e6b5 (diff)
Prune trivial ops (concatenation of a single tensor, AddN of a single tensor,
...) PiperOrigin-RevId: 163131793
-rw-r--r--tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc14
-rw-r--r--tensorflow/core/grappler/op_types.cc5
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner.cc27
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner_test.cc66
5 files changed, 93 insertions(+), 20 deletions(-)
diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
index 446ae2df64..b1ec35e268 100644
--- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
+++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
@@ -48,9 +48,17 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
for (int i = 0; i < num_stages; i++) {
std::vector<Output> this_stage;
for (int j = 0; j < width; j++) {
- Output combine = AddN(
- s.WithDevice(device_names[use_multiple_devices ? j : 0]), last_stage);
- this_stage.push_back(combine);
+ if (last_stage.size() == 1) {
+ Output unary_op =
+ Square(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
+ last_stage[0]);
+ this_stage.push_back(unary_op);
+ } else {
+ Output combine =
+ AddN(s.WithDevice(device_names[use_multiple_devices ? j : 0]),
+ last_stage);
+ this_stage.push_back(combine);
+ }
}
last_stage = this_stage;
}
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 9b2584f970..8584681220 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -18,6 +18,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+bool IsAddN(const NodeDef& node) {
+ const auto op = node.op();
+ return op == "AddN";
+}
+
bool IsConcat(const NodeDef& node) {
const auto op = node.op();
return op == "Concat" || op == "ConcatV2";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 9c9dd22e2c..d83cb777ed 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+bool IsAddN(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConstant(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc
index df9aca8aa3..e313155563 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner.cc
@@ -26,6 +26,29 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+int NumNonControlInputs(const NodeDef& node) {
+ int num_inputs = node.input_size();
+ for (int i = 0; i < node.input_size(); ++i) {
+ if (!node.input(i).empty() && node.input(i)[0] == '^') {
+ num_inputs--;
+ }
+ }
+ return num_inputs;
+}
+
+bool IsTrivialOp(const NodeDef& node) {
+ // Remove the stop gradient nodes since they serve no purpose once the graph
+ // is built. Also remove Identity ops.
+ if (IsStopGradient(node) || IsIdentity(node)) {
+ return true;
+ }
+ if (IsAddN(node) && NumNonControlInputs(node) <= 1) {
+ return true;
+ }
+
+ return false;
+}
+
Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* pruned_graph) {
GraphRewriter rewriter(item);
@@ -43,9 +66,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
std::unordered_set<const NodeDef*> nodes_to_delete;
for (auto& node : item.graph.node()) {
- // Remove the stop gradient nodes since they serve no purpose once the graph
- // is built. Also remove Identity ops.
- if (!IsStopGradient(node) && !IsIdentity(node)) {
+ if (!IsTrivialOp(node)) {
continue;
}
// Don't remove nodes that must be preserved.
diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
index fdfb3f41cf..72d9c7bf27 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
@@ -57,10 +57,10 @@ TEST_F(ModelPrunerTest, StopGradientPruning) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
- Output b = ops::AddN(s.WithOpName("b"), {a});
+ Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::StopGradient(s.WithOpName("c"), b);
Output d = ops::StopGradient(s.WithOpName("d"), c);
- Output e = ops::AddN(s.WithOpName("e"), {d});
+ Output e = ops::Sqrt(s.WithOpName("e"), {d});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -93,10 +93,10 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
- Output b = ops::AddN(s.WithOpName("b"), {a});
+ Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), c);
- Output e = ops::AddN(s.WithOpName("e"), {d});
+ Output e = ops::Sqrt(s.WithOpName("e"), {d});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -126,15 +126,53 @@ TEST_F(ModelPrunerTest, IdentityPruning) {
EXPECT_EQ(NodeName(b.name()), new_c.input(0));
}
-TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
+TEST_F(ModelPrunerTest, NoOpPruning) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::AddN(s.WithOpName("b"), {a});
+ Output c = ops::AddN(s.WithOpName("c"), {b});
+ Output d = ops::AddN(s.WithOpName("d").WithControlDependencies(b), {c});
+ Output e = ops::AddN(s.WithOpName("e"), {d});
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ModelPruner pruner;
+ GraphDef output;
+ Status status = pruner.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(5, output.node_size());
+ const NodeDef& new_a = output.node(0);
+ EXPECT_EQ(NodeName(a.name()), new_a.name());
+ const NodeDef& new_b = output.node(1);
+ EXPECT_EQ(NodeName(b.name()), new_b.name());
+ const NodeDef& new_c = output.node(2);
+ EXPECT_EQ(NodeName(c.name()), new_c.name());
+ const NodeDef& new_d = output.node(3);
+ EXPECT_EQ(NodeName(d.name()), new_d.name());
+ const NodeDef& new_e = output.node(4);
+ EXPECT_EQ(NodeName(e.name()), new_e.name());
+
+ EXPECT_EQ(1, new_e.input_size());
+ EXPECT_EQ(NodeName(d.name()), new_e.input(0));
+ EXPECT_EQ(2, new_d.input_size());
+ EXPECT_EQ(NodeName(b.name()), new_d.input(0));
+ EXPECT_EQ(1, new_c.input_size());
+ EXPECT_EQ(NodeName(b.name()), new_c.input(0));
+}
+
+TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
+ // Build a simple graph with a few trivially prunable ops.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
Output d = ops::Identity(s.WithOpName("d"), c);
- Output e = ops::AddN(s.WithOpName("e").WithControlDependencies(c), {d});
+ Output e = ops::Sqrt(s.WithOpName("e").WithControlDependencies(c), {d});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -166,11 +204,11 @@ TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
- Output b = ops::AddN(s.WithOpName("b"), {a});
- Output c = ops::AddN(s.WithOpName("c"), {a});
+ Output b = ops::Sqrt(s.WithOpName("b"), {a});
+ Output c = ops::Sqrt(s.WithOpName("c"), {a});
Output d = ops::Identity(s.WithOpName("d"), c);
Output e = ops::Identity(s.WithOpName("e"), d);
- Output f = ops::AddN(s.WithOpName("f"), {e});
+ Output f = ops::Sqrt(s.WithOpName("f"), {e});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -216,7 +254,7 @@ TEST_F(ModelPrunerTest, PruningPerservesFetch) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
- Output b = ops::AddN(s.WithOpName("b"), {a});
+ Output b = ops::Sqrt(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), b);
GrapplerItem item;
@@ -243,13 +281,13 @@ TEST_F(ModelPrunerTest, PruningPerservesCrossDeviceIdentity) {
// Node i1 should be preserved.
Output i1 = ops::Identity(s.WithOpName("i1").WithDevice("/gpu:0"), c);
- Output a1 = ops::AddN(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
- Output a2 = ops::AddN(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});
+ Output a1 = ops::Sqrt(s.WithOpName("a1").WithDevice("/gpu:0"), {i1});
+ Output a2 = ops::Sqrt(s.WithOpName("a2").WithDevice("/gpu:0"), {i1});
// Node i2 should be pruned since it resides on the sender's device.
Output i2 = ops::Identity(s.WithOpName("i2").WithDevice("/cpu:0"), c);
- Output a3 = ops::AddN(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
- Output a4 = ops::AddN(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});
+ Output a3 = ops::Sqrt(s.WithOpName("a3").WithDevice("/gpu:0"), {i2});
+ Output a4 = ops::Sqrt(s.WithOpName("a4").WithDevice("/gpu:0"), {i2});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));