aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-07-27 18:04:43 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-07-27 18:04:43 -0400
commitc01cfe7ced91dabc19b2392696cf0598a5df70f9 (patch)
treef2e748e6b1c216f4c97054025257d206ba980fcf /tensorflow/c/c_api.cc
parenta278365e8848f5fcbccb42f95a3c523367c1602f (diff)
2nd review: Cover more prefix conflict cases
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc32
1 files changed, 25 insertions, 7 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index c1f4745e56..bcecbb0bc6 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -2411,20 +2412,25 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
const int first_new_node_id = g->graph.num_node_ids();
- const char* child_scope_name = prefix;
- if (child_scope_name != nullptr) {
+ string prefix_cmp;
+ const char* child_scope_name;
+ if (prefix == nullptr) {
+ child_scope_name = "gradients";
+ } else {
+ prefix_cmp = string(prefix) + "/";
// The operation should fail if the provided name prefix has already been
// used in this graph
for (const auto& pair: g->name_map) {
const string& name = pair.first;
- if (name.compare(0, name.find_last_of('/'), prefix) == 0) {
- status->status =
- InvalidArgument("Duplicate node name in graph: '", prefix, "'");
+ if (name.compare(prefix) == 0 ||
+ tensorflow::str_util::StartsWith(name, prefix_cmp)) {
+ status->status = InvalidArgument("prefix [", prefix,
+ "] conflicts with existing node in the graph named [",
+ name, "]");
return;
}
}
- } else {
- child_scope_name = "gradients";
+ child_scope_name = prefix;
}
tensorflow::Scope scope =
NewInternalScope(&g->graph, &status->status, &g->refiner)
@@ -2443,6 +2449,18 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
+
+ // Adding the gradients to the graph can alter the prefix to prevent
+ // name collisions only if this prefix has not been provided explicitly
+ // by the user. If it was provided, assert that it remained intact.
+ if (prefix != nullptr &&
+ !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
+ status->status = tensorflow::errors::Internal(
+ "BUG: The gradients prefix have been unexpectedly altered when "
+ "adding the nodes to the graph. This is a bug. Please file an "
+ "issue at https://github.com/tensorflow/tensorflow/issues.");
+ return;
+ }
// We have a convoluted scheme here: Using the C++ graph construction API
// to add potentially many nodes to the graph without running the checks
// (such as uniqueness of the names of nodes) we run with other functions