author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-05-22 13:58:57 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-05-22 14:02:55 -0700
commit    4e131d27354bc9be90e291f3ec4538c0e3bf06eb (patch)
tree      2677eb5ff6b721721f0bf6b5ac2cd41f1afa1075 /tensorflow/core
parent    89e09f6357863f05ffd3db1ac5f202559470bbfd (diff)
Many algorithms need to enumerate the set of nodes within a graph, while excluding the special Sink and Source nodes. The checks for skipping Source and Sink are duplicated in dozens of loops.
This CL adds a new Graph::op_nodes() method, which returns an enumerable range of all operation nodes, excluding Sink and Source. This allows many for loops to be simplified.

This simplification is being done mainly for readability / reliability. There may be a tiny performance difference owing to this change (as well as from making the Graph::nodes() and Graph::op_nodes() methods inlineable), but the measured difference is not reliably large enough to be significant.

The changes to graph.h and graph.cc are quite minimal. I updated all of the uses of Graph::nodes() that I could reliably determine would be unaffected by the change. Most uses immediately checked node->IsOp(). Some compared node->type_string() against literal strings, none of which were "_SINK" or "_SOURCE", and so using op_nodes() was more appropriate than nodes(). In some cases, it was not obvious whether an existing use of Graph::nodes() wanted to enumerate Sink / Source, so I left those uses unchanged.

PiperOrigin-RevId: 156782112
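For illustration, a minimal sketch of the simplification at a typical call site (here "g" is a Graph* and "Process" is a hypothetical stand-in for a loop body; neither appears in this CL):

// Before: each loop filtered out the special Source/Sink nodes by hand.
for (Node* n : g->nodes()) {
  if (n->IsSource() || n->IsSink()) continue;  // skip the two special nodes
  Process(n);
}

// After: op_nodes() yields only operation nodes, so the filtering disappears.
for (Node* n : g->op_nodes()) {
  Process(n);
}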
Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/common_runtime/function.cc                          | 10
-rw-r--r--  tensorflow/core/common_runtime/parallel_concat_optimizer.cc         |  2
-rw-r--r--  tensorflow/core/common_runtime/resource_variable_read_optimizer.cc  |  2
-rw-r--r--  tensorflow/core/common_runtime/simple_placer.cc                     | 19
-rw-r--r--  tensorflow/core/graph/costmodel.cc                                  | 24
-rw-r--r--  tensorflow/core/graph/graph.cc                                      |  6
-rw-r--r--  tensorflow/core/graph/graph.h                                       | 33
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc                          |  4
-rw-r--r--  tensorflow/core/graph/graph_partition.cc                            |  8
-rw-r--r--  tensorflow/core/graph/quantize_training.cc                          |  4
10 files changed, 57 insertions, 55 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 407c20bbf2..6de848341d 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -848,9 +848,7 @@ void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
// remember 'y' in node_map[x->id()].
std::vector<Node*> node_map(fbody->graph->num_node_ids());
Status s;
- for (Node* n : fbody->graph->nodes()) {
- if (n->IsSource() || n->IsSink()) continue;
- CHECK(n->IsOp());
+ for (Node* n : fbody->graph->op_nodes()) {
NodeDef ndef = n->def();
ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
Node* clone = g->AddNode(ndef, &s);
@@ -1077,7 +1075,7 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
ret_types(ret_t.begin(), ret_t.end()) {
this->arg_nodes.resize(arg_types.size());
this->ret_nodes.resize(ret_types.size());
- for (Node* n : this->graph->nodes()) {
+ for (Node* n : this->graph->op_nodes()) {
gtl::InlinedVector<Node*, 4>* node_vec;
if (n->type_string() == kRetOp) {
node_vec = &this->ret_nodes;
@@ -1124,9 +1122,7 @@ void SymbolicGradientHelper::Copy() {
// Copy the nodes.
node_map[src.source_node()->id()] = dst->source_node();
node_map[src.sink_node()->id()] = dst->sink_node();
- for (Node* n : src.nodes()) {
- if (n->IsSource() || n->IsSink()) continue;
- CHECK(n->IsOp());
+ for (Node* n : src.op_nodes()) {
node_map[n->id()] = dst->CopyNode(n);
}
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index bbd38a2e07..f9f36443a8 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -43,7 +43,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
"graph should be available.");
}
gtl::InlinedVector<Node*, 2> matches;
- for (Node* n : g->nodes()) {
+ for (Node* n : g->op_nodes()) {
if (n->type_string() == "ParallelConcat") {
matches.push_back(n);
}
diff --git a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
index b40924ef3a..228c4b5406 100644
--- a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
+++ b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
@@ -39,7 +39,7 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
"and a graph should be available.");
}
gtl::InlinedVector<Node*, 2> matches;
- for (Node* n : g->nodes()) {
+ for (Node* n : g->op_nodes()) {
if (n->type_string() == "ReadVariableOp") {
bool skip = false;
for (const Edge* e : n->out_edges()) {
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index ae225e8b35..59bf0544c1 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -638,22 +638,14 @@ Status SimplePlacer::Run() {
// 1. First add all of the nodes. Note that steps (1) and (2)
// requires two passes over the nodes because the graph (and hence
// the constraints) may not be acyclic.
- for (Node* node : graph_->nodes()) {
- // Skip the source and sink nodes.
- if (!node->IsOp()) {
- continue;
- }
+ for (Node* node : graph_->op_nodes()) {
status = colocation_graph.AddNode(*node);
if (!status.ok()) return AttachDef(status, *node);
}
// 2. Enumerate the constraint edges, and use them to update the disjoint
// node set.
- for (Node* node : graph_->nodes()) {
- if (!node->IsOp()) {
- continue;
- }
-
+ for (Node* node : graph_->op_nodes()) {
// If `node` has an input edge with reference type, add an
// edge from the source of that edge to `node`.
for (const auto& edge : node->in_edges()) {
@@ -717,12 +709,7 @@ Status SimplePlacer::Run() {
// disjoint node set.
std::vector<Device*> devices;
std::vector<Node*> second_pass;
- for (Node* node : graph_->nodes()) {
- // Skip the source and sink nodes.
- if (!node->IsOp()) {
- continue;
- }
-
+ for (Node* node : graph_->op_nodes()) {
// The graph may have come pre-populated by the framework with assigned
// devices (e.g., for stateful placements), so the placer should not try to
// place nodes that are already placed.
diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
index 1809b35c84..727b201cc5 100644
--- a/tensorflow/core/graph/costmodel.cc
+++ b/tensorflow/core/graph/costmodel.cc
@@ -217,19 +217,17 @@ Microseconds CostModel::TimeEstimate(const Node* node) const {
}
void CostModel::CheckInitialized(const Graph& graph) const {
- for (const Node* n : graph.nodes()) {
- if (n->IsOp()) {
- CHECK(static_cast<size_t>(n->id()) < time_.size() &&
- time_[n->id()] >= Microseconds(0))
- << ": no time estimate for " << n->DebugString();
-
- CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
- << ": no size estimate for " << n->DebugString();
- const auto& perslot = slot_bytes_[n->id()];
- for (size_t i = 0; i < perslot.size(); i++) {
- CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
- << " of " << n->DebugString();
- }
+ for (const Node* n : graph.op_nodes()) {
+ CHECK(static_cast<size_t>(n->id()) < time_.size() &&
+ time_[n->id()] >= Microseconds(0))
+ << ": no time estimate for " << n->DebugString();
+
+ CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
+ << ": no size estimate for " << n->DebugString();
+ const auto& perslot = slot_bytes_[n->id()];
+ for (size_t i = 0; i < perslot.size(); i++) {
+ CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
+ << " of " << n->DebugString();
}
}
}
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 9066de5668..80161ceb56 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -488,12 +488,6 @@ string Graph::NewName(StringPiece prefix) {
return strings::StrCat(prefix, "/_", name_counter_++);
}
-gtl::iterator_range<NodeIter> Graph::nodes() const {
- // Note that NodeId 0 is always valid since we don't let the source
- // node be removed from the graph.
- return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
-}
-
bool Graph::IsValidNode(Node* node) const {
if (node == nullptr) return false;
const int id = node->id();
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 8554cb2f4b..e82580f204 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -417,6 +417,12 @@ class Graph {
// array's size.
int num_nodes() const { return num_nodes_; }
+ // The number of live nodes in the graph, excluding the Source and Sink nodes.
+ int num_op_nodes() const {
+ DCHECK_GE(num_nodes_, 2);
+ return num_nodes_ - 2;
+ }
+
// The number of live edges in the graph.
//
// Because edges can be removed from the graph, num_edges() is often
@@ -439,6 +445,9 @@ class Graph {
// for (Node* node : graph.nodes()) { ... }
gtl::iterator_range<NodeIter> nodes() const;
+ // Access to the list of all nodes, excluding the Source and Sink nodes.
+ gtl::iterator_range<NodeIter> op_nodes() const;
+
// Returns one more than the maximum id assigned to any node.
int num_node_ids() const { return nodes_.size(); }
@@ -633,6 +642,30 @@ inline bool Edge::IsControlEdge() const {
return src_output_ == Graph::kControlSlot;
}
+inline gtl::iterator_range<NodeIter> Graph::nodes() const {
+ // Note that NodeId 0 is always valid since we don't let the source
+ // node be removed from the graph.
+ return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
+}
+
+inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
+ // Note that NodeId 0 is always valid since we don't let the source
+ // node be removed from the graph.
+ //
+ // The current implementation of Graph maintains the invariant that the
+ // first two nodes are the source and sink nodes, and all other nodes are op
+ // nodes. This method (op_nodes()) relies on this invariant.
+ NodeIter begin(this, 0);
+ NodeIter end(this, num_node_ids());
+ if (begin != end) {
+ ++begin;
+ }
+ if (begin != end) {
+ ++begin;
+ }
+ return gtl::make_range(begin, end);
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_GRAPH_GRAPH_H_
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 70087b8fe1..1d7eea2206 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -898,9 +898,7 @@ void CopyGraph(const Graph& src, Graph* dest) {
node_map; // "Node in src" -> "Node in *dest"
node_map[src.source_node()] = dest->source_node();
node_map[src.sink_node()] = dest->sink_node();
- for (Node* n : src.nodes()) {
- if (n->IsSource() || n->IsSink()) continue;
- CHECK(n->IsOp());
+ for (Node* n : src.op_nodes()) {
node_map[n] = dest->CopyNode(n);
}
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 57a2f399e0..6036317559 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -519,9 +519,7 @@ Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
MemoryTypeVector output_memory_types;
info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
- for (const Node* node : g.nodes()) {
- if (!node->IsOp()) continue; // Skip Sink/Source nodes.
-
+ for (const Node* node : g.op_nodes()) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
&parsed)) {
@@ -831,9 +829,7 @@ Status Partition(const PartitionOptions& opts, Graph* g,
int32 num_data = 0;
int32 num_control = 0;
- for (const Node* dst : g->nodes()) {
- if (!dst->IsOp()) continue; // Skip Sink/Source nodes.
-
+ for (const Node* dst : g->op_nodes()) {
dstp = opts.node_to_loc(dst);
GraphDef* dst_graph = &(*partitions)[dstp];
NodeDef* dst_def = dst_graph->add_node();
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index a0c3fbe2aa..48b6b2a497 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -139,7 +139,7 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input,
Status FindSaveOp(const Graph* graph, Node** save_op,
std::vector<const Edge*>* in_edges, bool* found) {
*found = false;
- for (Node* node : graph->nodes()) {
+ for (Node* node : graph->op_nodes()) {
if (node->type_string() == "SaveV2") {
// We found multiple save ops.
if (*found) {
@@ -154,7 +154,7 @@ Status FindSaveOp(const Graph* graph, Node** save_op,
}
Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
- for (Node* node : graph->nodes()) {
+ for (Node* node : graph->op_nodes()) {
// The restore_all op should have the same prefix of the save_op.
if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
return node;