diff options
Diffstat (limited to 'tensorflow/core/framework/op_kernel.cc')
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 55 |
1 files changed, 34 insertions, 21 deletions
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 8a332fa1d8..507aa9e447 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -263,11 +263,13 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) outputs_(num_outputs), temp_memory_allocated_(0), persistent_memory_allocated_(0) { - Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); params_->ensure_eigen_gpu_device(); - params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device, - params_->op_device_context, - eigen_gpu_allocator); + if (params_->eigen_gpu_device != nullptr) { + Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); + params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device, + params_->op_device_context, + eigen_gpu_allocator); + } if (params_->record_tensor_accesses) { referenced_tensors_.Init(); } @@ -1059,40 +1061,51 @@ Status SupportedDeviceTypesForNode( } void LogAllRegisteredKernels() { - for (const auto& key_registration : *GlobalKernelRegistryTyped()) { - const KernelDef& kernel_def(key_registration.second.def); + KernelList kernel_list = GetAllRegisteredKernels(); + for (const auto& kernel_def : kernel_list.kernel()) { LOG(INFO) << "OpKernel ('" << ProtoShortDebugString(kernel_def) << "')"; } } KernelList GetAllRegisteredKernels() { + return GetFilteredRegisteredKernels([](const KernelDef& k) { return true; }); +} + +KernelList GetFilteredRegisteredKernels( + const std::function<bool(const KernelDef&)>& predicate) { const KernelRegistry* const typed_registry = GlobalKernelRegistryTyped(); KernelList kernel_list; kernel_list.mutable_kernel()->Reserve(typed_registry->size()); for (const auto& p : *typed_registry) { - *kernel_list.add_kernel() = p.second.def; + const KernelDef& kernel_def = p.second.def; + if (predicate(kernel_def)) { + *kernel_list.add_kernel() = kernel_def; + } } return kernel_list; } +KernelList GetRegisteredKernelsForOp(StringPiece op_name) { + auto op_pred = [op_name](const KernelDef& k) { return k.op() == op_name; }; + return GetFilteredRegisteredKernels(op_pred); +} + string KernelsRegisteredForOp(StringPiece op_name) { + KernelList kernel_list = GetRegisteredKernelsForOp(op_name); + if (kernel_list.kernel_size() == 0) return " <no registered kernels>\n"; string ret; - for (const auto& key_registration : *GlobalKernelRegistryTyped()) { - const KernelDef& kernel_def(key_registration.second.def); - if (kernel_def.op() == op_name) { - strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'"); - if (!kernel_def.label().empty()) { - strings::StrAppend(&ret, "; label='", kernel_def.label(), "'"); - } - for (int i = 0; i < kernel_def.constraint_size(); ++i) { - strings::StrAppend( - &ret, "; ", kernel_def.constraint(i).name(), " in ", - SummarizeAttrValue(kernel_def.constraint(i).allowed_values())); - } - strings::StrAppend(&ret, "\n"); + for (const auto& kernel_def : kernel_list.kernel()) { + strings::StrAppend(&ret, " device='", kernel_def.device_type(), "'"); + if (!kernel_def.label().empty()) { + strings::StrAppend(&ret, "; label='", kernel_def.label(), "'"); + } + for (int i = 0; i < kernel_def.constraint_size(); ++i) { + strings::StrAppend( + &ret, "; ", kernel_def.constraint(i).name(), " in ", + SummarizeAttrValue(kernel_def.constraint(i).allowed_values())); } + strings::StrAppend(&ret, "\n"); } - if (ret.empty()) return " <no registered kernels>\n"; return ret; } |