diff options
Diffstat (limited to 'tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc')
-rw-r--r-- | tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc b/tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc new file mode 100644 index 0000000000..3d2f50d16e --- /dev/null +++ b/tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc @@ -0,0 +1,64 @@ +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" + +using namespace tensorflow; + +template <typename T> +class ZeroOutOp : public OpKernel { + public: + explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat<T>(); + + // Create an output tensor + Tensor* output = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_tensor.shape(), &output)); + auto output_flat = output->template flat<T>(); + + // Set all the elements of the output tensor to 0 + const int N = input.size(); + for (int i = 0; i < N; i++) { + output_flat(i) = 0; + } + + // Preserve the first input value + if (N > 0) output_flat(0) = input(0); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ZeroOut") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T"), + ZeroOutOp<float>); +REGISTER_KERNEL_BUILDER(Name("ZeroOut") + .Device(DEVICE_CPU) + .TypeConstraint<double>("T"), + ZeroOutOp<double>); +REGISTER_KERNEL_BUILDER(Name("ZeroOut") + .Device(DEVICE_CPU) + .TypeConstraint<int>("T"), + ZeroOutOp<int>); + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + ZeroOutOp<type>) + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +REGISTER_KERNEL(int32); + +#undef REGISTER_KERNEL + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + ZeroOutOp<type>) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); + +#undef REGISTER_KERNEL |