about | summary | refs | log | tree | commit | diff | homepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
diff options
context:
space:
mode:
author: Benoit Steiner <benoit.steiner.goog@gmail.com> 2014-11-12 22:35:44 -0800
committer: Benoit Steiner <benoit.steiner.goog@gmail.com> 2014-11-12 22:35:44 -0800
commit: eeabf7975e59b47f4e3677c340013ebbfcfbc2bd (patch)
tree: a69c6f1f5905d5952896bca8f34829ed2276641c /unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
parent: c2d1074932ae92a001eadb27e9f85eaf2de187b9 (diff)
Optimized broadcasting
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r-- unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h 36
1 files changed, 29 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
index 2bd158dac..a77903dca 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
@@ -24,11 +24,13 @@ template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
typedef typename XprType::Scalar Scalar;
- typedef typename internal::packet_traits<Scalar>::type Packet;
- typedef typename traits<XprType>::StorageKind StorageKind;
- typedef typename traits<XprType>::Index Index;
+ typedef traits<XprType> XprTraits;
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename XprTraits::StorageKind StorageKind;
+ typedef typename XprTraits::Index Index;
typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested;
+ static const int NumDimensions = XprTraits::NumDimensions;
};
template<typename Broadcast, typename XprType>
@@ -85,6 +87,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
typedef DSizes<Index, NumDims> Dimensions;
typedef typename XprType::Scalar Scalar;
+ typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
enum {
IsAligned = false,
@@ -129,10 +132,19 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
Index inputIndex = 0;
for (int i = NumDims - 1; i > 0; --i) {
const Index idx = index / m_outputStrides[i];
- inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
+ if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
+ eigen_assert(idx % m_impl.dimensions()[i] == 0);
+ } else {
+ inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
+ }
index -= idx * m_outputStrides[i];
}
- inputIndex += (index % m_impl.dimensions()[0]);
+ if (internal::index_statically_eq<Broadcast>()(0, 1)) {
+ eigen_assert(index < m_impl.dimensions()[0]);
+ inputIndex += index;
+ } else {
+ inputIndex += (index % m_impl.dimensions()[0]);
+ }
return m_impl.coeff(inputIndex);
}
@@ -150,10 +162,20 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
Index inputIndex = 0;
for (int i = NumDims - 1; i > 0; --i) {
const Index idx = index / m_outputStrides[i];
- inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
+ if (internal::index_statically_eq<InputDimensions>()(i, 1)) {
+ eigen_assert(idx % m_impl.dimensions()[i] == 0);
+ } else {
+ inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
+ }
index -= idx * m_outputStrides[i];
}
- const Index innermostLoc = index % m_impl.dimensions()[0];
+ Index innermostLoc;
+ if (internal::index_statically_eq<Broadcast>()(0, 1)) {
+ eigen_assert(index < m_impl.dimensions()[0]);
+ innermostLoc = index;
+ } else {
+ innermostLoc = index % m_impl.dimensions()[0];
+ }
inputIndex += innermostLoc;
// Todo: this could be extended to the second dimension if we're not