diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-02-08 14:10:51 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-08 14:13:53 -0800 |
commit | 037895185e970e3bdc789fbf6aad643c271415d4 (patch) | |
tree | b5bd3b7c88f2930e3a32c0dc629d66ba6ad03125 /tensorflow/c/c_api_function.cc | |
parent | fa3fb289ba6a1718f9c76b2277a58f95f5e878ab (diff) |
Don't fail if control dependency is on an input of the function.
PiperOrigin-RevId: 185049319
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r-- | tensorflow/c/c_api_function.cc | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 46271e0514..384e6c8cb9 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -44,8 +44,12 @@ class NodeNameMapping { public: NodeNameMapping() = default; - // Normalize the input/output name and make it unique. - string GetIOName(const string& name); + // Normalize the input name and make it unique. This is the same as the + // function for output, expect that it adds a name mapping for the name. + string GetInputName(const string& name); + + // Normalize the output name and make it unique. + string GetOutputName(const string& name); // Make the node name unique. string Uniquify(const string& name); @@ -107,7 +111,13 @@ string NodeNameMapping::UniquifyHelper(const string& name) const { } } -string NodeNameMapping::GetIOName(const string& name) { +string NodeNameMapping::GetInputName(const string& name) { + const string& input_name = GetOutputName(name); + name_mapping_[name] = input_name; + return input_name; +} + +string NodeNameMapping::GetOutputName(const string& name) { const string& input_name = UniquifyHelper(Normalize(name)); // Record that we used this name, but don't add it to name_mapping_ // since this name is not for a node. @@ -214,10 +224,11 @@ Status FillFunctionBody( // Add control inputs. for (const Edge* edge : control_edges) { - // Add this control input only if the src node is in the body. + // Add this control input only if the src node is in the body or a part of + // the inputs. const string normalized = node_names.Lookup(edge->src()->name()); // If we did not find a name for the source of control edge, this - // source must be outside of the body. Raise an error. + // source must be outside of the body, and not an input. Raise an error. if (normalized.empty()) { return InvalidArgument( "The source of control edge ", edge->DebugString(), @@ -279,7 +290,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i])); argdef->set_name(output_names[i]); } else { - argdef->set_name(node_names.GetIOName(node->name())); + argdef->set_name(node_names.GetOutputName(node->name())); } } @@ -289,7 +300,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, int idx = inputs[i].index; OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); argdef->set_type(node->output_type(idx)); - const string& input_name = node_names.GetIOName(node->name()); + const string& input_name = node_names.GetInputName(node->name()); argdef->set_name(input_name); tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name; } @@ -467,7 +478,7 @@ Status ComputeBodyNodes( return Status::OK(); } -} // anonymous namespace +} // namespace } // namespace tensorflow using tensorflow::Node; |