diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:47:46 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-01-14 12:47:46 -0800 |
commit | 1ac86001266db55b78086617fb68206b29748919 (patch) | |
tree | 77ef4a659d2743390cbed843726cb6a47b1229c0 /unsupported | |
parent | 378bdfb7f0c4b2a8eb2b91c2a65f3bc1c57e689e (diff) |
Fixed the return type of coefficient wise operations. For example, the abs function returns a floating point value when called on a complex input.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 58 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 87 |
2 files changed, 107 insertions, 38 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index f7c784942..97f225f0a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -34,9 +34,15 @@ struct TensorEvaluator typedef typename Derived::Packet PacketReturnType; typedef typename Derived::Dimensions Dimensions; + // NumDimensions is -1 for variable dim tensors + static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ? + internal::traits<Derived>::NumDimensions : 0; + enum { IsAligned = Derived::IsAligned, PacketAccess = Derived::PacketAccess, + Layout = Derived::Layout, + CoordAccess = NumCoords > 0, }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device) @@ -77,6 +83,24 @@ struct TensorEvaluator return internal::pstoret<Scalar, Packet, StoreMode>(m_data + index, x); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const { + eigen_assert(m_data); + if (Layout == ColMajor) { + return m_data[m_dims.IndexOfColMajor(coords)]; + } else { + return m_data[m_dims.IndexOfRowMajor(coords)]; + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) { + eigen_assert(m_data); + if (Layout == ColMajor) { + return m_data[m_dims.IndexOfColMajor(coords)]; + } else { + return m_data[m_dims.IndexOfRowMajor(coords)]; + } + } + Scalar* data() const { return m_data; } protected: @@ -97,9 +121,15 @@ struct TensorEvaluator<const Derived, Device> typedef typename Derived::Packet PacketReturnType; typedef typename Derived::Dimensions Dimensions; + // NumDimensions is -1 for variable dim tensors + static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ? + internal::traits<Derived>::NumDimensions : 0; + enum { IsAligned = Derived::IsAligned, PacketAccess = Derived::PacketAccess, + Layout = Derived::Layout, + CoordAccess = NumCoords > 0, }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&) @@ -126,6 +156,17 @@ struct TensorEvaluator<const Derived, Device> return internal::ploadt_ro<Packet, LoadMode>(m_data + index); } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const { + eigen_assert(m_data); + const Index index = (Layout == ColMajor) ? m_dims.IndexOfColMajor(coords) + : m_dims.IndexOfRowMajor(coords); +#ifdef __CUDA_ARCH__ + return __ldg(m_data+index); +#else + return m_data[index]; +#endif + } + const Scalar* data() const { return m_data; } protected: @@ -146,6 +187,8 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device> enum { IsAligned = true, PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess, + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC @@ -194,6 +237,8 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device> enum { IsAligned = TensorEvaluator<ArgType, Device>::IsAligned, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess, + Layout = TensorEvaluator<ArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) @@ -247,6 +292,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned, PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess & internal::functor_traits<BinaryOp>::PacketAccess, + Layout = TensorEvaluator<LeftArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) @@ -254,7 +301,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device) { - eigen_assert(internal::dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions())); + EIGEN_STATIC_ASSERT((TensorEvaluator<LeftArgType, Device>::Layout == TensorEvaluator<RightArgType, Device>::Layout || internal::traits<XprType>::NumDimensions == 1), YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions())); } typedef typename XprType::Index Index; @@ -309,6 +357,8 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned, PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess/* & TensorEvaluator<IfArgType>::PacketAccess*/, + Layout = TensorEvaluator<IfArgType, Device>::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) @@ -316,8 +366,10 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> m_thenImpl(op.thenExpression(), device), m_elseImpl(op.elseExpression(), device) { - eigen_assert(internal::dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions())); - eigen_assert(internal::dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions())); + EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ThenArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((TensorEvaluator<IfArgType, Device>::Layout == TensorEvaluator<ElseArgType, Device>::Layout), YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions())); + eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions())); } typedef typename XprType::Index Index; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 6e5503de1..b66b3ec2c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -17,14 +17,14 @@ namespace Eigen { * * \brief Tensor expression classes. * - * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This - * is typically used to generate constants. + * The TensorCwiseNullaryOp class applies a nullary operators to an expression. + * This is typically used to generate constants. * * The TensorCwiseUnaryOp class represents an expression where a unary operator * (e.g. cwiseSqrt) is applied to an expression. * - * The TensorCwiseBinaryOp class represents an expression where a binary operator - * (e.g. addition) is applied to a lhs and a rhs expression. + * The TensorCwiseBinaryOp class represents an expression where a binary + * operator (e.g. addition) is applied to a lhs and a rhs expression. * */ namespace internal { @@ -33,9 +33,12 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > : traits<XprType> { typedef typename XprType::Packet Packet; + typedef traits<XprType> XprTraits; typedef typename XprType::Scalar Scalar; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; enum { Flags = 0, @@ -47,7 +50,7 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > template<typename NullaryOp, typename XprType> -class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType> > +class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; @@ -81,12 +84,15 @@ template<typename UnaryOp, typename XprType> struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > : traits<XprType> { - typedef typename result_of< - UnaryOp(typename XprType::Scalar) - >::type Scalar; + // TODO(phli): Add InputScalar, InputPacket. Check references to + // current Scalar/Packet to see if the intent is Input or Output. + typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar; + typedef traits<XprType> XprTraits; typedef typename internal::packet_traits<Scalar>::type Packet; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; }; template<typename UnaryOp, typename XprType> @@ -106,14 +112,16 @@ struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwise template<typename UnaryOp, typename XprType> -class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType> > +class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> { public: + // TODO(phli): Add InputScalar, InputPacket. Check references to + // current Scalar/Packet to see if the intent is Input or Output. typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar; typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Packet Packet; typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; - typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename XprType::PacketReturnType PacketReturnType; + typedef Scalar CoeffReturnType; + typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType; typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested; typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index; @@ -139,22 +147,27 @@ namespace internal { template<typename BinaryOp, typename LhsXprType, typename RhsXprType> struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > { - // Type promotion to handle the case where the types of the lhs and the rhs are different. + // Type promotion to handle the case where the types of the lhs and the rhs + // are different. + // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to + // current Scalar/Packet to see if the intent is Inputs or Output. typedef typename result_of< - BinaryOp( - typename LhsXprType::Scalar, - typename RhsXprType::Scalar - ) - >::type Scalar; + BinaryOp(typename LhsXprType::Scalar, + typename RhsXprType::Scalar)>::type Scalar; + typedef traits<LhsXprType> XprTraits; typedef typename internal::packet_traits<Scalar>::type Packet; - typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, - typename traits<RhsXprType>::StorageKind>::ret StorageKind; - typedef typename promote_index_type<typename traits<LhsXprType>::Index, - typename traits<RhsXprType>::Index>::type Index; + typedef typename promote_storage_type< + typename traits<LhsXprType>::StorageKind, + typename traits<RhsXprType>::StorageKind>::ret StorageKind; + typedef typename promote_index_type< + typename traits<LhsXprType>::Index, + typename traits<RhsXprType>::Index>::type Index; typedef typename LhsXprType::Nested LhsNested; typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference<LhsNested>::type _LhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; enum { Flags = 0, @@ -178,21 +191,22 @@ struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename template<typename BinaryOp, typename LhsXprType, typename RhsXprType> -class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > +class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> { public: - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet; - typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; - typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, - typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; - typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType; - typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; - - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) - : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} + // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to + // current Scalar/Packet to see if the intent is Inputs or Output. + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef Scalar CoeffReturnType; + typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType; + typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) + : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} EIGEN_DEVICE_FUNC const BinaryOp& functor() const { return m_functor; } @@ -219,7 +233,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > : traits<ThenXprType> { typedef typename traits<ThenXprType>::Scalar Scalar; - typedef typename internal::packet_traits<Scalar>::type Packet; + typedef traits<ThenXprType> XprTraits; + typedef typename packet_traits<Scalar>::type Packet; typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, typename traits<ElseXprType>::StorageKind>::ret StorageKind; typedef typename promote_index_type<typename traits<ElseXprType>::Index, @@ -227,6 +242,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > typedef typename IfXprType::Nested IfNested; typedef typename ThenXprType::Nested ThenNested; typedef typename ElseXprType::Nested ElseNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; }; template<typename IfXprType, typename ThenXprType, typename ElseXprType> |