aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-06 18:47:45 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-01-06 18:47:45 -0800
commit213459d81850f98f3822624ae84c1f420f12092c (patch)
treeb90b42ce4d6d1edc0e7c1063a3b9456488dc81d4 /unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
parentee738321aa6c13f327821f4a4b1aaa4ead635687 (diff)
Optimized the performance of broadcasting of scalars.
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h25
1 files changed, 24 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
index dc64959e1..0c95e5c0b 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h
@@ -46,6 +46,21 @@ struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorB
typedef TensorBroadcastingOp<Broadcast, XprType> type;
};
+template <typename Dims>
+struct is_input_scalar {
+ static const bool value = false;
+};
+template <>
+struct is_input_scalar<Sizes<> > {
+ static const bool value = true;
+};
+#ifndef EIGEN_EMULATE_CXX11_META_H
+template <typename std::size_t... Indices>
+struct is_input_scalar<Sizes<Indices...> > {
+ static const bool value = (Sizes<Indices...>::total_size == 1);
+};
+#endif
+
} // end namespace internal
@@ -103,7 +118,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
// and store the result in a scalar. Instead one should reshape the scalar into a a N-D
// tensor with N >= 1 of 1 element first and then broadcast.
EIGEN_STATIC_ASSERT(NumDims > 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
- const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
+ const InputDimensions& input_dims = m_impl.dimensions();
const Broadcast& broadcast = op.broadcast();
for (int i = 0; i < NumDims; ++i) {
eigen_assert(input_dims[i] > 0);
@@ -143,6 +158,10 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
{
+ if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
+ return m_impl.coeff(0);
+ }
+
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
return coeffColMajor(index);
} else {
@@ -214,6 +233,10 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
template<int LoadMode>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
{
+ if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
+ return m_impl.coeff(0);
+ }
+
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
return packetColMajor<LoadMode>(index);
} else {