aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-04-20 17:34:11 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-04-20 17:34:11 -0700
commit10a1f81822d9bc6d998af413848fddc103c801a0 (patch)
tree3efe4fcee5e56ae9ec946776ffdf91bde60b61f5 /unsupported
parent43eb2ca6e1d80c2c3517f7af3c144b50b472cfae (diff)
Sped up the assignment of a tensor to a tensor slice, as well as the assigment of a constant slice to a tensor
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h14
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h2
2 files changed, 12 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
index 9198c17ef..a38af84d5 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h
@@ -132,13 +132,20 @@ struct TensorEvaluator<const Derived, Device>
CoordAccess = NumCoords > 0,
};
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&)
- : m_data(m.data()), m_dims(m.dimensions())
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
+ : m_data(m.data()), m_dims(m.dimensions()), m_device(device)
{ }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }
- EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
+ if (internal::is_arithmetic<typename internal::remove_const<Scalar>::type>::value && data) {
+ m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar));
+ return false;
+ }
+ return true;
+ }
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
@@ -172,6 +179,7 @@ struct TensorEvaluator<const Derived, Device>
protected:
const Scalar* m_data;
Dimensions m_dims;
+ const Device& m_device;
};
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index 90ac7b6a8..d15055727 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -346,7 +346,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
m_impl.evalSubExprsIfNeeded(NULL);
- if (internal::is_arithmetic<Scalar>::value && data && m_impl.data()) {
+ if (internal::is_arithmetic<typename internal::remove_const<Scalar>::type>::value && data && m_impl.data()) {
Index contiguous_values = 1;
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
for (int i = 0; i < NumDims; ++i) {