aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/numeric_op.h
blob: 8413d18f336b55f8dd6b08c3d38f033b852a53d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/status.h"

namespace tensorflow {

// Base kernel for ops with exactly one input and one output, both of element
// type T. The constructor checks the op's signature against {T} -> {T} and
// reports any mismatch on `context` through OP_REQUIRES_OK.
template <typename T>
class UnaryOp : public OpKernel {
 public:
  explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) {
    // DataType enum value corresponding to the C++ element type T.
    const DataType t_enum = DataTypeToEnum<T>::v();
    // Expected signature: one T input -> one T output.
    OP_REQUIRES_OK(context, context->MatchSignature({t_enum}, {t_enum}));
  }
};

// Base kernel for ops with exactly two inputs and one output, all of element
// type T. The constructor checks the op's signature against {T, T} -> {T} and
// reports any mismatch on `context` through OP_REQUIRES_OK.
template <typename T>
class BinaryOp : public OpKernel {
 public:
  explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) {
    // DataType enum value corresponding to the C++ element type T.
    const DataType t_enum = DataTypeToEnum<T>::v();
    // Expected signature: two T inputs -> one T output.
    OP_REQUIRES_OK(context,
                   context->MatchSignature({t_enum, t_enum}, {t_enum}));
  }
};

// For operations where the input and output are the same shape.
//
// Compute() allocates output 0 with the shape of input 0 and then hands both
// tensors to CHILD::Operate(context, input, output). CHILD is the concrete
// kernel deriving from this class, so the call is resolved statically.
//
// For usage, see ../framework/elementwise_ops.cc.
template <typename T, typename CHILD>
class UnaryElementWiseOp : public UnaryOp<T> {
 public:
  using UnaryOp<T>::UnaryOp;

  void Compute(OpKernelContext* context) override {
    const Tensor& in = context->input(0);

    // The result has exactly the input's shape.
    Tensor* out = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, in.shape(), &out));

    // Hand off to the derived kernel's Operate().
    CHILD* const derived = static_cast<CHILD*>(this);
    derived->Operate(context, in, out);
  }
};

// For binary elementwise operations.
//
// Compute() requires both inputs to have the same shape, allocates output 0
// with that shape, and then dispatches on the runtime rank of the inputs to
// CHILD::Operate<NDIMS>(context, a, b, output). CHILD is the concrete kernel
// deriving from this class, and it sees the rank as a compile-time constant.
// Ranks 0 (scalars) through 8 are dispatched, so CHILD::Operate<NDIMS> must
// be instantiable for every NDIMS in [0, 8]. Any other rank sets an
// InvalidArgument status on the context.
template <class T, class CHILD>
class BinaryElementWiseOp : public BinaryOp<T> {
 public:
  using BinaryOp<T>::BinaryOp;

  void Compute(OpKernelContext* context) override {
    const Tensor& a = context->input(0);
    const Tensor& b = context->input(1);

    // Bail out on a shape mismatch. NOTE(review): this assumes
    // ValidateInputsAreSameShape() records the error status on `context`
    // itself before returning false; confirm in op_kernel.h.
    if (!context->ValidateInputsAreSameShape(this)) {
      return;
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(), &output));

    // Dispatch to the descendant's Operate() function, with the rank baked in
    // as a template argument.
    switch (a.dims()) {
#define NDIM_CASE(NDIMS)                                                       \
  case NDIMS: {                                                                \
    static_cast<CHILD*>(this)->template Operate<NDIMS>(context, a, b, output); \
    break;                                                                     \
  }

      // Rank 0 covers scalar inputs, which are valid same-shape operands.
      NDIM_CASE(0);
      NDIM_CASE(1);
      NDIM_CASE(2);
      NDIM_CASE(3);
      NDIM_CASE(4);
      NDIM_CASE(5);
      NDIM_CASE(6);
      NDIM_CASE(7);
      NDIM_CASE(8);
#undef NDIM_CASE

      default:
        // An unsupported rank is a caller error, so report InvalidArgument.
        // TensorFlow uses OUT_OF_RANGE to mean "iterated past the end of the
        // input", and callers commonly treat it as normal end-of-data rather
        // than as a failure.
        context->SetStatus(errors::InvalidArgument(
            "We only handle up to Tensor::dims() up to 8, not ", a.dims()));
        break;
    }
  }
};

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_