author     2017-05-22 13:58:57 -0700
committer  2017-05-22 14:02:55 -0700
commit     4e131d27354bc9be90e291f3ec4538c0e3bf06eb (patch)
tree       2677eb5ff6b721721f0bf6b5ac2cd41f1afa1075 /tensorflow/core
parent     89e09f6357863f05ffd3db1ac5f202559470bbfd (diff)
Many algorithms need to enumerate the set of nodes within a graph, while excluding the special Sink and Source nodes. The checks for skipping Source and Sink are duplicated in dozens of loops.
This CL adds a new Graph::op_nodes() method, which returns an enumerable range of all operation nodes, excluding Sink and Source. This allows many for loops to be simplified.
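To make the shape of the simplification concrete, here is a minimal before/after sketch (an illustration, not a verbatim excerpt from the diff; `g` stands for any `Graph*`):

    // Before: every loop had to skip the special nodes by hand.
    for (Node* n : g->nodes()) {
      if (n->IsSource() || n->IsSink()) continue;
      CHECK(n->IsOp());
      // ... process the operation node ...
    }

    // After: op_nodes() yields only operation nodes.
    for (Node* n : g->op_nodes()) {
      // ... process the operation node ...
    }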
This simplification is being done mainly for readability / reliability. There may be a tiny performance difference owing to this change (as well as to making the Graph::nodes() and Graph::op_nodes() methods inlineable), but the measured difference was not reliably large enough to be significant.
The changes to graph.h and graph.cc are quite minimal. I updated all of the uses of Graph::nodes() that I could reliably determine would be unaffected by the change. Most uses immediately checked node->IsOp(). Some compared node->type_string() against literal strings, none of which were "_SINK" or "_SOURCE", so using op_nodes() was more appropriate than nodes(). In some cases it was not obvious whether an existing use of Graph::nodes() wanted to enumerate Sink / Source, so I left those uses unchanged.
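For reference, the essence of the new accessor (condensed here from the graph.h change in the diff below) is to advance a NodeIter past the first two node ids, relying on the invariant that the Source and Sink nodes always occupy ids 0 and 1:

    inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
      // Relies on the invariant that the first two nodes are the Source and
      // Sink nodes, and every other node is an op node.
      NodeIter begin(this, 0);
      NodeIter end(this, num_node_ids());
      if (begin != end) ++begin;  // skip Source (id 0)
      if (begin != end) ++begin;  // skip Sink (id 1)
      return gtl::make_range(begin, end);
    }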
PiperOrigin-RevId: 156782112
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/common_runtime/function.cc                          | 10
-rw-r--r--  tensorflow/core/common_runtime/parallel_concat_optimizer.cc         |  2
-rw-r--r--  tensorflow/core/common_runtime/resource_variable_read_optimizer.cc  |  2
-rw-r--r--  tensorflow/core/common_runtime/simple_placer.cc                     | 19
-rw-r--r--  tensorflow/core/graph/costmodel.cc                                  | 24
-rw-r--r--  tensorflow/core/graph/graph.cc                                      |  6
-rw-r--r--  tensorflow/core/graph/graph.h                                       | 33
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc                          |  4
-rw-r--r--  tensorflow/core/graph/graph_partition.cc                            |  8
-rw-r--r--  tensorflow/core/graph/quantize_training.cc                          |  4
10 files changed, 57 insertions, 55 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 407c20bbf2..6de848341d 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -848,9 +848,7 @@ void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
   // remember 'y' in node_map[x->id()].
   std::vector<Node*> node_map(fbody->graph->num_node_ids());
   Status s;
-  for (Node* n : fbody->graph->nodes()) {
-    if (n->IsSource() || n->IsSink()) continue;
-    CHECK(n->IsOp());
+  for (Node* n : fbody->graph->op_nodes()) {
     NodeDef ndef = n->def();
     ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
     Node* clone = g->AddNode(ndef, &s);
@@ -1077,7 +1075,7 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
       ret_types(ret_t.begin(), ret_t.end()) {
   this->arg_nodes.resize(arg_types.size());
   this->ret_nodes.resize(ret_types.size());
-  for (Node* n : this->graph->nodes()) {
+  for (Node* n : this->graph->op_nodes()) {
     gtl::InlinedVector<Node*, 4>* node_vec;
     if (n->type_string() == kRetOp) {
       node_vec = &this->ret_nodes;
@@ -1124,9 +1122,7 @@ void SymbolicGradientHelper::Copy() {
   // Copy the nodes.
   node_map[src.source_node()->id()] = dst->source_node();
   node_map[src.sink_node()->id()] = dst->sink_node();
-  for (Node* n : src.nodes()) {
-    if (n->IsSource() || n->IsSink()) continue;
-    CHECK(n->IsOp());
+  for (Node* n : src.op_nodes()) {
     node_map[n->id()] = dst->CopyNode(n);
   }
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index bbd38a2e07..f9f36443a8 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -43,7 +43,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
           "graph should be available.");
     }
     gtl::InlinedVector<Node*, 2> matches;
-    for (Node* n : g->nodes()) {
+    for (Node* n : g->op_nodes()) {
       if (n->type_string() == "ParallelConcat") {
         matches.push_back(n);
       }
diff --git a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
index b40924ef3a..228c4b5406 100644
--- a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
+++ b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
@@ -39,7 +39,7 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
           "and a graph should be available.");
     }
     gtl::InlinedVector<Node*, 2> matches;
-    for (Node* n : g->nodes()) {
+    for (Node* n : g->op_nodes()) {
       if (n->type_string() == "ReadVariableOp") {
         bool skip = false;
         for (const Edge* e : n->out_edges()) {
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index ae225e8b35..59bf0544c1 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -638,22 +638,14 @@ Status SimplePlacer::Run() {
   // 1. First add all of the nodes. Note that steps (1) and (2)
   // requires two passes over the nodes because the graph (and hence
   // the constraints) may not be acyclic.
-  for (Node* node : graph_->nodes()) {
-    // Skip the source and sink nodes.
-    if (!node->IsOp()) {
-      continue;
-    }
+  for (Node* node : graph_->op_nodes()) {
     status = colocation_graph.AddNode(*node);
     if (!status.ok()) return AttachDef(status, *node);
   }

   // 2. Enumerate the constraint edges, and use them to update the disjoint
   // node set.
-  for (Node* node : graph_->nodes()) {
-    if (!node->IsOp()) {
-      continue;
-    }
-
+  for (Node* node : graph_->op_nodes()) {
     // If `node` has an input edge with reference type, add an
     // edge from the source of that edge to `node`.
     for (const auto& edge : node->in_edges()) {
@@ -717,12 +709,7 @@ Status SimplePlacer::Run() {
   // disjoint node set.
   std::vector<Device*> devices;
   std::vector<Node*> second_pass;
-  for (Node* node : graph_->nodes()) {
-    // Skip the source and sink nodes.
-    if (!node->IsOp()) {
-      continue;
-    }
-
+  for (Node* node : graph_->op_nodes()) {
     // The graph may have come pre-populated by the framework with assigned
     // devices (e.g., for stateful placements), so the placer should not try to
     // place nodes that are already placed.
diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
index 1809b35c84..727b201cc5 100644
--- a/tensorflow/core/graph/costmodel.cc
+++ b/tensorflow/core/graph/costmodel.cc
@@ -217,19 +217,17 @@ Microseconds CostModel::TimeEstimate(const Node* node) const {
 }

 void CostModel::CheckInitialized(const Graph& graph) const {
-  for (const Node* n : graph.nodes()) {
-    if (n->IsOp()) {
-      CHECK(static_cast<size_t>(n->id()) < time_.size() &&
-            time_[n->id()] >= Microseconds(0))
-          << ": no time estimate for " << n->DebugString();
-
-      CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
-          << ": no size estimate for " << n->DebugString();
-      const auto& perslot = slot_bytes_[n->id()];
-      for (size_t i = 0; i < perslot.size(); i++) {
-        CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
-                                       << " of " << n->DebugString();
-      }
+  for (const Node* n : graph.op_nodes()) {
+    CHECK(static_cast<size_t>(n->id()) < time_.size() &&
+          time_[n->id()] >= Microseconds(0))
+        << ": no time estimate for " << n->DebugString();
+
+    CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
+        << ": no size estimate for " << n->DebugString();
+    const auto& perslot = slot_bytes_[n->id()];
+    for (size_t i = 0; i < perslot.size(); i++) {
+      CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
+                                     << " of " << n->DebugString();
     }
   }
 }
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 9066de5668..80161ceb56 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -488,12 +488,6 @@ string Graph::NewName(StringPiece prefix) {
   return strings::StrCat(prefix, "/_", name_counter_++);
 }

-gtl::iterator_range<NodeIter> Graph::nodes() const {
-  // Note that NodeId 0 is always valid since we don't let the source
-  // node be removed from the graph.
-  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
-}
-
 bool Graph::IsValidNode(Node* node) const {
   if (node == nullptr) return false;
   const int id = node->id();
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 8554cb2f4b..e82580f204 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -417,6 +417,12 @@ class Graph {
   // array's size.
   int num_nodes() const { return num_nodes_; }

+  // The number of live nodes in the graph, excluding the Source and Sink
+  // nodes.
+  int num_op_nodes() const {
+    DCHECK_GE(num_nodes_, 2);
+    return num_nodes_ - 2;
+  }
+
   // The number of live edges in the graph.
   //
   // Because edges can be removed from the graph, num_edges() is often
@@ -439,6 +445,9 @@ class Graph {
   //   for (Node* node : graph.nodes()) { ... }
   gtl::iterator_range<NodeIter> nodes() const;

+  // Access to the list of all nodes, excluding the Source and Sink nodes.
+  gtl::iterator_range<NodeIter> op_nodes() const;
+
   // Returns one more than the maximum id assigned to any node.
   int num_node_ids() const { return nodes_.size(); }

@@ -633,6 +642,30 @@ inline bool Edge::IsControlEdge() const {
   return src_output_ == Graph::kControlSlot;
 }

+inline gtl::iterator_range<NodeIter> Graph::nodes() const {
+  // Note that NodeId 0 is always valid since we don't let the source
+  // node be removed from the graph.
+  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
+}
+
+inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
+  // Note that NodeId 0 is always valid since we don't let the source
+  // node be removed from the graph.
+  //
+  // The current implementation of Graph maintains the invariant that the
+  // first two nodes are the source and sink nodes, and all other nodes are
+  // op nodes. This method (op_nodes()) relies on this invariant.
+  NodeIter begin(this, 0);
+  NodeIter end(this, num_node_ids());
+  if (begin != end) {
+    ++begin;
+  }
+  if (begin != end) {
+    ++begin;
+  }
+  return gtl::make_range(begin, end);
+}
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_GRAPH_GRAPH_H_
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 70087b8fe1..1d7eea2206 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -898,9 +898,7 @@ void CopyGraph(const Graph& src, Graph* dest) {
       node_map;  // "Node in src" -> "Node in *dest"
   node_map[src.source_node()] = dest->source_node();
   node_map[src.sink_node()] = dest->sink_node();
-  for (Node* n : src.nodes()) {
-    if (n->IsSource() || n->IsSink()) continue;
-    CHECK(n->IsOp());
+  for (Node* n : src.op_nodes()) {
     node_map[n] = dest->CopyNode(n);
   }
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 57a2f399e0..6036317559 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -519,9 +519,7 @@ Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
   MemoryTypeVector output_memory_types;

   info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
-  for (const Node* node : g.nodes()) {
-    if (!node->IsOp()) continue;  // Skip Sink/Source nodes.
-
+  for (const Node* node : g.op_nodes()) {
     DeviceNameUtils::ParsedName parsed;
     if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
                                         &parsed)) {
@@ -831,9 +829,7 @@ Status Partition(const PartitionOptions& opts, Graph* g,
   int32 num_data = 0;
   int32 num_control = 0;
-  for (const Node* dst : g->nodes()) {
-    if (!dst->IsOp()) continue;  // Skip Sink/Source nodes.
-
+  for (const Node* dst : g->op_nodes()) {
     dstp = opts.node_to_loc(dst);
     GraphDef* dst_graph = &(*partitions)[dstp];
     NodeDef* dst_def = dst_graph->add_node();
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index a0c3fbe2aa..48b6b2a497 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -139,7 +139,7 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
 Status FindSaveOp(const Graph* graph, Node** save_op,
                   std::vector<const Edge*>* in_edges, bool* found) {
   *found = false;
-  for (Node* node : graph->nodes()) {
+  for (Node* node : graph->op_nodes()) {
     if (node->type_string() == "SaveV2") {
       // We found multiple save ops.
       if (*found) {
@@ -154,7 +154,7 @@ Status FindSaveOp(const Graph* graph, Node** save_op,
 }

 Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
-  for (Node* node : graph->nodes()) {
+  for (Node* node : graph->op_nodes()) {
     // The restore_all op should have the same prefix of the save_op.
     if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
       return node;