diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-04 09:21:48 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-04 09:21:48 -0700 |
commit | 6fa6cdd2b988da98cbdd2b1a5fd2fd3b9d56a4b1 (patch) | |
tree | 195d19a0318e92323a6148570c7e68831c3c77b2 /unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | |
parent | 736267cf6b17832a571acf7e34ca07c7f55907ee (diff) |
Added support for tensor contractions
Updated expression evaluation mechanism to also compute the size of the tensor result
Misc fixes and improvements.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 36 |
1 files changed, 20 insertions, 16 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 94cfae05c..60908ee94 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -28,13 +28,13 @@ namespace Eigen { * */ namespace internal { -template<typename NullaryOp, typename PlainObjectType> -struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > - : traits<PlainObjectType> +template<typename NullaryOp, typename XprType> +struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> > + : traits<XprType> { - typedef typename PlainObjectType::Packet Packet; - typedef typename PlainObjectType::Scalar Scalar; - typedef typename PlainObjectType::Nested XprTypeNested; + typedef typename XprType::Packet Packet; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; }; @@ -42,27 +42,31 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > -template<typename NullaryOp, typename PlainObjectType> -class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > +template<typename NullaryOp, typename XprType> +class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType> > { public: typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet; typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; - typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; - typedef typename PlainObjectType::PacketReturnType PacketReturnType; - typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested; typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp()) - : m_functor(func) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp()) + : m_xpr(xpr), m_functor(func) {} + + EIGEN_DEVICE_FUNC + const typename internal::remove_all<typename XprType::Nested>::type& + nestedExpression() const { return m_xpr; } EIGEN_DEVICE_FUNC const NullaryOp& functor() const { return m_functor; } protected: - // todo: add tensor dimension to be able to do some sanity checks + typename XprType::Nested m_xpr; const NullaryOp m_functor; }; @@ -71,7 +75,7 @@ class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, P namespace internal { template<typename UnaryOp, typename XprType> struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> > - : traits<XprType> + : traits<XprType> { typedef typename result_of< UnaryOp(typename XprType::Scalar) @@ -207,7 +211,7 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX namespace internal { template<typename IfXprType, typename ThenXprType, typename ElseXprType> struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> > - : traits<ThenXprType> + : traits<ThenXprType> { typedef typename traits<ThenXprType>::Scalar Scalar; typedef typename internal::packet_traits<Scalar>::type Packet; |