aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc')
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc64
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc b/tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc
new file mode 100644
index 0000000000..3d2f50d16e
--- /dev/null
+++ b/tensorflow/g3doc/how_tos/adding_an_op/register_kernels.cc
@@ -0,0 +1,64 @@
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+
+using namespace tensorflow;
+
+template <typename T>
+class ZeroOutOp : public OpKernel {
+ public:
+ explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // Grab the input tensor
+ const Tensor& input_tensor = context->input(0);
+ auto input = input_tensor.flat<T>();
+
+ // Create an output tensor
+ Tensor* output = NULL;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_tensor.shape(), &output));
+ auto output_flat = output->template flat<T>();
+
+ // Set all the elements of the output tensor to 0
+ const int N = input.size();
+ for (int i = 0; i < N; i++) {
+ output_flat(i) = 0;
+ }
+
+ // Preserve the first input value
+ if (N > 0) output_flat(0) = input(0);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ZeroOut")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T"),
+ ZeroOutOp<float>);
+REGISTER_KERNEL_BUILDER(Name("ZeroOut")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<double>("T"),
+ ZeroOutOp<double>);
+REGISTER_KERNEL_BUILDER(Name("ZeroOut")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<int>("T"),
+ ZeroOutOp<int>);
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ZeroOutOp<type>)
+
+REGISTER_KERNEL(float);
+REGISTER_KERNEL(double);
+REGISTER_KERNEL(int32);
+
+#undef REGISTER_KERNEL
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ZeroOutOp<type>)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+
+#undef REGISTER_KERNEL