diff options
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 70 |
1 files changed, 66 insertions, 4 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index 9ab6b3565..b35b36475 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -161,6 +161,22 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } } } + + // Handle special format like NCHW, its input shape is '[1, N..., 1]' and + // broadcast shape is '[N, 1..., N]' + if (!oneByN && !nByOne) { + if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) { + nByOne = true; + oneByN = true; + for (int i = 1; i < NumDims-1; ++i) { + if (broadcast[i] != 1) { + nByOne = false; + oneByN = false; + break; + } + } + } + } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -256,18 +272,22 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { - if (oneByN) { + if (oneByN && !nByOne) { return packetNByOne<LoadMode>(index); - } else if (nByOne) { + } else if (!oneByN && nByOne) { return packetOneByN<LoadMode>(index); + } else if (oneByN && nByOne) { + return packetOneByNByOne<LoadMode>(index); } else { return packetColMajor<LoadMode>(index); } } else { - if (oneByN) { + if (oneByN && !nByOne) { return packetOneByN<LoadMode>(index); - } else if (nByOne) { + } else if (!oneByN && nByOne) { return packetNByOne<LoadMode>(index); + } else if (oneByN && nByOne) { + return packetOneByNByOne<LoadMode>(index); } else { return packetRowMajor<LoadMode>(index); } @@ -275,6 +295,48 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } template<int LoadMode> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne + (Index index) const + { + EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+PacketSize-1 < dimensions().TotalSize()); + + EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize]; + Index startDim, endDim; + Index inputIndex, outputOffset, batchedIndex; + + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { + startDim = NumDims - 1; + endDim = 1; + } else { + startDim = 0; + endDim = NumDims - 2; + } + + batchedIndex = index % m_outputStrides[startDim]; + inputIndex = batchedIndex / m_outputStrides[endDim]; + outputOffset = batchedIndex % m_outputStrides[endDim]; + + if (outputOffset + PacketSize <= m_outputStrides[endDim]) { + values[0] = m_impl.coeff(inputIndex); + return internal::pload1<PacketReturnType>(values); + } else { + for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) { + if (outputOffset + cur < m_outputStrides[endDim]) { + values[i] = m_impl.coeff(inputIndex); + } else { + ++inputIndex; + inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex); + values[i] = m_impl.coeff(inputIndex); + outputOffset = 0; + cur = 0; + } + } + return internal::pload<PacketReturnType>(values); + } + } + + template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const { EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE) |