diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-10-13 16:01:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-13 16:05:44 -0700 |
commit | 40d5bf33829249404f935441bac0fa1615a58c13 (patch) | |
tree | f556fac48b6929d1f2cc2b786a0ccedb9e48128a /tensorflow/core/graph/graph.cc | |
parent | 7679a2ec746bec36191087feaf9ec8371180669c (diff) |
Enable Operation._add_control_inputs() with the C API and related improvements
This change:
- Implements the C API logic for Operation._add_control_inputs()
- Adds type-checking to Operation._add_control_input()
- Makes Graph::AddControlEdge() update the node def if necessary
- Makes Graph::AddControlEdge() a no-op if the control edge already exists
The AddControlEdge() changes may have a performance impact if anything
is sensitive to AddControlEdge(), but nothing is to my knowledge. I'm
not sure what benchmarks would confirm this.
PiperOrigin-RevId: 172158589
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index daefb6b1fb..87c41186d5 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -424,6 +424,35 @@ void Graph::RemoveEdge(const Edge* e) { --num_edges_; } +const Edge* Graph::AddControlEdge(Node* source, Node* dest, + bool allow_duplicates) { + if (!allow_duplicates) { + for (const Edge* edge : dest->in_edges()) { + if (edge->IsControlEdge() && edge->src() == source) { + // The requested edge already exists. + return nullptr; + } + } + } + // Modify dest's NodeDef if necessary. + if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) { + // Check if this input is already in dest's NodeDef. + const string new_input = strings::StrCat("^", source->name()); + bool input_exists = false; + for (const string& input : dest->props_->node_def.input()) { + if (input == new_input) { + input_exists = true; + break; + } + } + if (!input_exists) { + dest->MaybeCopyOnWrite(); + dest->props_->node_def.add_input(new_input); + } + } + return AddEdge(source, kControlSlot, dest, kControlSlot); +} + Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index) { TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index)); |