diff options
-rw-r--r-- | tensorflow/core/framework/op.h | 12 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 31 | ||||
-rw-r--r-- | tensorflow/core/kernels/function_ops.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/ops/function_ops.cc | 4 |
4 files changed, 47 insertions, 4 deletions
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index f047ddb12a..892ed9b60b 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -293,6 +293,18 @@ struct OpDefBuilderReceiver { ::tensorflow::register_op::OpDefBuilderWrapper<SHOULD_REGISTER_OP( \ name)>(name) +// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except +// that the op is registered unconditionally even when selective +// registration is used. +#define REGISTER_SYSTEM_OP(name) \ + REGISTER_SYSTEM_OP_UNIQ_HELPER(__COUNTER__, name) +#define REGISTER_SYSTEM_OP_UNIQ_HELPER(ctr, name) \ + REGISTER_SYSTEM_OP_UNIQ(ctr, name) +#define REGISTER_SYSTEM_OP_UNIQ(ctr, name) \ + static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr \ + TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::register_op::OpDefBuilderWrapper<true>(name) + } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_OP_H_ diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 91e6a98304..48bb69cb4e 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1194,6 +1194,17 @@ class Name : public KernelDefBuilder { : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {} }; +namespace system { + +class Name : public KernelDefBuilder { + public: + // For system kernels, we ignore selective registration and + // unconditionally register the kernel. + explicit Name(const char* op) : KernelDefBuilder(op) {} +}; + +} // namespace system + } // namespace register_kernel #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ @@ -1216,6 +1227,26 @@ class Name : public KernelDefBuilder { return new __VA_ARGS__(context); \ }); +// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as +// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered +// unconditionally even when selective registration is used. +#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \ + REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, \ + __VA_ARGS__) + +#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \ + REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) + +#define REGISTER_SYSTEM_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ + static ::tensorflow::kernel_factory::OpKernelRegistrar \ + registrar__body__##ctr##__object( \ + ::tensorflow::register_kernel::system::kernel_builder.Build(), \ + #__VA_ARGS__, \ + [](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { \ + return new __VA_ARGS__(context); \ + }); + void* GlobalKernelRegistry(); // If node_def has a corresponding kernel registered on device_type, diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index bcbf7424b1..ba408f3657 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -84,8 +84,8 @@ class RetvalOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); -REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); +REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); +REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); #if TENSORFLOW_USE_SYCL #define REGISTER(type) \ diff --git a/tensorflow/core/ops/function_ops.cc b/tensorflow/core/ops/function_ops.cc index 9fbebdb088..ada96fa1d2 100644 --- a/tensorflow/core/ops/function_ops.cc +++ b/tensorflow/core/ops/function_ops.cc @@ -18,7 +18,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_Arg") +REGISTER_SYSTEM_OP("_Arg") .Output("output: T") .Attr("T: type") .Attr("index: int >= 0") @@ -34,7 +34,7 @@ output: The argument. index: This argument is the index-th argument of the function. )doc"); -REGISTER_OP("_Retval") +REGISTER_SYSTEM_OP("_Retval") .Input("input: T") .Attr("T: type") .Attr("index: int >= 0") |