diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h | 61 |
1 files changed, 42 insertions, 19 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h b/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h index 7acdbfc72..ecfdb762c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h @@ -48,7 +48,7 @@ struct nested<TensorStridingOp<Strides, XprType>, 1, typename eval<TensorStridin template<typename Strides, typename XprType> -class TensorStridingOp : public TensorBase<TensorStridingOp<Strides, XprType>, WriteAccessors> +class TensorStridingOp : public TensorBase<TensorStridingOp<Strides, XprType> > { public: typedef typename Eigen::internal::traits<TensorStridingOp>::Scalar Scalar; @@ -97,7 +97,7 @@ struct TensorEvaluator<const TensorStridingOp<Strides, ArgType>, Device> enum { IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/false, - PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false, + PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -109,28 +109,23 @@ struct TensorEvaluator<const TensorStridingOp<Strides, ArgType>, Device> } const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions(); - for (int i = 0; i < NumDims; ++i) { - if (i > 0) { - m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1]; - m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1]; - } else { - m_inputStrides[0] = 1; - m_outputStrides[0] = 1; - } - } - for (int i = 0; i < NumDims; ++i) { - m_inputStrides[i] *= op.strides()[i]; + m_outputStrides[0] = 1; + m_inputStrides[0] = 1; + for (int i = 1; i < NumDims; ++i) { + m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1]; + m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1]; + m_inputStrides[i-1] *= op.strides()[i-1]; } + m_inputStrides[NumDims-1] *= op.strides()[NumDims-1]; } - // typedef typename XprType::Index Index; typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) { m_impl.evalSubExprsIfNeeded(NULL); return true; } @@ -150,16 +145,44 @@ struct TensorEvaluator<const TensorStridingOp<Strides, ArgType>, Device> return m_impl.coeff(inputIndex); } - /* template<int LoadMode> + template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const { - return m_impl.template packet<LoadMode>(index); - }*/ + const int packetSize = internal::unpacket_traits<PacketReturnType>::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+packetSize-1 < dimensions().TotalSize()); + + Index inputIndices[] = {0, 0}; + Index indices[] = {index, index + packetSize - 1}; + for (int i = NumDims - 1; i > 0; --i) { + const Index idx0 = indices[0] / m_outputStrides[i]; + const Index idx1 = indices[1] / m_outputStrides[i]; + inputIndices[0] += idx0 * m_inputStrides[i]; + inputIndices[1] += idx1 * m_inputStrides[i]; + indices[0] -= idx0 * m_outputStrides[i]; + indices[1] -= idx1 * m_outputStrides[i]; + } + inputIndices[0] += indices[0] * m_inputStrides[0]; + inputIndices[1] += indices[1] * m_inputStrides[0]; + if (inputIndices[1] - inputIndices[0] == packetSize - 1) { + PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]); + return rslt; + } + else { + EIGEN_ALIGN_DEFAULT typename internal::remove_const<CoeffReturnType>::type values[packetSize]; + values[0] = m_impl.coeff(inputIndices[0]); + values[packetSize-1] = m_impl.coeff(inputIndices[1]); + for (int i = 1; i < packetSize-1; ++i) { + values[i] = coeff(index+i); + } + PacketReturnType rslt = internal::pload<PacketReturnType>(values); + return rslt; + } + } Scalar* data() const { return NULL; } protected: - // Strides m_strides; Dimensions m_dimensions; array<Index, NumDims> m_outputStrides; array<Index, NumDims> m_inputStrides; |