diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-08-02 15:27:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-02 15:34:43 -0700 |
commit | daaaab25fb8aa981fdf76740763714c9ea3f2879 (patch) | |
tree | b1a62c005b3b1fbd21e28b3335b5a69d7377bc1b /tensorflow/c | |
parent | f8afea0e1dc7226d2d4e6bc9fb75ba7094fa727e (diff) |
Check if the handle is nullptr, and fail early instead of segfaulting.
PiperOrigin-RevId: 207176253
Diffstat (limited to 'tensorflow/c')
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 20 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_test.cc | 36 |
2 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 7321b4b791..555dab3e89 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -348,6 +348,11 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } int result; status->status = h->handle->NumDims(&result); return result; @@ -355,12 +360,22 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return -1; + } tensorflow::int64 result; status->status = h->handle->Dim(dim_index, &result); return result; } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } tensorflow::Device* d = nullptr; status->status = h->handle->OpDevice(&d); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" @@ -368,6 +383,11 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } // TODO(agarwal): move this implementation inside TFE_TensorHandle. tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 0bdea70fe6..6f2fbee884 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -615,6 +615,42 @@ void SetAndGetOpDevices(bool async) { TF_DeleteStatus(status); } +TEST(CAPI, TensorHandleNullptr) { + TFE_TensorHandle* h = nullptr; + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + + TF_Tensor* t = TFE_TensorHandleResolve(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(t, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + const char* device_name = TFE_TensorHandleDeviceName(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(device_name, nullptr); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int num_dims = TFE_TensorHandleNumDims(h, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(num_dims, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); + + TF_SetStatus(status.get(), TF_OK, ""); + + int dim = TFE_TensorHandleDim(h, 0, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + ASSERT_EQ(dim, -1); + ASSERT_EQ("The passed in handle is a nullptr", + string(TF_Message(status.get()))); +} + void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); |