aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/count_up_to_op.cc
blob: 7cf4bdb6d03a89890346e25018027175d32f6290 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/port.h"

namespace tensorflow {

template <class T>
class CountUpToOp : public OpKernel {
 public:
  explicit CountUpToOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("limit", &limit_));
  }

  void Compute(OpKernelContext* context) override {
    T before_increment;
    {
      mutex_lock l(*context->input_ref_mutex(0));
      Tensor tensor = context->mutable_input(0, true);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(tensor.shape()),
                  errors::InvalidArgument("input is not a scalar: ",
                                          tensor.shape().DebugString()));
      T* ptr = &tensor.scalar<T>()();
      before_increment = *ptr;
      if (*ptr >= limit_) {
        context->SetStatus(errors::OutOfRange("Reached limit of ", limit_));
        return;
      }
      ++*ptr;
    }
    // Output if no error.
    Tensor* out_tensor;
    OP_REQUIRES_OK(context, context->allocate_output("output", TensorShape({}),
                                                     &out_tensor));
    out_tensor->scalar<T>()() = before_increment;
  }

 private:
  T limit_;
};

// Registers CountUpToOp<TYPE> as the CPU kernel for the "CountUpTo" op when
// its "T" type attribute is TYPE. The helper macro is file-local and is
// undefined again right after the registrations below.
#define REGISTER_COUNT_UP_TO_KERNEL(TYPE)                        \
  REGISTER_KERNEL_BUILDER(Name("CountUpTo")                      \
                              .TypeConstraint<TYPE>("T")         \
                              .Device(DEVICE_CPU),               \
                          CountUpToOp<TYPE>)

// Element types for which a kernel is registered.
REGISTER_COUNT_UP_TO_KERNEL(int32);
REGISTER_COUNT_UP_TO_KERNEL(int64);

#undef REGISTER_COUNT_UP_TO_KERNEL

}  // namespace tensorflow