diff options
Diffstat (limited to 'tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc')
-rw-r--r-- | tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc new file mode 100644 index 0000000000..e960adc047 --- /dev/null +++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc @@ -0,0 +1,43 @@ +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +using namespace tensorflow; + +REGISTER_OP("ZeroOut") + .Input("to_zero: int32") + .Output("zeroed: int32") + .Doc(R"doc( +Zeros out all but the first value of a Tensor. + +zeroed: A Tensor whose first value is identical to `to_zero`, and 0 + otherwise. + +)doc"); + +class ZeroOutOp : public OpKernel { + public: + explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat<int32>(); + + // Create an output tensor + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), + &output_tensor)); + auto output = output_tensor->template flat<int32>(); + + // Set all but the first element of the output tensor to 0. + const int N = input.size(); + for (int i = 1; i < N; i++) { + output(i) = 0; + } + + // Preserve the first input value. + if (N > 0) output(0) = input(0); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp); |