diff options
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorBase.h | 5 | ||||
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h | 88 | ||||
-rw-r--r-- | unsupported/test/cxx11_tensor_concatenation.cpp | 21 |
3 files changed, 113 insertions, 1 deletions
Review note (not part of the patch; patch tools skip text ahead of the first "diff --git" line):
  * Layout: the hunks below are laid out one patch line per line again. Indentation and blank
    context lines were reconstructed from the hunk line counts and the surrounding code style.
  * TensorConcatenation.h, lvalue evaluator, writePacket(): the submitted line
        PacketReturnType rslt = internal::pstore<PacketReturnType>(values, x);
    is replaced by
        internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    pstore() stores a packet to memory and returns void, and its first template argument is the
    scalar type of the destination buffer, so the submitted form cannot compile once writePacket()
    is instantiated. It went unnoticed because both operator= overloads run the executor with
    vectorization disabled. The replacement is line-for-line, so all hunk counts are unchanged;
    the "index" blob id of TensorConcatenation.h no longer describes the patched file.

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index e08ac6aa1..cfcf18e8e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -524,6 +524,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
     swap_layout() const {
       return TensorLayoutSwapOp<Derived>(derived());
     }
+    template <typename Axis, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+    TensorConcatenationOp<const Axis, Derived, OtherDerived>
+    concatenate(const OtherDerived& other, const Axis& axis) const {
+      return TensorConcatenationOp<const Axis, Derived, OtherDerived>(derived(), other.derived(), axis);
+    }
     template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
     TensorReshapingOp<const NewDimensions, Derived>
     reshape(const NewDimensions& newDimensions) const {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
index a1dec76d1..78214df11 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
@@ -81,7 +81,26 @@ class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsX
     const typename internal::remove_all<typename RhsXprType::Nested>::type&
     rhsExpression() const { return m_rhs_xpr; }
 
-    EIGEN_DEVICE_FUNC Axis axis() const { return m_axis; }
+    EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
+
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
+    {
+      typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
+      Assign assign(*this, other);
+      internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+      return *this;
+    }
+
+    template<typename OtherDerived>
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
+    {
+      typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
+      Assign assign(*this, other);
+      internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+      return *this;
+    }
 
   protected:
     typename LhsXprType::Nested m_lhs_xpr;
@@ -252,6 +271,73 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy
     const Axis m_axis;
 };
 
+// Eval as lvalue
+template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
+  struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
+  : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
+{
+  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
+  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
+  typedef typename Base::Dimensions Dimensions;
+  enum {
+    IsAligned = false,
+    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
+    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+  };
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
+    : Base(op, device)
+  {
+    EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
+  }
+
+  typedef typename XprType::Index Index;
+  typedef typename XprType::Scalar Scalar;
+  typedef typename XprType::CoeffReturnType CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
+  {
+    // Collect dimension-wise indices (subs).
+    array<Index, Base::NumDims> subs;
+    for (int i = Base::NumDims - 1; i > 0; --i) {
+      subs[i] = index / this->m_outputStrides[i];
+      index -= subs[i] * this->m_outputStrides[i];
+    }
+    subs[0] = index;
+
+    const Dimensions& left_dims = this->m_leftImpl.dimensions();
+    if (subs[this->m_axis] < left_dims[this->m_axis]) {
+      Index left_index = subs[0];
+      for (int i = 1; i < Base::NumDims; ++i) {
+        left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
+      }
+      return this->m_leftImpl.coeffRef(left_index);
+    } else {
+      subs[this->m_axis] -= left_dims[this->m_axis];
+      const Dimensions& right_dims = this->m_rightImpl.dimensions();
+      Index right_index = subs[0];
+      for (int i = 1; i < Base::NumDims; ++i) {
+        right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
+      }
+      return this->m_rightImpl.coeffRef(right_index);
+    }
+  }
+
+  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+  void writePacket(Index index, const PacketReturnType& x)
+  {
+    static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
+    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+    eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
+
+    EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
+    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
+    for (int i = 0; i < packetSize; ++i) {
+      coeffRef(index+i) = values[i];
+    }
+  }
+};
 
 } // end namespace Eigen
 
diff --git a/unsupported/test/cxx11_tensor_concatenation.cpp b/unsupported/test/cxx11_tensor_concatenation.cpp
index 9fdf33c16..cc9dfb769 100644
--- a/unsupported/test/cxx11_tensor_concatenation.cpp
+++ b/unsupported/test/cxx11_tensor_concatenation.cpp
@@ -103,6 +103,25 @@ static void test_simple_concatenation()
 // TODO(phli): Add test once we have a real vectorized implementation.
 // static void test_vectorized_concatenation() {}
 
+static void test_concatenation_as_lvalue()
+{
+  Tensor<int, 2> t1(2, 3);
+  Tensor<int, 2> t2(2, 3);
+  t1.setRandom();
+  t2.setRandom();
+
+  Tensor<int, 2> result(4, 3);
+  result.setRandom();
+  t1.concatenate(t2, 0) = result;
+
+  for (int i = 0; i < 2; ++i) {
+    for (int j = 0; j < 3; ++j) {
+      VERIFY_IS_EQUAL(t1(i, j), result(i, j));
+      VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
+    }
+  }
+}
+
 
 void test_cxx11_tensor_concatenation()
 {
@@ -113,4 +132,6 @@ void test_cxx11_tensor_concatenation()
   CALL_SUBTEST(test_simple_concatenation<ColMajor>());
   CALL_SUBTEST(test_simple_concatenation<RowMajor>());
   // CALL_SUBTEST(test_vectorized_concatenation());
+  CALL_SUBTEST(test_concatenation_as_lvalue());
+
 }
|