aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/VectorwiseOp.h39
-rw-r--r--test/vectorwiseop.cpp24
2 files changed, 63 insertions, 0 deletions
diff --git a/Eigen/src/Core/VectorwiseOp.h b/Eigen/src/Core/VectorwiseOp.h
index f48ec884c..4549641dd 100644
--- a/Eigen/src/Core/VectorwiseOp.h
+++ b/Eigen/src/Core/VectorwiseOp.h
@@ -233,6 +233,28 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
Direction==Vertical ? 1 : m_matrix.rows(),
Direction==Horizontal ? 1 : m_matrix.cols());
}
+
+ template<typename OtherDerived> struct OppositeExtendedType {
+ typedef Replicate<OtherDerived,
+ Direction==Horizontal ? 1 : ExpressionType::RowsAtCompileTime,
+ Direction==Vertical ? 1 : ExpressionType::ColsAtCompileTime> Type;
+ };
+
+ /** \internal
+ * Replicates a vector in the opposite direction to match the size of \c *this */
+ template<typename OtherDerived>
+ typename OppositeExtendedType<OtherDerived>::Type
+ extendedToOpposite(const DenseBase<OtherDerived>& other) const
+ {
+ EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(Direction==Horizontal, OtherDerived::MaxColsAtCompileTime==1),
+ YOU_PASSED_A_ROW_VECTOR_BUT_A_COLUMN_VECTOR_WAS_EXPECTED)
+ EIGEN_STATIC_ASSERT(EIGEN_IMPLIES(Direction==Vertical, OtherDerived::MaxRowsAtCompileTime==1),
+ YOU_PASSED_A_COLUMN_VECTOR_BUT_A_ROW_VECTOR_WAS_EXPECTED)
+ return typename OppositeExtendedType<OtherDerived>::Type
+ (other.derived(),
+ Direction==Horizontal ? 1 : m_matrix.rows(),
+ Direction==Vertical ? 1 : m_matrix.cols());
+ }
public:
@@ -504,6 +526,23 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
return m_matrix / extendedTo(other.derived());
}
+
+ /** \returns an expression where each column of row of the referenced matrix are normalized.
+ * The referenced matrix is \b not modified.
+ * \sa MatrixBase::normalized(), normalize()
+ */
+ CwiseBinaryOp<internal::scalar_quotient_op<Scalar>,
+ const ExpressionTypeNestedCleaned,
+ const typename OppositeExtendedType<typename ReturnType<internal::member_norm,RealScalar>::Type>::Type>
+ normalized() const { return m_matrix.cwiseQuotient(extendedToOpposite(this->norm())); }
+
+
+ /** Normalize in-place each row or columns of the referenced matrix.
+ * \sa MatrixBase::normalize(), normalized()
+ */
+ void normalize() {
+ m_matrix = this->normalized();
+ }
/////////// Geometry module ///////////
diff --git a/test/vectorwiseop.cpp b/test/vectorwiseop.cpp
index b938e3957..904b51664 100644
--- a/test/vectorwiseop.cpp
+++ b/test/vectorwiseop.cpp
@@ -111,6 +111,8 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVectorType;
typedef Matrix<Scalar, 1, MatrixType::ColsAtCompileTime> RowVectorType;
+ typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealColVectorType;
+ typedef Matrix<RealScalar, 1, MatrixType::ColsAtCompileTime> RealRowVectorType;
Index rows = m.rows();
Index cols = m.cols();
@@ -123,6 +125,8 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
ColVectorType colvec = ColVectorType::Random(rows);
RowVectorType rowvec = RowVectorType::Random(cols);
+ RealColVectorType rcres;
+ RealRowVectorType rrres;
// test addition
@@ -159,6 +163,26 @@ template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose());
VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose());
+
+ // test norm
+ rrres = m1.colwise().norm();
+ VERIFY_IS_APPROX(rrres(c), m1.col(c).norm());
+ rcres = m1.rowwise().norm();
+ VERIFY_IS_APPROX(rcres(r), m1.row(r).norm());
+
+ // test normalized
+ m2 = m1.colwise().normalized();
+ VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
+ m2 = m1.rowwise().normalized();
+ VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
+
+ // test normalize
+ m2 = m1;
+ m2.colwise().normalize();
+ VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
+ m2 = m1;
+ m2.rowwise().normalize();
+ VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
}
void test_vectorwiseop()