aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Yong Tang <yong.tang.github@outlook.com> 2017-10-21 23:01:42 -0700
committer Vijay Vasudevan <vrv@google.com> 2017-10-21 23:01:42 -0700
commit 1c1dad105a57bb13711492a8ba5ab9d10c91b5df (patch)
tree 5e440f8890d176b05db9a8e4bd9f7b00f1496c60
parent 17096081eed7881c0b8ce3c32b5e9795619e27bb (diff)
Add int64 axis support for reduction ops. (#13891)
* Add int64 axis support for reduction ops.
  This fix is a follow-up to PR 13863. PR 13863 fixed the program crash that occurred when an int64 axis was passed to reduction ops, e.g. reduce_sum, reduce_max, etc. However, 13863 did not add int64 axis support; it merely fixed the crash. This fix adds support for an int64 axis in reduction ops.
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add int64 axis support for mean, prod, sum
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add int64 axis support for min and max.
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add int64 axis support for reduce_all and reduce_any
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test cases for int64 axis support of reduce_any and reduce_all
  Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
-rw-r--r--tensorflow/core/kernels/reduction_ops_all.cc16
-rw-r--r--tensorflow/core/kernels/reduction_ops_any.cc16
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.cc22
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.h27
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc90
-rw-r--r--tensorflow/core/kernels/reduction_ops_mean.cc68
-rw-r--r--tensorflow/core/kernels/reduction_ops_min.cc90
-rw-r--r--tensorflow/core/kernels/reduction_ops_prod.cc68
-rw-r--r--tensorflow/core/kernels/reduction_ops_sum.cc90
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py52
10 files changed, 391 insertions, 148 deletions
diff --git a/tensorflow/core/kernels/reduction_ops_all.cc b/tensorflow/core/kernels/reduction_ops_all.cc
index 41abc2b957..4a34c4ef51 100644
--- a/tensorflow/core/kernels/reduction_ops_all.cc
+++ b/tensorflow/core/kernels/reduction_ops_all.cc
@@ -22,7 +22,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_CPU)
.HostMemory("reduction_indices"),
- ReductionOp<CPUDevice, bool, Eigen::internal::AndReducer>);
+ ReductionOp<CPUDevice, bool, int32, Eigen::internal::AndReducer>);
+REGISTER_KERNEL_BUILDER(
+ Name("All")
+ .TypeConstraint<int64>("Tidx")
+ .Device(DEVICE_CPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, bool, int64, Eigen::internal::AndReducer>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
@@ -30,7 +36,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices"),
- ReductionOp<GPUDevice, bool, Eigen::internal::AndReducer>);
+ ReductionOp<GPUDevice, bool, int32, Eigen::internal::AndReducer>);
+REGISTER_KERNEL_BUILDER(
+ Name("All")
+ .TypeConstraint<int64>("Tidx")
+ .Device(DEVICE_GPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<GPUDevice, bool, int64, Eigen::internal::AndReducer>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_any.cc b/tensorflow/core/kernels/reduction_ops_any.cc
index a2087cc3b7..6c0519de95 100644
--- a/tensorflow/core/kernels/reduction_ops_any.cc
+++ b/tensorflow/core/kernels/reduction_ops_any.cc
@@ -22,7 +22,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_CPU)
.HostMemory("reduction_indices"),
- ReductionOp<CPUDevice, bool, Eigen::internal::OrReducer>);
+ ReductionOp<CPUDevice, bool, int32, Eigen::internal::OrReducer>);
+REGISTER_KERNEL_BUILDER(
+ Name("Any")
+ .TypeConstraint<int64>("Tidx")
+ .Device(DEVICE_CPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, bool, int64, Eigen::internal::OrReducer>);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
@@ -30,7 +36,13 @@ REGISTER_KERNEL_BUILDER(
.TypeConstraint<int32>("Tidx")
.Device(DEVICE_GPU)
.HostMemory("reduction_indices"),
- ReductionOp<GPUDevice, bool, Eigen::internal::OrReducer>);
+ ReductionOp<GPUDevice, bool, int32, Eigen::internal::OrReducer>);
+REGISTER_KERNEL_BUILDER(
+ Name("Any")
+ .TypeConstraint<int64>("Tidx")
+ .Device(DEVICE_GPU)
+ .HostMemory("reduction_indices"),
+ ReductionOp<GPUDevice, bool, int64, Eigen::internal::OrReducer>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc
index 5eba4288ac..8daab0d6be 100644
--- a/tensorflow/core/kernels/reduction_ops_common.cc
+++ b/tensorflow/core/kernels/reduction_ops_common.cc
@@ -57,13 +57,12 @@ gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
return perm;
}
-Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
- const bool keep_dims) {
- // bitmap[i] indicates whether to reduce data along i-th axis.
- gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
- auto axis_vec = axis.flat<int32>();
+template <typename Tperm>
+Status SimplifyHelper(const Tensor& data, const Tensor& axis,
+ gtl::InlinedVector<bool, 4>& bitmap) {
+ auto axis_vec = axis.flat<Tperm>();
for (int64 i = 0; i < axis.NumElements(); ++i) {
- int32 index = axis_vec(i);
+ Tperm index = axis_vec(i);
if (index < -data.dims() || index >= data.dims()) {
return errors::InvalidArgument("Invalid reduction dimension (", index,
" for input with ", data.dims(),
@@ -72,7 +71,18 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
index = (index + data.dims()) % data.dims();
bitmap[index] = true;
}
+ return Status::OK();
+}
+Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
+ const bool keep_dims) {
+ // bitmap[i] indicates whether to reduce data along i-th axis.
+ gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
+ if (axis.dtype() == DT_INT32) {
+ TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
+ } else {
+ TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
+ }
// Output tensor's dim sizes.
out_shape_.clear();
for (int i = 0; i < data.dims(); ++i) {
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 71af9d88dc..9da992ccd1 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -42,7 +43,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
template <typename Device>
struct Constants {
@@ -68,11 +69,13 @@ struct ConstantsBase {
const Eigen::IndexList<Eigen::type2index<1>> kOne;
const Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<2>> kZeroTwo;
};
-template<> struct Constants<CPUDevice> : ConstantsBase{};
+template <>
+struct Constants<CPUDevice> : ConstantsBase {};
#ifdef TENSORFLOW_USE_SYCL
-template<> struct Constants<SYCLDevice> : ConstantsBase{};
-#endif // TENSORFLOW_USE_SYCL
-#endif // EIGEN_HAS_INDEX_LIST
+template <>
+struct Constants<SYCLDevice> : ConstantsBase {};
+#endif // TENSORFLOW_USE_SYCL
+#endif // EIGEN_HAS_INDEX_LIST
class ReductionHelper {
public:
@@ -131,12 +134,13 @@ class ReductionHelper {
// For operations where the output is a reduction function along some
// dimensions of the input.
-template <typename Device, class T, typename Reducer>
+template <typename Device, class T, typename Tperm, typename Reducer>
class ReductionOp : public OpKernel {
public:
explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
const DataType dt = DataTypeToEnum<T>::v();
- OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
+ const DataType pt = DataTypeToEnum<Tperm>::v();
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, pt}, {dt}));
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
}
@@ -266,20 +270,19 @@ struct ReduceFunctorBase {
}
template <typename OUT_T>
- static void FillIdentity(const Device& d, OUT_T out,
- const Reducer& reducer) {
+ static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) {
FillIdentityEigenImpl(d, out, reducer);
}
};
template <typename Reducer>
struct ReduceFunctor<CPUDevice, Reducer>
- : ReduceFunctorBase<CPUDevice, Reducer>{};
+ : ReduceFunctorBase<CPUDevice, Reducer> {};
#if TENSORFLOW_USE_SYCL
template <typename Reducer>
struct ReduceFunctor<SYCLDevice, Reducer>
- : ReduceFunctorBase<SYCLDevice, Reducer>{};
-#endif // TENSORFLOW_USE_SYCL
+ : ReduceFunctorBase<SYCLDevice, Reducer> {};
+#endif // TENSORFLOW_USE_SYCL
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 4ca5c11a48..9cf953f4bf 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
-#define REGISTER_CPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Max") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- ReductionOp<CPUDevice, type, Eigen::internal::MaxReducer<type>>);
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Max") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ ReductionOp<CPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Max") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx"), \
+ ReductionOp<CPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Max") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<GPUDevice, type, Eigen::internal::MaxReducer<type>>);
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Max") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int32, Eigen::internal::MaxReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Max") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
REGISTER_GPU_KERNELS(int64);
@@ -52,21 +65,37 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
- ReductionOp<CPUDevice, int32, Eigen::internal::MaxReducer<int32>>);
+ ReductionOp<CPUDevice, int32, int32, Eigen::internal::MaxReducer<int32>>);
+REGISTER_KERNEL_BUILDER(
+ Name("Max")
+ .Device(DEVICE_GPU)
+ .HostMemory("reduction_indices")
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx"),
+ ReductionOp<CPUDevice, int32, int64, Eigen::internal::MaxReducer<int32>>);
#undef REGISTER_GPU_KERNELS
#endif
#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Max") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<SYCLDevice, type, Eigen::internal::MaxReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Max") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int32, \
+ Eigen::internal::MaxReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Max") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int64, \
+ Eigen::internal::MaxReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
@@ -78,8 +107,17 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
- ReductionOp<CPUDevice, int32, Eigen::internal::MaxReducer<int32>>);
+ ReductionOp<CPUDevice, int32, int32, Eigen::internal::MaxReducer<int32>>);
+REGISTER_KERNEL_BUILDER(
+ Name("Max")
+ .Device(DEVICE_SYCL)
+ .HostMemory("reduction_indices")
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx"),
+ ReductionOp<CPUDevice, int32, int64, Eigen::internal::MaxReducer<int32>>);
#undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc
index 5b01de8ddb..f61589f913 100644
--- a/tensorflow/core/kernels/reduction_ops_mean.cc
+++ b/tensorflow/core/kernels/reduction_ops_mean.cc
@@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
-#define REGISTER_CPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Mean") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- ReductionOp<CPUDevice, type, Eigen::internal::MeanReducer<type>>);
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Mean") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ ReductionOp<CPUDevice, type, int32, \
+ Eigen::internal::MeanReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Mean") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx"), \
+ ReductionOp<CPUDevice, type, int64, \
+ Eigen::internal::MeanReducer<type>>);
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Mean") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<GPUDevice, type, Eigen::internal::MeanReducer<type>>);
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Mean") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int32, \
+ Eigen::internal::MeanReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Mean") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int64, \
+ Eigen::internal::MeanReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
@@ -45,17 +58,24 @@ TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Mean") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<SYCLDevice, type, Eigen::internal::MeanReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Mean") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int32, \
+ Eigen::internal::MeanReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Mean") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int64, \
+ Eigen::internal::MeanReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
#undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
index 1e394bea41..807ac0a456 100644
--- a/tensorflow/core/kernels/reduction_ops_min.cc
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
-#define REGISTER_CPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Min") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- ReductionOp<CPUDevice, type, Eigen::internal::MinReducer<type>>);
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Min") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ ReductionOp<CPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Min") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx"), \
+ ReductionOp<CPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Min") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<GPUDevice, type, Eigen::internal::MinReducer<type>>);
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Min") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int32, Eigen::internal::MinReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Min") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
@@ -51,21 +64,37 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
- ReductionOp<CPUDevice, int32, Eigen::internal::MinReducer<int32>>);
+ ReductionOp<CPUDevice, int32, int32, Eigen::internal::MinReducer<int32>>);
+REGISTER_KERNEL_BUILDER(
+ Name("Min")
+ .Device(DEVICE_GPU)
+ .HostMemory("reduction_indices")
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx"),
+ ReductionOp<CPUDevice, int32, int64, Eigen::internal::MinReducer<int32>>);
#undef REGISTER_GPU_KERNELS
#endif
#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Min") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<SYCLDevice, type, Eigen::internal::MinReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Min") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int32, \
+ Eigen::internal::MinReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Min") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int64, \
+ Eigen::internal::MinReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
@@ -77,8 +106,17 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("output")
.TypeConstraint<int32>("T")
.TypeConstraint<int32>("Tidx"),
- ReductionOp<CPUDevice, int32, Eigen::internal::MinReducer<int32>>);
+ ReductionOp<CPUDevice, int32, int32, Eigen::internal::MinReducer<int32>>);
+REGISTER_KERNEL_BUILDER(
+ Name("Min")
+ .Device(DEVICE_SYCL)
+ .HostMemory("reduction_indices")
+ .HostMemory("input")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx"),
+ ReductionOp<CPUDevice, int32, int64, Eigen::internal::MinReducer<int32>>);
#undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc
index 33f6ae6bae..e9b23df746 100644
--- a/tensorflow/core/kernels/reduction_ops_prod.cc
+++ b/tensorflow/core/kernels/reduction_ops_prod.cc
@@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
-#define REGISTER_CPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Prod") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- ReductionOp<CPUDevice, type, Eigen::internal::ProdReducer<type>>);
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Prod") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ ReductionOp<CPUDevice, type, int32, \
+ Eigen::internal::ProdReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Prod") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx"), \
+ ReductionOp<CPUDevice, type, int64, \
+ Eigen::internal::ProdReducer<type>>);
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Prod") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<GPUDevice, type, Eigen::internal::ProdReducer<type>>);
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Prod") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int32, \
+ Eigen::internal::ProdReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Prod") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int64, \
+ Eigen::internal::ProdReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int32(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
@@ -46,18 +59,25 @@ TF_CALL_complex128(REGISTER_GPU_KERNELS);
#endif
#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Prod") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<SYCLDevice, type, Eigen::internal::ProdReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Prod") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int32, \
+ Eigen::internal::ProdReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Prod") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int64, \
+ Eigen::internal::ProdReducer<type>>);
REGISTER_SYCL_KERNELS(int32);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
#undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index c1f4f3475a..5318d8c133 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -17,26 +17,39 @@ limitations under the License.
namespace tensorflow {
-#define REGISTER_CPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Sum") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- ReductionOp<CPUDevice, type, Eigen::internal::SumReducer<type>>);
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ ReductionOp<CPUDevice, type, int32, Eigen::internal::SumReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx"), \
+ ReductionOp<CPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
#if GOOGLE_CUDA
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Sum") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<GPUDevice, type, Eigen::internal::SumReducer<type>>);
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sum") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int32, Eigen::internal::SumReducer<type>>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Sum") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<GPUDevice, type, int64, Eigen::internal::SumReducer<type>>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_complex64(REGISTER_GPU_KERNELS);
TF_CALL_complex128(REGISTER_GPU_KERNELS);
@@ -53,19 +66,35 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("input")
.HostMemory("output")
.HostMemory("reduction_indices"),
- ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>);
+ ReductionOp<CPUDevice, int32, int32, Eigen::internal::SumReducer<int32>>);
+REGISTER_KERNEL_BUILDER(
+ Name("Sum")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("input")
+ .HostMemory("output")
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
#endif
#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Sum") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx") \
- .HostMemory("reduction_indices"), \
- ReductionOp<SYCLDevice, type, Eigen::internal::SumReducer<type>>);
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("Sum") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int32, \
+ Eigen::internal::SumReducer<type>>); \
+ REGISTER_KERNEL_BUILDER(Name("Sum") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("Tidx") \
+ .HostMemory("reduction_indices"), \
+ ReductionOp<SYCLDevice, type, int64, \
+ Eigen::internal::SumReducer<type>>);
REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);
@@ -77,8 +106,17 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("input")
.HostMemory("output")
.HostMemory("reduction_indices"),
- ReductionOp<CPUDevice, int32, Eigen::internal::SumReducer<int32>>);
+ ReductionOp<CPUDevice, int32, int32, Eigen::internal::SumReducer<int32>>);
+REGISTER_KERNEL_BUILDER(
+ Name("Sum")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int64>("Tidx")
+ .HostMemory("input")
+ .HostMemory("output")
+ .HostMemory("reduction_indices"),
+ ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
#undef REGISTER_SYCL_KERNELS
-#endif // TENSORFLOW_USE_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index c794351fe9..2dc65b1384 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -163,6 +163,13 @@ class SumReductionTest(BaseReductionTest):
reduction_axes = tuple(reduction_axes)
return np.sum(x, axis=reduction_axes, keepdims=keep_dims)
+ def testAxesType(self):
+ for dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True) as sess:
+ v = math_ops.reduce_sum([0, 0], constant_op.constant(0, dtype=dtype))
+ tf_v = sess.run(v)
+ self.assertAllEqual(tf_v, 0)
+
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@@ -193,6 +200,7 @@ class SumReductionTest(BaseReductionTest):
tf_out_mean = sess.run(tf_mean)
self.assertAllClose(tf_out_mean, 1.)
+
def testFloat32(self):
for rank in range(1, _MAX_RANK + 1):
np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
@@ -369,6 +377,13 @@ class MeanReductionTest(BaseReductionTest):
return np_sum // count
return np_sum / count
+ def testAxesType(self):
+ for dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True) as sess:
+ v = math_ops.reduce_mean([0, 0], constant_op.constant(0, dtype=dtype))
+ tf_v = sess.run(v)
+ self.assertAllEqual(tf_v, 0)
+
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@@ -435,6 +450,13 @@ class ProdReductionTest(BaseReductionTest):
reduction_axes = tuple(reduction_axes)
return np.prod(x, axis=reduction_axes, keepdims=keep_dims)
+ def testAxesType(self):
+ for dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True) as sess:
+ v = math_ops.reduce_prod([0, 0], constant_op.constant(0, dtype=dtype))
+ tf_v = sess.run(v)
+ self.assertAllEqual(tf_v, 0)
+
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@@ -531,6 +553,13 @@ class MinReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
+ def testAxesType(self):
+ for dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True) as sess:
+ v = math_ops.reduce_min([0, 0], constant_op.constant(0, dtype=dtype))
+ tf_v = sess.run(v)
+ self.assertAllEqual(tf_v, 0)
+
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@@ -637,6 +666,13 @@ class MaxReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
+ def testAxesType(self):
+ for dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True) as sess:
+ v = math_ops.reduce_max([0, 0], constant_op.constant(0, dtype=dtype))
+ tf_v = sess.run(v)
+ self.assertAllEqual(tf_v, 0)
+
def testInfinity(self):
for dtype in [np.float32, np.float64]:
for special_value_x in [-np.inf, np.inf]:
@@ -757,6 +793,14 @@ class AllReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
+ def testAxesType(self):
+ for dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True) as sess:
+ v = math_ops.reduce_all([True, True],
+ constant_op.constant(0, dtype=dtype))
+ tf_v = sess.run(v)
+ self.assertAllEqual(tf_v, True)
+
def testAll3D(self):
# Create a 3D array of bools and reduce across all possible
# dimensions
@@ -798,6 +842,14 @@ class AnyReductionTest(test.TestCase):
self._compare(x, reduction_axes, True, use_gpu=True)
self._compare(x, reduction_axes, True, use_gpu=False)
+ def testAxesType(self):
+ for dtype in [dtypes.int64, dtypes.int32]:
+ with self.test_session(use_gpu=True) as sess:
+ v = math_ops.reduce_any([True, True],
+ constant_op.constant(0, dtype=dtype))
+ tf_v = sess.run(v)
+ self.assertAllEqual(tf_v, True)
+
def testAll3D(self):
# Create a 3D array of bools and reduce across all possible
# dimensions