diff options
author | Igor Ganichev <iga@google.com> | 2017-09-19 23:32:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-19 23:35:55 -0700 |
commit | 5f453c9959e12c7922500c9ae188a7afe41036e0 (patch) | |
tree | d9a56f4bb547dba29beb2ab132de5715836fb28e /tensorflow/c/c_api_function.cc | |
parent | 0de5f7cecb785b03d652479ac4f359b284e8c3a5 (diff) |
Relax restriction on ref types in body
PiperOrigin-RevId: 169356209
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r-- | tensorflow/c/c_api_function.cc | 28 |
1 files changed, 10 insertions, 18 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 92ee77935e..b713aa7645 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -137,17 +137,12 @@ string NodeNameMapping::Lookup(const string& name) const { return iter->second; } -Status ValidateNoRefOutputs(const Node* node) { - for (int i = 0; i < node->num_outputs(); ++i) { - const DataType& dt = node->output_type(i); - if (IsRefType(dt)) { - return InvalidArgument("Output ", i, " of node '", node->name(), - "' has a reference " - "type ", - DataTypeString(dt)); - } - } - return Status::OK(); +Status ValidateNonRefOutput(const Node* node, int idx) { + const DataType& dt = node->output_type(idx); + return IsRefType(dt) + ? InvalidArgument("Output ", idx, " of node '", node->name(), + "' has a reference type ", DataTypeString(dt)) + : Status::OK(); } Status FillFunctionBody( @@ -366,7 +361,7 @@ Status ProcessInputs( fn_body->graph.IsValidOutputTensor(&node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node), + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), "Encountered while processing input ", i, " into function '", fn_name, "'"); @@ -401,6 +396,9 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, fn_body->graph.IsValidOutputTensor(&node, idx), "Encountered while processing output ", i, " from function '", fn_name, "'"); + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx), + "Encountered while creating function '", + fn_name, "'"); output_tensors->emplace_back(&node, idx); } return Status::OK(); @@ -419,9 +417,6 @@ Status ComputeBodyNodes( const auto& iter = input_nodes.find(node); if (iter == input_nodes.end()) { // This node is not referenced in inputs. Add it to the body. - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node), - "Encountered while creating function '", - fn_name, "'"); body_nodes->push_back(node); } else { // This node is referenced in inputs. Currently, we place an @@ -440,9 +435,6 @@ Status ComputeBodyNodes( body_nodes->reserve(num_opers); for (int i = 0; i < num_opers; ++i) { const Node* node = &opers[i]->node; - TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node), - "Encountered while creating function '", - fn_name, "'"); body_nodes->push_back(node); } } |