diff options
-rw-r--r-- | tensorflow/c/c_api_function.cc | 27 | ||||
-rw-r--r-- | tensorflow/c/c_api_function_test.cc | 24 |
2 files changed, 36 insertions, 15 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 46271e0514..384e6c8cb9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -44,8 +44,12 @@ class NodeNameMapping { public: NodeNameMapping() = default; - // Normalize the input/output name and make it unique. - string GetIOName(const string& name); + // Normalize the input name and make it unique. This is the same as the + // function for output, expect that it adds a name mapping for the name. + string GetInputName(const string& name); + + // Normalize the output name and make it unique. + string GetOutputName(const string& name); // Make the node name unique. string Uniquify(const string& name); @@ -107,7 +111,13 @@ string NodeNameMapping::UniquifyHelper(const string& name) const { } } -string NodeNameMapping::GetIOName(const string& name) { +string NodeNameMapping::GetInputName(const string& name) { + const string& input_name = GetOutputName(name); + name_mapping_[name] = input_name; + return input_name; +} + +string NodeNameMapping::GetOutputName(const string& name) { const string& input_name = UniquifyHelper(Normalize(name)); // Record that we used this name, but don't add it to name_mapping_ // since this name is not for a node. @@ -214,10 +224,11 @@ Status FillFunctionBody( // Add control inputs. for (const Edge* edge : control_edges) { - // Add this control input only if the src node is in the body. + // Add this control input only if the src node is in the body or a part of + // the inputs. const string normalized = node_names.Lookup(edge->src()->name()); // If we did not find a name for the source of control edge, this - // source must be outside of the body. Raise an error. + // source must be outside of the body, and not an input. Raise an error. if (normalized.empty()) { return InvalidArgument( "The source of control edge ", edge->DebugString(), @@ -279,7 +290,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i])); argdef->set_name(output_names[i]); } else { - argdef->set_name(node_names.GetIOName(node->name())); + argdef->set_name(node_names.GetOutputName(node->name())); } } @@ -289,7 +300,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, int idx = inputs[i].index; OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); argdef->set_type(node->output_type(idx)); - const string& input_name = node_names.GetIOName(node->name()); + const string& input_name = node_names.GetInputName(node->name()); argdef->set_name(input_name); tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name; } @@ -467,7 +478,7 @@ Status ComputeBodyNodes( return Status::OK(); } -} // anonymous namespace +} // namespace } // namespace tensorflow using tensorflow::Node; diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index dbce66d231..7ca50119ea 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -331,6 +331,11 @@ class CApiFunctionTest : public ::testing::Test { << "Failed to find expected edge " << e.ToString() << " in fdef: " << fdef.DebugString(); } + for (const EdgeSpec& e : c_edges) { + ASSERT_TRUE(a_edges.find(e) != a_edges.end()) + << "Failed to find expected control edge " << e.ToString() + << " in fdef: " << fdef.DebugString(); + } // If caller specified all edges, check that we have seen all if (is_exact_edges) { @@ -980,7 +985,7 @@ TEST_F(CApiFunctionTest, ControlDependency) { VerifyFDef( {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, - {{"scalar", "add_0"}}); + {{"^scalar", "add_0:2"}}); } TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) { @@ -1023,12 +1028,17 @@ TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) { TF_Operation* add = AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - Define(-1, {}, {feed1, feed2}, {add}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); - EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] " - "is not in the body. Encountered while creating " - "function 'MyFunc'"), - string(TF_Message(s_))); + Define(-1, {}, {feed1, feed2}, {add}, {}); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, + {{"^feed1", "add_0:2"}}); } TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) { |