diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 31 |
1 files changed, 26 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index dc64959e1..b6e6db12a 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -25,7 +25,6 @@ struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType { typedef typename XprType::Scalar Scalar; typedef traits<XprType> XprTraits; - typedef typename packet_traits<Scalar>::type Packet; typedef typename XprTraits::StorageKind StorageKind; typedef typename XprTraits::Index Index; typedef typename XprType::Nested Nested; @@ -46,6 +45,21 @@ struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorB typedef TensorBroadcastingOp<Broadcast, XprType> type; }; +template <typename Dims> +struct is_input_scalar { + static const bool value = false; +}; +template <> +struct is_input_scalar<Sizes<> > { + static const bool value = true; +}; +#ifndef EIGEN_EMULATE_CXX11_META_H +template <typename std::size_t... Indices> +struct is_input_scalar<Sizes<Indices...> > { + static const bool value = (Sizes<Indices...>::total_size == 1); +}; +#endif + } // end namespace internal @@ -55,10 +69,8 @@ class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, X { public: typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar; - typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Packet Packet; typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename XprType::PacketReturnType PacketReturnType; typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested; typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index; @@ -94,6 +106,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, Layout = TensorEvaluator<ArgType, Device>::Layout, + RawAccess = false }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) @@ -103,7 +116,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> // and store the result in a scalar. Instead one should reshape the scalar into a a N-D // tensor with N >= 1 of 1 element first and then broadcast. EIGEN_STATIC_ASSERT(NumDims > 0, YOU_MADE_A_PROGRAMMING_MISTAKE); - const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions(); + const InputDimensions& input_dims = m_impl.dimensions(); const Broadcast& broadcast = op.broadcast(); for (int i = 0; i < NumDims; ++i) { eigen_assert(input_dims[i] > 0); @@ -128,7 +141,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } typedef typename XprType::CoeffReturnType CoeffReturnType; - typedef typename XprType::PacketReturnType PacketReturnType; + typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -143,6 +156,10 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const { + if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) { + return m_impl.coeff(0); + } + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { return coeffColMajor(index); } else { @@ -214,6 +231,10 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const { + if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) { + return internal::pset1<PacketReturnType>(m_impl.coeff(0)); + } + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { return packetColMajor<LoadMode>(index); } else { |