From c01cfe7ced91dabc19b2392696cf0598a5df70f9 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Fri, 27 Jul 2018 18:04:43 -0400 Subject: 2nd review: Cover more prefix conflict cases --- tensorflow/c/c_api.cc | 32 +++++++++++++++++----- tensorflow/c/c_api_test.cc | 66 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 82 insertions(+), 16 deletions(-) (limited to 'tensorflow/c') diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index c1f4745e56..bcecbb0bc6 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -53,6 +53,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" @@ -2411,20 +2412,25 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, const int first_new_node_id = g->graph.num_node_ids(); - const char* child_scope_name = prefix; - if (child_scope_name != nullptr) { + string prefix_cmp; + const char* child_scope_name; + if (prefix == nullptr) { + child_scope_name = "gradients"; + } else { + prefix_cmp = string(prefix) + "/"; // The operation should fail if the provided name prefix has already been // used in this graph for (const auto& pair: g->name_map) { const string& name = pair.first; - if (name.compare(0, name.find_last_of('/'), prefix) == 0) { - status->status = - InvalidArgument("Duplicate node name in graph: '", prefix, "'"); + if (name.compare(prefix) == 0 || + tensorflow::str_util::StartsWith(name, prefix_cmp)) { + status->status = InvalidArgument("prefix [", prefix, + "] conflicts with existing node in the graph named [", + name, "]"); return; } } - } else { - child_scope_name = "gradients"; + child_scope_name = prefix; } tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) @@ -2443,6 +2449,18 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y, for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; + + // Adding the gradients to the graph can alter the prefix to prevent + // name collisions only if this prefix has not been provided explicitly + // by the user. If it was provided, assert that it remained intact. + if (prefix != nullptr && + !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) { + status->status = tensorflow::errors::Internal( + "BUG: The gradients prefix have been unexpectedly altered when " + "adding the nodes to the graph. This is a bug. Please file an " + "issue at https://github.com/tensorflow/tensorflow/issues."); + return; + } // We have a convoluted scheme here: Using the C++ graph construction API // to add potentially many nodes to the graph without running the checks // (such as uniqueness of the names of nodes) we run with other functions diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 7094d5d32d..d8d2533c60 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1708,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test { return op; } + void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1, + const char* prefix2 = nullptr) { + TF_Output inputs[2]; + TF_Output outputs[1]; + TF_Output grad_outputs[2]; + + BuildSuccessGraph(inputs, outputs); + + AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs); + if (prefix2 != nullptr) { + AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs); + } + } + TF_Status* s_; TF_Graph* graph_; TF_Graph* expected_graph_; @@ -1727,19 +1741,53 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } -TEST_F(CApiGradientsTest, Gradients_WithPrefix) { - TF_Output inputs[2]; - TF_Output outputs[1]; - TF_Output grad_outputs[2]; +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) { + BuildGraphAndAddGradientsWithPrefixes("gradients"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} - BuildSuccessGraph(inputs, outputs); - AddGradients(false, "mygrads", inputs, 2, outputs, 1, grad_outputs); - EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} - AddGradients(false, "mygrads_1", inputs, 2, outputs, 1, grad_outputs); +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) { + BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients"); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) { + BuildGraphAndAddGradientsWithPrefixes("Const_0"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} + +TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) { + BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients"); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); +} - AddGradients(false, "mygrads_1", inputs, 2, outputs, 1, grad_outputs); +TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) { + BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients"); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); } -- cgit v1.2.3