aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-27 14:46:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-27 14:50:34 -0700
commite5353c941c4cfd7f256d69cc50caf6c90e70dd4a (patch)
tree6e04c41be958e60a109461ccc9b0caff0de1bd5b
parent22651083406ca01ac9d481e3367a3510d25f88cd (diff)
Don't prune nodes that have reference inputs.
PiperOrigin-RevId: 163390862
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/graph_rewriter.cc39
-rw-r--r--tensorflow/core/grappler/optimizers/graph_rewriter.h15
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner.cc25
-rw-r--r--tensorflow/core/grappler/optimizers/model_pruner_test.cc40
5 files changed, 105 insertions, 15 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index edd3fae7b2..0eb5ecdc5e 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -134,6 +134,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc
index 5273f11ca0..9e4247fd1a 100644
--- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc
+++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
@@ -26,8 +28,24 @@ namespace tensorflow {
namespace grappler {
GraphRewriter::GraphRewriter(const GrapplerItem& item) {
+ OpRegistryInterface* op_registry = OpRegistry::Global();
for (auto& node : item.graph.node()) {
- nodes_[node.name()] = &node;
+ NodeInfo* info = new NodeInfo();
+ info->def = &node;
+
+ const OpRegistrationData* op_reg_data = nullptr;
+ Status s = op_registry->LookUp(node.op(), &op_reg_data);
+ // TODO(bsteiner): make this not a best-effort lookup and evaluation?
+ if (s.ok()) {
+ s = InOutTypesForNode(node, op_reg_data->op_def, &info->inputs,
+ &info->outputs);
+ if (!s.ok()) {
+ info->inputs.clear();
+ info->outputs.clear();
+ }
+ }
+
+ nodes_[node.name()].reset(info);
}
std::unordered_set<string> function_names;
@@ -73,11 +91,16 @@ bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const {
return cross_device_receivers_.find(&node) != cross_device_receivers_.end();
}
+bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const {
+ return ref_receivers_.find(&node) != ref_receivers_.end();
+}
+
void GraphRewriter::RecordConnectivity(
const NodeDef& node, const std::unordered_set<string>& function_names) {
const bool is_function =
function_names.find(node.op()) != function_names.end();
+ bool ref_receiver = false;
for (const auto& input : node.input()) {
int position = 0;
string input_node_name = ParseNodeName(input, &position);
@@ -85,7 +108,8 @@ void GraphRewriter::RecordConnectivity(
if (itr == nodes_.end()) {
continue;
}
- const NodeDef* fanin = itr->second;
+ const NodeInfo* fanin_info = itr->second.get();
+ const NodeDef* fanin = fanin_info->def;
if (position < 0) {
// This is a control edge
control_dependency_drivers_.insert(fanin);
@@ -97,11 +121,20 @@ void GraphRewriter::RecordConnectivity(
if (is_function) {
function_neighbors_.insert(fanin);
}
+
+ if (position < fanin_info->outputs.size() &&
+ IsRefType(fanin_info->outputs[position])) {
+ ref_receiver = true;
+ }
}
if (fanin->device() != node.device()) {
cross_device_receivers_.insert(&node);
}
}
+
+ if (ref_receiver) {
+ ref_receivers_.insert(&node);
+ }
}
void GraphRewriter::ForwardInputsInternal(
@@ -125,7 +158,7 @@ void GraphRewriter::ForwardInputsInternal(
*new_node->add_input() = input;
continue;
}
- const NodeDef* input_node = itr->second;
+ const NodeDef* input_node = itr->second->def;
if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
ForwardInputsInternal(*input_node, nodes_to_delete, new_node);
} else {
diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h
index 4bdb063d58..cdc246369f 100644
--- a/tensorflow/core/grappler/optimizers/graph_rewriter.h
+++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h
@@ -55,6 +55,9 @@ class GraphRewriter {
// device.
bool IsDrivenByAnotherDevice(const NodeDef& node) const;
+ // Returns true if the node has input from a stateful op.
+ bool ReceivesRefValue(const NodeDef& node) const;
+
private:
void RecordConnectivity(const NodeDef& node,
const std::unordered_set<string>& function_names);
@@ -63,11 +66,21 @@ class GraphRewriter {
const std::unordered_set<const NodeDef*>& nodes_to_delete,
NodeDef* new_node);
- std::unordered_map<string, const NodeDef*> nodes_;
+ struct NodeInfo {
+ const NodeDef* def;
+
+ // These are filled in when the NodeInfo is built, but not that they
+ // may be empty - if the op could not be loaded from the registry.
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ };
+
+ std::unordered_map<string, std::unique_ptr<NodeInfo>> nodes_;
std::unordered_map<string, const NodeDef*> optimized_nodes_;
std::unordered_set<const NodeDef*> control_dependency_drivers_;
std::unordered_set<const NodeDef*> function_neighbors_;
std::unordered_set<const NodeDef*> cross_device_receivers_;
+ std::unordered_set<const NodeDef*> ref_receivers_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc
index e313155563..4e6218c0fb 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner.cc
@@ -74,20 +74,23 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}
- // Don't remove nodes that drive control dependencies.
- // Don't remove nodes that are driven by control dependencies either since
- // we can't ensure (yet) that we won't increase the number of control
- // dependency edges by deleting them (for example, removing a node driven by
- // 10 control edges and driving 10 control edges would result in the
- // creation of 100 edges).
- // Don't modify nodes that are connected to functions since that can result
- // in inlining failures later on.
- // Don't prune nodes that are driven by another device since these could be
- // used to reduce cross device communication.
+ // - Don't remove nodes that drive control dependencies.
+ // - Don't remove nodes that are driven by control dependencies either since
+ // we can't ensure (yet) that we won't increase the number of control
+ // dependency edges by deleting them (for example, removing a node driven
+ // by 10 control edges and driving 10 control edges would result in the
+ // creation of 100 edges).
+ // - Don't modify nodes that are connected to functions since that can
+ // result in inlining failures later on.
+ // - Don't prune nodes that are driven by another device since these could
+ // be used to reduce cross device communication.
+ // - Don't remove nodes that receive reference values, as those can be
+ // converting references to non-references.
if (!rewriter.DrivesControlDependency(node) &&
!rewriter.IsDrivenByControlDependency(node) &&
!rewriter.IsConnectedToFunction(node) &&
- !rewriter.IsDrivenByAnotherDevice(node)) {
+ !rewriter.IsDrivenByAnotherDevice(node) &&
+ !rewriter.ReceivesRefValue(node)) {
nodes_to_delete.insert(&node);
}
}
diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
index 72d9c7bf27..aea1fcd7c9 100644
--- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc
+++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc
@@ -199,6 +199,46 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) {
EXPECT_EQ("^c", new_e.input(1));
}
+TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // Make graph of Identity(Identity(Identity(Identity(Variable)))).
+ Output a = ops::Variable(s.WithOpName("a"), {}, DT_INT64);
+ Output b = ops::Identity(s.WithOpName("b"), a);
+ Output c = ops::Identity(s.WithOpName("c"), b);
+ Output d = ops::Identity(s.WithOpName("d"), c);
+ Output e = ops::Identity(s.WithOpName("e"), d);
+
+ // Run pruner.
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ ModelPruner pruner;
+ GraphDef output;
+ Status status = pruner.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // Get the updated nodes.
+ ASSERT_EQ(5, output.node_size());
+ const NodeDef& new_a = output.node(0);
+ const NodeDef& new_b = output.node(1);
+ const NodeDef& new_c = output.node(2);
+ const NodeDef& new_d = output.node(3);
+ const NodeDef& new_e = output.node(4);
+ EXPECT_EQ("a", new_a.name());
+ EXPECT_EQ("b", new_b.name());
+ EXPECT_EQ("c", new_c.name());
+ EXPECT_EQ("d", new_d.name());
+ EXPECT_EQ("e", new_e.name());
+
+ // Verify the connections. Identity "b" can't be removed from the chain
+ // because it is converting a reference input to a non-reference, so c,d,e all
+ // refer to it as an input.
+ EXPECT_EQ("a", new_b.input(0));
+ EXPECT_EQ("b", new_c.input(0));
+ EXPECT_EQ("b", new_d.input(0));
+ EXPECT_EQ("b", new_e.input(0));
+}
+
TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();