diff options
author | Asim Shankar <ashankar@google.com> | 2017-08-01 12:00:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-01 12:08:17 -0700 |
commit | 96675956ef17e609d1bd60591fc998890d505004 (patch) | |
tree | da9825ac24727f5c51869845f7f2ae35065db5a4 /tensorflow/c/c_api_test.cc | |
parent | 9593704b28e43b1a10a9c16317e1ba3cef2e1921 (diff) |
C API: Avoid converting uninitialized tensorflow::Tensor to TF_Tensor*
And return error messages instead of CHECK failing when the conversion
fails.
PiperOrigin-RevId: 163863981
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r-- | tensorflow/c/c_api_test.cc | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 25b6cbd8e7..1d191fc36d 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -45,7 +45,7 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -TF_Tensor* TF_TensorFromTensor(const Tensor& src); +TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { @@ -137,6 +137,7 @@ TEST(CAPI, LibraryLoadFunctions) { void TestEncodeDecode(int line, const std::vector<string>& data) { const tensorflow::int64 n = data.size(); + TF_Status* status = TF_NewStatus(); for (const std::vector<tensorflow::int64>& dims : std::vector<std::vector<tensorflow::int64>>{ {n}, {1, n}, {n, 1}, {n / 2, 2}}) { @@ -145,7 +146,8 @@ void TestEncodeDecode(int line, const std::vector<string>& data) { for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { src.flat<string>()(i) = data[i]; } - TF_Tensor* dst = TF_TensorFromTensor(src); + TF_Tensor* dst = TF_TensorFromTensor(src, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // Convert back to a C++ Tensor and ensure we get expected output. Tensor output; @@ -157,6 +159,7 @@ void TestEncodeDecode(int line, const std::vector<string>& data) { TF_DeleteTensor(dst); } + TF_DeleteStatus(status); } TEST(CAPI, TensorEncodeDecodeStrings) { @@ -914,7 +917,8 @@ TEST(CAPI, SavedModel) { TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); - csession.SetInputs({{input_op, TF_TensorFromTensor(input)}}); + csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}}); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); const tensorflow::string output_op_name = tensorflow::ParseTensorName(output_name).first.ToString(); |