diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 87 |
1 files changed, 52 insertions, 35 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 6e5503de1..b66b3ec2c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -17,14 +17,14 @@ namespace Eigen { * * \brief Tensor expression classes. * - * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This - * is typically used to generate constants. + * The TensorCwiseNullaryOp class applies a nullary operators to an expression. + * This is typically used to generate constants. * * The TensorCwiseUnaryOp class represents an expression where a unary operator * (e.g. cwiseSqrt) is applied to an expression. * - * The TensorCwiseBinaryOp class represents an expression where a binary operator - * (e.g. addition) is applied to a lhs and a rhs expression. + * The TensorCwiseBinaryOp class represents an expression where a binary + * operator (e.g. addition) is applied to a lhs and a rhs expression. * */ namespace internal { @@ -33,9 +33,12 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > : traits<XprType> { typedef typename XprType::Packet Packet; + typedef traits<XprType> XprTraits; typedef typename XprType::Scalar Scalar; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; enum { Flags = 0, @@ -47,7 +50,7 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > template<typename NullaryOp, typename XprType> -class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType> > +class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> { public: typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; @@ -81,12 +84,15 @@ template<typename UnaryOp, typename XprType> struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > : traits<XprType> { - typedef typename result_of< - UnaryOp(typename XprType::Scalar) - >::type Scalar; + // TODO(phli): Add InputScalar, InputPacket. Check references to + // current Scalar/Packet to see if the intent is Input or Output. + typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar; + typedef traits<XprType> XprTraits; typedef typename internal::packet_traits<Scalar>::type Packet; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; }; template<typename UnaryOp, typename XprType> @@ -106,14 +112,16 @@ struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwise template<typename UnaryOp, typename XprType> -class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType> > +class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> { public: + // TODO(phli): Add InputScalar, InputPacket. Check references to + // current Scalar/Packet to see if the intent is Input or Output. typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar; typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Packet Packet; typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; - typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename XprType::PacketReturnType PacketReturnType; + typedef Scalar CoeffReturnType; + typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType; typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested; typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index; @@ -139,22 +147,27 @@ namespace internal { template<typename BinaryOp, typename LhsXprType, typename RhsXprType> struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > { - // Type promotion to handle the case where the types of the lhs and the rhs are different. + // Type promotion to handle the case where the types of the lhs and the rhs + // are different. + // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to + // current Scalar/Packet to see if the intent is Inputs or Output. typedef typename result_of< - BinaryOp( - typename LhsXprType::Scalar, - typename RhsXprType::Scalar - ) - >::type Scalar; + BinaryOp(typename LhsXprType::Scalar, + typename RhsXprType::Scalar)>::type Scalar; + typedef traits<LhsXprType> XprTraits; typedef typename internal::packet_traits<Scalar>::type Packet; - typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind, - typename traits<RhsXprType>::StorageKind>::ret StorageKind; - typedef typename promote_index_type<typename traits<LhsXprType>::Index, - typename traits<RhsXprType>::Index>::type Index; + typedef typename promote_storage_type< + typename traits<LhsXprType>::StorageKind, + typename traits<RhsXprType>::StorageKind>::ret StorageKind; + typedef typename promote_index_type< + typename traits<LhsXprType>::Index, + typename traits<RhsXprType>::Index>::type Index; typedef typename LhsXprType::Nested LhsNested; typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference<LhsNested>::type _LhsNested; typedef typename remove_reference<RhsNested>::type _RhsNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; enum { Flags = 0, @@ -178,21 +191,22 @@ struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename template<typename BinaryOp, typename LhsXprType, typename RhsXprType> -class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> > +class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> { public: - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet; - typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; - typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType, - typename RhsXprType::CoeffReturnType>::ret CoeffReturnType; - typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType; - typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; - typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; - - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) - : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} + // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to + // current Scalar/Packet to see if the intent is Inputs or Output. + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar; + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef Scalar CoeffReturnType; + typedef typename internal::packet_traits<CoeffReturnType>::type PacketReturnType; + typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested; + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp()) + : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {} EIGEN_DEVICE_FUNC const BinaryOp& functor() const { return m_functor; } @@ -219,7 +233,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > : traits<ThenXprType> { typedef typename traits<ThenXprType>::Scalar Scalar; - typedef typename internal::packet_traits<Scalar>::type Packet; + typedef traits<ThenXprType> XprTraits; + typedef typename packet_traits<Scalar>::type Packet; typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, typename traits<ElseXprType>::StorageKind>::ret StorageKind; typedef typename promote_index_type<typename traits<ElseXprType>::Index, @@ -227,6 +242,8 @@ struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > typedef typename IfXprType::Nested IfNested; typedef typename ThenXprType::Nested ThenNested; typedef typename ElseXprType::Nested ElseNested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; }; template<typename IfXprType, typename ThenXprType, typename ElseXprType> |