aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/numeric_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/numeric_op.h')
-rw-r--r--tensorflow/core/framework/numeric_op.h96
1 files changed, 96 insertions, 0 deletions
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
new file mode 100644
index 0000000000..8413d18f33
--- /dev/null
+++ b/tensorflow/core/framework/numeric_op.h
@@ -0,0 +1,96 @@
+#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// One input and one output, both the same type.
+template <class T>
+class UnaryOp : public OpKernel {
+ public:
+ explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt}));
+ }
+};
+
+// Two inputs and one output, all the same type.
+template <class T>
+class BinaryOp : public OpKernel {
+ public:
+ explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) {
+ const DataType dt = DataTypeToEnum<T>::v();
+ OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt}));
+ }
+};
+
+// For operations where the input and output are the same shape.
+//
+// For usage, see ../framework/elementwise_ops.cc.
+template <class T, class CHILD>
+class UnaryElementWiseOp : public UnaryOp<T> {
+ public:
+ using UnaryOp<T>::UnaryOp;
+
+ void Compute(OpKernelContext* context) override {
+ // Output shape is the same as input shape.
+ const Tensor& input = context->input(0);
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+ static_cast<CHILD*>(this)->Operate(context, input, output);
+ }
+};
+
+// For binary elementwise operations.
+template <class T, class CHILD>
+class BinaryElementWiseOp : public BinaryOp<T> {
+ public:
+ using BinaryOp<T>::BinaryOp;
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& a = context->input(0);
+ const Tensor& b = context->input(1);
+
+ if (!context->ValidateInputsAreSameShape(this)) {
+ return;
+ }
+
+ Tensor* output;
+ OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output));
+
+ // Dispatch to the descendant's Operate() function.
+ switch (a.dims()) {
+#define NDIM_CASE(NDIMS) \
+ case NDIMS: { \
+ static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \
+ break; \
+ }
+
+ NDIM_CASE(1);
+ NDIM_CASE(2);
+ NDIM_CASE(3);
+ NDIM_CASE(4);
+ NDIM_CASE(5);
+ NDIM_CASE(6);
+ NDIM_CASE(7);
+ NDIM_CASE(8);
+#undef NDIM_CASE
+
+ default:
+ context->SetStatus(errors::OutOfRange(
+ "We only handle up to Tensor::dims() up to 8, not ", a.dims()));
+ break;
+ }
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_