diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-04-20 17:34:11 -0700 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-04-20 17:34:11 -0700 |
commit | 10a1f81822d9bc6d998af413848fddc103c801a0 (patch) | |
tree | 3efe4fcee5e56ae9ec946776ffdf91bde60b61f5 /unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | |
parent | 43eb2ca6e1d80c2c3517f7af3c144b50b472cfae (diff) |
Sped up the assignment of a tensor to a tensor slice, as well as the assigment of a constant slice to a tensor
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index 9198c17ef..a38af84d5 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -132,13 +132,20 @@ struct TensorEvaluator<const Derived, Device> CoordAccess = NumCoords > 0, }; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&) - : m_data(m.data()), m_dims(m.dimensions()) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device) + : m_data(m.data()), m_dims(m.dimensions()), m_device(device) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { + if (internal::is_arithmetic<typename internal::remove_const<Scalar>::type>::value && data) { + m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar)); + return false; + } + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { @@ -172,6 +179,7 @@ struct TensorEvaluator<const Derived, Device> protected: const Scalar* m_data; Dimensions m_dims; + const Device& m_device; }; |