From f697df723798779bc29d9f7299bb5398767d5db0 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 14 Jan 2015 15:38:48 -0800 Subject: Improved support for RowMajor tensors Misc fixes and API cleanups. --- .../Eigen/CXX11/src/Tensor/TensorConcatenation.h | 75 +++++++++++++++++----- 1 file changed, 58 insertions(+), 17 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h index 74485b15b..fb4e7fb11 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h @@ -35,6 +35,8 @@ struct traits > typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; + static const int NumDimensions = traits::NumDimensions; + static const int Layout = traits::Layout; enum { Flags = 0 }; }; @@ -103,11 +105,13 @@ struct TensorEvaluator::PacketAccess & TensorEvaluator::PacketAccess, + Layout = TensorEvaluator::Layout, }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis()) { + EIGEN_STATIC_ASSERT((TensorEvaluator::Layout == TensorEvaluator::Layout || NumDims == 1), YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(NumDims == RightNumDims, YOU_MADE_A_PROGRAMMING_MISTAKE) eigen_assert(0 <= m_axis && m_axis < NumDims); const Dimensions& lhs_dims = m_leftImpl.dimensions(); @@ -127,13 +131,26 @@ struct TensorEvaluator= 0; --i) { + m_leftStrides[i] = m_leftStrides[i+1] * lhs_dims[i+1]; + m_rightStrides[i] = m_rightStrides[i+1] * rhs_dims[i+1]; + m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1]; + } } } @@ -159,25 +176,49 @@ struct TensorEvaluator subs; - for (int i = NumDims - 1; i > 0; --i) { - subs[i] = index / m_outputStrides[i]; - index -= subs[i] * m_outputStrides[i]; + if (Layout == ColMajor) { + for (int i = NumDims - 1; i > 0; --i) { + subs[i] = index / m_outputStrides[i]; + index -= subs[i] * m_outputStrides[i]; + } + subs[0] = index; + } else { + for (int i = 0; i < NumDims - 1; ++i) { + subs[i] = index / m_outputStrides[i]; + index -= subs[i] * m_outputStrides[i]; + } + subs[NumDims - 1] = index; } - subs[0] = index; const Dimensions& left_dims = m_leftImpl.dimensions(); if (subs[m_axis] < left_dims[m_axis]) { - Index left_index = subs[0]; - for (int i = 1; i < NumDims; ++i) { - left_index += (subs[i] % left_dims[i]) * m_leftStrides[i]; + Index left_index; + if (Layout == ColMajor) { + left_index = subs[0]; + for (int i = 1; i < NumDims; ++i) { + left_index += (subs[i] % left_dims[i]) * m_leftStrides[i]; + } + } else { + left_index = subs[NumDims - 1]; + for (int i = NumDims - 2; i >= 0; --i) { + left_index += (subs[i] % left_dims[i]) * m_leftStrides[i]; + } } return m_leftImpl.coeff(left_index); } else { subs[m_axis] -= left_dims[m_axis]; const Dimensions& right_dims = m_rightImpl.dimensions(); - Index right_index = subs[0]; - for (int i = 1; i < NumDims; ++i) { - right_index += (subs[i] % right_dims[i]) * m_rightStrides[i]; + Index right_index; + if (Layout == ColMajor) { + right_index = subs[0]; + for (int i = 1; i < NumDims; ++i) { + right_index += (subs[i] % right_dims[i]) * m_rightStrides[i]; + } + } else { + right_index = subs[NumDims - 1]; + for (int i = NumDims - 2; i >= 0; --i) { + right_index += (subs[i] % right_dims[i]) * m_rightStrides[i]; + } } return m_rightImpl.coeff(right_index); } -- cgit v1.2.3