aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Antonio Sanchez <cantonios@google.com>2020-12-10 14:05:38 -0800
committerGravatar Antonio Sanchez <cantonios@google.com>2021-01-05 10:41:25 -0800
commitbb1de9dbdede6669c2c86c028a9deff637e3d1f6 (patch)
tree3d7936aeea85b1541c2ff8c5eece202aefe82c42
parent12dda34b15fe2ea4d994fafdba8c4666963e1a55 (diff)
Fix Ref Stride checks.
The existing `Ref` class failed to consider cases where the Ref's `Stride` setting *could* match the underlying referred object's stride, but **didn't** at runtime. This led to trying to set invalid stride values, causing runtime failures in some cases, and garbage due to mismatched strides in others. Here we add the missing runtime checks. This involves computing the strides necessary to align with the referred object's storage, and verifying we can actually set those strides at runtime. In the `const` case, if it *may* be possible to refer to the original storage at compile-time but fails at runtime, then we defer to the `construct(...)` method that makes a copy. Added more tests to check these cases. Fixes #2093.
-rw-r--r--Eigen/src/Core/Ref.h133
-rw-r--r--test/ref.cpp66
2 files changed, 184 insertions, 15 deletions
diff --git a/Eigen/src/Core/Ref.h b/Eigen/src/Core/Ref.h
index 172c8ffb6..00aa45d34 100644
--- a/Eigen/src/Core/Ref.h
+++ b/Eigen/src/Core/Ref.h
@@ -93,29 +93,127 @@ protected:
typedef Stride<StrideType::OuterStrideAtCompileTime,StrideType::InnerStrideAtCompileTime> StrideBase;
+ // Resolves inner stride if default 0.
+ static Index resolveInnerStride(Index inner) {
+ if (inner == 0) {
+ return 1;
+ }
+ return inner;
+ }
+
+ // Resolves outer stride if default 0.
+ static Index resolveOuterStride(Index inner, Index outer, Index rows, Index cols, bool isVectorAtCompileTime, bool isRowMajor) {
+ if (outer == 0) {
+ if (isVectorAtCompileTime) {
+ outer = inner * rows * cols;
+ } else if (isRowMajor) {
+ outer = inner * cols;
+ } else {
+ outer = inner * rows;
+ }
+ }
+ return outer;
+ }
+
+ // Returns true if construction is valid, false if there is a stride mismatch,
+ // and fails if there is a size mismatch.
template<typename Expression>
- EIGEN_DEVICE_FUNC void construct(Expression& expr)
+ EIGEN_DEVICE_FUNC bool construct(Expression& expr)
{
- EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(PlainObjectType,Expression);
-
+ // Check matrix sizes. If this is a compile-time vector, we do allow
+ // implicitly transposing.
+ EIGEN_STATIC_ASSERT(
+ EIGEN_PREDICATE_SAME_MATRIX_SIZE(PlainObjectType, Expression)
+ // If it is a vector, the transpose sizes might match.
+ || ( PlainObjectType::IsVectorAtCompileTime
+ && ((int(PlainObjectType::RowsAtCompileTime)==Eigen::Dynamic
+ || int(Expression::ColsAtCompileTime)==Eigen::Dynamic
+ || int(PlainObjectType::RowsAtCompileTime)==int(Expression::ColsAtCompileTime))
+ && (int(PlainObjectType::ColsAtCompileTime)==Eigen::Dynamic
+ || int(Expression::RowsAtCompileTime)==Eigen::Dynamic
+ || int(PlainObjectType::ColsAtCompileTime)==int(Expression::RowsAtCompileTime)))),
+ YOU_MIXED_MATRICES_OF_DIFFERENT_SIZES
+ )
+
+ // Determine runtime rows and columns.
+ Index rows = expr.rows();
+ Index cols = expr.cols();
if(PlainObjectType::RowsAtCompileTime==1)
{
eigen_assert(expr.rows()==1 || expr.cols()==1);
- ::new (static_cast<Base*>(this)) Base(expr.data(), 1, expr.size());
+ rows = 1;
+ cols = expr.size();
}
else if(PlainObjectType::ColsAtCompileTime==1)
{
eigen_assert(expr.rows()==1 || expr.cols()==1);
- ::new (static_cast<Base*>(this)) Base(expr.data(), expr.size(), 1);
+ rows = expr.size();
+ cols = 1;
+ }
+ // Verify that the sizes are valid.
+ eigen_assert(
+ (PlainObjectType::RowsAtCompileTime == Dynamic) || (PlainObjectType::RowsAtCompileTime == rows));
+ eigen_assert(
+ (PlainObjectType::ColsAtCompileTime == Dynamic) || (PlainObjectType::ColsAtCompileTime == cols));
+
+
+ // If this is a vector, we might be transposing, which means that stride should swap.
+ const bool transpose = PlainObjectType::IsVectorAtCompileTime && (rows != expr.rows());
+ // If the storage format differs, we also need to swap the stride.
+ const bool row_major = ((PlainObjectType::Flags)&RowMajorBit) != 0;
+ const bool expr_row_major = (Expression::Flags&RowMajorBit) != 0;
+ const bool storage_differs = (row_major != expr_row_major);
+
+ const bool swap_stride = (transpose != storage_differs);
+
+ // Determine expr's actual strides, resolving any defaults if zero.
+ const Index expr_inner_actual = resolveInnerStride(expr.innerStride());
+ const Index expr_outer_actual = resolveOuterStride(expr_inner_actual,
+ expr.outerStride(),
+ expr.rows(),
+ expr.cols(),
+ Expression::IsVectorAtCompileTime != 0,
+ expr_row_major);
+
+ // If this is a column-major row vector or row-major column vector, the inner-stride
+ // is arbitrary, so set it to either the compile-time inner stride or 1.
+ const bool row_vector = (rows == 1);
+ const bool col_vector = (cols == 1);
+ const Index inner_stride =
+ ( (!row_major && row_vector) || (row_major && col_vector) ) ?
+ ( StrideType::InnerStrideAtCompileTime > 0 ? Index(StrideType::InnerStrideAtCompileTime) : 1)
+ : swap_stride ? expr_outer_actual : expr_inner_actual;
+
+ // If this is a column-major column vector or row-major row vector, the outer-stride
+ // is arbitrary, so set it to either the compile-time outer stride or vector size.
+ const Index outer_stride =
+ ( (!row_major && col_vector) || (row_major && row_vector) ) ?
+ ( StrideType::OuterStrideAtCompileTime > 0 ? Index(StrideType::OuterStrideAtCompileTime) : rows * cols * inner_stride)
+ : swap_stride ? expr_inner_actual : expr_outer_actual;
+
+ // Check if given inner/outer strides are compatible with compile-time strides.
+ const bool inner_valid = (StrideType::InnerStrideAtCompileTime == Dynamic)
+ || (resolveInnerStride(Index(StrideType::InnerStrideAtCompileTime)) == inner_stride);
+ if (!inner_valid) {
+ return false;
+ }
+
+ const bool outer_valid = (StrideType::OuterStrideAtCompileTime == Dynamic)
+ || (resolveOuterStride(
+ inner_stride,
+ Index(StrideType::OuterStrideAtCompileTime),
+ rows, cols, PlainObjectType::IsVectorAtCompileTime != 0,
+ row_major)
+ == outer_stride);
+ if (!outer_valid) {
+ return false;
}
- else
- ::new (static_cast<Base*>(this)) Base(expr.data(), expr.rows(), expr.cols());
- if(Expression::IsVectorAtCompileTime && (!PlainObjectType::IsVectorAtCompileTime) && ((Expression::Flags&RowMajorBit)!=(PlainObjectType::Flags&RowMajorBit)))
- ::new (&m_stride) StrideBase(expr.innerStride(), StrideType::InnerStrideAtCompileTime==0?0:1);
- else
- ::new (&m_stride) StrideBase(StrideType::OuterStrideAtCompileTime==0?0:expr.outerStride(),
- StrideType::InnerStrideAtCompileTime==0?0:expr.innerStride());
+ ::new (static_cast<Base*>(this)) Base(expr.data(), rows, cols);
+ ::new (&m_stride) StrideBase(
+ (StrideType::OuterStrideAtCompileTime == 0) ? 0 : outer_stride,
+ (StrideType::InnerStrideAtCompileTime == 0) ? 0 : inner_stride );
+ return true;
}
StrideBase m_stride;
@@ -212,7 +310,8 @@ template<typename PlainObjectType, int Options, typename StrideType> class Ref
typename internal::enable_if<bool(Traits::template match<Derived>::MatchAtCompileTime),Derived>::type* = 0)
{
EIGEN_STATIC_ASSERT(bool(Traits::template match<Derived>::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH);
- Base::construct(expr.derived());
+ // Construction must pass since we will not create temprary storage in the non-const case.
+ eigen_assert(Base::construct(expr.derived()));
}
template<typename Derived>
EIGEN_DEVICE_FUNC inline Ref(const DenseBase<Derived>& expr,
@@ -226,7 +325,8 @@ template<typename PlainObjectType, int Options, typename StrideType> class Ref
EIGEN_STATIC_ASSERT(bool(internal::is_lvalue<Derived>::value), THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY);
EIGEN_STATIC_ASSERT(bool(Traits::template match<Derived>::MatchAtCompileTime), STORAGE_LAYOUT_DOES_NOT_MATCH);
EIGEN_STATIC_ASSERT(!Derived::IsPlainObjectBase,THIS_EXPRESSION_IS_NOT_A_LVALUE__IT_IS_READ_ONLY);
- Base::construct(expr.const_cast_derived());
+ // Construction must pass since we will not create temporary storage in the non-const case.
+ eigen_assert(Base::construct(expr.const_cast_derived()));
}
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Ref)
@@ -267,7 +367,10 @@ template<typename TPlainObjectType, int Options, typename StrideType> class Ref<
template<typename Expression>
EIGEN_DEVICE_FUNC void construct(const Expression& expr,internal::true_type)
{
- Base::construct(expr);
+ // Check if we can use the underlying expr's storage directly, otherwise call the copy version.
+ if (!Base::construct(expr)) {
+ construct(expr, internal::false_type());
+ }
}
template<typename Expression>
diff --git a/test/ref.cpp b/test/ref.cpp
index c0b6ffdcf..ebfc70d3d 100644
--- a/test/ref.cpp
+++ b/test/ref.cpp
@@ -141,6 +141,69 @@ template<typename VectorType> void ref_vector(const VectorType& m)
VERIFY_IS_APPROX(mat1, mat2);
}
+template<typename Scalar, int Rows, int Cols>
+void ref_vector_fixed_sizes()
+{
+ typedef Matrix<Scalar,Rows,Cols,RowMajor> RowMajorMatrixType;
+ typedef Matrix<Scalar,Rows,Cols,ColMajor> ColMajorMatrixType;
+ typedef Matrix<Scalar,1,Cols> RowVectorType;
+ typedef Matrix<Scalar,Rows,1> ColVectorType;
+ typedef Matrix<Scalar,Cols,1> RowVectorTransposeType;
+ typedef Matrix<Scalar,1,Rows> ColVectorTransposeType;
+ typedef Stride<Dynamic, Dynamic> DynamicStride;
+
+ RowMajorMatrixType mr = RowMajorMatrixType::Random();
+ ColMajorMatrixType mc = ColMajorMatrixType::Random();
+
+ Index i = internal::random<Index>(0,Rows-1);
+ Index j = internal::random<Index>(0,Cols-1);
+
+ // Reference ith row.
+ Ref<RowVectorType, 0, DynamicStride> mr_ri = mr.row(i);
+ VERIFY_IS_EQUAL(mr_ri, mr.row(i));
+ Ref<RowVectorType, 0, DynamicStride> mc_ri = mc.row(i);
+ VERIFY_IS_EQUAL(mc_ri, mc.row(i));
+
+ // Reference jth col.
+ Ref<ColVectorType, 0, DynamicStride> mr_cj = mr.col(j);
+ VERIFY_IS_EQUAL(mr_cj, mr.col(j));
+ Ref<ColVectorType, 0, DynamicStride> mc_cj = mc.col(j);
+ VERIFY_IS_EQUAL(mc_cj, mc.col(j));
+
+ // Reference the transpose of row i.
+ Ref<RowVectorTransposeType, 0, DynamicStride> mr_rit = mr.row(i);
+ VERIFY_IS_EQUAL(mr_rit, mr.row(i).transpose());
+ Ref<RowVectorTransposeType, 0, DynamicStride> mc_rit = mc.row(i);
+ VERIFY_IS_EQUAL(mc_rit, mc.row(i).transpose());
+
+ // Reference the transpose of col j.
+ Ref<ColVectorTransposeType, 0, DynamicStride> mr_cjt = mr.col(j);
+ VERIFY_IS_EQUAL(mr_cjt, mr.col(j).transpose());
+ Ref<ColVectorTransposeType, 0, DynamicStride> mc_cjt = mc.col(j);
+ VERIFY_IS_EQUAL(mc_cjt, mc.col(j).transpose());
+
+ // Const references without strides.
+ Ref<const RowVectorType> cmr_ri = mr.row(i);
+ VERIFY_IS_EQUAL(cmr_ri, mr.row(i));
+ Ref<const RowVectorType> cmc_ri = mc.row(i);
+ VERIFY_IS_EQUAL(cmc_ri, mc.row(i));
+
+ Ref<const ColVectorType> cmr_cj = mr.col(j);
+ VERIFY_IS_EQUAL(cmr_cj, mr.col(j));
+ Ref<const ColVectorType> cmc_cj = mc.col(j);
+ VERIFY_IS_EQUAL(cmc_cj, mc.col(j));
+
+ Ref<const RowVectorTransposeType> cmr_rit = mr.row(i);
+ VERIFY_IS_EQUAL(cmr_rit, mr.row(i).transpose());
+ Ref<const RowVectorTransposeType> cmc_rit = mc.row(i);
+ VERIFY_IS_EQUAL(cmc_rit, mc.row(i).transpose());
+
+ Ref<const ColVectorTransposeType> cmr_cjt = mr.col(j);
+ VERIFY_IS_EQUAL(cmr_cjt, mr.col(j).transpose());
+ Ref<const ColVectorTransposeType> cmc_cjt = mc.col(j);
+ VERIFY_IS_EQUAL(cmc_cjt, mc.col(j).transpose());
+}
+
template<typename PlainObjectType> void check_const_correctness(const PlainObjectType&)
{
// verify that ref-to-const don't have LvalueBit
@@ -287,6 +350,9 @@ EIGEN_DECLARE_TEST(ref)
CALL_SUBTEST_4( ref_matrix(Matrix<std::complex<double>,10,15>()) );
CALL_SUBTEST_5( ref_matrix(MatrixXi(internal::random<int>(1,10),internal::random<int>(1,10))) );
CALL_SUBTEST_6( call_ref() );
+
+ CALL_SUBTEST_8( (ref_vector_fixed_sizes<float,3,5>()) );
+ CALL_SUBTEST_8( (ref_vector_fixed_sizes<float,15,10>()) );
}
CALL_SUBTEST_7( test_ref_overloads() );