aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorBase.h5
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h88
-rw-r--r--unsupported/test/cxx11_tensor_concatenation.cpp21
3 files changed, 113 insertions, 1 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
index e08ac6aa1..cfcf18e8e 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h
@@ -524,6 +524,11 @@ class TensorBase<Derived, WriteAccessors> : public TensorBase<Derived, ReadOnlyA
swap_layout() const {
return TensorLayoutSwapOp<Derived>(derived());
}
+ template <typename Axis, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ TensorConcatenationOp<const Axis, Derived, OtherDerived>
+ concatenate(const OtherDerived& other, const Axis& axis) const {
+ return TensorConcatenationOp<const Axis, Derived, OtherDerived>(derived(), other.derived(), axis);
+ }
template <typename NewDimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorReshapingOp<const NewDimensions, Derived>
reshape(const NewDimensions& newDimensions) const {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
index a1dec76d1..78214df11 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h
@@ -81,7 +81,26 @@ class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsX
const typename internal::remove_all<typename RhsXprType::Nested>::type&
rhsExpression() const { return m_rhs_xpr; }
- EIGEN_DEVICE_FUNC Axis axis() const { return m_axis; }
+ EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const TensorConcatenationOp& other)
+ {
+ typedef TensorAssignOp<TensorConcatenationOp, const TensorConcatenationOp> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
+
+ template<typename OtherDerived>
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE TensorConcatenationOp& operator = (const OtherDerived& other)
+ {
+ typedef TensorAssignOp<TensorConcatenationOp, const OtherDerived> Assign;
+ Assign assign(*this, other);
+ internal::TensorExecutor<const Assign, DefaultDevice, false>::run(assign, DefaultDevice());
+ return *this;
+ }
protected:
typename LhsXprType::Nested m_lhs_xpr;
@@ -252,6 +271,73 @@ struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgTy
const Axis m_axis;
};
+// Eval as lvalue
+template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
+ struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
+ : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
+{
+ typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
+ typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
+ typedef typename Base::Dimensions Dimensions;
+ enum {
+ IsAligned = false,
+ PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess,
+ Layout = TensorEvaluator<LeftArgType, Device>::Layout,
+ };
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
+ : Base(op, device)
+ {
+ EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
+ }
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
+ {
+ // Collect dimension-wise indices (subs).
+ array<Index, Base::NumDims> subs;
+ for (int i = Base::NumDims - 1; i > 0; --i) {
+ subs[i] = index / this->m_outputStrides[i];
+ index -= subs[i] * this->m_outputStrides[i];
+ }
+ subs[0] = index;
+
+ const Dimensions& left_dims = this->m_leftImpl.dimensions();
+ if (subs[this->m_axis] < left_dims[this->m_axis]) {
+ Index left_index = subs[0];
+ for (int i = 1; i < Base::NumDims; ++i) {
+ left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
+ }
+ return this->m_leftImpl.coeffRef(left_index);
+ } else {
+ subs[this->m_axis] -= left_dims[this->m_axis];
+ const Dimensions& right_dims = this->m_rightImpl.dimensions();
+ Index right_index = subs[0];
+ for (int i = 1; i < Base::NumDims; ++i) {
+ right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
+ }
+ return this->m_rightImpl.coeffRef(right_index);
+ }
+ }
+
+ template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+ void writePacket(Index index, const PacketReturnType& x)
+ {
+ static const int packetSize = internal::unpacket_traits<PacketReturnType>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
+
+ EIGEN_ALIGN_DEFAULT CoeffReturnType values[packetSize];
+ PacketReturnType rslt = internal::pstore<PacketReturnType>(values, x);
+ for (int i = 0; i < packetSize; ++i) {
+ coeffRef(index+i) = values[i];
+ }
+ }
+};
} // end namespace Eigen
diff --git a/unsupported/test/cxx11_tensor_concatenation.cpp b/unsupported/test/cxx11_tensor_concatenation.cpp
index 9fdf33c16..cc9dfb769 100644
--- a/unsupported/test/cxx11_tensor_concatenation.cpp
+++ b/unsupported/test/cxx11_tensor_concatenation.cpp
@@ -103,6 +103,25 @@ static void test_simple_concatenation()
// TODO(phli): Add test once we have a real vectorized implementation.
// static void test_vectorized_concatenation() {}
+static void test_concatenation_as_lvalue()
+{
+ Tensor<int, 2> t1(2, 3);
+ Tensor<int, 2> t2(2, 3);
+ t1.setRandom();
+ t2.setRandom();
+
+ Tensor<int, 2> result(4, 3);
+ result.setRandom();
+ t1.concatenate(t2, 0) = result;
+
+ for (int i = 0; i < 2; ++i) {
+ for (int j = 0; j < 3; ++j) {
+ VERIFY_IS_EQUAL(t1(i, j), result(i, j));
+ VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
+ }
+ }
+}
+
void test_cxx11_tensor_concatenation()
{
@@ -113,4 +132,6 @@ void test_cxx11_tensor_concatenation()
CALL_SUBTEST(test_simple_concatenation<ColMajor>());
CALL_SUBTEST(test_simple_concatenation<RowMajor>());
// CALL_SUBTEST(test_vectorized_concatenation());
+ CALL_SUBTEST(test_concatenation_as_lvalue());
+
}