aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-03-04 11:10:21 -0800
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-03-04 11:10:21 -0800
commit694084ecbd12c5183a8ff0604d04971d043abfff (patch)
tree9dbed8e2b2da739bbdf8e99a135778c3765be5ff /unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
parentb0d406d91c62ff32153df43d5f698ceb02341ac7 (diff)
Use fast divisors in TensorGeneratorOp
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h12
1 files changed, 10 insertions, 2 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
index ac66f9cf1..0fee18fb6 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h
@@ -98,6 +98,8 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
RawAccess = false
};
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_generator(op.generator())
#ifdef EIGEN_USE_SYCL
@@ -118,6 +120,9 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
}
}
+ for (int i = 0; i < NumDims; ++i) {
+ m_fast_strides[i] = IndexDivisor(m_strides[i]);
+ }
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
@@ -150,6 +155,8 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
return rslt;
}
+ // TODO(ezhulenev): Add tiled evaluation support.
+
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
costPerCoeff(bool) const {
// TODO(rmlarsen): This is just a placeholder. Define interface to make
@@ -170,14 +177,14 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
void extract_coordinates(Index index, array<Index, NumDims>& coords) const {
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
for (int i = NumDims - 1; i > 0; --i) {
- const Index idx = index / m_strides[i];
+ const Index idx = index / m_fast_strides[i];
index -= idx * m_strides[i];
coords[i] = idx;
}
coords[0] = index;
} else {
for (int i = 0; i < NumDims - 1; ++i) {
- const Index idx = index / m_strides[i];
+ const Index idx = index / m_fast_strides[i];
index -= idx * m_strides[i];
coords[i] = idx;
}
@@ -187,6 +194,7 @@ struct TensorEvaluator<const TensorGeneratorOp<Generator, ArgType>, Device>
Dimensions m_dimensions;
array<Index, NumDims> m_strides;
+ array<IndexDivisor, NumDims> m_fast_strides;
Generator m_generator;
#ifdef EIGEN_USE_SYCL
TensorEvaluator<ArgType, Device> m_argImpl;