From 213459d81850f98f3822624ae84c1f420f12092c Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 6 Jan 2016 18:47:45 -0800 Subject: Optimized the performance of broadcasting of scalars. --- .../Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 25 +++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h') diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index dc64959e1..0c95e5c0b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -46,6 +46,21 @@ struct nested, 1, typename eval type; }; +template +struct is_input_scalar { + static const bool value = false; +}; +template <> +struct is_input_scalar > { + static const bool value = true; +}; +#ifndef EIGEN_EMULATE_CXX11_META_H +template +struct is_input_scalar > { + static const bool value = (Sizes::total_size == 1); +}; +#endif + } // end namespace internal @@ -103,7 +118,7 @@ struct TensorEvaluator, Device> // and store the result in a scalar. Instead one should reshape the scalar into a a N-D // tensor with N >= 1 of 1 element first and then broadcast. EIGEN_STATIC_ASSERT(NumDims > 0, YOU_MADE_A_PROGRAMMING_MISTAKE); - const typename TensorEvaluator::Dimensions& input_dims = m_impl.dimensions(); + const InputDimensions& input_dims = m_impl.dimensions(); const Broadcast& broadcast = op.broadcast(); for (int i = 0; i < NumDims; ++i) { eigen_assert(input_dims[i] > 0); @@ -143,6 +158,10 @@ struct TensorEvaluator, Device> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const { + if (internal::is_input_scalar::type>::value) { + return m_impl.coeff(0); + } + if (static_cast(Layout) == static_cast(ColMajor)) { return coeffColMajor(index); } else { @@ -214,6 +233,10 @@ struct TensorEvaluator, Device> template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const { + if (internal::is_input_scalar::type>::value) { + return m_impl.coeff(0); + } + if (static_cast(Layout) == static_cast(ColMajor)) { return packetColMajor(index); } else { -- cgit v1.2.3