diff options
author | 2017-07-27 14:46:40 -0700 | |
---|---|---|
committer | 2017-07-27 14:50:34 -0700 | |
commit | e5353c941c4cfd7f256d69cc50caf6c90e70dd4a (patch) | |
tree | 6e04c41be958e60a109461ccc9b0caff0de1bd5b | |
parent | 22651083406ca01ac9d481e3367a3510d25f88cd (diff) |
Don't prune nodes that have reference inputs.
PiperOrigin-RevId: 163390862
5 files changed, 105 insertions, 15 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index edd3fae7b2..0eb5ecdc5e 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -134,6 +134,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.cc b/tensorflow/core/grappler/optimizers/graph_rewriter.cc index 5273f11ca0..9e4247fd1a 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.cc +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.cc @@ -18,6 +18,8 @@ limitations under the License. #include <unordered_set> #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" @@ -26,8 +28,24 @@ namespace tensorflow { namespace grappler { GraphRewriter::GraphRewriter(const GrapplerItem& item) { + OpRegistryInterface* op_registry = OpRegistry::Global(); for (auto& node : item.graph.node()) { - nodes_[node.name()] = &node; + NodeInfo* info = new NodeInfo(); + info->def = &node; + + const OpRegistrationData* op_reg_data = nullptr; + Status s = op_registry->LookUp(node.op(), &op_reg_data); + // TODO(bsteiner): make this not a best-effort lookup and evaluation? + if (s.ok()) { + s = InOutTypesForNode(node, op_reg_data->op_def, &info->inputs, + &info->outputs); + if (!s.ok()) { + info->inputs.clear(); + info->outputs.clear(); + } + } + + nodes_[node.name()].reset(info); } std::unordered_set<string> function_names; @@ -73,11 +91,16 @@ bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const { return cross_device_receivers_.find(&node) != cross_device_receivers_.end(); } +bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const { + return ref_receivers_.find(&node) != ref_receivers_.end(); +} + void GraphRewriter::RecordConnectivity( const NodeDef& node, const std::unordered_set<string>& function_names) { const bool is_function = function_names.find(node.op()) != function_names.end(); + bool ref_receiver = false; for (const auto& input : node.input()) { int position = 0; string input_node_name = ParseNodeName(input, &position); @@ -85,7 +108,8 @@ void GraphRewriter::RecordConnectivity( if (itr == nodes_.end()) { continue; } - const NodeDef* fanin = itr->second; + const NodeInfo* fanin_info = itr->second.get(); + const NodeDef* fanin = fanin_info->def; if (position < 0) { // This is a control edge control_dependency_drivers_.insert(fanin); @@ -97,11 +121,20 @@ void GraphRewriter::RecordConnectivity( if (is_function) { function_neighbors_.insert(fanin); } + + if (position < fanin_info->outputs.size() && + IsRefType(fanin_info->outputs[position])) { + ref_receiver = true; + } } if (fanin->device() != node.device()) { cross_device_receivers_.insert(&node); } } + + if (ref_receiver) { + ref_receivers_.insert(&node); + } } void GraphRewriter::ForwardInputsInternal( @@ -125,7 +158,7 @@ void GraphRewriter::ForwardInputsInternal( *new_node->add_input() = input; continue; } - const NodeDef* input_node = itr->second; + const NodeDef* input_node = itr->second->def; if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) { ForwardInputsInternal(*input_node, nodes_to_delete, new_node); } else { diff --git a/tensorflow/core/grappler/optimizers/graph_rewriter.h b/tensorflow/core/grappler/optimizers/graph_rewriter.h index 4bdb063d58..cdc246369f 100644 --- a/tensorflow/core/grappler/optimizers/graph_rewriter.h +++ b/tensorflow/core/grappler/optimizers/graph_rewriter.h @@ -55,6 +55,9 @@ class GraphRewriter { // device. bool IsDrivenByAnotherDevice(const NodeDef& node) const; + // Returns true if the node has input from a stateful op. + bool ReceivesRefValue(const NodeDef& node) const; + private: void RecordConnectivity(const NodeDef& node, const std::unordered_set<string>& function_names); @@ -63,11 +66,21 @@ class GraphRewriter { const std::unordered_set<const NodeDef*>& nodes_to_delete, NodeDef* new_node); - std::unordered_map<string, const NodeDef*> nodes_; + struct NodeInfo { + const NodeDef* def; + + // These are filled in when the NodeInfo is built, but not that they + // may be empty - if the op could not be loaded from the registry. + DataTypeVector inputs; + DataTypeVector outputs; + }; + + std::unordered_map<string, std::unique_ptr<NodeInfo>> nodes_; std::unordered_map<string, const NodeDef*> optimized_nodes_; std::unordered_set<const NodeDef*> control_dependency_drivers_; std::unordered_set<const NodeDef*> function_neighbors_; std::unordered_set<const NodeDef*> cross_device_receivers_; + std::unordered_set<const NodeDef*> ref_receivers_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index e313155563..4e6218c0fb 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -74,20 +74,23 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } - // Don't remove nodes that drive control dependencies. - // Don't remove nodes that are driven by control dependencies either since - // we can't ensure (yet) that we won't increase the number of control - // dependency edges by deleting them (for example, removing a node driven by - // 10 control edges and driving 10 control edges would result in the - // creation of 100 edges). - // Don't modify nodes that are connected to functions since that can result - // in inlining failures later on. - // Don't prune nodes that are driven by another device since these could be - // used to reduce cross device communication. + // - Don't remove nodes that drive control dependencies. + // - Don't remove nodes that are driven by control dependencies either since + // we can't ensure (yet) that we won't increase the number of control + // dependency edges by deleting them (for example, removing a node driven + // by 10 control edges and driving 10 control edges would result in the + // creation of 100 edges). + // - Don't modify nodes that are connected to functions since that can + // result in inlining failures later on. + // - Don't prune nodes that are driven by another device since these could + // be used to reduce cross device communication. + // - Don't remove nodes that receive reference values, as those can be + // converting references to non-references. if (!rewriter.DrivesControlDependency(node) && !rewriter.IsDrivenByControlDependency(node) && !rewriter.IsConnectedToFunction(node) && - !rewriter.IsDrivenByAnotherDevice(node)) { + !rewriter.IsDrivenByAnotherDevice(node) && + !rewriter.ReceivesRefValue(node)) { nodes_to_delete.insert(&node); } } diff --git a/tensorflow/core/grappler/optimizers/model_pruner_test.cc b/tensorflow/core/grappler/optimizers/model_pruner_test.cc index 72d9c7bf27..aea1fcd7c9 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner_test.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner_test.cc @@ -199,6 +199,46 @@ TEST_F(ModelPrunerTest, PruningSkipsCtrlDependencies) { EXPECT_EQ("^c", new_e.input(1)); } +TEST_F(ModelPrunerTest, PruningSkipsRefOutputs) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + // Make graph of Identity(Identity(Identity(Identity(Variable)))). + Output a = ops::Variable(s.WithOpName("a"), {}, DT_INT64); + Output b = ops::Identity(s.WithOpName("b"), a); + Output c = ops::Identity(s.WithOpName("c"), b); + Output d = ops::Identity(s.WithOpName("d"), c); + Output e = ops::Identity(s.WithOpName("e"), d); + + // Run pruner. + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + ModelPruner pruner; + GraphDef output; + Status status = pruner.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // Get the updated nodes. + ASSERT_EQ(5, output.node_size()); + const NodeDef& new_a = output.node(0); + const NodeDef& new_b = output.node(1); + const NodeDef& new_c = output.node(2); + const NodeDef& new_d = output.node(3); + const NodeDef& new_e = output.node(4); + EXPECT_EQ("a", new_a.name()); + EXPECT_EQ("b", new_b.name()); + EXPECT_EQ("c", new_c.name()); + EXPECT_EQ("d", new_d.name()); + EXPECT_EQ("e", new_e.name()); + + // Verify the connections. Identity "b" can't be removed from the chain + // because it is converting a reference input to a non-reference, so c,d,e all + // refer to it as an input. + EXPECT_EQ("a", new_b.input(0)); + EXPECT_EQ("b", new_c.input(0)); + EXPECT_EQ("b", new_d.input(0)); + EXPECT_EQ("b", new_e.input(0)); +} + TEST_F(ModelPrunerTest, PruningPerservesCtrlDependencies) { // Build a simple graph with a few trivially prunable ops. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); |