aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/reduction_ops_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/reduction_ops_common.h')
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.h27
1 files changed, 15 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 71af9d88dc..9da992ccd1 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -42,7 +43,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
template <typename Device>
struct Constants {
@@ -68,11 +69,13 @@ struct ConstantsBase {
const Eigen::IndexList<Eigen::type2index<1>> kOne;
const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
};
-template<> struct Constants<CPUDevice> : ConstantsBase{};
+template <>
+struct Constants<CPUDevice> : ConstantsBase {};
#ifdef TENSORFLOW_USE_SYCL
-template<> struct Constants<SYCLDevice> : ConstantsBase{};
-#endif // TENSORFLOW_USE_SYCL
-#endif // EIGEN_HAS_INDEX_LIST
+template <>
+struct Constants<SYCLDevice> : ConstantsBase {};
+#endif // TENSORFLOW_USE_SYCL
+#endif // EIGEN_HAS_INDEX_LIST
class ReductionHelper {
public:
@@ -131,12 +134,13 @@ class ReductionHelper {
// For operations where the output is a reduction function along some
// dimensions of the input.
-template <typename Device, class T, typename Reducer>
+template <typename Device, class T, typename Tperm, typename Reducer>
class ReductionOp : public OpKernel {
public:
explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
const DataType dt = DataTypeToEnum<T>::v();
- OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
+ const DataType pt = DataTypeToEnum<Tperm>::v();
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, pt}, {dt}));
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
}
@@ -266,20 +270,19 @@ struct ReduceFunctorBase {
}
template <typename OUT_T>
- static void FillIdentity(const Device& d, OUT_T out,
- const Reducer& reducer) {
+ static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) {
FillIdentityEigenImpl(d, out, reducer);
}
};
template <typename Reducer>
struct ReduceFunctor<CPUDevice, Reducer>
- : ReduceFunctorBase<CPUDevice, Reducer>{};
+ : ReduceFunctorBase<CPUDevice, Reducer> {};
#if TENSORFLOW_USE_SYCL
template <typename Reducer>
struct ReduceFunctor<SYCLDevice, Reducer>
- : ReduceFunctorBase<SYCLDevice, Reducer>{};
-#endif // TENSORFLOW_USE_SYCL
+ : ReduceFunctorBase<SYCLDevice, Reducer> {};
+#endif // TENSORFLOW_USE_SYCL
} // namespace functor
} // namespace tensorflow