diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 14 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h | 2 |
2 files changed, 12 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index 9198c17ef..a38af84d5 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -132,13 +132,20 @@ struct TensorEvaluator<const Derived, Device> CoordAccess = NumCoords > 0, }; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&) - : m_data(m.data()), m_dims(m.dimensions()) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device) + : m_data(m.data()), m_dims(m.dimensions()), m_device(device) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { + if (internal::is_arithmetic<typename internal::remove_const<Scalar>::type>::value && data) { + m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar)); + return false; + } + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { @@ -172,6 +179,7 @@ struct TensorEvaluator<const Derived, Device> protected: const Scalar* m_data; Dimensions m_dims; + const Device& m_device; }; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h index 90ac7b6a8..d15055727 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h @@ -346,7 +346,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { m_impl.evalSubExprsIfNeeded(NULL); - if (internal::is_arithmetic<Scalar>::value && data && m_impl.data()) { + if (internal::is_arithmetic<typename internal::remove_const<Scalar>::type>::value && data && m_impl.data()) { Index contiguous_values = 1; if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { for (int i = 0; i < NumDims; ++i) { |