diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-27 09:06:25 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-27 10:31:05 -0700 |
commit | 1f486662bf6626239c115bc4ef82c9aa45429c40 (patch) | |
tree | 29b846f0d0c4bd159b07abf30677bdfe3cea0162 /tensorflow/core/grappler/optimizers | |
parent | 5e2cef7155fea04469578e48123ec6925e998c2f (diff) |
Automated rollback of change 151331705
Change: 151335029
Diffstat (limited to 'tensorflow/core/grappler/optimizers')
5 files changed, 7 insertions, 116 deletions
diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc index fbb7e849ba..b554baea95 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc @@ -27,19 +27,6 @@ GraphRewriter::GraphRewriter(const GrapplerItem& item) { for (auto& node : item.graph.node()) { nodes_[node.name()] = &node; } - - for (auto& node : item.graph.node()) { - for (const auto& input : node.input()) { - int position = 0; - string input_node_name = ParseNodeName(input, &position); - if (position < 0) { - // This is a control edge - auto itr = nodes_.find(input_node_name); - CHECK(itr != nodes_.end()); - control_dependency_drivers_.insert(itr->second); - } - } - } } void GraphRewriter::ForwardInputs( @@ -59,10 +46,5 @@ void GraphRewriter::ForwardInputs( } } -bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const { - return control_dependency_drivers_.find(&node) != - control_dependency_drivers_.end(); -} - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h index a9cc777809..f373b49fee 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.h +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h @@ -39,13 +39,8 @@ class GraphRewriter { const std::unordered_set<const NodeDef*>& nodes_to_delete, NodeDef* new_node); - // Returns true if at least one of the edges in the direct fanout of 'node' is - // a control dependency edge. - bool DrivesControlDependency(const NodeDef& node) const; - private: std::unordered_map<string, const NodeDef*> nodes_; - std::unordered_set<const NodeDef*> control_dependency_drivers_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index 26d62a0fcc..e53411df7c 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -29,18 +29,11 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, std::unordered_set<const NodeDef*> nodes_to_delete; for (auto& node : item.graph.node()) { // Remove the stop gradient nodes since they serve no purpose once the graph - // is built. Also remove Identity ops. - if (node.op() != "StopGradient" && node.op() != "Identity") { + // is built. + if (node.op() != "StopGradient") { continue; } - // Don't prune nodes that are explicitely placed. - if (!node.device().empty()) { - continue; - } - // Don't remove nodes that drive control dependencies. - if (!rewriter.DrivesControlDependency(node)) { - nodes_to_delete.insert(&node); - } + nodes_to_delete.insert(&node); } for (auto& node : item.graph.node()) { @@ -53,15 +46,11 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, rewriter.ForwardInputs(node, nodes_to_delete, new_node); } - LOG(INFO) << "Pruned " << nodes_to_delete.size() - << " nodes from the graph. The graph now contains " - << pruned_graph->node_size() " nodes."; - return Status::OK(); } void ModelPruner::Feedback(Cluster* cluster, const GrapplerItem& item, - const GraphDef& pruned_graph, double result) { + const GraphDef& optimize_output, double result) { // Nothing to do for ModelPruner. } diff --git a/tensorflow/core/grappler/optimizers/model_pruner.h b/tensorflow/core/grappler/optimizers/model_pruner.h index 3956d33961..3bc6dd5706 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.h +++ b/tensorflow/core/grappler/optimizers/model_pruner.h @@ -32,10 +32,10 @@ class ModelPruner : public GraphOptimizer { string name() const override { return "model_pruner"; }; Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* pruned_graph) override; + GraphDef* output) override; void Feedback(Cluster* cluster, const GrapplerItem& item, - const GraphDef& pruned_graph, double result) override; + const GraphDef& optimize_output, double result) override; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index 53f868a6dd..320ba86147 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -52,7 +52,7 @@ TEST_F(ModelPrunerTest, NoPruning) { } } -TEST_F(ModelPrunerTest, StopGradientPruning) { +TEST_F(ModelPrunerTest, SimplePruning) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); @@ -82,81 +82,6 @@ TEST_F(ModelPrunerTest, StopGradientPruning) { EXPECT_EQ(NodeName(b.name()), new_e.input(0)); } -TEST_F(ModelPrunerTest, IdentityPruning) { - // Build a simple graph with a few trivially prunable ops. - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::AddN(s.WithOpName("b"), {a}); - Output c = ops::Identity(s.WithOpName("c"), b); - Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::AddN(s.WithOpName("e"), {d}); - - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - // Force the placement of c. This should ensure it is preserved. - EXPECT_EQ("c", item.graph.node(2).name()); - item.graph.mutable_node(2)->set_device("CPU"); - - ModelPruner pruner; - GraphDef output; - Status status = pruner.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - - EXPECT_EQ(4, output.node_size()); - const NodeDef& new_a = output.node(0); - EXPECT_EQ(NodeName(a.name()), new_a.name()); - const NodeDef& new_b = output.node(1); - EXPECT_EQ(NodeName(b.name()), new_b.name()); - const NodeDef& new_c = output.node(2); - EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_e = output.node(3); - EXPECT_EQ(NodeName(e.name()), new_e.name()); - - EXPECT_EQ(1, new_e.input_size()); - EXPECT_EQ(NodeName(c.name()), new_e.input(0)); -} - -TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { - // Build a simple graph with a few trivially prunable ops. - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); - Output b = ops::AddN(s.WithOpName("b"), {a}); - Output c = ops::Identity(s.WithOpName("c"), b); - Output d = ops::Identity(s.WithOpName("d"), c); - Output e = ops::AddN(s.WithOpName("e"), {d}); - - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - // Add a control dependency between c and e. This should ensure c is - // preserved. - EXPECT_EQ("c", item.graph.node(2).name()); - EXPECT_EQ("e", item.graph.node(4).name()); - *item.graph.mutable_node(4)->add_input() = "^c"; - - ModelPruner pruner; - GraphDef output; - Status status = pruner.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - - EXPECT_EQ(4, output.node_size()); - const NodeDef& new_a = output.node(0); - EXPECT_EQ(NodeName(a.name()), new_a.name()); - const NodeDef& new_b = output.node(1); - EXPECT_EQ(NodeName(b.name()), new_b.name()); - const NodeDef& new_c = output.node(2); - EXPECT_EQ(NodeName(c.name()), new_c.name()); - const NodeDef& new_e = output.node(3); - EXPECT_EQ(NodeName(e.name()), new_e.name()); - - EXPECT_EQ(2, new_e.input_size()); - EXPECT_EQ(NodeName(c.name()), new_e.input(0)); - EXPECT_EQ("^c", new_e.input(1)); -} - } // namespace } // namespace grappler } // namespace tensorflow |