diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-11-12 22:35:44 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2014-11-12 22:35:44 -0800 |
commit | eeabf7975e59b47f4e3677c340013ebbfcfbc2bd (patch) | |
tree | a69c6f1f5905d5952896bca8f34829ed2276641c /unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | |
parent | c2d1074932ae92a001eadb27e9f85eaf2de187b9 (diff) |
Optimized broadcasting
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index 2bd158dac..a77903dca 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -24,11 +24,13 @@ template<typename Broadcast, typename XprType> struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType> { typedef typename XprType::Scalar Scalar; - typedef typename internal::packet_traits<Scalar>::type Packet; - typedef typename traits<XprType>::StorageKind StorageKind; - typedef typename traits<XprType>::Index Index; + typedef traits<XprType> XprTraits; + typedef typename packet_traits<Scalar>::type Packet; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; typedef typename XprType::Nested Nested; typedef typename remove_reference<Nested>::type _Nested; + static const int NumDimensions = XprTraits::NumDimensions; }; template<typename Broadcast, typename XprType> @@ -85,6 +87,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value; typedef DSizes<Index, NumDims> Dimensions; typedef typename XprType::Scalar Scalar; + typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions; enum { IsAligned = false, @@ -129,10 +132,19 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> Index inputIndex = 0; for (int i = NumDims - 1; i > 0; --i) { const Index idx = index / m_outputStrides[i]; - inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + if (internal::index_statically_eq<InputDimensions>()(i, 1)) { + eigen_assert(idx % m_impl.dimensions()[i] == 0); + } else { + inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + } index -= idx * m_outputStrides[i]; } - inputIndex += (index % m_impl.dimensions()[0]); + if (internal::index_statically_eq<Broadcast>()(0, 1)) { + eigen_assert(index < m_impl.dimensions()[0]); + inputIndex += index; + } else { + inputIndex += (index % m_impl.dimensions()[0]); + } return m_impl.coeff(inputIndex); } @@ -150,10 +162,20 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> Index inputIndex = 0; for (int i = NumDims - 1; i > 0; --i) { const Index idx = index / m_outputStrides[i]; - inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + if (internal::index_statically_eq<InputDimensions>()(i, 1)) { + eigen_assert(idx % m_impl.dimensions()[i] == 0); + } else { + inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + } index -= idx * m_outputStrides[i]; } - const Index innermostLoc = index % m_impl.dimensions()[0]; + Index innermostLoc; + if (internal::index_statically_eq<Broadcast>()(0, 1)) { + eigen_assert(index < m_impl.dimensions()[0]); + innermostLoc = index; + } else { + innermostLoc = index % m_impl.dimensions()[0]; + } inputIndex += innermostLoc; // Todo: this could be extended to the second dimension if we're not |