aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar James Keeling <jtkeeling@google.com>2018-07-23 07:12:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 07:17:03 -0700
commit7aef462279657517377e0da15b90b3f3f5be16e1 (patch)
tree41b332c62a0f3ade7d7a68a62edb3069ea56ef9f /tensorflow/core/framework
parent21d0205916eded7e2bf2f26e43dd41b2f86cba3f (diff)
Add GetFilteredRegisteredKernels and refactor
GetFilteredRegisteredKernels makes it easier for users to query at runtime which kernels are available which match some predicate. The most common usage will be querying which kernels are available for a given op, so we add the specialized GetRegisteredKernelsForOp. This is part of the work to make available kernels possible to query, to support Swift For TensorFlow. There are also a number of github issues asking for the functionality. I will add C API and Python API support in upcoming changes. PiperOrigin-RevId: 205656251
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/op_kernel.cc45
-rw-r--r--tensorflow/core/framework/op_kernel.h7
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc18
3 files changed, 52 insertions, 18 deletions
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 58feec90f0..507aa9e447 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -1061,40 +1061,51 @@ Status SupportedDeviceTypesForNode(
}
void LogAllRegisteredKernels() {
- for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
- const KernelDef& kernel_def(key_registration.second.def);
+ KernelList kernel_list = GetAllRegisteredKernels();
+ for (const auto& kernel_def : kernel_list.kernel()) {
LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')";
}
}
KernelList GetAllRegisteredKernels() {
+ return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
+}
+
+KernelList GetFilteredRegisteredKernels(
+ const std::function<bool(const KernelDef&)>& predicate) {
const KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
KernelList kernel_list;
kernel_list.mutable_kernel()->Reserve(typed_registry->size());
for (const auto& p : *typed_registry) {
- *kernel_list.add_kernel() = p.second.def;
+ const KernelDef& kernel_def = p.second.def;
+ if (predicate(kernel_def)) {
+ *kernel_list.add_kernel() = kernel_def;
+ }
}
return kernel_list;
}
+KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
+ auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
+ return GetFilteredRegisteredKernels(op_pred);
+}
+
string KernelsRegisteredForOp(StringPiece op_name) {
+ KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
+ if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n";
string ret;
- for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
- const KernelDef& kernel_def(key_registration.second.def);
- if (kernel_def.op() == op_name) {
- strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
- if (!kernel_def.label().empty()) {
- strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
- }
- for (int i = 0; i < kernel_def.constraint_size(); ++i) {
- strings::StrAppend(
- &ret, "; ", kernel_def.constraint(i).name(), " in ",
- SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
- }
- strings::StrAppend(&ret, "\n");
+ for (const auto& kernel_def : kernel_list.kernel()) {
+ strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
+ if (!kernel_def.label().empty()) {
+ strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
+ }
+ for (int i = 0; i < kernel_def.constraint_size(); ++i) {
+ strings::StrAppend(
+ &ret, "; ", kernel_def.constraint(i).name(), " in ",
+ SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
}
+ strings::StrAppend(&ret, "\n");
}
- if (ret.empty()) return " <no registered kernels>\n";
return ret;
}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index d9fe42fcbb..1fc5e9908e 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1304,6 +1304,13 @@ void LogAllRegisteredKernels();
// Gets a list of all registered kernels.
KernelList GetAllRegisteredKernels();
+// Gets a list of all registered kernels for which predicate returns true
+KernelList GetFilteredRegisteredKernels(
+ const std::function<bool(const KernelDef&)>& predicate);
+
+// Gets a list of all registered kernels for a given op
+KernelList GetRegisteredKernelsForOp(StringPiece op_name);
+
namespace kernel_factory {
class OpKernelRegistrar {
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index b76a3400a8..83dda6579b 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -965,7 +965,8 @@ BENCHMARK(BM_ConcatInputRange);
BENCHMARK(BM_SelectInputRange);
TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
- auto all_registered_kernels = GetAllRegisteredKernels().kernel();
+ auto kernel_list = GetAllRegisteredKernels();
+ auto all_registered_kernels = kernel_list.kernel();
auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
// Verify we can find the "Test1" op registered above
@@ -986,5 +987,20 @@ TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
tensorflow::LogAllRegisteredKernels();
}
+TEST(RegisteredKernels, GetFilteredRegisteredKernels) {
+ auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
+ auto kernel_list = GetFilteredRegisteredKernels(has_name_test1);
+ ASSERT_EQ(kernel_list.kernel_size(), 1);
+ EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
+ EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
+}
+
+TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
+ auto kernel_list = GetRegisteredKernelsForOp("Test1");
+ ASSERT_EQ(kernel_list.kernel_size(), 1);
+ EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
+ EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
+}
+
} // namespace
} // namespace tensorflow