diff options
author | 2017-10-21 23:01:42 -0700 | |
---|---|---|
committer | 2017-10-21 23:01:42 -0700 | |
commit | 1c1dad105a57bb13711492a8ba5ab9d10c91b5df (patch) | |
tree | 5e440f8890d176b05db9a8e4bd9f7b00f1496c60 /tensorflow/core/kernels/reduction_ops_common.cc | |
parent | 17096081eed7881c0b8ce3c32b5e9795619e27bb (diff) |
Add int64 axis support for reduction ops. (#13891)
* Add int64 axis support for reduction ops.
This fix is a follow up to PR 13863. In PR 13863 the
program crash is fixed if int64 axis is passed to reduction ops,
e.g. reduce_sum, reduce_max, etc. However, 13863 does not
process the case of int64 support, it merely fixes the crash.
This fix adds the support for int64 axis of reduction ops.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add int64 axis support for mean, prod, sum
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add int64 axis support for min and max.
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add int64 axis support for reduce_all and reduce_any
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
* Add test cases for int64 axis support of reduce_any and reduce_all
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
Diffstat (limited to 'tensorflow/core/kernels/reduction_ops_common.cc')
-rw-r--r-- | tensorflow/core/kernels/reduction_ops_common.cc | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc index 5eba4288ac..8daab0d6be 100644 --- a/tensorflow/core/kernels/reduction_ops_common.cc +++ b/tensorflow/core/kernels/reduction_ops_common.cc @@ -57,13 +57,12 @@ gtl::InlinedVector<int32, 8> ReductionHelper::permutation() { return perm; } -Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, - const bool keep_dims) { - // bitmap[i] indicates whether to reduce data along i-th axis. - gtl::InlinedVector<bool, 4> bitmap(data.dims(), false); - auto axis_vec = axis.flat<int32>(); +template <typename Tperm> +Status SimplifyHelper(const Tensor& data, const Tensor& axis, + gtl::InlinedVector<bool, 4>& bitmap) { + auto axis_vec = axis.flat<Tperm>(); for (int64 i = 0; i < axis.NumElements(); ++i) { - int32 index = axis_vec(i); + Tperm index = axis_vec(i); if (index < -data.dims() || index >= data.dims()) { return errors::InvalidArgument("Invalid reduction dimension (", index, " for input with ", data.dims(), @@ -72,7 +71,18 @@ Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, index = (index + data.dims()) % data.dims(); bitmap[index] = true; } + return Status::OK(); +} +Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, + const bool keep_dims) { + // bitmap[i] indicates whether to reduce data along i-th axis. + gtl::InlinedVector<bool, 4> bitmap(data.dims(), false); + if (axis.dtype() == DT_INT32) { + TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap)); + } else { + TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap)); + } // Output tensor's dim sizes. out_shape_.clear(); for (int i = 0; i < data.dims(); ++i) { |