diff options
 unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h | 110 +++++++++++++++-
 unsupported/test/cxx11_tensor_broadcasting.cpp           |  62 ++++++++++
 2 files changed, 168 insertions(+), 4 deletions(-)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h index b6c93aff9..9ab6b3565 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h @@ -105,6 +105,7 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size; + bool nByOne = false, oneByN = false; enum { IsAligned = true, @@ -142,6 +143,24 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1]; } } + + if (input_dims[0] == 1) { + oneByN = true; + for (int i = 1; i < NumDims; ++i) { + if (broadcast[i] != 1) { + oneByN = false; + break; + } + } + } else if (input_dims[NumDims-1] == 1) { + nByOne = true; + for (int i = 0; i < NumDims-1; ++i) { + if (broadcast[i] != 1) { + nByOne = false; + break; + } + } + } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -237,9 +256,84 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> } if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { - return packetColMajor<LoadMode>(index); + if (oneByN) { + return packetNByOne<LoadMode>(index); + } else if (nByOne) { + return packetOneByN<LoadMode>(index); + } else { + return packetColMajor<LoadMode>(index); + } } else { - return packetRowMajor<LoadMode>(index); + if (oneByN) { + return packetOneByN<LoadMode>(index); + } else if (nByOne) { + return packetNByOne<LoadMode>(index); + } else { + return packetRowMajor<LoadMode>(index); + } + } + } + + template<int LoadMode> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const + { + 
EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+PacketSize-1 < dimensions().TotalSize()); + + Index dim, inputIndex; + + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { + dim = NumDims - 1; + } else { + dim = 0; + } + + inputIndex = index % m_inputStrides[dim]; + if (inputIndex + PacketSize <= m_inputStrides[dim]) { + return m_impl.template packet<Unaligned>(inputIndex); + } else { + EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize]; + for (int i = 0; i < PacketSize; ++i) { + if (inputIndex > m_inputStrides[dim]-1) { + inputIndex = 0; + } + values[i] = m_impl.coeff(inputIndex++); + } + return internal::pload<PacketReturnType>(values); + } + } + + template<int LoadMode> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const + { + EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+PacketSize-1 < dimensions().TotalSize()); + + EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize]; + Index dim, inputIndex, outputOffset; + + if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) { + dim = 1; + } else { + dim = NumDims - 2; + } + + inputIndex = index / m_outputStrides[dim]; + outputOffset = index % m_outputStrides[dim]; + if (outputOffset + PacketSize <= m_outputStrides[dim]) { + values[0] = m_impl.coeff(inputIndex); + return internal::pload1<PacketReturnType>(values); + } else { + for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) { + if (outputOffset + cur < m_outputStrides[dim]) { + values[i] = m_impl.coeff(inputIndex); + } else { + values[i] = m_impl.coeff(++inputIndex); + outputOffset = 0; + cur = 0; + } + } + return internal::pload<PacketReturnType>(values); } } @@ -290,7 +384,11 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize]; 
values[0] = m_impl.coeff(inputIndex); for (int i = 1; i < PacketSize; ++i) { - values[i] = coeffColMajor(originalIndex+i); + if (innermostLoc + i < m_impl.dimensions()[0]) { + values[i] = m_impl.coeff(inputIndex+i); + } else { + values[i] = coeffColMajor(originalIndex+i); + } } PacketReturnType rslt = internal::pload<PacketReturnType>(values); return rslt; @@ -342,7 +440,11 @@ struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize]; values[0] = m_impl.coeff(inputIndex); for (int i = 1; i < PacketSize; ++i) { - values[i] = coeffRowMajor(originalIndex+i); + if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) { + values[i] = m_impl.coeff(inputIndex+i); + } else { + values[i] = coeffRowMajor(originalIndex+i); + } } PacketReturnType rslt = internal::pload<PacketReturnType>(values); return rslt; diff --git a/unsupported/test/cxx11_tensor_broadcasting.cpp b/unsupported/test/cxx11_tensor_broadcasting.cpp index 5c0ea5889..a9d268ea6 100644 --- a/unsupported/test/cxx11_tensor_broadcasting.cpp +++ b/unsupported/test/cxx11_tensor_broadcasting.cpp @@ -180,6 +180,64 @@ static void test_fixed_size_broadcasting() #endif } +template <int DataLayout> +static void test_simple_broadcasting_one_by_n() +{ + Tensor<float, 4, DataLayout> tensor(1,13,5,7); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 9; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 1; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 9); + VERIFY_IS_EQUAL(broadcast.dimension(1), 13); + VERIFY_IS_EQUAL(broadcast.dimension(2), 5); + VERIFY_IS_EQUAL(broadcast.dimension(3), 7); + + for (int i = 0; i < 9; ++i) { + for (int j = 0; j < 13; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l)); + } + } + } + } +} + 
+template <int DataLayout> +static void test_simple_broadcasting_n_by_one() +{ + Tensor<float, 4, DataLayout> tensor(7,3,5,1); + tensor.setRandom(); + array<ptrdiff_t, 4> broadcasts; + broadcasts[0] = 1; + broadcasts[1] = 1; + broadcasts[2] = 1; + broadcasts[3] = 19; + Tensor<float, 4, DataLayout> broadcast; + broadcast = tensor.broadcast(broadcasts); + + VERIFY_IS_EQUAL(broadcast.dimension(0), 7); + VERIFY_IS_EQUAL(broadcast.dimension(1), 3); + VERIFY_IS_EQUAL(broadcast.dimension(2), 5); + VERIFY_IS_EQUAL(broadcast.dimension(3), 19); + + for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 19; ++l) { + VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l)); + } + } + } + } +} + void test_cxx11_tensor_broadcasting() { @@ -191,4 +249,8 @@ void test_cxx11_tensor_broadcasting() CALL_SUBTEST(test_static_broadcasting<RowMajor>()); CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>()); CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>()); + CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>()); + CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>()); } |