blob: 7cf4bdb6d03a89890346e25018027175d32f6290 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
template <class T>
class CountUpToOp : public OpKernel {
public:
explicit CountUpToOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("limit", &limit_));
}
void Compute(OpKernelContext* context) override {
T before_increment;
{
mutex_lock l(*context->input_ref_mutex(0));
Tensor tensor = context->mutable_input(0, true);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(tensor.shape()),
errors::InvalidArgument("input is not a scalar: ",
tensor.shape().DebugString()));
T* ptr = &tensor.scalar<T>()();
before_increment = *ptr;
if (*ptr >= limit_) {
context->SetStatus(errors::OutOfRange("Reached limit of ", limit_));
return;
}
++*ptr;
}
// Output if no error.
Tensor* out_tensor;
OP_REQUIRES_OK(context, context->allocate_output("output", TensorShape({}),
&out_tensor));
out_tensor->scalar<T>()() = before_increment;
}
private:
T limit_;
};
// CPU kernels for CountUpTo, one per supported counter type (int32, int64).
// Each registration binds the "T" type attribute to the matching
// CountUpToOp<T> instantiation.
REGISTER_KERNEL_BUILDER(
    Name("CountUpTo").TypeConstraint<int32>("T").Device(DEVICE_CPU),
    CountUpToOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("CountUpTo").TypeConstraint<int64>("T").Device(DEVICE_CPU),
    CountUpToOp<int64>);
} // namespace tensorflow
|