diff options
author | 2018-03-13 11:36:28 -0700 | |
---|---|---|
committer | 2018-03-13 11:46:31 -0700 | |
commit | 74938be9aabe057fd7f779b6cd023f98e5a4bba2 (patch) | |
tree | f49b6dee57bbdaf979634c74b80a33fa8eda32f5 | |
parent | 9dcf033873007b48033b38b428af45abdef97ee7 (diff) |
Change back TFE_Execute logic to set '*num_retvals' on return.
PiperOrigin-RevId: 188903892
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 9 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.h | 3 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 5 | ||||
-rw-r--r-- | tensorflow/python/eager/core_test.py | 16 |
4 files changed, 17 insertions, 16 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 56cec2d668..0811bd363f 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -714,10 +714,7 @@ tensorflow::Status Execute( *dev_stats->add_node_stats() = *maybe_stats; } } - if (num_retvals != outputs.size()) { - return tensorflow::errors::InvalidArgument( - "Expecting ", num_retvals, " outputs but got ", outputs.size()); - } + DCHECK_EQ(num_retvals, outputs.size()); tensorflow::Device* op_device = IsCPU(device) ? nullptr : device; for (int i = 0; i < num_retvals; ++i) { tensorflow::Device* d = op_device; @@ -1154,7 +1151,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); - if (output_dtypes.size() != *num_retvals) { + const int output_dtypes_size = output_dtypes.size(); + if (output_dtypes_size > *num_retvals) { TF_SetStatus(status, TF_INVALID_ARGUMENT, tensorflow::strings::StrCat("Expecting ", output_dtypes.size(), " outputs, but *num_retvals is ", @@ -1162,6 +1160,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, .c_str()); return; } + *num_retvals = output_dtypes_size; if (device == nullptr) { // TODO(apassos) debug how the assignment below might return a different // device from the one requested above. diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 316006bafb..a5029bf211 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -285,7 +285,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, // // 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and // '*num_retvals' should be set to the size of this array. It is an error if -// the number of outputs is different from *num_retvals. +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. // // If async execution is enabled, the call may simply enqueue the execution // and return "non-ready" handles in `retvals`. Note that any handles contained diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 927d119389..2268aba90d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -553,9 +553,10 @@ void Execute_MatMul_CPU(bool async) { TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[1] = {nullptr}; - int num_retvals = 1; + TFE_TensorHandle* retvals[2] = {nullptr, nullptr}; + int num_retvals = 2; TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(1, num_retvals); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 012c68f68e..61c5526d48 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -250,16 +250,16 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteTooManyNumOutputs(self): # num_outputs provided is 50, but only one output is produced. - with self.assertRaises(errors.InvalidArgumentError): - _ = execute( - b'Mul', - num_outputs=50, - inputs=[constant_op.constant(3), - constant_op.constant(5)], - attrs=('T', dtypes.int32.as_datatype_enum))[0] + product = execute( + b'Mul', + num_outputs=50, + inputs=[constant_op.constant(3), + constant_op.constant(5)], + attrs=('T', dtypes.int32.as_datatype_enum))[0] + self.assertAllEqual(15, product) def testExecuteTooFewNumOutputs(self): - # num_outputs provided is 50, but only one output is produced. + # num_outputs provided is 0, but one output is produced. with self.assertRaises(errors.InvalidArgumentError): _ = execute( b'Mul', |