aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
diff options
context:
space:
mode:
author: Benoit Steiner <benoit.steiner.goog@gmail.com> 2014-05-16 15:08:05 -0700
committer: Benoit Steiner <benoit.steiner.goog@gmail.com> 2014-05-16 15:08:05 -0700
commit: 7402fea0a8e63e3ea248257047c584afee8f8bde (patch)
tree: 429aee7ea314c579ed62c1c5e1ff84850b14370a /unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
parent: 0320f7e3a71406b9a03d1bab0d168fd76e63d457 (diff)
Vectorized the evaluation of tensor expression (using SSE, AVX, NEON, ...)
Added the ability to parallelize the evaluation of a tensor expression over multiple cpu cores. Added the ability to offload the evaluation of a tensor expression to a GPU.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMap.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorMap.h 158
1 file changed, 147 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
index bb0b39c5a..3fc9c5335 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h
@@ -22,16 +22,16 @@ template<int InnerStrideAtCompileTime, int OuterStrideAtCompileTime> class Strid
*
*/
-template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap<PlainObjectType> >
+template<typename PlainObjectType, int Options_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_> >
{
public:
- typedef TensorMap<PlainObjectType> Self;
+ typedef TensorMap<PlainObjectType, Options_> Self;
typedef typename PlainObjectType::Base Base;
typedef typename Eigen::internal::nested<Self>::type Nested;
typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
typedef typename internal::traits<PlainObjectType>::Index Index;
typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
- typedef typename internal::packet_traits<Scalar>::type PacketScalar;
+ typedef typename internal::packet_traits<Scalar>::type Packet;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef typename Base::CoeffReturnType CoeffReturnType;
@@ -43,13 +43,12 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap
typedef Scalar* PointerType;
typedef PointerType PointerArgType;
- // Fixed size plain object type only
- /* EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr) {
- // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
- //EIGEN_STATIC_ASSERT(1 == PlainObjectType::NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
- // todo: add assert to ensure we don't screw up here.
- }*/
+ static const int Options = Options_;
+
+ enum {  // Compile-time flags describing this map type.
+ IsAligned = bool(EIGEN_ALIGN) && ((int(Options_)&Aligned)==Aligned),  // True only when EIGEN_ALIGN is non-zero and the Aligned bit is set in Options_.
+ PacketAccess = true,  // Hard-coded to true for every instantiation of this template.
+ };
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(array<DenseIndex, PlainObjectType::NumIndices>({{firstDimension}})) {
@@ -65,7 +64,7 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap
}
#endif
- inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions)
+ inline TensorMap(PointerArgType dataPtr, const array<Index, PlainObjectType::NumIndices>& dimensions)
: m_data(dataPtr), m_dimensions(dimensions)
{ }
@@ -81,11 +80,96 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap
EIGEN_STRONG_INLINE const Scalar* data() const { return m_data; }
EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices) const
+ {
+ // eigen_assert(checkIndexRange(indices));
+ if (PlainObjectType::Options&RowMajor) {
+ const Index index = m_dimensions.IndexOfRowMajor(indices);
+ return m_data[index];
+ } else {
+ const Index index = m_dimensions.IndexOfColMajor(indices);
+ return m_data[index];
+ }
+ }
+
+#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
+  /** Read-only coefficient lookup from a list of scalar indices.
+   *
+   * The number of indices must equal PlainObjectType::NumIndices; this is
+   * enforced at compile time. The indices are packed into an array and the
+   * flat offset is computed by m_dimensions for the storage order of
+   * PlainObjectType.
+   */
+  template<typename... IndexTypes> EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, IndexTypes... otherIndices) const
+  {
+    static_assert(sizeof...(otherIndices) + 1 == PlainObjectType::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
+    const array<Index, PlainObjectType::NumIndices> indices{{firstIndex, otherIndices...}};
+    const Index offset = (PlainObjectType::Options & RowMajor)
+        ? m_dimensions.IndexOfRowMajor(indices)
+        : m_dimensions.IndexOfColMajor(indices);
+    return m_data[offset];
+  }
+#else
+ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const  // Read-only access by flat offset: `index` is used directly as the position in m_data, with no per-dimension translation.
{
eigen_internal_assert(index >= 0 && index < size());  // Range check against size(); whether it is compiled in depends on how eigen_internal_assert is defined (not visible here).
return m_data[index];
}
+  /** Read-only access to the coefficient at (i0, i1) of a rank-2 map.
+   *
+   * The flat offset into the mapped buffer depends on the storage order of
+   * PlainObjectType:
+   *   - RowMajor: the last index varies fastest, so one step in i0 spans
+   *     m_dimensions[1] elements (the extent of the trailing dimension);
+   *   - otherwise: the first index varies fastest, so one step in i1 spans
+   *     m_dimensions[0] elements.
+   *
+   * No bounds checking is performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const
+  {
+    if (PlainObjectType::Options&RowMajor) {
+      // The row stride is the trailing extent. Scaling by m_dimensions[0]
+      // would only be correct for square maps.
+      const Index index = i1 + i0 * m_dimensions[1];
+      return m_data[index];
+    } else {
+      const Index index = i0 + i1 * m_dimensions[0];
+      return m_data[index];
+    }
+  }
+  /** Read-only access to the coefficient at (i0, i1, i2) of a rank-3 map.
+   *
+   * The flat offset is built in nested (Horner) form from the stored dimensions:
+   *   - RowMajor: i2 varies fastest; i1 is scaled by m_dimensions[2] and
+   *     i0 by m_dimensions[1] * m_dimensions[2];
+   *   - otherwise: i0 varies fastest; i1 is scaled by m_dimensions[0] and
+   *     i2 by m_dimensions[0] * m_dimensions[1].
+   *
+   * No bounds checking is performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const
+  {
+    if (PlainObjectType::Options&RowMajor) {
+      // Each nesting level multiplies by the extent of the dimension whose
+      // index was just added (dimension 2, then dimension 1).
+      const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
+      return m_data[index];
+    } else {
+      const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
+      return m_data[index];
+    }
+  }
+  /** Read-only access to the coefficient at (i0, i1, i2, i3) of a rank-4 map.
+   *
+   * The flat offset is evaluated in nested form: with RowMajor storage i3
+   * varies fastest, otherwise i0 varies fastest. No bounds checking is
+   * performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3) const
+  {
+    const bool rowMajor = (PlainObjectType::Options & RowMajor) != 0;
+    const Index offset = rowMajor
+        ? i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0))
+        : i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
+    return m_data[offset];
+  }
+  /** Read-only access to the coefficient at (i0, i1, i2, i3, i4) of a rank-5 map.
+   *
+   * The flat offset is evaluated in nested form: with RowMajor storage i4
+   * varies fastest, otherwise i0 varies fastest. No bounds checking is
+   * performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
+  {
+    const bool rowMajor = (PlainObjectType::Options & RowMajor) != 0;
+    const Index offset = rowMajor
+        ? i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)))
+        : i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
+    return m_data[offset];
+  }
+#endif
+
+  /** Mutable coefficient lookup from a full index tuple.
+   *
+   * \param indices one entry per dimension (PlainObjectType::NumIndices entries).
+   * \returns a writable reference into the mapped buffer at the flat offset
+   *          that m_dimensions computes for the storage order of PlainObjectType.
+   *
+   * The index range is not checked (the assertion below is disabled).
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, PlainObjectType::NumIndices>& indices)
+  {
+    // eigen_assert(checkIndexRange(indices));
+    const bool rowMajor = (PlainObjectType::Options & RowMajor) != 0;
+    const Index offset = rowMajor ? m_dimensions.IndexOfRowMajor(indices)
+                                  : m_dimensions.IndexOfColMajor(indices);
+    return m_data[offset];
+  }
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
@@ -100,8 +184,60 @@ template<typename PlainObjectType> class TensorMap : public TensorBase<TensorMap
return m_data[index];
}
}
+#else
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar& operator()(Index index)  // Mutable access by flat offset: `index` is used directly as the position in m_data, with no per-dimension translation.
+ {
+ eigen_internal_assert(index >= 0 && index < size());  // Range check against size(); whether it is compiled in depends on how eigen_internal_assert is defined (not visible here).
+ return m_data[index];
+ }
+  /** Mutable access to the coefficient at (i0, i1) of a rank-2 map.
+   *
+   * The flat offset into the mapped buffer depends on the storage order of
+   * PlainObjectType:
+   *   - RowMajor: the last index varies fastest, so one step in i0 spans
+   *     m_dimensions[1] elements (the extent of the trailing dimension);
+   *   - otherwise: the first index varies fastest, so one step in i1 spans
+   *     m_dimensions[0] elements.
+   *
+   * No bounds checking is performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1)
+  {
+    if (PlainObjectType::Options&RowMajor) {
+      // The row stride is the trailing extent. Scaling by m_dimensions[0]
+      // would only be correct for square maps.
+      const Index index = i1 + i0 * m_dimensions[1];
+      return m_data[index];
+    } else {
+      const Index index = i0 + i1 * m_dimensions[0];
+      return m_data[index];
+    }
+  }
+  /** Mutable access to the coefficient at (i0, i1, i2) of a rank-3 map.
+   *
+   * The flat offset is built in nested (Horner) form from the stored dimensions:
+   *   - RowMajor: i2 varies fastest; i1 is scaled by m_dimensions[2] and
+   *     i0 by m_dimensions[1] * m_dimensions[2];
+   *   - otherwise: i0 varies fastest; i1 is scaled by m_dimensions[0] and
+   *     i2 by m_dimensions[0] * m_dimensions[1].
+   *
+   * No bounds checking is performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2)
+  {
+    if (PlainObjectType::Options&RowMajor) {
+      // Each nesting level multiplies by the extent of the dimension whose
+      // index was just added (dimension 2, then dimension 1).
+      const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
+      return m_data[index];
+    } else {
+      const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
+      return m_data[index];
+    }
+  }
+  /** Mutable access to the coefficient at (i0, i1, i2, i3) of a rank-4 map.
+   *
+   * The flat offset is evaluated in nested form: with RowMajor storage i3
+   * varies fastest, otherwise i0 varies fastest. No bounds checking is
+   * performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
+  {
+    const bool rowMajor = (PlainObjectType::Options & RowMajor) != 0;
+    const Index offset = rowMajor
+        ? i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0))
+        : i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
+    return m_data[offset];
+  }
+  /** Mutable access to the coefficient at (i0, i1, i2, i3, i4) of a rank-5 map.
+   *
+   * The flat offset is evaluated in nested form: with RowMajor storage i4
+   * varies fastest, otherwise i0 varies fastest. No bounds checking is
+   * performed.
+   */
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
+  {
+    const bool rowMajor = (PlainObjectType::Options & RowMajor) != 0;
+    const Index offset = rowMajor
+        ? i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)))
+        : i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
+    return m_data[offset];
+  }
#endif
+
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
Self& operator=(const OtherDerived& other)