From 6fa6cdd2b988da98cbdd2b1a5fd2fd3b9d56a4b1 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 4 Jun 2014 09:21:48 -0700 Subject: Added support for tensor contractions Updated expression evaluation mechanism to also compute the size of the tensor result Misc fixes and improvements. --- unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h | 36 ++++++++++++++----------- 1 file changed, 20 insertions(+), 16 deletions(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 94cfae05c..60908ee94 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -28,13 +28,13 @@ namespace Eigen { * */ namespace internal { -template -struct traits > - : traits +template +struct traits > + : traits { - typedef typename PlainObjectType::Packet Packet; - typedef typename PlainObjectType::Scalar Scalar; - typedef typename PlainObjectType::Nested XprTypeNested; + typedef typename XprType::Packet Packet; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference::type _XprTypeNested; }; @@ -42,27 +42,31 @@ struct traits > -template -class TensorCwiseNullaryOp : public TensorBase > +template +class TensorCwiseNullaryOp : public TensorBase > { public: typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename Eigen::internal::traits::Packet Packet; typedef typename Eigen::NumTraits::Real RealScalar; - typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; - typedef typename PlainObjectType::PacketReturnType PacketReturnType; - typedef TensorCwiseNullaryOp Nested; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + typedef TensorCwiseNullaryOp Nested; typedef typename Eigen::internal::traits::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp()) - : m_functor(func) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp()) + : m_xpr(xpr), m_functor(func) {} + + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + nestedExpression() const { return m_xpr; } EIGEN_DEVICE_FUNC const NullaryOp& functor() const { return m_functor; } protected: - // todo: add tensor dimension to be able to do some sanity checks + typename XprType::Nested m_xpr; const NullaryOp m_functor; }; @@ -71,7 +75,7 @@ class TensorCwiseNullaryOp : public TensorBase struct traits > - : traits + : traits { typedef typename result_of< UnaryOp(typename XprType::Scalar) @@ -207,7 +211,7 @@ class TensorCwiseBinaryOp : public TensorBase struct traits > - : traits + : traits { typedef typename traits::Scalar Scalar; typedef typename internal::packet_traits::type Packet; -- cgit v1.2.3