aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 12:47:46 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-01-14 12:47:46 -0800
commit1ac86001266db55b78086617fb68206b29748919 (patch)
tree77ef4a659d2743390cbed843726cb6a47b1229c0 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parent378bdfb7f0c4b2a8eb2b91c2a65f3bc1c57e689e (diff)
Fixed the return type of coefficient wise operations. For example, the abs function returns a floating point value when called on a complex input.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h58
1 files changed, 55 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index f7c784942..97f225f0a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -34,9 +34,15 @@ struct TensorEvaluator
typedef typename Derived::Packet PacketReturnType;
typedef typename Derived::Dimensions Dimensions;
+ // NumDimensions is -1 for variable dim tensors
+ static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
+ internal::traits<Derived>::NumDimensions : 0;
+
enum {
IsAligned = Derived::IsAligned,
PacketAccess = Derived::PacketAccess,
+ Layout = Derived::Layout,
+ CoordAccess = NumCoords > 0,
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
@@ -77,6 +83,24 @@ struct TensorEvaluator
return internal::pstoret<Scalar, Packet, StoreMode>(m_data + index, x);
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
+ eigen_assert(m_data);
+ if (Layout == ColMajor) {
+ return m_data[m_dims.IndexOfColMajor(coords)];
+ } else {
+ return m_data[m_dims.IndexOfRowMajor(coords)];
+ }
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) {
+ eigen_assert(m_data);
+ if (Layout == ColMajor) {
+ return m_data[m_dims.IndexOfColMajor(coords)];
+ } else {
+ return m_data[m_dims.IndexOfRowMajor(coords)];
+ }
+ }
+
Scalar* data() const { return m_data; }
protected:
@@ -97,9 +121,15 @@ struct TensorEvaluator<const Derived, Device>
typedef typename Derived::Packet PacketReturnType;
typedef typename Derived::Dimensions Dimensions;
+ // NumDimensions is -1 for variable dim tensors
+ static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
+ internal::traits<Derived>::NumDimensions : 0;
+
enum {
IsAligned = Derived::IsAligned,
PacketAccess = Derived::PacketAccess,
+ Layout = Derived::Layout,
+ CoordAccess = NumCoords > 0,
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&)
@@ -126,6 +156,17 @@ struct TensorEvaluator<const Derived, Device>
return internal::ploadt_ro<Packet, LoadMode>(m_data + index);
}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
+ eigen_assert(m_data);
+ const Index index = (Layout == ColMajor) ? m_dims.IndexOfColMajor(coords)
+ : m_dims.IndexOfRowMajor(coords);
+#ifdef __CUDA_ARCH__
+ return __ldg(m_data+index);
+#else
+ return m_data[index];
+#endif
+ }
+
const Scalar* data() const { return m_data; }
protected:
@@ -146,6 +187,8 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
enum {
IsAligned = true,
PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
+ Layout = TensorEvaluator<ArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC
@@ -194,6 +237,8 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
enum {
IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
+ Layout = TensorEvaluator<ArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
@@ -247,6 +292,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
internal::functor_traits<BinaryOp>::PacketAccess,
+ Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
@@ -254,7 +301,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
m_leftImpl(op.lhsExpression(), device),
m_rightImpl(op.rhsExpression(), device)
{
- eigen_assert(internal::dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
+ EIGEN_STATIC_ASSERT((TensorEvaluator<LeftArgType, Device>::Layout == TensorEvaluator<RightArgType, Device>::Layout || internal::traits<XprType>::NumDimensions == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
}
typedef typename XprType::Index Index;
@@ -309,6 +357,8 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess/* &
TensorEvaluator<IfArgType>::PacketAccess*/,
+ Layout = TensorEvaluator<IfArgType, Device>::Layout,
+ CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
@@ -316,8 +366,10 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
m_thenImpl(op.thenExpression(), device),
m_elseImpl(op.elseExpression(), device)
{
- eigen_assert(internal::dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
- eigen_assert(internal::dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
+ EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ThenArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ElseArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
+ eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
}
typedef typename XprType::Index Index;