diff options
author | 2018-02-08 14:10:51 -0800 | |
---|---|---|
committer | 2018-02-08 14:13:53 -0800 | |
commit | 037895185e970e3bdc789fbf6aad643c271415d4 (patch) | |
tree | b5bd3b7c88f2930e3a32c0dc629d66ba6ad03125 /tensorflow/c/c_api_function_test.cc | |
parent | fa3fb289ba6a1718f9c76b2277a58f95f5e878ab (diff) |
Don't fail if control dependency is on an input of the function.
PiperOrigin-RevId: 185049319
Diffstat (limited to 'tensorflow/c/c_api_function_test.cc')
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 24 |
1 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index dbce66d231..7ca50119ea 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -331,6 +331,11 @@ class CApiFunctionTest : public ::testing::Test { << "Failed to find expected edge " << e.ToString() << " in fdef: " << fdef.DebugString(); } + for (const EdgeSpec& e : c_edges) { + ASSERT_TRUE(a_edges.find(e) != a_edges.end()) + << "Failed to find expected control edge " << e.ToString() + << " in fdef: " << fdef.DebugString(); + } // If caller specified all edges, check that we have seen all if (is_exact_edges) { @@ -980,7 +985,7 @@ TEST_F(CApiFunctionTest, ControlDependency) { VerifyFDef( {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, - {{"scalar", "add_0"}}); + {{"^scalar", "add_0:2"}}); } TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) { @@ -1023,12 +1028,17 @@ TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) { TF_Operation* add = AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - Define(-1, {}, {feed1, feed2}, {add}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); - EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] " - "is not in the body. Encountered while creating " - "function 'MyFunc'"), - string(TF_Message(s_))); + Define(-1, {}, {feed1, feed2}, {add}, {}); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, + {{"^feed1", "add_0:2"}}); } TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) { |