diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h | 12 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 55 |
2 files changed, 50 insertions, 17 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h index a2a925775..3bfe80c9e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h @@ -102,6 +102,7 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device> { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator<RightArgType, Device>::Dimensions Dimensions; @@ -112,9 +113,14 @@ struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>, Device> return m_rightImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { - m_leftImpl.evalSubExprsIfNeeded(); - m_rightImpl.evalSubExprsIfNeeded(); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + eigen_assert(internal::dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions())); + m_leftImpl.evalSubExprsIfNeeded(NULL); + // If the lhs provides raw access to its storage area (i.e. if m_leftImpl.data() returns a non + // null value), attempt to evaluate the rhs expression in place. Returns true iff in place + // evaluation isn't supported and the caller still needs to manually assign the values generated + // by the rhs to the lhs. + return m_rightImpl.evalSubExprsIfNeeded(m_leftImpl.data()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index ac9829ce9..0f969036c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -39,13 +39,20 @@ struct TensorEvaluator PacketAccess = Derived::PacketAccess, }; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&) - : m_data(const_cast<Scalar*>(m.data())), m_dims(m.dimensions()) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device) + : m_data(const_cast<Scalar*>(m.data())), m_dims(m.dimensions()), m_device(device) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* dest) { + if (dest) { + m_device.memcpy((void*)dest, m_data, sizeof(Scalar) * m_dims.TotalSize()); + return false; + } + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { @@ -70,9 +77,12 @@ struct TensorEvaluator return internal::pstoret<Scalar, Packet, StoreMode>(m_data + index, x); } + Scalar* data() const { return m_data; } + protected: Scalar* m_data; Dimensions m_dims; + const Device& m_device; }; @@ -98,7 +108,7 @@ struct TensorEvaluator<const Derived, Device> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { @@ -112,6 +122,8 @@ struct TensorEvaluator<const Derived, Device> return internal::ploadt<Packet, LoadMode>(m_data + index); } + const Scalar* data() const { return m_data; } + protected: const Scalar* m_data; Dimensions m_dims; @@ -138,13 +150,14 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device> { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const @@ -158,6 +171,8 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device> return m_functor.packetOp(index); } + Scalar* data() const { return NULL; } + private: const NullaryOp m_functor; TensorEvaluator<ArgType, Device> m_argImpl; @@ -183,14 +198,16 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device> { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { - m_argImpl.evalSubExprsIfNeeded(); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + m_argImpl.evalSubExprsIfNeeded(NULL); + return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_argImpl.cleanup(); @@ -207,6 +224,8 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device> return m_functor.packetOp(m_argImpl.template packet<LoadMode>(index)); } + Scalar* data() const { return NULL; } + private: const UnaryOp m_functor; TensorEvaluator<ArgType, Device> m_argImpl; @@ -233,6 +252,7 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions; @@ -243,9 +263,10 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg return m_leftImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { - m_leftImpl.evalSubExprsIfNeeded(); - m_rightImpl.evalSubExprsIfNeeded(); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + m_leftImpl.evalSubExprsIfNeeded(NULL); + m_rightImpl.evalSubExprsIfNeeded(NULL); + return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); @@ -262,6 +283,8 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg return m_functor.packetOp(m_leftImpl.template packet<LoadMode>(index), m_rightImpl.template packet<LoadMode>(index)); } + Scalar* data() const { return NULL; } + private: const BinaryOp m_functor; TensorEvaluator<LeftArgType, Device> m_leftImpl; @@ -289,6 +312,7 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions; @@ -299,10 +323,11 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> return m_condImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { - m_condImpl.evalSubExprsIfNeeded(); - m_thenImpl.evalSubExprsIfNeeded(); - m_elseImpl.evalSubExprsIfNeeded(); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + m_condImpl.evalSubExprsIfNeeded(NULL); + m_thenImpl.evalSubExprsIfNeeded(NULL); + m_elseImpl.evalSubExprsIfNeeded(NULL); + return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_condImpl.cleanup(); @@ -327,6 +352,8 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType> m_elseImpl.template packet<LoadMode>(index)); } + Scalar* data() const { return NULL; } + private: TensorEvaluator<IfArgType, Device> m_condImpl; TensorEvaluator<ThenArgType, Device> m_thenImpl; |