diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 15:38:48 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 15:38:48 -0800 |
commit | f697df723798779bc29d9f7299bb5398767d5db0 (patch) | |
tree | c155c21ad9ef0e6269f6af83fe2f29f97a0c0e21 /unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h | |
parent | 6559d09c60fb4acfc7ee5197284f576ac14926f1 (diff) |
Improved support for RowMajor tensors
Misc fixes and API cleanups.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h | 75 |
1 files changed, 58 insertions, 17 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h index 74485b15b..fb4e7fb11 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h @@ -35,6 +35,8 @@ struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> > typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference<LhsNested>::type _LhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested; + static const int NumDimensions = traits<LhsXprType>::NumDimensions; + static const int Layout = traits<LhsXprType>::Layout; enum { Flags = 0 }; }; @@ -103,11 +105,13 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy enum { IsAligned = false, PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess, + Layout = TensorEvaluator<LeftArgType, Device>::Layout, }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis()) { + EIGEN_STATIC_ASSERT((TensorEvaluator<LeftArgType, Device>::Layout == TensorEvaluator<RightArgType, Device>::Layout || NumDims == 1), YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(NumDims == RightNumDims, YOU_MADE_A_PROGRAMMING_MISTAKE) eigen_assert(0 <= m_axis && m_axis < NumDims); const Dimensions& lhs_dims = m_leftImpl.dimensions(); @@ -127,13 +131,26 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy m_dimensions[i] = lhs_dims[i]; } - m_leftStrides[0] = 1; - m_rightStrides[0] = 1; - m_outputStrides[0] = 1; - for (int i = 1; i < NumDims; ++i) { - m_leftStrides[i] = m_leftStrides[i-1] * lhs_dims[i-1]; - m_rightStrides[i] = m_rightStrides[i-1] * rhs_dims[i-1]; - m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1]; + if (Layout == ColMajor) { + m_leftStrides[0] = 1; + m_rightStrides[0] = 1; + m_outputStrides[0] = 1; + + for (int i = 1; i < NumDims; ++i) { + m_leftStrides[i] = m_leftStrides[i-1] * lhs_dims[i-1]; + m_rightStrides[i] = m_rightStrides[i-1] * rhs_dims[i-1]; + m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1]; + } + } else { + m_leftStrides[NumDims - 1] = 1; + m_rightStrides[NumDims - 1] = 1; + m_outputStrides[NumDims - 1] = 1; + + for (int i = NumDims - 2; i >= 0; --i) { + m_leftStrides[i] = m_leftStrides[i+1] * lhs_dims[i+1]; + m_rightStrides[i] = m_rightStrides[i+1] * rhs_dims[i+1]; + m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1]; + } } } @@ -159,25 +176,49 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy { // Collect dimension-wise indices (subs). array<Index, NumDims> subs; - for (int i = NumDims - 1; i > 0; --i) { - subs[i] = index / m_outputStrides[i]; - index -= subs[i] * m_outputStrides[i]; + if (Layout == ColMajor) { + for (int i = NumDims - 1; i > 0; --i) { + subs[i] = index / m_outputStrides[i]; + index -= subs[i] * m_outputStrides[i]; + } + subs[0] = index; + } else { + for (int i = 0; i < NumDims - 1; ++i) { + subs[i] = index / m_outputStrides[i]; + index -= subs[i] * m_outputStrides[i]; + } + subs[NumDims - 1] = index; } - subs[0] = index; const Dimensions& left_dims = m_leftImpl.dimensions(); if (subs[m_axis] < left_dims[m_axis]) { - Index left_index = subs[0]; - for (int i = 1; i < NumDims; ++i) { - left_index += (subs[i] % left_dims[i]) * m_leftStrides[i]; + Index left_index; + if (Layout == ColMajor) { + left_index = subs[0]; + for (int i = 1; i < NumDims; ++i) { + left_index += (subs[i] % left_dims[i]) * m_leftStrides[i]; + } + } else { + left_index = subs[NumDims - 1]; + for (int i = NumDims - 2; i >= 0; --i) { + left_index += (subs[i] % left_dims[i]) * m_leftStrides[i]; + } } return m_leftImpl.coeff(left_index); } else { subs[m_axis] -= left_dims[m_axis]; const Dimensions& right_dims = m_rightImpl.dimensions(); - Index right_index = subs[0]; - for (int i = 1; i < NumDims; ++i) { - right_index += (subs[i] % right_dims[i]) * m_rightStrides[i]; + Index right_index; + if (Layout == ColMajor) { + right_index = subs[0]; + for (int i = 1; i < NumDims; ++i) { + right_index += (subs[i] % right_dims[i]) * m_rightStrides[i]; + } + } else { + right_index = subs[NumDims - 1]; + for (int i = NumDims - 2; i >= 0; --i) { + right_index += (subs[i] % right_dims[i]) * m_rightStrides[i]; + } } return m_rightImpl.coeff(right_index); } |