diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-05-22 16:22:35 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-05-22 16:22:35 -0700 |
commit | 736267cf6b17832a571acf7e34ca07c7f55907ee (patch) | |
tree | 894d0bfd7455b670117a252afad0157ba01a766b /unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | |
parent | 7402fea0a8e63e3ea248257047c584afee8f8bde (diff) |
Added support for additional tensor operations:
* comparison (<, <=, ==, !=, ...)
* selection
* nullary ops such as random or constant generation
* misc unary ops such as log(), exp(), or a user defined unaryExpr()
Cleaned up the code a little.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index e32077f6e..94cfae05c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -17,6 +17,9 @@ namespace Eigen { * * \brief Tensor expression classes. * + * The TensorCwiseNullaryOp class applies a nullary operators to an expression. This + * is typically used to generate constants. + * * The TensorCwiseUnaryOp class represents an expression where a unary operator * (e.g. cwiseSqrt) is applied to an expression. * @@ -24,6 +27,46 @@ namespace Eigen { * (e.g. addition) is applied to a lhs and a rhs expression. * */ +namespace internal { +template<typename NullaryOp, typename PlainObjectType> +struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > + : traits<PlainObjectType> +{ + typedef typename PlainObjectType::Packet Packet; + typedef typename PlainObjectType::Scalar Scalar; + typedef typename PlainObjectType::Nested XprTypeNested; + typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; +}; + +} // end namespace internal + + + +template<typename NullaryOp, typename PlainObjectType> +class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > +{ + public: + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; + typedef typename PlainObjectType::PacketReturnType PacketReturnType; + typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested; + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp()) + : m_functor(func) {} + + EIGEN_DEVICE_FUNC + const NullaryOp& functor() const { return m_functor; } + + protected: + // todo: add tensor dimension to be able to do some sanity checks + const NullaryOp m_functor; +}; + + namespace internal { template<typename UnaryOp, typename XprType> @@ -160,6 +203,72 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX const BinaryOp m_functor; }; + +namespace internal { +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > + : traits<ThenXprType> +{ + typedef typename traits<ThenXprType>::Scalar Scalar; + typedef typename internal::packet_traits<Scalar>::type Packet; + typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind, + typename traits<ElseXprType>::StorageKind>::ret StorageKind; + typedef typename promote_index_type<typename traits<ElseXprType>::Index, + typename traits<ThenXprType>::Index>::type Index; + typedef typename IfXprType::Nested IfNested; + typedef typename ThenXprType::Nested ThenNested; + typedef typename ElseXprType::Nested ElseNested; +}; + +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense> +{ + typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type; +}; + +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type> +{ + typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type; +}; + +} // end namespace internal + + +template<typename IfXprType, typename ThenXprType, typename ElseXprType> +class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > +{ + public: + typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar; + typedef typename Eigen::internal::traits<TensorSelectOp>::Packet Packet; + typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; + typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType, + typename ElseXprType::CoeffReturnType>::ret CoeffReturnType; + typedef typename internal::promote_storage_type<typename ThenXprType::PacketReturnType, + typename ElseXprType::PacketReturnType>::ret PacketReturnType; + typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested; + typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind; + typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index; + + TensorSelectOp(const IfXprType& a_condition, + const ThenXprType& a_then, + const ElseXprType& a_else) + : m_condition(a_condition), m_then(a_then), m_else(a_else) + { } + + const IfXprType& ifExpression() const { return m_condition; } + + const ThenXprType& thenExpression() const { return m_then; } + + const ElseXprType& elseExpression() const { return m_else; } + + protected: + typename IfXprType::Nested m_condition; + typename ThenXprType::Nested m_then; + typename ElseXprType::Nested m_else; +}; + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H |