diff options
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r-- | tensorflow/c/c_api_test.cc | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index bc04b53fbb..e674b1623c 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb_text.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -1424,6 +1426,29 @@ TEST(CAPI, SavedModelNullArgsAreValid) { TF_DeleteStatus(s); } +TEST(CAPI, DeletingNullPointerIsSafe) { + TF_Status* status = TF_NewStatus(); + + TF_DeleteStatus(nullptr); + TF_DeleteBuffer(nullptr); + TF_DeleteTensor(nullptr); + TF_DeleteSessionOptions(nullptr); + TF_DeleteGraph(nullptr); + TF_DeleteImportGraphDefOptions(nullptr); + TF_DeleteImportGraphDefResults(nullptr); + TF_DeleteFunction(nullptr); + TF_DeleteSession(nullptr, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeletePRunHandle(nullptr); + TF_DeleteDeprecatedSession(nullptr, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteDeviceList(nullptr); + TF_DeleteLibraryHandle(nullptr); + TF_DeleteApiDefMap(nullptr); + + TF_DeleteStatus(status); +} + REGISTER_OP("TestOpWithNoGradient") .Input("x: T") .Output("y: T") @@ -2312,6 +2337,57 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) { TF_DeleteLibraryHandle(lib); } +class DummyKernel : public tensorflow::OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +// Test we can query kernels +REGISTER_OP("TestOpWithSingleKernel") + .Input("a: float") + .Input("b: float") + .Output("o: float"); +REGISTER_KERNEL_BUILDER( + Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel); + +TEST(TestKernel, TestGetAllRegisteredKernels) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = TF_GetAllRegisteredKernels(status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_GT(kernel_list.kernel_size(), 0); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + +TEST(TestKernel, TestGetRegisteredKernelsForOp) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = + TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_EQ(kernel_list.kernel_size(), 1); + EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel"); + EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU"); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + +TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) { + TF_Status* status = TF_NewStatus(); + TF_Buffer* kernel_list_buf = TF_GetRegisteredKernelsForOp("Unknown", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + KernelList kernel_list; + kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length); + ASSERT_EQ(kernel_list.kernel_size(), 0); + TF_DeleteBuffer(kernel_list_buf); + TF_DeleteStatus(status); +} + #undef EXPECT_TF_META } // namespace |