aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_kernel.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/op_kernel.cc')
-rw-r--r--tensorflow/core/framework/op_kernel.cc55
1 files changed, 34 insertions, 21 deletions
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 8a332fa1d8..507aa9e447 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -263,11 +263,13 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
outputs_(num_outputs),
temp_memory_allocated_(0),
persistent_memory_allocated_(0) {
- Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
params_->ensure_eigen_gpu_device();
- params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
- params_->op_device_context,
- eigen_gpu_allocator);
+ if (params_->eigen_gpu_device != nullptr) {
+ Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
+ params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
+ params_->op_device_context,
+ eigen_gpu_allocator);
+ }
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
}
@@ -1059,40 +1061,51 @@ Status SupportedDeviceTypesForNode(
}
void LogAllRegisteredKernels() {
- for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
- const KernelDef& kernel_def(key_registration.second.def);
+ KernelList kernel_list = GetAllRegisteredKernels();
+ for (const auto& kernel_def : kernel_list.kernel()) {
LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')";
}
}
KernelList GetAllRegisteredKernels() {
+ return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; });
+}
+
+KernelList GetFilteredRegisteredKernels(
+ const std::function<bool(const KernelDef&)>& predicate) {
const KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
KernelList kernel_list;
kernel_list.mutable_kernel()->Reserve(typed_registry->size());
for (const auto& p : *typed_registry) {
- *kernel_list.add_kernel() = p.second.def;
+ const KernelDef& kernel_def = p.second.def;
+ if (predicate(kernel_def)) {
+ *kernel_list.add_kernel() = kernel_def;
+ }
}
return kernel_list;
}
+KernelList GetRegisteredKernelsForOp(StringPiece op_name) {
+ auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; };
+ return GetFilteredRegisteredKernels(op_pred);
+}
+
string KernelsRegisteredForOp(StringPiece op_name) {
+ KernelList kernel_list = GetRegisteredKernelsForOp(op_name);
+ if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n";
string ret;
- for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
- const KernelDef& kernel_def(key_registration.second.def);
- if (kernel_def.op() == op_name) {
- strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
- if (!kernel_def.label().empty()) {
- strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
- }
- for (int i = 0; i < kernel_def.constraint_size(); ++i) {
- strings::StrAppend(
- &ret, "; ", kernel_def.constraint(i).name(), " in ",
- SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
- }
- strings::StrAppend(&ret, "\n");
+ for (const auto& kernel_def : kernel_list.kernel()) {
+ strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'");
+ if (!kernel_def.label().empty()) {
+ strings::StrAppend(&ret, "; label='", kernel_def.label(), "'");
+ }
+ for (int i = 0; i < kernel_def.constraint_size(); ++i) {
+ strings::StrAppend(
+ &ret, "; ", kernel_def.constraint(i).name(), " in ",
+ SummarizeAttrValue(kernel_def.constraint(i).allowed_values()));
}
+ strings::StrAppend(&ret, "\n");
}
- if (ret.empty()) return " <no registered kernels>\n";
return ret;
}