aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c
diff options
context:
space:
mode:
authorGravatar James Keeling <jtkeeling@google.com>2018-07-23 07:56:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 08:00:32 -0700
commit8647db865ce41361413a2eb4c3b4d0ba404dd4e0 (patch)
tree97eafd2fd44fa322dc5b91bd676c324059612be8 /tensorflow/c
parent7aef462279657517377e0da15b90b3f3f5be16e1 (diff)
Add C API for kernel info
This is part of the work to make available kernels possible to query, to support Swift For TensorFlow. There are also a number of github issues asking for the functionality. PiperOrigin-RevId: 205660862
Diffstat (limited to 'tensorflow/c')
-rw-r--r--tensorflow/c/c_api.cc24
-rw-r--r--tensorflow/c/c_api.h12
-rw-r--r--tensorflow/c/c_api_test.cc53
3 files changed, 89 insertions, 0 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index a3003953a3..1b937883c8 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -2729,4 +2730,27 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
return ret;
#endif // __ANDROID__
}
+
+TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
+ tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
+ TF_Buffer* ret = TF_NewBuffer();
+ status->status = MessageToBuffer(kernel_list, ret);
+ if (!status->status.ok()) {
+ TF_DeleteBuffer(ret);
+ return nullptr;
+ }
+ return ret;
+}
+
+TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
+ tensorflow::KernelList kernel_list =
+ tensorflow::GetRegisteredKernelsForOp(name);
+ TF_Buffer* ret = TF_NewBuffer();
+ status->status = MessageToBuffer(kernel_list, ret);
+ if (!status->status.ok()) {
+ TF_DeleteBuffer(ret);
+ return nullptr;
+ }
+ return ret;
+}
} // end extern "C"
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index fddc09d45e..c5035e0e41 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1610,6 +1610,18 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map,
size_t name_len,
TF_Status* status);
+// --------------------------------------------------------------------------
+// Kernel definition information.
+
+// Returns a serialized KernelList protocol buffer containing KernelDefs for all
+// registered kernels.
+TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
+
+// Returns a serialized KernelList protocol buffer containing KernelDefs for all
+// kernels registered for the operation named `name`.
+TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
+ const char* name, TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index bc04b53fbb..c470ab5649 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -29,9 +29,11 @@ limitations under the License.
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb_text.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -2312,6 +2314,57 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
TF_DeleteLibraryHandle(lib);
}
+class DummyKernel : public tensorflow::OpKernel {
+ public:
+ explicit DummyKernel(tensorflow::OpKernelConstruction* context)
+ : OpKernel(context) {}
+ void Compute(tensorflow::OpKernelContext* context) override {}
+};
+
+// Test we can query kernels
+REGISTER_OP("TestOpWithSingleKernel")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("o: float");
+REGISTER_KERNEL_BUILDER(
+ Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel);
+
+TEST(TestKernel, TestGetAllRegisteredKernels) {
+ TF_Status* status = TF_NewStatus();
+ TF_Buffer* kernel_list_buf = TF_GetAllRegisteredKernels(status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ KernelList kernel_list;
+ kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length);
+ ASSERT_GT(kernel_list.kernel_size(), 0);
+ TF_DeleteBuffer(kernel_list_buf);
+ TF_DeleteStatus(status);
+}
+
+TEST(TestKernel, TestGetRegisteredKernelsForOp) {
+ TF_Status* status = TF_NewStatus();
+ TF_Buffer* kernel_list_buf =
+ TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ KernelList kernel_list;
+ kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length);
+ ASSERT_EQ(kernel_list.kernel_size(), 1);
+ EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel");
+ EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
+ TF_DeleteBuffer(kernel_list_buf);
+ TF_DeleteStatus(status);
+}
+
+TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) {
+ TF_Status* status = TF_NewStatus();
+ TF_Buffer* kernel_list_buf = TF_GetRegisteredKernelsForOp("Unknown", status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ KernelList kernel_list;
+ kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length);
+ ASSERT_EQ(kernel_list.kernel_size(), 0);
+ TF_DeleteBuffer(kernel_list_buf);
+ TF_DeleteStatus(status);
+}
+
#undef EXPECT_TF_META
} // namespace