diff options
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 31 |
1 files changed, 16 insertions, 15 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 12f8ddc2ee..2d29f5176d 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -299,7 +299,23 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) { Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, OpKernel** kernel) { + // If a custom kernel creator is given, try that. + Status s; + if (custom_kernel_creator_) { + std::unique_ptr<OpKernel> ret; + s = custom_kernel_creator_(this, ndef, &ret); + if (s.ok()) { + *kernel = ret.release(); + return s; + } else { + VLOG(2) << "Custom creator error: " << s; + // Falls through. + s = Status::OK(); + } + } + if (lib_def_->Find(ndef.op()) == nullptr) { + // A primitive operation. Creates the registered kernel. return CreateNonCachedKernel(device_, this, ndef, graph_def_version_, kernel); } @@ -325,21 +341,6 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, output_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY); } - // If a custom kernel creator is given, try that. - Status s; - if (custom_kernel_creator_) { - std::unique_ptr<OpKernel> ret; - s = custom_kernel_creator_(this, ndef, &ret); - if (s.ok()) { - *kernel = ret.release(); - return s; - } else { - VLOG(2) << "Custom creator error: " << s; - // Falls through. - s = Status::OK(); - } - } - // Constructs a CallOp kernel for running the instantiated function. auto device_type = DeviceType(device_->attributes().device_type()); OpKernelConstruction construction( |