diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 15:38:48 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 15:38:48 -0800 |
commit | f697df723798779bc29d9f7299bb5398767d5db0 (patch) | |
tree | c155c21ad9ef0e6269f6af83fe2f29f97a0c0e21 /unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h | |
parent | 6559d09c60fb4acfc7ee5197284f576ac14926f1 (diff) |
Improved support for RowMajor tensors
Misc fixes and API cleanups.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h | 46 |
1 files changed, 41 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h b/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h index e2fe32d67..1c03d202f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h @@ -24,11 +24,14 @@ template<typename PatchDim, typename XprType> struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType> { typedef typename XprType::Scalar Scalar; - typedef typename internal::packet_traits<Scalar>::type Packet; - typedef typename traits<XprType>::StorageKind StorageKind; - typedef typename traits<XprType>::Index Index; + typedef traits<XprType> XprTraits; + typedef typename packet_traits<Scalar>::type Packet; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; typedef typename XprType::Nested Nested; typedef typename remove_reference<Nested>::type _Nested; + static const int NumDimensions = XprTraits::NumDimensions + 1; + static const int Layout = XprTraits::Layout; }; template<typename PatchDim, typename XprType> @@ -89,11 +92,16 @@ struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device> enum { IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, - }; + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = true, + }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_impl(op.expression(), device) { + // Only column major tensors are supported for now. + EIGEN_STATIC_ASSERT((Layout == ColMajor), YOU_MADE_A_PROGRAMMING_MISTAKE); + Index num_patches = 1; const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions(); const PatchDim& patch_dims = op.patch_dims(); @@ -195,6 +203,35 @@ struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device> } } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<Index, NumDims>& coords) const + { + // Location of the first element of the patch. + const Index patchIndex = coords[NumDims - 1]; + + if (TensorEvaluator<ArgType, Device>::CoordAccess) { + array<Index, NumDims-1> inputCoords; + for (int i = NumDims - 2; i > 0; --i) { + const Index patchIdx = patchIndex / m_patchStrides[i]; + patchIndex -= patchIdx * m_patchStrides[i]; + const Index offsetIdx = coords[i]; + inputCoords[i] = coords[i] + patchIdx; + } + inputCoords[0] = (patchIndex + coords[0]); + return m_impl.coeff(inputCoords); + } + else { + Index inputIndex = 0; + for (int i = NumDims - 2; i > 0; --i) { + const Index patchIdx = patchIndex / m_patchStrides[i]; + patchIndex -= patchIdx * m_patchStrides[i]; + const Index offsetIdx = coords[i]; + inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i]; + } + inputIndex += (patchIndex + coords[0]); + return m_impl.coeff(inputIndex); + } + } + Scalar* data() const { return NULL; } protected: @@ -206,7 +243,6 @@ struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device> TensorEvaluator<ArgType, Device> m_impl; }; - } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H |