aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-03-27 16:24:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 18:42:24 -0700
commit164b6a88f56a8e491b315f2747303c46f04b5c76 (patch)
treef5430ee8a1d8597610cdcc9c88717c76f5cd14b6
parent08b145b26ea3c8389197cddc3ddcee4531b8b08b (diff)
Avoid pruning fetch nodes
Change: 151394592
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner.cc15
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner_test.cc26
3 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index f5af4d959b..02716b3f78 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -62,6 +62,7 @@ cc_library(
":graph_rewriter",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
],
)
diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc
index f371363657..a89831b6e6 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
+#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
@@ -26,6 +27,14 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* pruned_graph) {
GraphRewriter rewriter(item);
+ std::unordered_set<string> nodes_to_preserve;
+ for (const auto& node : item.fetch) {
+ nodes_to_preserve.insert(NodeName(node));
+ }
+ for (const auto& node : item.init_ops) {
+ nodes_to_preserve.insert(NodeName(node));
+ }
+
std::unordered_set<const NodeDef*> nodes_to_delete;
for (auto& node : item.graph.node()) {
// Remove the stop gradient nodes since they serve no purpose once the graph
@@ -33,7 +42,11 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
if (node.op() != "StopGradient" && node.op() != "Identity") {
continue;
}
- // Don't prune nodes that are explicitely placed.
+ // Don't remove nodes that must be preserved.
+ if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
+ continue;
+ }
+ // Don't remove nodes that are explicitely placed.
if (!node.device().empty()) {
continue;
}
diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
index 6b90dba864..47d45a6f49 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
@@ -199,6 +199,32 @@ TEST_F(ModelPrunerTest, PruningForwardsCtrlDependencies) {
EXPECT_EQ("^c", new_f.input(2));
}
+TEST_F(ModelPrunerTest, PruningPerservesFetch) {
+ // Build a simple graph with a few trivially prunable ops.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ Output b = ops::AddN(s.WithOpName("b"), {a});
+ Output c = ops::Identity(s.WithOpName("c"), b);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("c");
+
+ ModelPruner pruner;
+ GraphDef output;
+ Status status = pruner.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(3, output.node_size());
+ const NodeDef& new_a = output.node(0);
+ EXPECT_EQ(NodeName(a.name()), new_a.name());
+ const NodeDef& new_b = output.node(1);
+ EXPECT_EQ(NodeName(b.name()), new_b.name());
+ const NodeDef& new_c = output.node(2);
+ EXPECT_EQ(NodeName(c.name()), new_c.name());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow