aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-07-27 18:04:43 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-07-27 18:04:43 -0400
commitc01cfe7ced91dabc19b2392696cf0598a5df70f9 (patch)
treef2e748e6b1c216f4c97054025257d206ba980fcf /tensorflow/c
parenta278365e8848f5fcbccb42f95a3c523367c1602f (diff)
2nd review: Cover more prefix conflict cases
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api.cc32
-rw-r--r--tensorflow/c/c_api_test.cc66
2 files changed, 82 insertions, 16 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index c1f4745e56..bcecbb0bc6 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -2411,20 +2412,25 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
const int first_new_node_id = g->graph.num_node_ids();
- const char* child_scope_name = prefix;
- if (child_scope_name != nullptr) {
+ string prefix_cmp;
+ const char* child_scope_name;
+ if (prefix == nullptr) {
+ child_scope_name = "gradients";
+ } else {
+ prefix_cmp = string(prefix) + "/";
// The operation should fail if the provided name prefix has already been
// used in this graph
for (const auto& pair: g->name_map) {
const string& name = pair.first;
- if (name.compare(0, name.find_last_of('/'), prefix) == 0) {
- status->status =
- InvalidArgument("Duplicate node name in graph: '", prefix, "'");
+ if (name.compare(prefix) == 0 ||
+ tensorflow::str_util::StartsWith(name, prefix_cmp)) {
+ status->status = InvalidArgument("prefix [", prefix,
+ "] conflicts with existing node in the graph named [",
+ name, "]");
return;
}
}
- } else {
- child_scope_name = "gradients";
+ child_scope_name = prefix;
}
tensorflow::Scope scope =
NewInternalScope(&g->graph, &status->status, &g->refiner)
@@ -2443,6 +2449,18 @@ void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
+
+ // Adding the gradients to the graph can alter the prefix to prevent
+ // name collisions only if this prefix has not been provided explicitly
+ // by the user. If it was provided, assert that it remained intact.
+ if (prefix != nullptr &&
+ !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
+ status->status = tensorflow::errors::Internal(
+ "BUG: The gradients prefix have been unexpectedly altered when "
+ "adding the nodes to the graph. This is a bug. Please file an "
+ "issue at https://github.com/tensorflow/tensorflow/issues.");
+ return;
+ }
// We have a convoluted scheme here: Using the C++ graph construction API
// to add potentially many nodes to the graph without running the checks
// (such as uniqueness of the names of nodes) we run with other functions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 7094d5d32d..d8d2533c60 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1708,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test {
return op;
}
+ void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
+ const char* prefix2 = nullptr) {
+ TF_Output inputs[2];
+ TF_Output outputs[1];
+ TF_Output grad_outputs[2];
+
+ BuildSuccessGraph(inputs, outputs);
+
+ AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs);
+ if (prefix2 != nullptr) {
+ AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs);
+ }
+ }
+
TF_Status* s_;
TF_Graph* graph_;
TF_Graph* expected_graph_;
@@ -1727,19 +1741,53 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
-TEST_F(CApiGradientsTest, Gradients_WithPrefix) {
- TF_Output inputs[2];
- TF_Output outputs[1];
- TF_Output grad_outputs[2];
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
- BuildSuccessGraph(inputs, outputs);
- AddGradients(false, "mygrads", inputs, 2, outputs, 1, grad_outputs);
- EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) {
+ BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
- AddGradients(false, "mygrads_1", inputs, 2, outputs, 1, grad_outputs);
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) {
+ BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients");
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) {
+ BuildGraphAndAddGradientsWithPrefixes("Const_0");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
- AddGradients(false, "mygrads_1", inputs, 2, outputs, 1, grad_outputs);
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients");
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
}