aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-02-17 09:57:41 -0800
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2015-02-17 09:57:41 -0800
commit1d3b64d32b57a756a2c8409a6d60fb308f17e595 (patch)
tree648e0e2965144fe015035d2eb3f4f8f3d845253a /unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
parent00f048d44f7dd40a0a4e80e40787a930db0f18f0 (diff)
Added support for tensor concatenation as lvalue
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h88
1 files changed, 87 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
index a1dec76d1..91d8d54dc 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
@@ -81,7 +81,26 @@ class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsX
const typename internal::remove_all<typename RhsXprType::Nested>::type&
rhsExpression() const { return m_rhs_xpr; }
- EIGEN_DEVICE_FUNC Axis axis() const { return m_axis; }
+ EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
+ {
+ typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
+
+ template<typename OtherDerived>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
+ {
+ typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
protected:
typename LhsXprType::Nested m_lhs_xpr;
@@ -252,6 +271,73 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy
const Axis m_axis;
};
+// Eval as lvalue
+template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
+ struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
+ : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
+{
+ typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
+ typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
+ typedef typename Base::Dimensions Dimensions;
+ enum {
+ IsAligned = false,
+ PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
+ Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
+ : Base(op, device)
+ {
+ EIGEN_STATIC_ASSERT((Layout == ColMajor), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
+ {
+ // Collect dimension-wise indices (subs).
+ array<Index, Base::NumDims> subs;
+ for (int i = Base::NumDims - 1; i > 0; --i) {
+ subs[i] = index / this->m_outputStrides[i];
+ index -= subs[i] * this->m_outputStrides[i];
+ }
+ subs[0] = index;
+
+ const Dimensions& left_dims = this->m_leftImpl.dimensions();
+ if (subs[this->m_axis] < left_dims[this->m_axis]) {
+ Index left_index = subs[0];
+ for (int i = 1; i < Base::NumDims; ++i) {
+ left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
+ }
+ return this->m_leftImpl.coeffRef(left_index);
+ } else {
+ subs[this->m_axis] -= left_dims[this->m_axis];
+ const Dimensions& right_dims = this->m_rightImpl.dimensions();
+ Index right_index = subs[0];
+ for (int i = 1; i < Base::NumDims; ++i) {
+ right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
+ }
+ return this->m_rightImpl.coeffRef(right_index);
+ }
+ }
+
+ template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void writePacket(Index index, const PacketReturnType& x)
+ {
+ static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
+
+ EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
+ PacketReturnType rslt = internal::pstore<PacketReturnType>(values, x);
+ for (int i = 0; i < packetSize; ++i) {
+ coeffRef(index+i) = values[i];
+ }
+ }
+};
} // end namespace Eigen