aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_function.cc
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-09-19 23:32:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-19 23:35:55 -0700
commit5f453c9959e12c7922500c9ae188a7afe41036e0 (patch)
treed9a56f4bb547dba29beb2ab132de5715836fb28e /tensorflow/c/c_api_function.cc
parent0de5f7cecb785b03d652479ac4f359b284e8c3a5 (diff)
Relax restriction on ref types in body
PiperOrigin-RevId: 169356209
Diffstat (limited to 'tensorflow/c/c_api_function.cc')
-rw-r--r--tensorflow/c/c_api_function.cc28
1 files changed, 10 insertions, 18 deletions
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 92ee77935e..b713aa7645 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -137,17 +137,12 @@ string NodeNameMapping::Lookup(const string& name) const {
return iter->second;
}
-Status ValidateNoRefOutputs(const Node* node) {
- for (int i = 0; i < node->num_outputs(); ++i) {
- const DataType& dt = node->output_type(i);
- if (IsRefType(dt)) {
- return InvalidArgument("Output ", i, " of node '", node->name(),
- "' has a reference "
- "type ",
- DataTypeString(dt));
- }
- }
- return Status::OK();
+Status ValidateNonRefOutput(const Node* node, int idx) {
+ const DataType& dt = node->output_type(idx);
+ return IsRefType(dt)
+ ? InvalidArgument("Output ", idx, " of node '", node->name(),
+ "' has a reference type ", DataTypeString(dt))
+ : Status::OK();
}
Status FillFunctionBody(
@@ -366,7 +361,7 @@ Status ProcessInputs(
fn_body->graph.IsValidOutputTensor(&node, idx),
"Encountered while processing input ", i, " into function '", fn_name,
"'");
- TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node),
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
"Encountered while processing input ", i,
" into function '", fn_name, "'");
@@ -401,6 +396,9 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
fn_body->graph.IsValidOutputTensor(&node, idx),
"Encountered while processing output ", i, " from function '", fn_name,
"'");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+ "Encountered while creating function '",
+ fn_name, "'");
output_tensors->emplace_back(&node, idx);
}
return Status::OK();
@@ -419,9 +417,6 @@ Status ComputeBodyNodes(
const auto& iter = input_nodes.find(node);
if (iter == input_nodes.end()) {
// This node is not referenced in inputs. Add it to the body.
- TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
- "Encountered while creating function '",
- fn_name, "'");
body_nodes->push_back(node);
} else {
// This node is referenced in inputs. Currently, we place an
@@ -440,9 +435,6 @@ Status ComputeBodyNodes(
body_nodes->reserve(num_opers);
for (int i = 0; i < num_opers; ++i) {
const Node* node = &opers[i]->node;
- TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node),
- "Encountered while creating function '",
- fn_name, "'");
body_nodes->push_back(node);
}
}