aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-27 09:06:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 10:31:05 -0700
commit1f486662bf6626239c115bc4ef82c9aa45429c40 (patch)
tree29b846f0d0c4bd159b07abf30677bdfe3cea0162 /tensorflow/core/grappler/optimizers
parent5e2cef7155fea04469578e48123ec6925e998c2f (diff)
Automated rollback of change 151331705
Change: 151335029
Diffstat (limited to 'tensorflow/core/grappler/optimizers')
-rw-r--r--tensorflow/core/grappler/optimizers/graph_rewriter.cc18
-rw-r--r--tensorflow/core/grappler/optimizers/graph_rewriter.h5
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner.cc19
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner.h4
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner_test.cc77
5 files changed, 7 insertions, 116 deletions
diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc
index fbb7e849ba..b554baea95 100644
--- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc
+++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc
@@ -27,19 +27,6 @@ GraphRewriter::GraphRewriter(const GrapplerItem& item) {
for (auto& node : item.graph.node()) {
nodes_[node.name()] = &node;
}
-
- for (auto& node : item.graph.node()) {
- for (const auto& input : node.input()) {
- int position = 0;
- string input_node_name = ParseNodeName(input, &position);
- if (position < 0) {
- // This is a control edge
- auto itr = nodes_.find(input_node_name);
- CHECK(itr != nodes_.end());
- control_dependency_drivers_.insert(itr->second);
- }
- }
- }
}
void GraphRewriter::ForwardInputs(
@@ -59,10 +46,5 @@ void GraphRewriter::ForwardInputs(
}
}
-bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const {
- return control_dependency_drivers_.find(&node) !=
- control_dependency_drivers_.end();
-}
-
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h
index a9cc777809..f373b49fee 100644
--- a/tensorflow/core/grappler/optimizers/graph_rewriter.h
+++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h
@@ -39,13 +39,8 @@ class GraphRewriter {
const std::unordered_set<const NodeDef*>& nodes_to_delete,
NodeDef* new_node);
- // Returns true if at least one of the edges in the direct fanout of 'node' is
- // a control dependency edge.
- bool DrivesControlDependency(const NodeDef& node) const;
-
private:
std::unordered_map<string, const NodeDef*> nodes_;
- std::unordered_set<const NodeDef*> control_dependency_drivers_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc
index 26d62a0fcc..e53411df7c 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner.cc
@@ -29,18 +29,11 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
std::unordered_set<const NodeDef*> nodes_to_delete;
for (auto& node : item.graph.node()) {
// Remove the stop gradient nodes since they serve no purpose once the graph
- // is built. Also remove Identity ops.
- if (node.op() != "StopGradient" && node.op() != "Identity") {
+ // is built.
+ if (node.op() != "StopGradient") {
continue;
}
- // Don't prune nodes that are explicitely placed.
- if (!node.device().empty()) {
- continue;
- }
- // Don't remove nodes that drive control dependencies.
- if (!rewriter.DrivesControlDependency(node)) {
- nodes_to_delete.insert(&node);
- }
+ nodes_to_delete.insert(&node);
}
for (auto& node : item.graph.node()) {
@@ -53,15 +46,11 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
rewriter.ForwardInputs(node, nodes_to_delete, new_node);
}
- LOG(INFO) << "Pruned " << nodes_to_delete.size()
- << " nodes from the graph. The graph now contains "
- << pruned_graph->node_size() " nodes.";
-
return Status::OK();
}
void ModelPruner::Feedback(Cluster* cluster, const GrapplerItem& item,
- const GraphDef& pruned_graph, double result) {
+ const GraphDef& optimize_output, double result) {
// Nothing to do for ModelPruner.
}
diff --git a/tensorflow/core/grappler/optimizers/model_pruner.h b/tensorflow/core/grappler/optimizers/model_pruner.h
index 3956d33961..3bc6dd5706 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner.h
+++ b/tensorflow/core/grappler/optimizers/model_pruner.h
@@ -32,10 +32,10 @@ class ModelPruner : public GraphOptimizer {
string name() const override { return "model_pruner"; };
Status Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* pruned_graph) override;
+ GraphDef* output) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
- const GraphDef& pruned_graph, double result) override;
+ const GraphDef& optimize_output, double result) override;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
index 53f868a6dd..320ba86147 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
@@ -52,7 +52,7 @@ TEST_F(ModelPrunerTest, NoPruning) {
}
}
-TEST_F(ModelPrunerTest, StopGradientPruning) {
+TEST_F(ModelPrunerTest, SimplePruning) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -82,81 +82,6 @@ TEST_F(ModelPrunerTest, StopGradientPruning) {
EXPECT_EQ(NodeName(b.name()), new_e.input(0));
}
-TEST_F(ModelPrunerTest, IdentityPruning) {
- // Build a simple graph with a few trivially prunable ops.
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
- Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
- Output b = ops::AddN(s.WithOpName("b"), {a});
- Output c = ops::Identity(s.WithOpName("c"), b);
- Output d = ops::Identity(s.WithOpName("d"), c);
- Output e = ops::AddN(s.WithOpName("e"), {d});
-
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
- // Force the placement of c. This should ensure it is preserved.
- EXPECT_EQ("c", item.graph.node(2).name());
- item.graph.mutable_node(2)->set_device("CPU");
-
- ModelPruner pruner;
- GraphDef output;
- Status status = pruner.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- EXPECT_EQ(4, output.node_size());
- const NodeDef& new_a = output.node(0);
- EXPECT_EQ(NodeName(a.name()), new_a.name());
- const NodeDef& new_b = output.node(1);
- EXPECT_EQ(NodeName(b.name()), new_b.name());
- const NodeDef& new_c = output.node(2);
- EXPECT_EQ(NodeName(c.name()), new_c.name());
- const NodeDef& new_e = output.node(3);
- EXPECT_EQ(NodeName(e.name()), new_e.name());
-
- EXPECT_EQ(1, new_e.input_size());
- EXPECT_EQ(NodeName(c.name()), new_e.input(0));
-}
-
-TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
- // Build a simple graph with a few trivially prunable ops.
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
- Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
- Output b = ops::AddN(s.WithOpName("b"), {a});
- Output c = ops::Identity(s.WithOpName("c"), b);
- Output d = ops::Identity(s.WithOpName("d"), c);
- Output e = ops::AddN(s.WithOpName("e"), {d});
-
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
- // Add a control dependency between c and e. This should ensure c is
- // preserved.
- EXPECT_EQ("c", item.graph.node(2).name());
- EXPECT_EQ("e", item.graph.node(4).name());
- *item.graph.mutable_node(4)->add_input() = "^c";
-
- ModelPruner pruner;
- GraphDef output;
- Status status = pruner.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- EXPECT_EQ(4, output.node_size());
- const NodeDef& new_a = output.node(0);
- EXPECT_EQ(NodeName(a.name()), new_a.name());
- const NodeDef& new_b = output.node(1);
- EXPECT_EQ(NodeName(b.name()), new_b.name());
- const NodeDef& new_c = output.node(2);
- EXPECT_EQ(NodeName(c.name()), new_c.name());
- const NodeDef& new_e = output.node(3);
- EXPECT_EQ(NodeName(e.name()), new_e.name());
-
- EXPECT_EQ(2, new_e.input_size());
- EXPECT_EQ(NodeName(c.name()), new_e.input(0));
- EXPECT_EQ("^c", new_e.input(1));
-}
-
} // namespace
} // namespace grappler
} // namespace tensorflow