#ifndef TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
#define TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/cwise_ops.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

class BinaryOpShared : public OpKernel {
 public:
  explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in);

 protected:
  struct BinaryOpState {
    // Sets up bcast with the shape of in0 and in1, ensures that the bcast
    // is valid, and if so, allocates out using ctx->output(...).
    // Caller must check ctx->status() upon return for non-ok status.
    // If ctx->status().ok() is true, then out is guaranteed to be allocated.
    BinaryOpState(OpKernelContext* ctx);

    BCast bcast;
    Tensor* out = nullptr;
  };

  template <int NDIMS>
  static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
      const BCast::Vec& vec) {
    CHECK_EQ(vec.size(), NDIMS);
    Eigen::array<Eigen::DenseIndex, NDIMS> ret;
    for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
    return ret;
  }

  void SetUnimplementedError(OpKernelContext* ctx);
};

// Coefficient-wise binary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_functors.h. E.g., functor::add2.
template <typename Device, typename Functor>
class BinaryOp : public BinaryOpShared {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.

  explicit BinaryOp(OpKernelConstruction* ctx)
      : BinaryOpShared(ctx, DataTypeToEnum<Tout>::v(),
                       DataTypeToEnum<Tin>::v()) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    // 'state': Shared helper not dependent on T to reduce code size.
    BinaryOpState state(ctx);
    if (!ctx->status().ok()) return;
    Tensor* out = state.out;
    BCast* bcast = &state.bcast;
    if (out->NumElements() == 0) {
      return;
    }
    const int ndims = bcast->x_reshape().size();
    if (ndims <= 1) {
      if (in1.NumElements() == 1) {
        // tensor op scalar
        functor::BinaryFunctor<Device, Functor, 1>().Right(
            ctx->eigen_device<Device>(), out->flat<Tout>(), in0.flat<Tin>(),
            in1.scalar<Tin>());
        return;
      }
      if (in0.NumElements() == 1) {
        // scalar op tensor
        functor::BinaryFunctor<Device, Functor, 1>().Left(
            ctx->eigen_device<Device>(), out->flat<Tout>(), in0.scalar<Tin>(),
            in1.flat<Tin>());
        return;
      }
      functor::BinaryFunctor<Device, Functor, 1>()(
          ctx->eigen_device<Device>(), out->flat<Tout>(), in0.flat<Tin>(),
          in1.flat<Tin>());
      return;
    }

    if (ndims == 2) {
      functor::BinaryFunctor<Device, Functor, 2>().BCast(
          ctx->eigen_device<Device>(),
          out->shaped<Tout, 2>(bcast->result_shape()),
          in0.shaped<Tin, 2>(bcast->x_reshape()),
          ToIndexArray<2>(bcast->x_bcast()),
          in1.shaped<Tin, 2>(bcast->y_reshape()),
          ToIndexArray<2>(bcast->y_bcast()));
      return;
    }

    if (ndims == 3) {
      functor::BinaryFunctor<Device, Functor, 3>().BCast(
          ctx->eigen_device<Device>(),
          out->shaped<Tout, 3>(bcast->result_shape()),
          in0.shaped<Tin, 3>(bcast->x_reshape()),
          ToIndexArray<3>(bcast->x_bcast()),
          in1.shaped<Tin, 3>(bcast->y_reshape()),
          ToIndexArray<3>(bcast->y_bcast()));
      return;
    }

    SetUnimplementedError(ctx);
  }

 private:
};

// Coefficient-wise unary operations:
//   Device: E.g., CPUDevice, GPUDevice.
//   Functor: defined in cwise_functors.h. E.g., functor::sqrt.
template <typename Device, typename Functor>
class UnaryOp : public OpKernel {
 public:
  typedef typename Functor::in_type Tin;    // Input scalar data type.
  typedef typename Functor::out_type Tout;  // Output scalar data type.
  // Tin may be different from Tout.
  // E.g., abs: complex64 -> float.

  explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto in = DataTypeToEnum<Tin>::v();
    auto out = DataTypeToEnum<Tout>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& inp = ctx->input(0);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out));
    functor::UnaryFunctor<Device, Functor>()(
        ctx->eigen_device<Device>(), out->flat<Tout>(), inp.flat<Tin>());
  }
};

// Coefficient-wise select operation.
//   Device: E.g., CPUDevice, GPUDevice.
template <typename Device, typename T>
class SelectOp : public OpKernel {
 public:
  explicit SelectOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    auto dt = DataTypeToEnum<T>::v();
    OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_BOOL, dt, dt}, {dt}));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    const Tensor& in2 = ctx->input(2);
    if (!ctx->ValidateInputsAreSameShape(this)) return;
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
    functor::SelectFunctor<Device, T> func;
    func(ctx->eigen_device<Device>(), out->flat<T>(), in0.flat<bool>(),
         in1.flat<T>(), in2.flat<T>());
  }
};

namespace functor {

// For CPUDevice, we do operations inline if the resulting tensor is
// modestly sized.
static bool DoInline(size_t size) { return size <= 32768; }

template <typename D, typename OUT, typename RHS>
void Assign(const D& d, OUT out, RHS rhs) {
  if (DoInline(out.size())) {
    out = rhs;
  } else {
    out.device(d) = rhs;
  }
}

// Partial specialization of BinaryFunctor for CPUDevice.
template <typename Functor, int NDIMS>
struct BinaryFunctor<CPUDevice, Functor, NDIMS> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1) {
    Assign(d, out, in0.binaryExpr(in1, typename Functor::func()));
  }

  void Left(const CPUDevice& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_left<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

  void Right(const CPUDevice& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar) {
    typedef typename Functor::out_type Tout;
    typedef typename Functor::in_type Tin;
    typedef typename Functor::func Binary;
    typedef typename Eigen::internal::scalar_right<Tout, Tin, Binary> Unary;
    Assign(d, out, in.unaryExpr(Unary(scalar.data())));
  }

#if !defined(EIGEN_HAS_INDEX_LIST)
  inline Eigen::DSizes<int, 2> NByOne(int n) {
    return Eigen::DSizes<int, 2>(n, 1);
  }
  inline Eigen::DSizes<int, 2> OneByM(int m) {
    return Eigen::DSizes<int, 2>(1, m);
  }
#else
  inline Eigen::IndexList<int, Eigen::type2index<1>> NByOne(int n) {
    Eigen::IndexList<int, Eigen::type2index<1>> ret;
    ret.set(0, n);
    return ret;
  }
  inline Eigen::IndexList<Eigen::type2index<1>, int> OneByM(int m) {
    Eigen::IndexList<Eigen::type2index<1>, int> ret;
    ret.set(1, m);
    return ret;
  }
#endif

  void BCast(const CPUDevice& dev,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1) {
    typedef typename Functor::in_type T;
    typename Functor::func func;
    if ((NDIMS == 2) && Functor::use_bcast_optimization &&
        use_bcast_optimization<T>::value) {
      // Optimize for speed by using Eigen::type2index and avoiding
      // .broadcast() when we know it is a no-op.
      //
      // Here, we need to handle 6 cases depending on how many "1"s
      // exist in in0's and in1's shapes (4 numbers in total). It is not
      // possible for the two shapes to have more than two 1s in total,
      // because those cases are simplified to the NDIMS == 1 case.
      //
      // Because this optimization increases the binary size for each
      // Functor (+, -, *, /, <, <=, etc.), type, and ndim combination,
      // we only apply it for selected ops/types/ndims.
      //
      // Because NDIMS, Functor::use_bcast_optimization, and
      // use_bcast_optimization<T>::value are compile-time constants, gcc
      // does a decent job of avoiding generating code when the conditions
      // are not met.
      const int a = in0.dimension(0);  // in0 is shape [a, b]
      const int b = in0.dimension(1);
      const int c = in1.dimension(0);  // in1 is shape [c, d]
      const int d = in1.dimension(1);
      if ((a == 1) && (d == 1)) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if ((b == 1) && (c == 1)) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (a == 1) {
        auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (b == 1) {
        auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d));
        auto rhs = in1;
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (c == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (d == 1) {
        auto lhs = in0;
        auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b));
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }

      const bool bcast0_all_one = AllOne(bcast0);
      const bool bcast1_all_one = AllOne(bcast1);
      if (bcast0_all_one && !bcast1_all_one) {
        auto lhs = in0;  // No need to broadcast in0.
        auto rhs = in1.broadcast(bcast1);
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
      if (!bcast0_all_one && bcast1_all_one) {
        auto lhs = in0.broadcast(bcast0);
        auto rhs = in1;  // No need to broadcast in1.
        Assign(dev, out, lhs.binaryExpr(rhs, func));
        return;
      }
    }

    // Fallback path. Always works, but is probably slower.
    auto lhs = in0.broadcast(bcast0);
    auto rhs = in1.broadcast(bcast1);
    Assign(dev, out, lhs.binaryExpr(rhs, func));
  }
};

// Partial specialization of UnaryFunctor for CPUDevice.
template <typename Functor>
struct UnaryFunctor<CPUDevice, Functor> {
  void operator()(const CPUDevice& d, typename Functor::tout_type out,
                  typename Functor::tin_type in) {
    Assign(d, out, in.unaryExpr(typename Functor::func()));
  }
};

// Partial specialization of SelectFunctor for CPUDevice.
template <typename T>
struct SelectFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat) {
    Assign(d, out, cond_flat.select(then_flat, else_flat));
  }
};

}  // end namespace functor

#define REGISTER_SELECT(D, N, F, T)                                          \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          SelectOp<D##Device, T>)

#define REGISTER(OP, D, N, F, T)                                             \
  REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint<T>("T"), \
                          OP<D##Device, F<T>>);

// Macros to register kernels for multiple types (T0, T1, etc.) on
// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using
// the functor "F" (e.g., functor::sqrt).
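//
// As a rough illustration (the op name "Add" and functor::add are assumptions
// about a typical caller of this header, not names declared here), a kernel
// .cc file might invoke:
//
//   REGISTER3(BinaryOp, CPU, "Add", functor::add, float, double, int32);
//
// which would expand to one REGISTER_KERNEL_BUILDER call per listed type,
// instantiating BinaryOp<CPUDevice, functor::add<float>>, <double>, <int32>.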
#ifdef __ANDROID__
// On Android, only register the first type (float).
#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0)
#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER(OP, D, N, F, T0)
#else  // !__ANDROID__
#define REGISTER2(OP, D, N, F, T0, T1) \
  REGISTER(OP, D, N, F, T0)            \
  REGISTER(OP, D, N, F, T1)
#define REGISTER3(OP, D, N, F, T0, T1, T2) \
  REGISTER2(OP, D, N, F, T0, T1)           \
  REGISTER(OP, D, N, F, T2)
#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \
  REGISTER2(OP, D, N, F, T0, T1)               \
  REGISTER2(OP, D, N, F, T2, T3)
#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \
  REGISTER3(OP, D, N, F, T0, T1, T2)               \
  REGISTER2(OP, D, N, F, T3, T4)
#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \
  REGISTER3(OP, D, N, F, T0, T1, T2)                   \
  REGISTER3(OP, D, N, F, T3, T4, T5)
#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \
  REGISTER4(OP, D, N, F, T0, T1, T2, T3)                   \
  REGISTER3(OP, D, N, F, T4, T5, T6)
#endif  // __ANDROID__

}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_