diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:47:46 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:47:46 -0800 |
commit | 1ac86001266db55b78086617fb68206b29748919 (patch) | |
tree | 77ef4a659d2743390cbed843726cb6a47b1229c0 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | |
parent | 378bdfb7f0c4b2a8eb2b91c2a65f3bc1c57e689e (diff) |
Fixed the return type of coefficient-wise operations. For example, the abs function returns a floating-point value when called on a complex input.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 58 |
1 file changed, 55 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index f7c784942..97f225f0a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -34,9 +34,15 @@ struct TensorEvaluator typedef typename Derived::Packet PacketReturnType; typedef typename Derived::Dimensions Dimensions; + // NumDimensions is -1 for variable dim tensors + static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ? + internal::traits<Derived>::NumDimensions : 0; + enum { IsAligned = Derived::IsAligned, PacketAccess = Derived::PacketAccess, + Layout = Derived::Layout, + CoordAccess = NumCoords > 0, }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device) @@ -77,6 +83,24 @@ struct TensorEvaluator return internal::pstoret<Scalar, Packet, StoreMode>(m_data + index, x); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const { + eigen_assert(m_data); + if (Layout == ColMajor) { + return m_data[m_dims.IndexOfColMajor(coords)]; + } else { + return m_data[m_dims.IndexOfRowMajor(coords)]; + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) { + eigen_assert(m_data); + if (Layout == ColMajor) { + return m_data[m_dims.IndexOfColMajor(coords)]; + } else { + return m_data[m_dims.IndexOfRowMajor(coords)]; + } + } + Scalar* data() const { return m_data; } protected: @@ -97,9 +121,15 @@ struct TensorEvaluator<const Derived, Device> typedef typename Derived::Packet PacketReturnType; typedef typename Derived::Dimensions Dimensions; + // NumDimensions is -1 for variable dim tensors + static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ? 
+ internal::traits<Derived>::NumDimensions : 0; + enum { IsAligned = Derived::IsAligned, PacketAccess = Derived::PacketAccess, + Layout = Derived::Layout, + CoordAccess = NumCoords > 0, }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&) @@ -126,6 +156,17 @@ struct TensorEvaluator<const Derived, Device> return internal::ploadt_ro<Packet, LoadMode>(m_data + index); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const { + eigen_assert(m_data); + const Index index = (Layout == ColMajor) ? m_dims.IndexOfColMajor(coords) + : m_dims.IndexOfRowMajor(coords); +#ifdef __CUDA_ARCH__ + return __ldg(m_data+index); +#else + return m_data[index]; +#endif + } + const Scalar* data() const { return m_data; } protected: @@ -146,6 +187,8 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device> enum { IsAligned = true, PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess, + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC @@ -194,6 +237,8 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device> enum { IsAligned = TensorEvaluator<ArgType, Device>::IsAligned, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess, + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) @@ -247,6 +292,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned, PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess & internal::functor_traits<BinaryOp>::PacketAccess, + Layout = TensorEvaluator<LeftArgType, 
Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) @@ -254,7 +301,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device) { - eigen_assert(internal::dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions())); + EIGEN_STATIC_ASSERT((TensorEvaluator<LeftArgType, Device>::Layout == TensorEvaluator<RightArgType, Device>::Layout || internal::traits<XprType>::NumDimensions == 1), YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions())); } typedef typename XprType::Index Index; @@ -309,6 +357,8 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned, PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess/* & TensorEvaluator<IfArgType>::PacketAccess*/, + Layout = TensorEvaluator<IfArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) @@ -316,8 +366,10 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> m_thenImpl(op.thenExpression(), device), m_elseImpl(op.elseExpression(), device) { - eigen_assert(internal::dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions())); - eigen_assert(internal::dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions())); + EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ThenArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ElseArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE); + 
eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions())); + eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions())); } typedef typename XprType::Index Index; |