aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gael Guennebaud <g.gael@free.fr>2017-02-20 11:46:21 +0100
committerGravatar Gael Guennebaud <g.gael@free.fr>2017-02-20 11:46:21 +0100
commit9081c8f6eaeb61a682950fac53af6b321667e355 (patch)
tree4531a647e5952576c89a4fbac3b3a679774a09d6
parent4b22048cead4b3b34f2a784bb77f215350496103 (diff)
Add support for RowOrder reshaped
-rw-r--r--Eigen/src/Core/Reshaped.h24
-rw-r--r--Eigen/src/plugins/ReshapedMethods.h12
-rw-r--r--test/reshape.cpp108
3 files changed, 88 insertions, 56 deletions
diff --git a/Eigen/src/Core/Reshaped.h b/Eigen/src/Core/Reshaped.h
index 42ce2dbae..56fd3519a 100644
--- a/Eigen/src/Core/Reshaped.h
+++ b/Eigen/src/Core/Reshaped.h
@@ -21,7 +21,7 @@ namespace Eigen {
* \tparam XprType the type of the expression in which we are taking a reshape
* \tparam Rows the number of rows of the reshape we are taking at compile time (optional)
* \tparam Cols the number of columns of the reshape we are taking at compile time (optional)
- * \tparam Order
+ * \tparam Order can be ColMajor or RowMajor, default is ColMajor.
*
* This class represents an expression of either a fixed-size or dynamic-size reshape.
* It is the return type of DenseBase::reshaped(NRowsType,NColsType) and
@@ -68,9 +68,8 @@ struct traits<Reshaped<XprType, Rows, Cols, Order> > : traits<XprType>
: Dynamic,
OuterStrideAtCompileTime = Dynamic,
- InOrder = Order,
HasDirectAccess = internal::has_direct_access<XprType>::ret
- && (Order==int(AutoOrderValue) || Order==int(XpxStorageOrder))
+ && (Order==int(XpxStorageOrder))
&& ((evaluator<XprType>::Flags&LinearAccessBit)==LinearAccessBit),
MaskPacketAccessBit = (InnerSize == Dynamic || (InnerSize % packet_traits<Scalar>::size) == 0)
@@ -324,11 +323,20 @@ struct reshaped_evaluator<ArgType, Rows, Cols, Order, /* HasDirectAccess */ fals
typedef std::pair<Index, Index> RowCol;
- inline RowCol index_remap(Index rowId, Index colId) const {
- const Index nth_elem_idx = colId * m_xpr.rows() + rowId;
- const Index actual_col = nth_elem_idx / m_xpr.nestedExpression().rows();
- const Index actual_row = nth_elem_idx % m_xpr.nestedExpression().rows();
- return RowCol(actual_row, actual_col);
+ inline RowCol index_remap(Index rowId, Index colId) const
+ {
+ if(Order==ColMajor)
+ {
+ const Index nth_elem_idx = colId * m_xpr.rows() + rowId;
+ return RowCol(nth_elem_idx % m_xpr.nestedExpression().rows(),
+ nth_elem_idx / m_xpr.nestedExpression().rows());
+ }
+ else
+ {
+ const Index nth_elem_idx = colId + rowId * m_xpr.cols();
+ return RowCol(nth_elem_idx / m_xpr.nestedExpression().cols(),
+ nth_elem_idx % m_xpr.nestedExpression().cols());
+ }
}
EIGEN_DEVICE_FUNC
diff --git a/Eigen/src/plugins/ReshapedMethods.h b/Eigen/src/plugins/ReshapedMethods.h
index a9b4af7c3..7a11a4bcc 100644
--- a/Eigen/src/plugins/ReshapedMethods.h
+++ b/Eigen/src/plugins/ReshapedMethods.h
@@ -40,10 +40,12 @@ reshaped(NRowsType nRows, NColsType nCols)
template<typename NRowsType, typename NColsType, typename OrderType>
EIGEN_DEVICE_FUNC
-inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,OrderType::value>
+inline Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
+ OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
reshaped(NRowsType nRows, NColsType nCols, OrderType)
{
- return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,OrderType::value>(
+ return Reshaped<Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
+ OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>(
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
}
@@ -59,10 +61,12 @@ reshaped(NRowsType nRows, NColsType nCols) const
template<typename NRowsType, typename NColsType, typename OrderType>
EIGEN_DEVICE_FUNC
-inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,OrderType::value>
+inline const Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
+ OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>
reshaped(NRowsType nRows, NColsType nCols, OrderType) const
{
- return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,OrderType::value>(
+ return Reshaped<const Derived,internal::get_fixed_value<NRowsType>::value,internal::get_fixed_value<NColsType>::value,
+ OrderType::value==AutoOrderValue?Flags&RowMajorBit:OrderType::value>(
derived(), internal::get_runtime_value(nRows), internal::get_runtime_value(nCols));
}
diff --git a/test/reshape.cpp b/test/reshape.cpp
index 9b2825d86..516dce0ba 100644
--- a/test/reshape.cpp
+++ b/test/reshape.cpp
@@ -1,6 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
+// Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2014 yoco <peter.xiau@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
@@ -9,45 +10,44 @@
#include "main.h"
-using Eigen::Map;
-using Eigen::MatrixXi;
-
-// just test a 4x4 matrix, enumerate all combination manually,
-// so I don't have to do template-meta-programming here.
+// just test a 4x4 matrix, enumerate all combination manually
template <typename MatType>
-void reshape_all_size(MatType m)
+void reshape4x4(MatType m)
{
- typedef Eigen::Map<MatrixXi> MapMat;
- // dynamic
- VERIFY_IS_EQUAL((m.reshaped( 1, 16)), MapMat(m.data(), 1, 16));
- VERIFY_IS_EQUAL((m.reshaped( 2, 8)), MapMat(m.data(), 2, 8));
- VERIFY_IS_EQUAL((m.reshaped( 4, 4)), MapMat(m.data(), 4, 4));
- VERIFY_IS_EQUAL((m.reshaped( 8, 2)), MapMat(m.data(), 8, 2));
- VERIFY_IS_EQUAL((m.reshaped(16, 1)), MapMat(m.data(), 16, 1));
+ if((MatType::Flags&RowMajorBit)==0)
+ {
+ typedef Map<MatrixXi> MapMat;
+ // dynamic
+ VERIFY_IS_EQUAL((m.reshaped( 1, 16)), MapMat(m.data(), 1, 16));
+ VERIFY_IS_EQUAL((m.reshaped( 2, 8)), MapMat(m.data(), 2, 8));
+ VERIFY_IS_EQUAL((m.reshaped( 4, 4)), MapMat(m.data(), 4, 4));
+ VERIFY_IS_EQUAL((m.reshaped( 8, 2)), MapMat(m.data(), 8, 2));
+ VERIFY_IS_EQUAL((m.reshaped(16, 1)), MapMat(m.data(), 16, 1));
- // static
- VERIFY_IS_EQUAL(m.reshaped(fix< 1>, fix<16>), MapMat(m.data(), 1, 16));
- VERIFY_IS_EQUAL(m.reshaped(fix< 2>, fix< 8>), MapMat(m.data(), 2, 8));
- VERIFY_IS_EQUAL(m.reshaped(fix< 4>, fix< 4>), MapMat(m.data(), 4, 4));
- VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2));
- VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1));
+ // static
+ VERIFY_IS_EQUAL(m.reshaped(fix< 1>, fix<16>), MapMat(m.data(), 1, 16));
+ VERIFY_IS_EQUAL(m.reshaped(fix< 2>, fix< 8>), MapMat(m.data(), 2, 8));
+ VERIFY_IS_EQUAL(m.reshaped(fix< 4>, fix< 4>), MapMat(m.data(), 4, 4));
+ VERIFY_IS_EQUAL(m.reshaped(fix< 8>, fix< 2>), MapMat(m.data(), 8, 2));
+ VERIFY_IS_EQUAL(m.reshaped(fix<16>, fix< 1>), MapMat(m.data(), 16, 1));
- // reshape chain
- VERIFY_IS_EQUAL(
- (m
- .reshaped( 1, 16)
- .reshaped(fix< 2>,fix< 8>)
- .reshaped(16, 1)
- .reshaped(fix< 8>,fix< 2>)
- .reshaped( 2, 8)
- .reshaped(fix< 1>,fix<16>)
- .reshaped( 4, 4)
- .reshaped(fix<16>,fix< 1>)
- .reshaped( 8, 2)
- .reshaped(fix< 4>,fix< 4>)
- ),
- MapMat(m.data(), 4, 4)
- );
+ // reshape chain
+ VERIFY_IS_EQUAL(
+ (m
+ .reshaped( 1, 16)
+ .reshaped(fix< 2>,fix< 8>)
+ .reshaped(16, 1)
+ .reshaped(fix< 8>,fix< 2>)
+ .reshaped( 2, 8)
+ .reshaped(fix< 1>,fix<16>)
+ .reshaped( 4, 4)
+ .reshaped(fix<16>,fix< 1>)
+ .reshaped( 8, 2)
+ .reshaped(fix< 4>,fix< 4>)
+ ),
+ MapMat(m.data(), 4, 4)
+ );
+ }
VERIFY_IS_EQUAL(m.reshaped( 1, 16).data(), m.data());
VERIFY_IS_EQUAL(m.reshaped( 1, 16).innerStride(), 1);
@@ -56,23 +56,43 @@ void reshape_all_size(MatType m)
VERIFY_IS_EQUAL(m.reshaped( 2, 8).innerStride(), 1);
VERIFY_IS_EQUAL(m.reshaped( 2, 8).outerStride(), 2);
- m.reshaped(2,8,ColOrder);
+ if((MatType::Flags&RowMajorBit)==0)
+ {
+ VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8));
+ VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8,AutoOrder));
+ VERIFY_IS_EQUAL(m.transpose().reshaped(2,8,RowOrder),m.transpose().reshaped(2,8,AutoOrder));
+ }
+ else
+ {
+ VERIFY_IS_EQUAL(m.reshaped(2,8,ColOrder),m.reshaped(2,8));
+ VERIFY_IS_EQUAL(m.reshaped(2,8,RowOrder),m.reshaped(2,8,AutoOrder));
+ VERIFY_IS_EQUAL(m.transpose().reshaped(2,8,ColOrder),m.transpose().reshaped(2,8,AutoOrder));
+ VERIFY_IS_EQUAL(m.transpose().reshaped(2,8),m.transpose().reshaped(2,8,AutoOrder));
+ }
- MatrixXi m28r = m.reshaped(2,8,RowOrder);
- std::cout << m28r << "\n";
+ MatrixXi m28r1 = m.reshaped(2,8,RowOrder);
+ MatrixXi m28r2 = m.transpose().reshaped(8,2,ColOrder).transpose();
+ VERIFY_IS_EQUAL( m28r1, m28r2);
}
void test_reshape()
{
- Eigen::MatrixXi mx = Eigen::MatrixXi::Random(4, 4);
- Eigen::Matrix4i m4 = Eigen::Matrix4i::Random(4, 4);
+ typedef Matrix<int,Dynamic,Dynamic> RowMatrixXi;
+ typedef Matrix<int,4,4> RowMatrix4i;
+ MatrixXi mx = MatrixXi::Random(4, 4);
+ Matrix4i m4 = Matrix4i::Random(4, 4);
+ RowMatrixXi rmx = RowMatrixXi::Random(4, 4);
+ RowMatrix4i rm4 = RowMatrix4i::Random(4, 4);
// test dynamic-size matrix
- CALL_SUBTEST(reshape_all_size(mx));
+ CALL_SUBTEST(reshape4x4(mx));
// test static-size matrix
- CALL_SUBTEST(reshape_all_size(m4));
+ CALL_SUBTEST(reshape4x4(m4));
// test dynamic-size const matrix
- CALL_SUBTEST(reshape_all_size(static_cast<const Eigen::MatrixXi>(mx)));
+ CALL_SUBTEST(reshape4x4(static_cast<const MatrixXi>(mx)));
// test static-size const matrix
- CALL_SUBTEST(reshape_all_size(static_cast<const Eigen::Matrix4i>(m4)));
+ CALL_SUBTEST(reshape4x4(static_cast<const Matrix4i>(m4)));
+
+ CALL_SUBTEST(reshape4x4(rmx));
+ CALL_SUBTEST(reshape4x4(rm4));
}