From 44848ac39bba2ba25514c6c897f5dc7bba1c76ae Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 23 Nov 2015 15:58:47 -0800 Subject: Fixed a bug in TensorArgMax.h --- unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h | 29 ++++++++++------------- 1 file changed, 12 insertions(+), 17 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h index d4f9a725d..c783aab97 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h @@ -215,10 +215,17 @@ struct TensorEvaluator, Devi EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_orig_impl(op.expression(), device), m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device), - m_return_dim(op.return_dim()), - m_stride_mod(gen_stride_mod(m_orig_impl.dimensions())), - m_stride_div(gen_stride_div()) { + m_return_dim(op.return_dim()) { + gen_strides(m_orig_impl.dimensions(), m_strides); + if (Layout == static_cast(ColMajor)) { + const Index total_size = internal::array_prod(m_orig_impl.dimensions()); + m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size; + } else { + const Index total_size = internal::array_prod(m_orig_impl.dimensions()); + m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size; + } + m_stride_div = m_strides[m_return_dim]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { @@ -263,25 +270,13 @@ struct TensorEvaluator, Devi } } - EIGEN_DEVICE_FUNC Index gen_stride_mod(const InputDimensions& dims) { - if (Layout == static_cast(ColMajor)) { - return (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : dims.TotalSize(); - } else { - return (m_return_dim > 0) ? m_strides[m_return_dim - 1] : dims.TotalSize(); - } - } - - EIGEN_DEVICE_FUNC Index gen_stride_div() { - return m_strides[m_return_dim]; - } - protected: TensorEvaluator, Device> m_orig_impl; TensorEvaluator >, Device> m_impl; const int m_return_dim; StrideDims m_strides; - const Index m_stride_mod; - const Index m_stride_div; + Index m_stride_mod; + Index m_stride_div; }; } // end namespace Eigen -- cgit v1.2.3