diff options
author | karl@kubx.ca <karl@kubx.ca> | 2018-07-16 00:49:06 -0400 |
---|---|---|
committer | karl@kubx.ca <karl@kubx.ca> | 2018-07-25 21:10:30 -0400 |
commit | a278365e8848f5fcbccb42f95a3c523367c1602f (patch) | |
tree | 368044bbb27312ea44e0856548b84d260fc304d0 /tensorflow/c/c_api.cc | |
parent | 7ebdc9834bbc583bcc42551b660c8ed256ea7416 (diff) |
Enforce uniqueness of custom prefixes for gradients
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 32b0b70620..c1f4745e56 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2411,9 +2411,24 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, const int first_new_node_id = g->graph.num_node_ids(); + const char* child_scope_name = prefix; + if (child_scope_name != nullptr) { + // The operation should fail if the provided name prefix has already been + // used in this graph + for (const auto& pair: g->name_map) { + const string& name = pair.first; + if (name.compare(0, name.find_last_of('/'), prefix) == 0) { + status->status = + InvalidArgument("Duplicate node name in graph: '", prefix, "'"); + return; + } + } + } else { + child_scope_name = "gradients"; + } tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) - .NewSubScope(prefix != nullptr ? prefix : "gradients"); + .NewSubScope(child_scope_name); if (dx != nullptr) { std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny); |