diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-10 09:14:44 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-06-10 09:14:44 -0700 |
commit | 925fb6b93710b95082ba44d30405289dff3707eb (patch) | |
tree | 004ce9af64e2ffdb9148a286938a2710ee8c2607 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | |
parent | a77458a8ff2a83e716add62253eb50ef64980b21 (diff) |
TensorEval are now typed on the device: this will make it possible to use partial template specialization to optimize the strategy of each evaluator for each device type.
Started work on partial evaluations.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 76 |
1 files changed, 38 insertions, 38 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index ab2513cea..80fe06957 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -23,7 +23,7 @@ namespace Eigen { * leading to lvalues (slicing, reshaping, etc...) */ -template<typename Derived> +template<typename Derived, typename Device> struct TensorEvaluator { typedef typename Derived::Index Index; @@ -38,7 +38,7 @@ struct TensorEvaluator PacketAccess = Derived::PacketAccess, }; - EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m) + EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m, const Device&) : m_data(const_cast<Scalar*>(m.data())), m_dims(m.dimensions()) { } @@ -73,8 +73,8 @@ struct TensorEvaluator // -------------------- CwiseNullaryOp -------------------- -template<typename NullaryOp, typename ArgType> -struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> > +template<typename NullaryOp, typename ArgType, typename Device> +struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device> { typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType; @@ -84,14 +84,14 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> > }; EIGEN_DEVICE_FUNC - TensorEvaluator(const XprType& op) - : m_functor(op.functor()), m_argImpl(op.nestedExpression()) + TensorEvaluator(const XprType& op, const Device& device) + : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device) { } typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions; + typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); } @@ -108,32 +108,32 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> > private: const NullaryOp m_functor; - TensorEvaluator<ArgType> m_argImpl; + TensorEvaluator<ArgType, Device> m_argImpl; }; // -------------------- CwiseUnaryOp -------------------- -template<typename UnaryOp, typename ArgType> -struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> > +template<typename UnaryOp, typename ArgType, typename Device> +struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device> { typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType; enum { - IsAligned = TensorEvaluator<ArgType>::IsAligned, - PacketAccess = TensorEvaluator<ArgType>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess, + IsAligned = TensorEvaluator<ArgType, Device>::IsAligned, + PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess, }; - EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op) + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : m_functor(op.functor()), - m_argImpl(op.nestedExpression()) + m_argImpl(op.nestedExpression(), device) { } typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions; + typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); } @@ -150,33 +150,33 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> > private: const UnaryOp m_functor; - TensorEvaluator<ArgType> m_argImpl; + TensorEvaluator<ArgType, Device> m_argImpl; }; // -------------------- CwiseBinaryOp -------------------- -template<typename BinaryOp, typename LeftArgType, typename RightArgType> -struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> > +template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device> +struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device> { typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType; enum { - IsAligned = TensorEvaluator<LeftArgType>::IsAligned & TensorEvaluator<RightArgType>::IsAligned, - PacketAccess = TensorEvaluator<LeftArgType>::PacketAccess & TensorEvaluator<RightArgType>::PacketAccess & + IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned, + PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess & internal::functor_traits<BinaryOp>::PacketAccess, }; - EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op) + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : m_functor(op.functor()), - m_leftImpl(op.lhsExpression()), - m_rightImpl(op.rhsExpression()) + m_leftImpl(op.lhsExpression(), device), + m_rightImpl(op.rhsExpression(), device) { } typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef typename TensorEvaluator<LeftArgType>::Dimensions Dimensions; + typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { @@ -196,34 +196,34 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg private: const BinaryOp m_functor; - TensorEvaluator<LeftArgType> m_leftImpl; - TensorEvaluator<RightArgType> m_rightImpl; + TensorEvaluator<LeftArgType, Device> m_leftImpl; + TensorEvaluator<RightArgType, Device> m_rightImpl; }; // -------------------- SelectOp -------------------- -template<typename IfArgType, typename ThenArgType, typename ElseArgType> -struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> > +template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device> +struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device> { typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType; enum { - IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned, - PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* & + IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned, + PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess/* & TensorEvaluator<IfArgType>::PacketAccess*/, }; - EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op) - : m_condImpl(op.ifExpression()), - m_thenImpl(op.thenExpression()), - m_elseImpl(op.elseExpression()) + EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) + : m_condImpl(op.ifExpression(), device), + m_thenImpl(op.thenExpression(), device), + m_elseImpl(op.elseExpression(), device) { } typedef typename XprType::Index Index; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef typename TensorEvaluator<IfArgType>::Dimensions Dimensions; + typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { @@ -248,9 +248,9 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> } private: - TensorEvaluator<IfArgType> m_condImpl; - TensorEvaluator<ThenArgType> m_thenImpl; - TensorEvaluator<ElseArgType> m_elseImpl; + TensorEvaluator<IfArgType, Device> m_condImpl; + TensorEvaluator<ThenArgType, Device> m_thenImpl; + TensorEvaluator<ElseArgType, Device> m_elseImpl; }; |