diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-06 18:47:45 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-01-06 18:47:45 -0800 |
commit | 213459d81850f98f3822624ae84c1f420f12092c (patch) | |
tree | b90b42ce4d6d1edc0e7c1063a3b9456488dc81d4 /unsupported | |
parent | ee738321aa6c13f327821f4a4b1aaa4ead635687 (diff) |
Optimized the performance of broadcasting of scalars.
Diffstat (limited to 'unsupported')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index dc64959e1..0c95e5c0b 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -46,6 +46,21 @@ struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorB typedef TensorBroadcastingOp<Broadcast, XprType> type; }; +template <typename Dims> +struct is_input_scalar { + static const bool value = false; +}; +template <> +struct is_input_scalar<Sizes<> > { + static const bool value = true; +}; +#ifndef EIGEN_EMULATE_CXX11_META_H +template <typename std::size_t... Indices> +struct is_input_scalar<Sizes<Indices...> > { + static const bool value = (Sizes<Indices...>::total_size == 1); +}; +#endif + } // end namespace internal @@ -103,7 +118,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> // and store the result in a scalar. Instead one should reshape the scalar into a a N-D // tensor with N >= 1 of 1 element first and then broadcast. EIGEN_STATIC_ASSERT(NumDims > 0, YOU_MADE_A_PROGRAMMING_MISTAKE); - const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions(); + const InputDimensions& input_dims = m_impl.dimensions(); const Broadcast& broadcast = op.broadcast(); for (int i = 0; i < NumDims; ++i) { eigen_assert(input_dims[i] > 0); @@ -143,6 +158,10 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const { + if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) { + return m_impl.coeff(0); + } + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { return coeffColMajor(index); } else { @@ -214,6 +233,10 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const { + if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) { + return m_impl.coeff(0); + } + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { return packetColMajor<LoadMode>(index); } else { |