aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-04 09:21:48 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-04 09:21:48 -0700
commit6fa6cdd2b988da98cbdd2b1a5fd2fd3b9d56a4b1 (patch)
tree195d19a0318e92323a6148570c7e68831c3c77b2 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
parent736267cf6b17832a571acf7e34ca07c7f55907ee (diff)
Added support for tensor contractions
Updated expression evaluation mechanism to also compute the size of the tensor result Misc fixes and improvements.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h44
1 files changed, 34 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index e0c0863b7..ab2513cea 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -21,7 +21,6 @@ namespace Eigen {
*
* TODO: add support for more types of expressions, in particular expressions
* leading to lvalues (slicing, reshaping, etc...)
- * TODO: add support for vectorization
*/
template<typename Derived>
@@ -32,16 +31,19 @@ struct TensorEvaluator
typedef typename Derived::Packet Packet;
typedef typename Derived::Scalar CoeffReturnType;
typedef typename Derived::Packet PacketReturnType;
+ typedef typename Derived::Dimensions Dimensions;
enum {
IsAligned = Derived::IsAligned,
PacketAccess = Derived::PacketAccess,
};
- TensorEvaluator(Derived& m)
- : m_data(const_cast<Scalar*>(m.data()))
+ EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m)
+ : m_data(const_cast<Scalar*>(m.data())), m_dims(m.dimensions())
{ }
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_dims; }
+
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const {
return m_data[index];
}
@@ -64,29 +66,34 @@ struct TensorEvaluator
protected:
Scalar* m_data;
+ Dimensions m_dims;
};
// -------------------- CwiseNullaryOp --------------------
-template<typename NullaryOp, typename PlainObjectType>
-struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+template<typename NullaryOp, typename ArgType>
+struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
{
- typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType;
+ typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;
enum {
IsAligned = true,
PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
};
+ EIGEN_DEVICE_FUNC
TensorEvaluator(const XprType& op)
- : m_functor(op.functor())
+ : m_functor(op.functor()), m_argImpl(op.nestedExpression())
{ }
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
+
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{
@@ -101,6 +108,7 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
private:
const NullaryOp m_functor;
+ TensorEvaluator<ArgType> m_argImpl;
};
@@ -117,7 +125,7 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
PacketAccess = TensorEvaluator<ArgType>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
};
- TensorEvaluator(const XprType& op)
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
: m_functor(op.functor()),
m_argImpl(op.nestedExpression())
{ }
@@ -125,6 +133,9 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
+
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{
@@ -156,7 +167,7 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
internal::functor_traits<BinaryOp>::PacketAccess,
};
- TensorEvaluator(const XprType& op)
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
: m_functor(op.functor()),
m_leftImpl(op.lhsExpression()),
m_rightImpl(op.rhsExpression())
@@ -165,6 +176,13 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef typename TensorEvaluator<LeftArgType>::Dimensions Dimensions;
+
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
+ {
+ // TODO: use right impl instead if right impl dimensions are known at compile time.
+ return m_leftImpl.dimensions();
+ }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{
@@ -196,7 +214,7 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
TensorEvaluator<IfArgType>::PacketAccess*/,
};
- TensorEvaluator(const XprType& op)
+ EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
: m_condImpl(op.ifExpression()),
m_thenImpl(op.thenExpression()),
m_elseImpl(op.elseExpression())
@@ -205,7 +223,13 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef typename TensorEvaluator<IfArgType>::Dimensions Dimensions;
+ EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
+ {
+ // TODO: use then or else impl instead if they happen to be known at compile time.
+ return m_condImpl.dimensions();
+ }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);