aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/reduction_ops_common.cc
diff options
context:
space:
mode:
authorGravatar Yong Tang <yong.tang.github@outlook.com>2017-10-21 23:01:42 -0700
committerGravatar Vijay Vasudevan <vrv@google.com>2017-10-21 23:01:42 -0700
commit1c1dad105a57bb13711492a8ba5ab9d10c91b5df (patch)
tree5e440f8890d176b05db9a8e4bd9f7b00f1496c60 /tensorflow/core/kernels/reduction_ops_common.cc
parent17096081eed7881c0b8ce3c32b5e9795619e27bb (diff)
Add int64 axis support for reduction ops. (#13891)
* Add int64 axis support for reduction ops. This fix is a follow-up to PR 13863. PR 13863 fixed the program crash that occurred when an int64 axis was passed to reduction ops, e.g. reduce_sum, reduce_max, etc. However, PR 13863 did not actually add int64 support; it merely fixed the crash. This fix adds support for an int64 axis in reduction ops. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add int64 axis support for mean, prod, sum Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add int64 axis support for min and max. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add int64 axis support for reduce_all and reduce_any Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test cases for int64 axis support of reduce_any and reduce_all Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/core/kernels/reduction_ops_common.cc')
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.cc22
1 file changed, 16 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc
index 5eba4288ac..8daab0d6be 100644
--- a/tensorflow/core/kernels/reduction_ops_common.cc
+++ b/tensorflow/core/kernels/reduction_ops_common.cc
@@ -57,13 +57,12 @@ gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
return perm;
}
-Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
- const bool keep_dims) {
- // bitmap[i] indicates whether to reduce data along i-th axis.
- gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
- auto axis_vec = axis.flat<int32>();
+template <typename Tperm>
+Status SimplifyHelper(const Tensor& data, const Tensor& axis,
+ gtl::InlinedVector<bool, 4>& bitmap) {
+ auto axis_vec = axis.flat<Tperm>();
for (int64 i = 0; i < axis.NumElements(); ++i) {
- int32 index = axis_vec(i);
+ Tperm index = axis_vec(i);
if (index < -data.dims() || index >= data.dims()) {
return errors::InvalidArgument("Invalid reduction dimension (", index,
" for input with ", data.dims(),
@@ -72,7 +71,18 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
index = (index + data.dims()) % data.dims();
bitmap[index] = true;
}
+ return Status::OK();
+}
+Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
+ const bool keep_dims) {
+ // bitmap[i] indicates whether to reduce data along i-th axis.
+ gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
+ if (axis.dtype() == DT_INT32) {
+ TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
+ } else {
+ TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
+ }
// Output tensor's dim sizes.
out_shape_.clear();
for (int i = 0; i < data.dims(); ++i) {