aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 15:38:48 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 15:38:48 -0800
commitf697df723798779bc29d9f7299bb5398767d5db0 (patch)
treec155c21ad9ef0e6269f6af83fe2f29f97a0c0e21 /unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h
parent6559d09c60fb4acfc7ee5197284f576ac14926f1 (diff)
Improved support for RowMajor tensors
Misc fixes and API cleanups.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h46
1 files changed, 41 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h b/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h
index e2fe32d67..1c03d202f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h
@@ -24,11 +24,14 @@ template<typename PatchDim, typename XprType>
struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
{
typedef typename XprType::Scalar Scalar;
- typedef typename internal::packet_traits<Scalar>::type Packet;
- typedef typename traits<XprType>::StorageKind StorageKind;
- typedef typename traits<XprType>::Index Index;
+ typedef traits<XprType> XprTraits;
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename XprTraits::StorageKind StorageKind;
+ typedef typename XprTraits::Index Index;
typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested;
+ static const int NumDimensions = XprTraits::NumDimensions + 1;
+ static const int Layout = XprTraits::Layout;
};
template<typename PatchDim, typename XprType>
@@ -89,11 +92,16 @@ struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
enum {
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
- };
+ Layout = TensorEvaluator<ArgType, Device>::Layout,
+ CoordAccess = true,
+ };
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device)
{
+ // Only column major tensors are supported for now.
+ EIGEN_STATIC_ASSERT((Layout == ColMajor), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
Index num_patches = 1;
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
const PatchDim& patch_dims = op.patch_dims();
@@ -195,6 +203,35 @@ struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
}
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<Index, NumDims>& coords) const
+ {
+ // Location of the first element of the patch.
+ const Index patchIndex = coords[NumDims - 1];
+
+ if (TensorEvaluator<ArgType, Device>::CoordAccess) {
+ array<Index, NumDims-1> inputCoords;
+ for (int i = NumDims - 2; i > 0; --i) {
+ const Index patchIdx = patchIndex / m_patchStrides[i];
+ patchIndex -= patchIdx * m_patchStrides[i];
+ const Index offsetIdx = coords[i];
+ inputCoords[i] = coords[i] + patchIdx;
+ }
+ inputCoords[0] = (patchIndex + coords[0]);
+ return m_impl.coeff(inputCoords);
+ }
+ else {
+ Index inputIndex = 0;
+ for (int i = NumDims - 2; i > 0; --i) {
+ const Index patchIdx = patchIndex / m_patchStrides[i];
+ patchIndex -= patchIdx * m_patchStrides[i];
+ const Index offsetIdx = coords[i];
+ inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
+ }
+ inputIndex += (patchIndex + coords[0]);
+ return m_impl.coeff(inputIndex);
+ }
+ }
+
Scalar* data() const { return NULL; }
protected:
@@ -206,7 +243,6 @@ struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
TensorEvaluator<ArgType, Device> m_impl;
};
-
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H