diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-03-27 16:24:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-27 18:42:24 -0700 |
commit | 164b6a88f56a8e491b315f2747303c46f04b5c76 (patch) | |
tree | f5430ee8a1d8597610cdcc9c88717c76f5cd14b6 | |
parent | 08b145b26ea3c8389197cddc3ddcee4531b8b08b (diff) |
Avoid pruning fetch nodes
Change: 151394592
-rw-r--r-- | tensorflow/core/grappler/optimizers/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/model_pruner.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/model_pruner_test.cc | 26 |
3 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index f5af4d959b..02716b3f78 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -62,6 +62,7 @@ cc_library( ":graph_rewriter", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", ], ) diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index f371363657..a89831b6e6 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/graph_rewriter.h" +#include "tensorflow/core/grappler/utils.h" namespace tensorflow { namespace grappler { @@ -26,6 +27,14 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* pruned_graph) { GraphRewriter rewriter(item); + std::unordered_set<string> nodes_to_preserve; + for (const auto& node : item.fetch) { + nodes_to_preserve.insert(NodeName(node)); + } + for (const auto& node : item.init_ops) { + nodes_to_preserve.insert(NodeName(node)); + } + std::unordered_set<const NodeDef*> nodes_to_delete; for (auto& node : item.graph.node()) { // Remove the stop gradient nodes since they serve no purpose once the graph @@ -33,7 +42,11 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, if (node.op() != "StopGradient" && node.op() != "Identity") { continue; } - // Don't prune nodes that are explicitely placed. + // Don't remove nodes that must be preserved. + if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { + continue; + } + // Don't remove nodes that are explicitely placed. if (!node.device().empty()) { continue; } diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index 6b90dba864..47d45a6f49 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -199,6 +199,32 @@ TEST_F(ModelPrunerTest, PruningForwardsCtrlDependencies) { EXPECT_EQ("^c", new_f.input(2)); } +TEST_F(ModelPrunerTest, PruningPerservesFetch) { + // Build a simple graph with a few trivially prunable ops. + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); + Output b = ops::AddN(s.WithOpName("b"), {a}); + Output c = ops::Identity(s.WithOpName("c"), b); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("c"); + + ModelPruner pruner; + GraphDef output; + Status status = pruner.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(3, output.node_size()); + const NodeDef& new_a = output.node(0); + EXPECT_EQ(NodeName(a.name()), new_a.name()); + const NodeDef& new_b = output.node(1); + EXPECT_EQ(NodeName(b.name()), new_b.name()); + const NodeDef& new_c = output.node(2); + EXPECT_EQ(NodeName(c.name()), new_c.name()); +} + } // namespace } // namespace grappler } // namespace tensorflow |