aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-10 09:14:44 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-10 09:14:44 -0700
commit925fb6b93710b95082ba44d30405289dff3707eb (patch)
tree004ce9af64e2ffdb9148a286938a2710ee8c2607 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parenta77458a8ff2a83e716add62253eb50ef64980b21 (diff)
TensorEval are now typed on the device: this will make it possible to use partial template specialization to optimize the strategy of each evaluator for each device type.
Started work on partial evaluations.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h76
1 files changed, 38 insertions, 38 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index ab2513cea..80fe06957 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -23,7 +23,7 @@ namespace Eigen {
* leading to lvalues (slicing, reshaping, etc...)
*/
-template<typename Derived>
+template<typename Derived, typename Device>
struct TensorEvaluator
{
typedef typename Derived::Index Index;
@@ -38,7 +38,7 @@ struct TensorEvaluator
PacketAccess = Derived::PacketAccess,
};
- EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m)
+ EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m, const Device&)
: m_data(const_cast<Scalar*>(m.data())), m_dims(m.dimensions())
{ }
@@ -73,8 +73,8 @@ struct TensorEvaluator
// -------------------- CwiseNullaryOp --------------------
-template<typename NullaryOp, typename ArgType>
-struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
+template<typename NullaryOp, typename ArgType, typename Device>
+struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
{
typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;
@@ -84,14 +84,14 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
};
EIGEN_DEVICE_FUNC
- TensorEvaluator(const XprType& op)
- : m_functor(op.functor()), m_argImpl(op.nestedExpression())
+ TensorEvaluator(const XprType& op, const Device& device)
+ : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
- typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
+ typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
@@ -108,32 +108,32 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
private:
const NullaryOp m_functor;
- TensorEvaluator<ArgType> m_argImpl;
+ TensorEvaluator<ArgType, Device> m_argImpl;
};
// -------------------- CwiseUnaryOp --------------------
-template<typename UnaryOp, typename ArgType>
-struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
+template<typename UnaryOp, typename ArgType, typename Device>
+struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
{
typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;
enum {
- IsAligned = TensorEvaluator<ArgType>::IsAligned,
- PacketAccess = TensorEvaluator<ArgType>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
+ IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
+ PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
};
- EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
: m_functor(op.functor()),
- m_argImpl(op.nestedExpression())
+ m_argImpl(op.nestedExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
- typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
+ typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
@@ -150,33 +150,33 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
private:
const UnaryOp m_functor;
- TensorEvaluator<ArgType> m_argImpl;
+ TensorEvaluator<ArgType, Device> m_argImpl;
};
// -------------------- CwiseBinaryOp --------------------
-template<typename BinaryOp, typename LeftArgType, typename RightArgType>
-struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> >
+template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device>
+struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device>
{
typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;
enum {
- IsAligned = TensorEvaluator<LeftArgType>::IsAligned & TensorEvaluator<RightArgType>::IsAligned,
- PacketAccess = TensorEvaluator<LeftArgType>::PacketAccess & TensorEvaluator<RightArgType>::PacketAccess &
+ IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
+ PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
internal::functor_traits<BinaryOp>::PacketAccess,
};
- EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
: m_functor(op.functor()),
- m_leftImpl(op.lhsExpression()),
- m_rightImpl(op.rhsExpression())
+ m_leftImpl(op.lhsExpression(), device),
+ m_rightImpl(op.rhsExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
- typedef typename TensorEvaluator<LeftArgType>::Dimensions Dimensions;
+ typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
{
@@ -196,34 +196,34 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
private:
const BinaryOp m_functor;
- TensorEvaluator<LeftArgType> m_leftImpl;
- TensorEvaluator<RightArgType> m_rightImpl;
+ TensorEvaluator<LeftArgType, Device> m_leftImpl;
+ TensorEvaluator<RightArgType, Device> m_rightImpl;
};
// -------------------- SelectOp --------------------
-template<typename IfArgType, typename ThenArgType, typename ElseArgType>
-struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> >
+template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device>
+struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device>
{
typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
enum {
- IsAligned = TensorEvaluator<ThenArgType>::IsAligned & TensorEvaluator<ElseArgType>::IsAligned,
- PacketAccess = TensorEvaluator<ThenArgType>::PacketAccess & TensorEvaluator<ElseArgType>::PacketAccess/* &
+ IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
+ PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess/* &
TensorEvaluator<IfArgType>::PacketAccess*/,
};
- EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
- : m_condImpl(op.ifExpression()),
- m_thenImpl(op.thenExpression()),
- m_elseImpl(op.elseExpression())
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
+ : m_condImpl(op.ifExpression(), device),
+ m_thenImpl(op.thenExpression(), device),
+ m_elseImpl(op.elseExpression(), device)
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
- typedef typename TensorEvaluator<IfArgType>::Dimensions Dimensions;
+ typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
{
@@ -248,9 +248,9 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
}
private:
- TensorEvaluator<IfArgType> m_condImpl;
- TensorEvaluator<ThenArgType> m_thenImpl;
- TensorEvaluator<ElseArgType> m_elseImpl;
+ TensorEvaluator<IfArgType, Device> m_condImpl;
+ TensorEvaluator<ThenArgType, Device> m_thenImpl;
+ TensorEvaluator<ElseArgType, Device> m_elseImpl;
};