diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-02-17 09:57:41 -0800 |
---|---|---|
committer | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2015-02-17 09:57:41 -0800 |
commit | 1d3b64d32b57a756a2c8409a6d60fb308f17e595 (patch) | |
tree | 648e0e2965144fe015035d2eb3f4f8f3d845253a /unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h | |
parent | 00f048d44f7dd40a0a4e80e40787a930db0f18f0 (diff) |
Added support for tensor concatenation as lvalue
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h | 88 |
1 files changed, 87 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h index a1dec76d1..91d8d54dc 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h @@ -81,7 +81,26 @@ class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsX const typename internal::remove_all<typename RhsXprType::Nested>::type& rhsExpression() const { return m_rhs_xpr; } - EIGEN_DEVICE_FUNC Axis axis() const { return m_axis; } + EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other) + { + typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign; + Assign assign(*this, other); + internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice()); + return *this; + } + + template<typename OtherDerived> + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other) + { + typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign; + Assign assign(*this, other); + internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice()); + return *this; + } protected: typename LhsXprType::Nested m_lhs_xpr; @@ -252,6 +271,73 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy const Axis m_axis; }; +// Eval as lvalue +template<typename Axis, typename LeftArgType, typename RightArgType, typename Device> + struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> + : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> +{ + typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base; + typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType; + typedef typename Base::Dimensions Dimensions; + enum { + IsAligned = false, + PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess, + Layout = TensorEvaluator<LeftArgType, Device>::Layout, + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device) + : Base(op, device) + { + EIGEN_STATIC_ASSERT((Layout == ColMajor), YOU_MADE_A_PROGRAMMING_MISTAKE); + } + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index) + { + // Collect dimension-wise indices (subs). + array<Index, Base::NumDims> subs; + for (int i = Base::NumDims - 1; i > 0; --i) { + subs[i] = index / this->m_outputStrides[i]; + index -= subs[i] * this->m_outputStrides[i]; + } + subs[0] = index; + + const Dimensions& left_dims = this->m_leftImpl.dimensions(); + if (subs[this->m_axis] < left_dims[this->m_axis]) { + Index left_index = subs[0]; + for (int i = 1; i < Base::NumDims; ++i) { + left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i]; + } + return this->m_leftImpl.coeffRef(left_index); + } else { + subs[this->m_axis] -= left_dims[this->m_axis]; + const Dimensions& right_dims = this->m_rightImpl.dimensions(); + Index right_index = subs[0]; + for (int i = 1; i < Base::NumDims; ++i) { + right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i]; + } + return this->m_rightImpl.coeffRef(right_index); + } + } + + template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + void writePacket(Index index, const PacketReturnType& x) + { + static const int packetSize = internal::unpacket_traits<PacketReturnType>::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize()); + + EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize]; + PacketReturnType rslt = internal::pstore<PacketReturnType>(values, x); + for (int i = 0; i < packetSize; ++i) { + coeffRef(index+i) = values[i]; + } + } +}; } // end namespace Eigen |