aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc')
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc43
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc
new file mode 100644
index 0000000000..e960adc047
--- /dev/null
+++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_op_kernel_1.cc
@@ -0,0 +1,43 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+using namespace tensorflow;
+
+REGISTER_OP("ZeroOut")
+ .Input("to_zero: int32")
+ .Output("zeroed: int32")
+ .Doc(R"doc(
+Zeros out all but the first value of a Tensor.
+
+zeroed: A Tensor whose first value is identical to `to_zero`, and 0
+ otherwise.
+
+)doc");
+
+class ZeroOutOp : public OpKernel {
+ public:
+ explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<int32>();
+
+ // Create an output tensor
+ Tensor* output_tensor = NULL;
+ OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+ &output_tensor));
+ auto output = output_tensor->template flat<int32>();
+
+ // Set all but the first element of the output tensor to 0.
+ const int N = input.size();
+ for (int i = 1; i < N; i++) {
+ output(i) = 0;
+ }
+
+ // Preserve the first input value.
+ if (N > 0) output(0) = input(0);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);