aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-08-01 12:00:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-01 12:08:17 -0700
commit96675956ef17e609d1bd60591fc998890d505004 (patch)
treeda9825ac24727f5c51869845f7f2ae35065db5a4 /tensorflow/c/c_api_test.cc
parent9593704b28e43b1a10a9c16317e1ba3cef2e1921 (diff)
C API: Avoid converting uninitialized tensorflow::Tensor to TF_Tensor*
And return error messages instead of CHECK failing when the conversion fails. PiperOrigin-RevId: 163863981
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 25b6cbd8e7..1d191fc36d 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -45,7 +45,7 @@ limitations under the License.
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
-TF_Tensor* TF_TensorFromTensor(const Tensor& src);
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
namespace {
@@ -137,6 +137,7 @@ TEST(CAPI, LibraryLoadFunctions) {
void TestEncodeDecode(int line, const std::vector<string>& data) {
const tensorflow::int64 n = data.size();
+ TF_Status* status = TF_NewStatus();
for (const std::vector<tensorflow::int64>& dims :
std::vector<std::vector<tensorflow::int64>>{
{n}, {1, n}, {n, 1}, {n / 2, 2}}) {
@@ -145,7 +146,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) {
src.flat<string>()(i) = data[i];
}
- TF_Tensor* dst = TF_TensorFromTensor(src);
+ TF_Tensor* dst = TF_TensorFromTensor(src, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Convert back to a C++ Tensor and ensure we get expected output.
Tensor output;
@@ -157,6 +159,7 @@ void TestEncodeDecode(int line, const std::vector<string>& data) {
TF_DeleteTensor(dst);
}
+ TF_DeleteStatus(status);
}
TEST(CAPI, TensorEncodeDecodeStrings) {
@@ -914,7 +917,8 @@ TEST(CAPI, SavedModel) {
TF_Operation* input_op =
TF_GraphOperationByName(graph, input_op_name.c_str());
ASSERT_TRUE(input_op != nullptr);
- csession.SetInputs({{input_op, TF_TensorFromTensor(input)}});
+ csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}});
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
const tensorflow::string output_op_name =
tensorflow::ParseTensorName(output_name).first.ToString();