aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/count_up_to_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/count_up_to_op.cc')
-rw-r--r--tensorflow/core/kernels/count_up_to_op.cc51
1 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/count_up_to_op.cc b/tensorflow/core/kernels/count_up_to_op.cc
new file mode 100644
index 0000000000..7cf4bdb6d0
--- /dev/null
+++ b/tensorflow/core/kernels/count_up_to_op.cc
@@ -0,0 +1,51 @@
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+template <class T>
+class CountUpToOp : public OpKernel {
+ public:
+ explicit CountUpToOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("limit", &limit_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ T before_increment;
+ {
+ mutex_lock l(*context->input_ref_mutex(0));
+ Tensor tensor = context->mutable_input(0, true);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(tensor.shape()),
+ errors::InvalidArgument("input is not a scalar: ",
+ tensor.shape().DebugString()));
+ T* ptr = &tensor.scalar<T>()();
+ before_increment = *ptr;
+ if (*ptr >= limit_) {
+ context->SetStatus(errors::OutOfRange("Reached limit of ", limit_));
+ return;
+ }
+ ++*ptr;
+ }
+ // Output if no error.
+ Tensor* out_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output("output", TensorShape({}),
+ &out_tensor));
+ out_tensor->scalar<T>()() = before_increment;
+ }
+
+ private:
+ T limit_;
+};
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("CountUpTo").TypeConstraint<TYPE>("T").Device(DEVICE_CPU), \
+ CountUpToOp<TYPE>)
+
+REGISTER(int32);
+REGISTER(int64);
+
+#undef REGISTER
+
+} // namespace tensorflow