diff options
Diffstat (limited to 'tensorflow/core/framework/op_kernel_test.cc')
-rw-r--r-- | tensorflow/core/framework/op_kernel_test.cc | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index b76a3400a8..83dda6579b 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -965,7 +965,8 @@ BENCHMARK(BM_ConcatInputRange); BENCHMARK(BM_SelectInputRange); TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) { - auto all_registered_kernels = GetAllRegisteredKernels().kernel(); + auto kernel_list = GetAllRegisteredKernels(); + auto all_registered_kernels = kernel_list.kernel(); auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; }; // Verify we can find the "Test1" op registered above @@ -986,5 +987,20 @@ TEST(RegisteredKernels, CanLogAllRegisteredKernels) { tensorflow::LogAllRegisteredKernels(); } +TEST(RegisteredKernels, GetFilteredRegisteredKernels) { + auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; }; + auto kernel_list = GetFilteredRegisteredKernels(has_name_test1); + ASSERT_EQ(kernel_list.kernel_size(), 1); + EXPECT_EQ(kernel_list.kernel(0).op(), "Test1"); + EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU"); +} + +TEST(RegisteredKernels, GetRegisteredKernelsForOp) { + auto kernel_list = GetRegisteredKernelsForOp("Test1"); + ASSERT_EQ(kernel_list.kernel_size(), 1); + EXPECT_EQ(kernel_list.kernel(0).op(), "Test1"); + EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU"); +} + } // namespace } // namespace tensorflow |