From 75c333f94c51344d610d886858265eddfa5abbbd Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 27 Jun 2016 10:32:38 -0700 Subject: Don't store the scan axis in the evaluator of the tensor scan operation since it's only used in the constructor. Also avoid taking references to values that may become stale after a copy construction. --- unsupported/Eigen/CXX11/src/Tensor/TensorScan.h | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorScan.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h index 1aa196b84..ba165ad4d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h @@ -101,32 +101,31 @@ struct TensorEvaluator, Device> { const Device& device) : m_impl(op.expression(), device), m_device(device), - m_axis(op.axis()), m_exclusive(op.exclusive()), m_accumulator(op.accumulator()), - m_dimensions(m_impl.dimensions()), - m_size(m_dimensions[m_axis]), + m_size(m_impl.dimensions()[op.axis()]), m_stride(1), m_output(NULL) { // Accumulating a scalar isn't supported. 
EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE); - eigen_assert(m_axis >= 0 && m_axis < NumDims); + eigen_assert(op.axis() >= 0 && op.axis() < NumDims); // Compute stride of scan axis + const Dimensions& dims = m_impl.dimensions(); if (static_cast(Layout) == static_cast(ColMajor)) { - for (int i = 0; i < m_axis; ++i) { - m_stride = m_stride * m_dimensions[i]; + for (int i = 0; i < op.axis(); ++i) { + m_stride = m_stride * dims[i]; } } else { - for (int i = NumDims - 1; i > m_axis; --i) { - m_stride = m_stride * m_dimensions[i]; + for (int i = NumDims - 1; i > op.axis(); --i) { + m_stride = m_stride * dims[i]; } } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { - return m_dimensions; + return m_impl.dimensions(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { @@ -135,7 +134,8 @@ struct TensorEvaluator, Device> { accumulateTo(data); return false; } else { - m_output = static_cast(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar))); + const Index total_size = internal::array_prod(dimensions()); + m_output = static_cast(m_device.allocate(total_size * sizeof(Scalar))); accumulateTo(m_output); return true; } @@ -171,11 +171,9 @@ struct TensorEvaluator, Device> { protected: TensorEvaluator m_impl; const Device& m_device; - const Index m_axis; const bool m_exclusive; Op m_accumulator; - const Dimensions& m_dimensions; - const Index& m_size; + const Index m_size; Index m_stride; CoeffReturnType* m_output; -- cgit v1.2.3