aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-04 09:21:48 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-04 09:21:48 -0700
commit6fa6cdd2b988da98cbdd2b1a5fd2fd3b9d56a4b1 (patch)
tree195d19a0318e92323a6148570c7e68831c3c77b2 /unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
parent736267cf6b17832a571acf7e34ca07c7f55907ee (diff)
Added support for tensor contractions
Updated expression evaluation mechanism to also compute the size of the tensor result Misc fixes and improvements.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h36
1 files changed, 20 insertions, 16 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index 94cfae05c..60908ee94 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -28,13 +28,13 @@ namespace Eigen {
*
*/
namespace internal {
-template<typename NullaryOp, typename PlainObjectType>
-struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
- : traits<PlainObjectType>
+template<typename NullaryOp, typename XprType>
+struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
+ : traits<XprType>
{
- typedef typename PlainObjectType::Packet Packet;
- typedef typename PlainObjectType::Scalar Scalar;
- typedef typename PlainObjectType::Nested XprTypeNested;
+ typedef typename XprType::Packet Packet;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
};
@@ -42,27 +42,31 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
-template<typename NullaryOp, typename PlainObjectType>
-class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
+template<typename NullaryOp, typename XprType>
+class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType> >
{
public:
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
- typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
- typedef typename PlainObjectType::PacketReturnType PacketReturnType;
- typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+ typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp())
- : m_functor(func) {}
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
+ : m_xpr(xpr), m_functor(func) {}
+
+ EIGEN_DEVICE_FUNC
+ const typename internal::remove_all<typename XprType::Nested>::type&
+ nestedExpression() const { return m_xpr; }
EIGEN_DEVICE_FUNC
const NullaryOp& functor() const { return m_functor; }
protected:
- // todo: add tensor dimension to be able to do some sanity checks
+ typename XprType::Nested m_xpr;
const NullaryOp m_functor;
};
@@ -71,7 +75,7 @@ class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, P
namespace internal {
template<typename UnaryOp, typename XprType>
struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
- : traits<XprType>
+ : traits<XprType>
{
typedef typename result_of<
UnaryOp(typename XprType::Scalar)
@@ -207,7 +211,7 @@ class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsX
namespace internal {
template<typename IfXprType, typename ThenXprType, typename ElseXprType>
struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
- : traits<ThenXprType>
+ : traits<ThenXprType>
{
typedef typename traits<ThenXprType>::Scalar Scalar;
typedef typename internal::packet_traits<Scalar>::type Packet;