diff options
Diffstat (limited to 'tensorflow/python/framework/test_kernel_label_op.cc')
-rw-r--r-- | tensorflow/python/framework/test_kernel_label_op.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/python/framework/test_kernel_label_op.cc b/tensorflow/python/framework/test_kernel_label_op.cc new file mode 100644 index 0000000000..50f8522e1b --- /dev/null +++ b/tensorflow/python/framework/test_kernel_label_op.cc @@ -0,0 +1,47 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/public/status.h" + +namespace tensorflow { + +REGISTER_OP("KernelLabel").Output("result: string"); + +namespace { +enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL }; +} // namespace + +template <KernelLabel KL> +class KernelLabelOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* ctx) override { + Tensor* output; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("result", TensorShape({}), &output)); + switch (KL) { + case DEFAULT_LABEL: + output->scalar<string>()() = "My label is: default"; + break; + case OVERLOAD_1_LABEL: + output->scalar<string>()() = "My label is: overload_1"; + break; + case OVERLOAD_2_LABEL: + output->scalar<string>()() = "My label is: overload_2"; + break; + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU), + KernelLabelOp<DEFAULT_LABEL>); +REGISTER_KERNEL_BUILDER(Name("KernelLabel") + .Device(DEVICE_CPU) + .Label("overload_1"), + KernelLabelOp<OVERLOAD_1_LABEL>); +REGISTER_KERNEL_BUILDER(Name("KernelLabel") + .Device(DEVICE_CPU) + .Label("overload_2"), + KernelLabelOp<OVERLOAD_2_LABEL>); + +} // end namespace tensorflow |