diff options
Diffstat (limited to 'tensorflow/core/kernels/reduction_ops_common.h')
-rw-r--r-- | tensorflow/core/kernels/reduction_ops_common.h | 27 |
1 files changed, 15 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index 71af9d88dc..9da992ccd1 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -25,6 +25,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -42,7 +43,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL typedef Eigen::SyclDevice SYCLDevice; -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL template <typename Device> struct Constants { @@ -68,11 +69,13 @@ struct ConstantsBase { const Eigen::IndexList<Eigen::type2index<1>> kOne; const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo; }; -template<> struct Constants<CPUDevice> : ConstantsBase{}; +template <> +struct Constants<CPUDevice> : ConstantsBase {}; #ifdef TENSORFLOW_USE_SYCL -template<> struct Constants<SYCLDevice> : ConstantsBase{}; -#endif // TENSORFLOW_USE_SYCL -#endif // EIGEN_HAS_INDEX_LIST +template <> +struct Constants<SYCLDevice> : ConstantsBase {}; +#endif // TENSORFLOW_USE_SYCL +#endif // EIGEN_HAS_INDEX_LIST class ReductionHelper { public: @@ -131,12 +134,13 @@ class ReductionHelper { // For operations where the output is a reduction function along some // dimensions of the input. -template <typename Device, class T, typename Reducer> +template <typename Device, class T, typename Tperm, typename Reducer> class ReductionOp : public OpKernel { public: explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) { const DataType dt = DataTypeToEnum<T>::v(); - OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); + const DataType pt = DataTypeToEnum<Tperm>::v(); + OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, pt}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); } @@ -266,20 +270,19 @@ struct ReduceFunctorBase { } template <typename OUT_T> - static void FillIdentity(const Device& d, OUT_T out, - const Reducer& reducer) { + static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) { FillIdentityEigenImpl(d, out, reducer); } }; template <typename Reducer> struct ReduceFunctor<CPUDevice, Reducer> - : ReduceFunctorBase<CPUDevice, Reducer>{}; + : ReduceFunctorBase<CPUDevice, Reducer> {}; #if TENSORFLOW_USE_SYCL template <typename Reducer> struct ReduceFunctor<SYCLDevice, Reducer> - : ReduceFunctorBase<SYCLDevice, Reducer>{}; -#endif // TENSORFLOW_USE_SYCL + : ReduceFunctorBase<SYCLDevice, Reducer> {}; +#endif // TENSORFLOW_USE_SYCL } // namespace functor } // namespace tensorflow |