aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_kernel_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/op_kernel_test.cc')
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index b76a3400a8..83dda6579b 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -965,7 +965,8 @@ BENCHMARK(BM_ConcatInputRange);
BENCHMARK(BM_SelectInputRange);
TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
- auto all_registered_kernels = GetAllRegisteredKernels().kernel();
+ auto kernel_list = GetAllRegisteredKernels();
+ auto all_registered_kernels = kernel_list.kernel();
auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
// Verify we can find the "Test1" op registered above
@@ -986,5 +987,20 @@ TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
tensorflow::LogAllRegisteredKernels();
}
+TEST(RegisteredKernels, GetFilteredRegisteredKernels) {
+ auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
+ auto kernel_list = GetFilteredRegisteredKernels(has_name_test1);
+ ASSERT_EQ(kernel_list.kernel_size(), 1);
+ EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
+ EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
+}
+
+TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
+ auto kernel_list = GetRegisteredKernelsForOp("Test1");
+ ASSERT_EQ(kernel_list.kernel_size(), 1);
+ EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
+ EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
+}
+
} // namespace
} // namespace tensorflow