diff options
author | karl@kubx.ca <karl@kubx.ca> | 2018-07-27 18:04:43 -0400 |
---|---|---|
committer | karl@kubx.ca <karl@kubx.ca> | 2018-07-27 18:04:43 -0400 |
commit | c01cfe7ced91dabc19b2392696cf0598a5df70f9 (patch) | |
tree | f2e748e6b1c216f4c97054025257d206ba980fcf /tensorflow/c/c_api.cc | |
parent | a278365e8848f5fcbccb42f95a3c523367c1602f (diff) |
2nd review: Cover more prefix conflict cases
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r-- | tensorflow/c/c_api.cc | 32 |
1 files changed, 25 insertions, 7 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index c1f4745e56..bcecbb0bc6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -53,6 +53,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" @@ -2411,20 +2412,25 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, const int first_new_node_id = g->graph.num_node_ids(); - const char* child_scope_name = prefix; - if (child_scope_name != nullptr) { + string prefix_cmp; + const char* child_scope_name; + if (prefix == nullptr) { + child_scope_name = "gradients"; + } else { + prefix_cmp = string(prefix) + "/"; // The operation should fail if the provided name prefix has already been // used in this graph for (const auto& pair: g->name_map) { const string& name = pair.first; - if (name.compare(0, name.find_last_of('/'), prefix) == 0) { - status->status = - InvalidArgument("Duplicate node name in graph: '", prefix, "'"); + if (name.compare(prefix) == 0 || + tensorflow::str_util::StartsWith(name, prefix_cmp)) { + status->status = InvalidArgument("prefix [", prefix, + "] conflicts with existing node in the graph named [", + name, "]"); return; } } - } else { - child_scope_name = "gradients"; + child_scope_name = prefix; } tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) @@ -2443,6 +2449,18 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; + + // Adding the gradients to the graph can alter the prefix to prevent + // name collisions only if this prefix has not been provided explicitly + // by the user. If it was provided, assert that it remained intact. + if (prefix != nullptr && + !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) { + status->status = tensorflow::errors::Internal( + "BUG: The gradients prefix have been unexpectedly altered when " + "adding the nodes to the graph. This is a bug. Please file an " + "issue at https://github.com/tensorflow/tensorflow/issues."); + return; + } // We have a convoluted scheme here: Using the C++ graph construction API // to add potentially many nodes to the graph without running the checks // (such as uniqueness of the names of nodes) we run with other functions |