diff options
-rw-r--r-- | Eigen/src/Core/AssignEvaluator.h | 53 | ||||
-rw-r--r-- | Eigen/src/Core/CoreEvaluators.h | 67 | ||||
-rw-r--r-- | Eigen/src/Core/SelfCwiseBinaryOp.h | 10 | ||||
-rw-r--r-- | test/evaluators.cpp | 26 |
4 files changed, 155 insertions, 1 deletions
diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h index c5f345a2f..006a87d47 100644 --- a/Eigen/src/Core/AssignEvaluator.h +++ b/Eigen/src/Core/AssignEvaluator.h @@ -624,6 +624,59 @@ void swap_using_evaluator(const DstXprType& dst, const SrcXprType& src) copy_using_evaluator(SwapWrapper<DstXprType>(const_cast<DstXprType&>(dst)), src); } +// Based on MatrixBase::operator+= (in CwiseBinaryOp.h) +template<typename DstXprType, typename SrcXprType> +void add_assign_using_evaluator(const MatrixBase<DstXprType>& dst, const MatrixBase<SrcXprType>& src) +{ + typedef typename DstXprType::Scalar Scalar; + SelfCwiseBinaryOp<internal::scalar_sum_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived()); + copy_using_evaluator(tmp, src.derived()); +} + +// Based on ArrayBase::operator+= +template<typename DstXprType, typename SrcXprType> +void add_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src) +{ + typedef typename DstXprType::Scalar Scalar; + SelfCwiseBinaryOp<internal::scalar_sum_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived()); + copy_using_evaluator(tmp, src.derived()); +} + +// TODO: Add add_assign_using_evaluator for EigenBase ? + +template<typename DstXprType, typename SrcXprType> +void subtract_assign_using_evaluator(const MatrixBase<DstXprType>& dst, const MatrixBase<SrcXprType>& src) +{ + typedef typename DstXprType::Scalar Scalar; + SelfCwiseBinaryOp<internal::scalar_difference_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived()); + copy_using_evaluator(tmp, src.derived()); +} + +template<typename DstXprType, typename SrcXprType> +void subtract_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src) +{ + typedef typename DstXprType::Scalar Scalar; + SelfCwiseBinaryOp<internal::scalar_difference_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived()); + copy_using_evaluator(tmp, src.derived()); +} + +template<typename DstXprType, typename SrcXprType> +void multiply_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src) +{ + typedef typename DstXprType::Scalar Scalar; + SelfCwiseBinaryOp<internal::scalar_product_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived()); + copy_using_evaluator(tmp, src.derived()); +} + +template<typename DstXprType, typename SrcXprType> +void divide_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src) +{ + typedef typename DstXprType::Scalar Scalar; + SelfCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived()); + copy_using_evaluator(tmp, src.derived()); +} + + } // namespace internal #endif // EIGEN_ASSIGN_EVALUATOR_H diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h index 899aa04ea..2314be719 100644 --- a/Eigen/src/Core/CoreEvaluators.h +++ b/Eigen/src/Core/CoreEvaluators.h @@ -1032,6 +1032,8 @@ struct evaluator_impl<SwapWrapper<ArgType> > typedef typename XprType::Scalar Scalar; typedef typename XprType::Packet Packet; + // This function and the next one are needed by assign to correctly align loads/stores + // TODO make Assign use .data() Scalar& coeffRef(Index row, Index col) { return m_argImpl.coeffRef(row, col); @@ -1085,6 +1087,71 @@ protected: }; +// ---------- SelfCwiseBinaryOp ---------- + +template<typename BinaryOp, typename LhsXpr, typename RhsXpr> +struct evaluator_impl<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> > + : evaluator_impl_base<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> > +{ + typedef SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> XprType; + + evaluator_impl(const XprType& selfCwiseBinaryOp) + : m_argImpl(selfCwiseBinaryOp.expression()), + m_functor(selfCwiseBinaryOp.functor()) + { } + + typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::Packet Packet; + + // This function and the next one are needed by assign to correctly align loads/stores + // TODO make Assign use .data() + Scalar& coeffRef(Index row, Index col) + { + return m_argImpl.coeffRef(row, col); + } + + inline Scalar& coeffRef(Index index) + { + return m_argImpl.coeffRef(index); + } + + template<typename OtherEvaluatorType> + void copyCoeff(Index row, Index col, const OtherEvaluatorType& other) + { + Scalar& tmp = m_argImpl.coeffRef(row, col); + tmp = m_functor(tmp, other.coeff(row, col)); + } + + template<typename OtherEvaluatorType> + void copyCoeff(Index index, const OtherEvaluatorType& other) + { + Scalar& tmp = m_argImpl.coeffRef(index); + tmp = m_functor(tmp, other.coeff(index)); + } + + template<int StoreMode, int LoadMode, typename OtherEvaluatorType> + void copyPacket(Index row, Index col, const OtherEvaluatorType& other) + { + const Packet res = m_functor.packetOp(m_argImpl.template packet<StoreMode>(row, col), + other.template packet<LoadMode>(row, col)); + m_argImpl.template writePacket<StoreMode>(row, col, res); + } + + template<int StoreMode, int LoadMode, typename OtherEvaluatorType> + void copyPacket(Index index, const OtherEvaluatorType& other) + { + const Packet res = m_functor.packetOp(m_argImpl.template packet<StoreMode>(index), + other.template packet<LoadMode>(index)); + m_argImpl.template writePacket<StoreMode>(index, res); + } + +protected: + typename evaluator<LhsXpr>::type m_argImpl; + const BinaryOp& m_functor; +}; + + } // namespace internal #endif // EIGEN_COREEVALUATORS_H diff --git a/Eigen/src/Core/SelfCwiseBinaryOp.h b/Eigen/src/Core/SelfCwiseBinaryOp.h index 4e9ca8874..d7cb261c4 100644 --- a/Eigen/src/Core/SelfCwiseBinaryOp.h +++ b/Eigen/src/Core/SelfCwiseBinaryOp.h @@ -163,6 +163,16 @@ template<typename BinaryOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp return Base::operator=(rhs); } + Lhs& expression() const + { + return m_matrix; + } + + const BinaryOp& functor() const + { + return m_functor; + } + protected: Lhs& m_matrix; const BinaryOp& m_functor; diff --git a/test/evaluators.cpp b/test/evaluators.cpp index 6e81ad5ef..5c8e500bc 100644 --- a/test/evaluators.cpp +++ b/test/evaluators.cpp @@ -236,5 +236,29 @@ void test_evaluators() VERIFY_IS_APPROX(mat1, mat1ref); VERIFY_IS_APPROX(mat2, mat2ref); } - + + { + // test compound assignment + const Matrix4d mat_const = Matrix4d::Random(); + Matrix4d mat, mat_ref; + mat = mat_ref = Matrix4d::Identity(); + add_assign_using_evaluator(mat, mat_const); + mat_ref += mat_const; + VERIFY_IS_APPROX(mat, mat_ref); + + subtract_assign_using_evaluator(mat.row(1), 2*mat.row(2)); + mat_ref.row(1) -= 2*mat_ref.row(2); + VERIFY_IS_APPROX(mat, mat_ref); + + const ArrayXXf arr_const = ArrayXXf::Random(5,3); + ArrayXXf arr, arr_ref; + arr = arr_ref = ArrayXXf::Constant(5, 3, 0.5); + multiply_assign_using_evaluator(arr, arr_const); + arr_ref *= arr_const; + VERIFY_IS_APPROX(arr, arr_ref); + + divide_assign_using_evaluator(arr.row(1), arr.row(2) + 1); + arr_ref.row(1) /= (arr_ref.row(2) + 1); + VERIFY_IS_APPROX(arr, arr_ref); + } } |