#include <cstdint>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

// Declares the "ZeroOut" op interface: a single int32 input named `to_zero`
// and a single int32 output named `zeroed`.
REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .Doc(R"doc(
Zeros out all but the first value of a Tensor.

zeroed: A Tensor whose first value is identical to `to_zero`, and 0 otherwise.
)doc");

// Kernel implementing "ZeroOut": the output has the same shape as the input,
// its first element (in flattened order) is copied from the input, and every
// other element is set to 0.
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  // Reads input 0, allocates output 0 with the same shape, and fills it as
  // described on the class.
  void Compute(OpKernelContext* context) override {
    // Grab the input tensor and view it as a flat int32 array; int32 matches
    // the element type declared in REGISTER_OP above.
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor with the same shape as the input.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->flat<int32>();

    // Keep the element count in a 64-bit integer so that it is not narrowed
    // to `int` for very large tensors.
    const int64_t num_elements = input.size();

    // Set all but the first element of the output tensor to 0.
    for (int64_t i = 1; i < num_elements; ++i) {
      output(i) = 0;
    }

    // Preserve the first input value; skipped when the tensor is empty so
    // that index 0 is never touched out of bounds.
    if (num_elements > 0) output(0) = input(0);
  }
};

// Registers ZeroOutOp as the "ZeroOut" kernel for DEVICE_CPU.
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);