diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index 2bd158dac..a77903dca 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -24,11 +24,13 @@ template<typename Broadcast, typename XprType> struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType> { typedef typename XprType::Scalar Scalar; - typedef typename internal::packet_traits<Scalar>::type Packet; - typedef typename traits<XprType>::StorageKind StorageKind; - typedef typename traits<XprType>::Index Index; + typedef traits<XprType> XprTraits; + typedef typename packet_traits<Scalar>::type Packet; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; typedef typename XprType::Nested Nested; typedef typename remove_reference<Nested>::type _Nested; + static const int NumDimensions = XprTraits::NumDimensions; }; template<typename Broadcast, typename XprType> @@ -85,6 +87,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value; typedef DSizes<Index, NumDims> Dimensions; typedef typename XprType::Scalar Scalar; + typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions; enum { IsAligned = false, @@ -129,10 +132,19 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> Index inputIndex = 0; for (int i = NumDims - 1; i > 0; --i) { const Index idx = index / m_outputStrides[i]; - inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + if (internal::index_statically_eq<InputDimensions>()(i, 1)) { + eigen_assert(idx % m_impl.dimensions()[i] == 0); + } else { + inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + } index -= idx * m_outputStrides[i]; } - inputIndex += (index % m_impl.dimensions()[0]); + if (internal::index_statically_eq<Broadcast>()(0, 1)) { + eigen_assert(index < m_impl.dimensions()[0]); + inputIndex += index; + } else { + inputIndex += (index % m_impl.dimensions()[0]); + } return m_impl.coeff(inputIndex); } @@ -150,10 +162,20 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> Index inputIndex = 0; for (int i = NumDims - 1; i > 0; --i) { const Index idx = index / m_outputStrides[i]; - inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + if (internal::index_statically_eq<InputDimensions>()(i, 1)) { + eigen_assert(idx % m_impl.dimensions()[i] == 0); + } else { + inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i]; + } index -= idx * m_outputStrides[i]; } - const Index innermostLoc = index % m_impl.dimensions()[0]; + Index innermostLoc; + if (internal::index_statically_eq<Broadcast>()(0, 1)) { + eigen_assert(index < m_impl.dimensions()[0]); + innermostLoc = index; + } else { + innermostLoc = index % m_impl.dimensions()[0]; + } inputIndex += innermostLoc; // Todo: this could be extended to the second dimension if we're not |