aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-13 11:36:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-13 11:46:31 -0700
commit74938be9aabe057fd7f779b6cd023f98e5a4bba2 (patch)
treef49b6dee57bbdaf979634c74b80a33fa8eda32f5
parent9dcf033873007b48033b38b428af45abdef97ee7 (diff)
Change back TFE_Execute logic to set '*num_retvals' on return.
PiperOrigin-RevId: 188903892
-rw-r--r--tensorflow/c/eager/c_api.cc9
-rw-r--r--tensorflow/c/eager/c_api.h3
-rw-r--r--tensorflow/c/eager/c_api_test.cc5
-rw-r--r--tensorflow/python/eager/core_test.py16
4 files changed, 17 insertions, 16 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 56cec2d668..0811bd363f 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -714,10 +714,7 @@ tensorflow::Status Execute(
*dev_stats->add_node_stats() = *maybe_stats;
}
}
- if (num_retvals != outputs.size()) {
- return tensorflow::errors::InvalidArgument(
- "Expecting ", num_retvals, " outputs but got ", outputs.size());
- }
+ DCHECK_EQ(num_retvals, outputs.size());
tensorflow::Device* op_device = IsCPU(device) ? nullptr : device;
for (int i = 0; i < num_retvals; ++i) {
tensorflow::Device* d = op_device;
@@ -1154,7 +1151,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes();
- if (output_dtypes.size() != *num_retvals) {
+ const int output_dtypes_size = output_dtypes.size();
+ if (output_dtypes_size > *num_retvals) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting ", output_dtypes.size(),
" outputs, but *num_retvals is ",
@@ -1162,6 +1160,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
.c_str());
return;
}
+ *num_retvals = output_dtypes_size;
if (device == nullptr) {
// TODO(apassos) debug how the assignment below might return a different
// device from the one requested above.
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 316006bafb..a5029bf211 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -285,7 +285,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op,
//
// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and
// '*num_retvals' should be set to the size of this array. It is an error if
-// the number of outputs is different from *num_retvals.
+// the size of 'retvals' is less than the number of outputs. This call sets
+// *num_retvals to the number of outputs.
//
// If async execution is enabled, the call may simply enqueue the execution
// and return "non-ready" handles in `retvals`. Note that any handles contained
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 927d119389..2268aba90d 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -553,9 +553,10 @@ void Execute_MatMul_CPU(bool async) {
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
- TFE_TensorHandle* retvals[1] = {nullptr};
- int num_retvals = 1;
+ TFE_TensorHandle* retvals[2] = {nullptr, nullptr};
+ int num_retvals = 2;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(1, num_retvals);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 012c68f68e..61c5526d48 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -250,16 +250,16 @@ class TFETest(test_util.TensorFlowTestCase):
def testExecuteTooManyNumOutputs(self):
# num_outputs provided is 50, but only one output is produced.
- with self.assertRaises(errors.InvalidArgumentError):
- _ = execute(
- b'Mul',
- num_outputs=50,
- inputs=[constant_op.constant(3),
- constant_op.constant(5)],
- attrs=('T', dtypes.int32.as_datatype_enum))[0]
+ product = execute(
+ b'Mul',
+ num_outputs=50,
+ inputs=[constant_op.constant(3),
+ constant_op.constant(5)],
+ attrs=('T', dtypes.int32.as_datatype_enum))[0]
+ self.assertAllEqual(15, product)
def testExecuteTooFewNumOutputs(self):
- # num_outputs provided is 50, but only one output is produced.
+ # num_outputs provided is 0, but one output is produced.
with self.assertRaises(errors.InvalidArgumentError):
_ = execute(
b'Mul',