aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/c_api_function.cc27
-rw-r--r--tensorflow/c/c_api_function_test.cc24
2 files changed, 36 insertions, 15 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 46271e0514..384e6c8cb9 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -44,8 +44,12 @@ class NodeNameMapping {
public:
NodeNameMapping() = default;
- // Normalize the input/output name and make it unique.
- string GetIOName(const string& name);
+ // Normalize the input name and make it unique. This is the same as the
+ // function for output, expect that it adds a name mapping for the name.
+ string GetInputName(const string& name);
+
+ // Normalize the output name and make it unique.
+ string GetOutputName(const string& name);
// Make the node name unique.
string Uniquify(const string& name);
@@ -107,7 +111,13 @@ string NodeNameMapping::UniquifyHelper(const string& name) const {
}
}
-string NodeNameMapping::GetIOName(const string& name) {
+string NodeNameMapping::GetInputName(const string& name) {
+ const string& input_name = GetOutputName(name);
+ name_mapping_[name] = input_name;
+ return input_name;
+}
+
+string NodeNameMapping::GetOutputName(const string& name) {
const string& input_name = UniquifyHelper(Normalize(name));
// Record that we used this name, but don't add it to name_mapping_
// since this name is not for a node.
@@ -214,10 +224,11 @@ Status FillFunctionBody(
// Add control inputs.
for (const Edge* edge : control_edges) {
- // Add this control input only if the src node is in the body.
+ // Add this control input only if the src node is in the body or a part of
+ // the inputs.
const string normalized = node_names.Lookup(edge->src()->name());
// If we did not find a name for the source of control edge, this
- // source must be outside of the body. Raise an error.
+ // source must be outside of the body, and not an input. Raise an error.
if (normalized.empty()) {
return InvalidArgument(
"The source of control edge ", edge->DebugString(),
@@ -279,7 +290,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
argdef->set_name(output_names[i]);
} else {
- argdef->set_name(node_names.GetIOName(node->name()));
+ argdef->set_name(node_names.GetOutputName(node->name()));
}
}
@@ -289,7 +300,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
int idx = inputs[i].index;
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
argdef->set_type(node->output_type(idx));
- const string& input_name = node_names.GetIOName(node->name());
+ const string& input_name = node_names.GetInputName(node->name());
argdef->set_name(input_name);
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}
@@ -467,7 +478,7 @@ Status ComputeBodyNodes(
return Status::OK();
}
-} // anonymous namespace
+} // namespace
} // namespace tensorflow
using tensorflow::Node;
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index dbce66d231..7ca50119ea 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -331,6 +331,11 @@ class CApiFunctionTest : public ::testing::Test {
<< "Failed to find expected edge " << e.ToString()
<< " in fdef: " << fdef.DebugString();
}
+ for (const EdgeSpec& e : c_edges) {
+ ASSERT_TRUE(a_edges.find(e) != a_edges.end())
+ << "Failed to find expected control edge " << e.ToString()
+ << " in fdef: " << fdef.DebugString();
+ }
// If caller specified all edges, check that we have seen all
if (is_exact_edges) {
@@ -980,7 +985,7 @@ TEST_F(CApiFunctionTest, ControlDependency) {
VerifyFDef(
{"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
{{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
- {{"scalar", "add_0"}});
+ {{"^scalar", "add_0:2"}});
}
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
@@ -1023,12 +1028,17 @@ TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
TF_Operation* add =
AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_);
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- Define(-1, {}, {feed1, feed2}, {add}, {}, true);
- EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
- EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] "
- "is not in the body. Encountered while creating "
- "function 'MyFunc'"),
- string(TF_Message(s_)));
+ Define(-1, {}, {feed1, feed2}, {add}, {});
+
+ // Use, run, and verify
+ TF_Operation* two = ScalarConst(2, host_graph_, s_);
+ TF_Operation* func_feed = Placeholder(host_graph_, s_);
+ TF_Operation* func_op = Use({two, func_feed});
+ Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
+ VerifyFDef(
+ {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
+ {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
+ {{"^feed1", "add_0:2"}});
}
TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {