aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function.cc
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-02-08 14:10:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-08 14:13:53 -0800
commit037895185e970e3bdc789fbf6aad643c271415d4 (patch)
treeb5bd3b7c88f2930e3a32c0dc629d66ba6ad03125 /tensorflow/c/c_api_function.cc
parentfa3fb289ba6a1718f9c76b2277a58f95f5e878ab (diff)
Don't fail if control dependency is on an input of the function.
PiperOrigin-RevId: 185049319
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r--tensorflow/c/c_api_function.cc27
1 files changed, 19 insertions, 8 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 46271e0514..384e6c8cb9 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -44,8 +44,12 @@ class NodeNameMapping {
public:
NodeNameMapping() = default;
- // Normalize the input/output name and make it unique.
- string GetIOName(const string& name);
+ // Normalize the input name and make it unique. This is the same as the
+ // function for output, expect that it adds a name mapping for the name.
+ string GetInputName(const string& name);
+
+ // Normalize the output name and make it unique.
+ string GetOutputName(const string& name);
// Make the node name unique.
string Uniquify(const string& name);
@@ -107,7 +111,13 @@ string NodeNameMapping::UniquifyHelper(const string& name) const {
}
}
-string NodeNameMapping::GetIOName(const string& name) {
+string NodeNameMapping::GetInputName(const string& name) {
+ const string& input_name = GetOutputName(name);
+ name_mapping_[name] = input_name;
+ return input_name;
+}
+
+string NodeNameMapping::GetOutputName(const string& name) {
const string& input_name = UniquifyHelper(Normalize(name));
// Record that we used this name, but don't add it to name_mapping_
// since this name is not for a node.
@@ -214,10 +224,11 @@ Status FillFunctionBody(
// Add control inputs.
for (const Edge* edge : control_edges) {
- // Add this control input only if the src node is in the body.
+ // Add this control input only if the src node is in the body or a part of
+ // the inputs.
const string normalized = node_names.Lookup(edge->src()->name());
// If we did not find a name for the source of control edge, this
- // source must be outside of the body. Raise an error.
+ // source must be outside of the body, and not an input. Raise an error.
if (normalized.empty()) {
return InvalidArgument(
"The source of control edge ", edge->DebugString(),
@@ -279,7 +290,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
argdef->set_name(output_names[i]);
} else {
- argdef->set_name(node_names.GetIOName(node->name()));
+ argdef->set_name(node_names.GetOutputName(node->name()));
}
}
@@ -289,7 +300,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
int idx = inputs[i].index;
OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
argdef->set_type(node->output_type(idx));
- const string& input_name = node_names.GetIOName(node->name());
+ const string& input_name = node_names.GetInputName(node->name());
argdef->set_name(input_name);
tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
}
@@ -467,7 +478,7 @@ Status ComputeBodyNodes(
return Status::OK();
}
-} // anonymous namespace
+} // namespace
} // namespace tensorflow
using tensorflow::Node;