diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index b35b36475..278689915 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -105,7 +105,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; - bool nByOne = false, oneByN = false; + bool isCopy= false, nByOne = false, oneByN = false; enum { IsAligned = true, @@ -122,10 +122,13 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> // tensor with N >= 1 of 1 element first and then broadcast. EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE); const InputDimensions& input_dims = m_impl.dimensions(); - const Broadcast& broadcast = op.broadcast(); + isCopy = true; for (int i = 0; i < NumDims; ++i) { eigen_assert(input_dims[i] > 0); - m_dimensions[i] = input_dims[i] * broadcast[i]; + m_dimensions[i] = input_dims[i] * m_broadcast[i]; + if (m_broadcast[i] != 1) { + isCopy = false; + } } if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { @@ -147,7 +150,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> if (input_dims[0] == 1) { oneByN = true; for (int i = 1; i < NumDims; ++i) { - if (broadcast[i] != 1) { + if (m_broadcast[i] != 1) { oneByN = false; break; } @@ -155,7 +158,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } else if (input_dims[NumDims-1] == 1) { nByOne = true; for (int i = 0; i < NumDims-1; ++i) { - if (broadcast[i] != 1) { + if (m_broadcast[i] != 1) { nByOne = false; break; } @@ -169,7 +172,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> nByOne = true; oneByN = true; for (int i = 1; i < NumDims-1; ++i) { - if (broadcast[i] != 1) { + if (m_broadcast[i] != 1) { nByOne = false; oneByN = false; break; @@ -197,9 +200,17 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { - return coeffColMajor(index); + if (isCopy) { + return m_impl.coeff(index); + } else { + return coeffColMajor(index); + } } else { - return coeffRowMajor(index); + if (isCopy) { + return m_impl.coeff(index); + } else { + return coeffRowMajor(index); + } } } @@ -272,7 +283,9 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { - if (oneByN && !nByOne) { + if (isCopy) { + return m_impl.template packet<LoadMode>(index); + } else if (oneByN && !nByOne) { return packetNByOne<LoadMode>(index); } else if (!oneByN && nByOne) { return packetOneByN<LoadMode>(index); @@ -282,7 +295,9 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> return packetColMajor<LoadMode>(index); } } else { - if (oneByN && !nByOne) { + if (isCopy) { + return m_impl.template packet<LoadMode>(index); + } else if (oneByN && !nByOne) { return packetOneByN<LoadMode>(index); } else if (!oneByN && nByOne) { return packetNByOne<LoadMode>(index); @@ -516,7 +531,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const { double compute_cost = TensorOpCost::AddCost<Index>(); - if (NumDims > 0) { + if (!isCopy && NumDims > 0) { for (int i = NumDims - 1; i > 0; --i) { compute_cost += TensorOpCost::DivCost<Index>(); if (internal::index_statically_eq<Broadcast>(i, 1)) { |