aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-09 09:45:30 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2014-06-09 09:45:30 -0700
commita669052f12d6d71ba815764d6419726d64fef675 (patch)
treea087876a5b341c0c3f2380d3530579cfbeb25c1c
parent36a2b2e9dc9368356b3f327a1fb00616397c1e0e (diff)
Improved support for rvalues in tensor expressions.
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h58
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h4
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h4
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h8
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h6
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h5
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h6
7 files changed, 71 insertions, 20 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index 932e5c82d..e447a5d40 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -22,7 +22,7 @@ namespace Eigen {
*/
template<typename Derived>
-class TensorBase
+class TensorBase<Derived, ReadOnlyAccessors>
{
public:
typedef typename internal::traits<Derived>::Scalar Scalar;
@@ -30,19 +30,6 @@ class TensorBase
typedef Scalar CoeffReturnType;
typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Derived& setZero() {
- return setConstant(Scalar(0));
- }
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
- return derived() = constant(val);
- }
- EIGEN_DEVICE_FUNC
- EIGEN_STRONG_INLINE Derived& setRandom() {
- return derived() = random();
- }
-
// Nullary operators
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
@@ -224,14 +211,53 @@ class TensorBase
return TensorReshapingOp<const Derived, const NewDimensions>(derived(), newDimensions);
}
+ protected:
+ template <typename OtherDerived, int AccessLevel> friend class TensorBase;
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast<const Derived*>(this); }
+};
+
+
+template<typename Derived>
+class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyAccessors> {
+ public:
+ typedef typename internal::traits<Derived>::Scalar Scalar;
+ typedef typename internal::traits<Derived>::Index Index;
+ typedef Scalar CoeffReturnType;
+ typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
+
+ template <typename OtherDerived, int AccessLevel> friend class TensorBase;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Derived& setZero() {
+ return setConstant(Scalar(0));
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
+ return derived() = this->constant(val);
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Derived& setRandom() {
+ return derived() = this->random();
+ }
+
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Derived& operator+=(const OtherDerived& other) {
+ return derived() = TensorCwiseBinaryOp<internal::scalar_sum_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
+ template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ Derived& operator-=(const OtherDerived& other) {
+ return derived() = TensorCwiseBinaryOp<internal::scalar_difference_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
+ }
+
// Select the device on which to evaluate the expression.
template <typename DeviceType>
TensorDevice<Derived, DeviceType> device(const DeviceType& device) {
return TensorDevice<Derived, DeviceType>(device, derived());
}
- protected:
- template <typename OtherDerived> friend class TensorBase;
+ protected:
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& derived() { return *static_cast<Derived*>(this); }
EIGEN_DEVICE_FUNC
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index d424df36e..d371eb76d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -35,6 +35,10 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
typedef typename RhsXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
+
+ enum {
+ Flags = 0,
+ };
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
index ca2e0e562..501e9a522 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h
@@ -35,6 +35,10 @@ struct traits<TensorConvolutionOp<Dimensions, InputXprType, KernelXprType> >
typedef typename KernelXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
+
+ enum {
+ Flags = 0,
+ };
};
template<typename Dimensions, typename InputXprType, typename KernelXprType>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
index 60908ee94..de66da13f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h
@@ -36,6 +36,10 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
typedef typename XprType::Scalar Scalar;
typedef typename XprType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
+
+ enum {
+ Flags = 0,
+ };
};
} // end namespace internal
@@ -153,6 +157,10 @@ struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
typedef typename RhsXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
+
+ enum {
+ Flags = 0,
+ };
};
template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
index b8833362c..1fb90478f 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h
@@ -15,7 +15,7 @@ namespace Eigen {
template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor;
template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize;
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
-template<typename Derived> class TensorBase;
+template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase;
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
@@ -29,6 +29,10 @@ template<typename ExpressionType, typename DeviceType> class TensorDevice;
// Move to internal?
template<typename Derived> struct TensorEvaluator;
+namespace internal {
+template<typename Derived, typename OtherDerived, bool Vectorizable> struct TensorAssign;
+} // end namespace internal
+
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_FORWARD_DECLARATIONS_H
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index 3e089fe1e..7d5f9271e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -21,7 +21,7 @@ namespace Eigen {
*/
namespace internal {
template<typename XprType, typename NewDimensions>
-struct traits<TensorReshapingOp<XprType, NewDimensions> >
+struct traits<TensorReshapingOp<XprType, NewDimensions> > : public traits<XprType>
{
// Type promotion to handle the case where the types of the lhs and the rhs are different.
typedef typename XprType::Scalar Scalar;
@@ -81,6 +81,7 @@ template<typename ArgType, typename NewDimensions>
struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
{
typedef TensorReshapingOp<ArgType, NewDimensions> XprType;
+ typedef NewDimensions Dimensions;
enum {
IsAligned = TensorEvaluator<ArgType>::IsAligned,
@@ -95,7 +96,7 @@ struct TensorEvaluator<const TensorReshapingOp<ArgType, NewDimensions> >
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType;
- const NewDimensions& dimensions() const { return m_dimensions; }
+ const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
{
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
index 2de698a57..40f805741 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h
@@ -52,7 +52,7 @@ struct traits<Tensor<Scalar_, NumIndices_, Options_> >
typedef DenseIndex Index;
enum {
Options = Options_,
- Flags = compute_tensor_flags<Scalar_, Options_>::ret,
+ Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit,
};
};
@@ -63,6 +63,10 @@ struct traits<TensorFixedSize<Scalar_, Dimensions, Options_> >
typedef Scalar_ Scalar;
typedef Dense StorageKind;
typedef DenseIndex Index;
+ enum {
+ Options = Options_,
+ Flags = compute_tensor_flags<Scalar_, Options_>::ret | LvalueBit,
+ };
};