aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-10-13 16:01:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 16:05:44 -0700
commit40d5bf33829249404f935441bac0fa1615a58c13 (patch)
treef556fac48b6929d1f2cc2b786a0ccedb9e48128a /tensorflow/core/graph/graph.cc
parent7679a2ec746bec36191087feaf9ec8371180669c (diff)
Enable Operation._add_control_inputs() with the C API and related improvements
This change: - Implements the C API logic for Operation._add_control_inputs() - Adds type-checking to Operation._add_control_input() - Makes Graph::AddControlEdge() update the node def if necessary - Makes Graph::AddControlEdge() a no-op if the control edge already exists The AddControlEdge() changes may have a performance impact if anything is sensitive to AddControlEdge(), but nothing is to my knowledge. I'm not sure what benchmarks would confirm this. PiperOrigin-RevId: 172158589
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index daefb6b1fb..87c41186d5 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -424,6 +424,35 @@ void Graph::RemoveEdge(const Edge* e) {
--num_edges_;
}
+const Edge* Graph::AddControlEdge(Node* source, Node* dest,
+ bool allow_duplicates) {
+ if (!allow_duplicates) {
+ for (const Edge* edge : dest->in_edges()) {
+ if (edge->IsControlEdge() && edge->src() == source) {
+ // The requested edge already exists.
+ return nullptr;
+ }
+ }
+ }
+ // Modify dest's NodeDef if necessary.
+ if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
+ // Check if this input is already in dest's NodeDef.
+ const string new_input = strings::StrCat("^", source->name());
+ bool input_exists = false;
+ for (const string& input : dest->props_->node_def.input()) {
+ if (input == new_input) {
+ input_exists = true;
+ break;
+ }
+ }
+ if (!input_exists) {
+ dest->MaybeCopyOnWrite();
+ dest->props_->node_def.add_input(new_input);
+ }
+ }
+ return AddEdge(source, kControlSlot, dest, kControlSlot);
+}
+
Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
int dst_index) {
TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));