aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-11-23 15:58:47 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-11-23 15:58:47 -0800
commit44848ac39bba2ba25514c6c897f5dc7bba1c76ae (patch)
treed5049542b597626737fba54d72e7fa5a5bc4a960 /unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
parent547a8608e5ff329c0f4e2da38c6eae023fc75647 (diff)
Fixed a bug in TensorArgMax.h
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h29
1 files changed, 12 insertions, 17 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
index d4f9a725d..c783aab97 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h
@@ -215,10 +215,17 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_orig_impl(op.expression(), device),
m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
- m_return_dim(op.return_dim()),
- m_stride_mod(gen_stride_mod(m_orig_impl.dimensions())),
- m_stride_div(gen_stride_div()) {
+ m_return_dim(op.return_dim()) {
+
gen_strides(m_orig_impl.dimensions(), m_strides);
+ if (Layout == static_cast<int>(ColMajor)) {
+ const Index total_size = internal::array_prod(m_orig_impl.dimensions());
+ m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
+ } else {
+ const Index total_size = internal::array_prod(m_orig_impl.dimensions());
+ m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
+ }
+ m_stride_div = m_strides[m_return_dim];
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
@@ -263,25 +270,13 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi
}
}
- EIGEN_DEVICE_FUNC Index gen_stride_mod(const InputDimensions& dims) {
- if (Layout == static_cast<int>(ColMajor)) {
- return (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : dims.TotalSize();
- } else {
- return (m_return_dim > 0) ? m_strides[m_return_dim - 1] : dims.TotalSize();
- }
- }
-
- EIGEN_DEVICE_FUNC Index gen_stride_div() {
- return m_strides[m_return_dim];
- }
-
protected:
TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
const int m_return_dim;
StrideDims m_strides;
- const Index m_stride_mod;
- const Index m_stride_div;
+ Index m_stride_mod;
+ Index m_stride_div;
};
} // end namespace Eigen