aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2012-06-29 13:49:25 +0100
committerGravatar Jitse Niesen <jitse@maths.leeds.ac.uk>2012-06-29 13:49:25 +0100
commitd0b873822f52f5739868ba322dae5b3d0c399a4d (patch)
tree8c6af7d69559a24555fd9531a4d3dda2ae9ccdd1
parent2393ceb38048506b799689e7bc109a4db5d09e99 (diff)
Make product eval-at-once.
* Make product EvalAtOnce in cases OuterProduct, GemmProduct and GemvProduct * Ensure that product evaluators are nested inside EvalToTemp evaluator * As temporary kludge, evaluate expression to temporary in AllAtOnce traversal and pass expression operator to evalTo()
-rw-r--r--Eigen/src/Core/AssignEvaluator.h8
-rw-r--r--Eigen/src/Core/CoreEvaluators.h207
-rw-r--r--Eigen/src/Core/ProductEvaluators.h44
-rw-r--r--test/evaluators.cpp5
4 files changed, 173 insertions, 91 deletions
diff --git a/Eigen/src/Core/AssignEvaluator.h b/Eigen/src/Core/AssignEvaluator.h
index 08a2c696a..9be00067d 100644
--- a/Eigen/src/Core/AssignEvaluator.h
+++ b/Eigen/src/Core/AssignEvaluator.h
@@ -616,7 +616,13 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, AllAtOnceTraversal, NoU
DstEvaluatorType dstEvaluator(dst);
SrcEvaluatorType srcEvaluator(src);
- srcEvaluator.evalTo(dstEvaluator);
+ // Evaluate rhs in temporary to prevent aliasing problems in a = a * a;
+ // TODO: Be smarter about this
+ // TODO: Do not pass the xpr object to evalTo()
+ typename DstXprType::PlainObject tmp;
+ typename evaluator<typename DstXprType::PlainObject>::type tmpEvaluator(tmp);
+ srcEvaluator.evalTo(tmpEvaluator, tmp);
+ copy_using_evaluator(dst, tmp);
}
};
diff --git a/Eigen/src/Core/CoreEvaluators.h b/Eigen/src/Core/CoreEvaluators.h
index 768fa8950..808546ec1 100644
--- a/Eigen/src/Core/CoreEvaluators.h
+++ b/Eigen/src/Core/CoreEvaluators.h
@@ -3,7 +3,7 @@
//
// Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2011 Gael Guennebaud <gael.guennebaud@inria.fr>
-// Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
+// Copyright (C) 2011-2012 Jitse Niesen <jitse@maths.leeds.ac.uk>
//
// Eigen is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
@@ -42,24 +42,46 @@ struct evaluator_traits
static const int HasEvalTo = 0;
};
+// expression class for evaluating nested expression to a temporary
+
+template<typename ArgType>
+class EvalToTemp;
+
// evaluator<T>::type is type of evaluator for T
+// evaluator<T>::nestedType is type of evaluator if T is nested inside another evaluator
+
+template<typename T>
+struct evaluator_impl
+{ };
+
+template<typename T, int Nested = evaluator_traits<T>::HasEvalTo>
+struct evaluator_nested_type;
template<typename T>
-struct evaluator_impl {};
+struct evaluator_nested_type<T, 0>
+{
+ typedef evaluator_impl<T> type;
+};
+
+template<typename T>
+struct evaluator_nested_type<T, 1>
+{
+ typedef evaluator_impl<EvalToTemp<T> > type;
+};
template<typename T>
struct evaluator
{
typedef evaluator_impl<T> type;
+ typedef typename evaluator_nested_type<T>::type nestedType;
};
// TODO: Think about const-correctness
template<typename T>
struct evaluator<const T>
-{
- typedef evaluator_impl<T> type;
-};
+ : evaluator<T>
+{ };
// ---------- base class for all writable evaluators ----------
@@ -132,70 +154,6 @@ struct evaluator_impl_base
}
};
-// -------------------- Transpose --------------------
-
-template<typename ArgType>
-struct evaluator_impl<Transpose<ArgType> >
- : evaluator_impl_base<Transpose<ArgType> >
-{
- typedef Transpose<ArgType> XprType;
-
- evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {}
-
- typedef typename XprType::Index Index;
- typedef typename XprType::Scalar Scalar;
- typedef typename XprType::CoeffReturnType CoeffReturnType;
- typedef typename XprType::PacketScalar PacketScalar;
- typedef typename XprType::PacketReturnType PacketReturnType;
-
- CoeffReturnType coeff(Index row, Index col) const
- {
- return m_argImpl.coeff(col, row);
- }
-
- CoeffReturnType coeff(Index index) const
- {
- return m_argImpl.coeff(index);
- }
-
- Scalar& coeffRef(Index row, Index col)
- {
- return m_argImpl.coeffRef(col, row);
- }
-
- typename XprType::Scalar& coeffRef(Index index)
- {
- return m_argImpl.coeffRef(index);
- }
-
- template<int LoadMode>
- PacketReturnType packet(Index row, Index col) const
- {
- return m_argImpl.template packet<LoadMode>(col, row);
- }
-
- template<int LoadMode>
- PacketReturnType packet(Index index) const
- {
- return m_argImpl.template packet<LoadMode>(index);
- }
-
- template<int StoreMode>
- void writePacket(Index row, Index col, const PacketScalar& x)
- {
- m_argImpl.template writePacket<StoreMode>(col, row, x);
- }
-
- template<int StoreMode>
- void writePacket(Index index, const PacketScalar& x)
- {
- m_argImpl.template writePacket<StoreMode>(index, x);
- }
-
-protected:
- typename evaluator<ArgType>::type m_argImpl;
-};
-
// -------------------- Matrix and Array --------------------
//
// evaluator_impl<PlainObjectBase> is a common base class for the
@@ -285,6 +243,89 @@ struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
{ }
};
+// -------------------- EvalToTemp --------------------
+
+template<typename ArgType>
+struct evaluator_impl<EvalToTemp<ArgType> >
+ : evaluator_impl<typename ArgType::PlainObject>
+{
+ typedef typename ArgType::PlainObject PlainObject;
+ typedef evaluator_impl<PlainObject> BaseType;
+
+ evaluator_impl(const ArgType& arg)
+ : BaseType(m_result)
+ {
+ copy_using_evaluator(m_result, arg);
+ };
+
+protected:
+ PlainObject m_result;
+};
+
+// -------------------- Transpose --------------------
+
+template<typename ArgType>
+struct evaluator_impl<Transpose<ArgType> >
+ : evaluator_impl_base<Transpose<ArgType> >
+{
+ typedef Transpose<ArgType> XprType;
+
+ evaluator_impl(const XprType& t) : m_argImpl(t.nestedExpression()) {}
+
+ typedef typename XprType::Index Index;
+ typedef typename XprType::Scalar Scalar;
+ typedef typename XprType::CoeffReturnType CoeffReturnType;
+ typedef typename XprType::PacketScalar PacketScalar;
+ typedef typename XprType::PacketReturnType PacketReturnType;
+
+ CoeffReturnType coeff(Index row, Index col) const
+ {
+ return m_argImpl.coeff(col, row);
+ }
+
+ CoeffReturnType coeff(Index index) const
+ {
+ return m_argImpl.coeff(index);
+ }
+
+ Scalar& coeffRef(Index row, Index col)
+ {
+ return m_argImpl.coeffRef(col, row);
+ }
+
+ typename XprType::Scalar& coeffRef(Index index)
+ {
+ return m_argImpl.coeffRef(index);
+ }
+
+ template<int LoadMode>
+ PacketReturnType packet(Index row, Index col) const
+ {
+ return m_argImpl.template packet<LoadMode>(col, row);
+ }
+
+ template<int LoadMode>
+ PacketReturnType packet(Index index) const
+ {
+ return m_argImpl.template packet<LoadMode>(index);
+ }
+
+ template<int StoreMode>
+ void writePacket(Index row, Index col, const PacketScalar& x)
+ {
+ m_argImpl.template writePacket<StoreMode>(col, row, x);
+ }
+
+ template<int StoreMode>
+ void writePacket(Index index, const PacketScalar& x)
+ {
+ m_argImpl.template writePacket<StoreMode>(index, x);
+ }
+
+protected:
+ typename evaluator<ArgType>::nestedType m_argImpl;
+};
+
// -------------------- CwiseNullaryOp --------------------
template<typename NullaryOp, typename PlainObjectType>
@@ -366,7 +407,7 @@ struct evaluator_impl<CwiseUnaryOp<UnaryOp, ArgType> >
protected:
const UnaryOp m_functor;
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
};
// -------------------- CwiseBinaryOp --------------------
@@ -412,8 +453,8 @@ struct evaluator_impl<CwiseBinaryOp<BinaryOp, Lhs, Rhs> >
protected:
const BinaryOp m_functor;
- typename evaluator<Lhs>::type m_lhsImpl;
- typename evaluator<Rhs>::type m_rhsImpl;
+ typename evaluator<Lhs>::nestedType m_lhsImpl;
+ typename evaluator<Rhs>::nestedType m_rhsImpl;
};
// -------------------- CwiseUnaryView --------------------
@@ -455,7 +496,7 @@ struct evaluator_impl<CwiseUnaryView<UnaryOp, ArgType> >
protected:
const UnaryOp m_unaryOp;
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
};
// -------------------- Map --------------------
@@ -626,7 +667,7 @@ struct evaluator_impl<Block<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDir
}
protected:
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
// TODO: Get rid of m_startRow, m_startCol if known at compile time
Index m_startRow;
@@ -681,9 +722,9 @@ struct evaluator_impl<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType
}
protected:
- typename evaluator<ConditionMatrixType>::type m_conditionImpl;
- typename evaluator<ThenMatrixType>::type m_thenImpl;
- typename evaluator<ElseMatrixType>::type m_elseImpl;
+ typename evaluator<ConditionMatrixType>::nestedType m_conditionImpl;
+ typename evaluator<ThenMatrixType>::nestedType m_thenImpl;
+ typename evaluator<ElseMatrixType>::nestedType m_elseImpl;
};
@@ -731,7 +772,7 @@ struct evaluator_impl<Replicate<ArgType, RowFactor, ColFactor> >
}
protected:
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
Index m_rows; // TODO: Get rid of this if known at compile time
Index m_cols;
};
@@ -834,7 +875,7 @@ struct evaluator_impl_wrapper_base
}
protected:
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
};
template<typename ArgType>
@@ -949,7 +990,7 @@ struct evaluator_impl<Reverse<ArgType, Direction> >
}
protected:
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
Index m_rows; // TODO: Don't use if known at compile time or not needed
Index m_cols;
};
@@ -993,7 +1034,7 @@ struct evaluator_impl<Diagonal<ArgType, DiagIndex> >
}
protected:
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
Index m_index; // TODO: Don't use if known at compile time
private:
@@ -1069,7 +1110,7 @@ struct evaluator_impl<SwapWrapper<ArgType> >
}
protected:
- typename evaluator<ArgType>::type m_argImpl;
+ typename evaluator<ArgType>::nestedType m_argImpl;
};
@@ -1133,7 +1174,7 @@ struct evaluator_impl<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> >
}
protected:
- typename evaluator<LhsXpr>::type m_argImpl;
+ typename evaluator<LhsXpr>::nestedType m_argImpl;
const BinaryOp m_functor;
};
diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
index aadaa9303..e814a4710 100644
--- a/Eigen/src/Core/ProductEvaluators.h
+++ b/Eigen/src/Core/ProductEvaluators.h
@@ -50,6 +50,14 @@ struct evaluator_impl<Product<Lhs, Rhs> >
{ }
};
+template<typename XprType, typename ProductType>
+struct product_evaluator_traits_dispatcher;
+
+template<typename Lhs, typename Rhs>
+struct evaluator_traits<Product<Lhs, Rhs> >
+ : product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, typename ProductReturnType<Lhs, Rhs>::Type>
+{ };
+
// Case 1: Evaluate all at once
//
// We can view the GeneralProduct class as a part of the product evaluator.
@@ -57,13 +65,20 @@ struct evaluator_impl<Product<Lhs, Rhs> >
// InnerProduct is special because GeneralProduct does not have an evalTo() method in this case.
template<typename Lhs, typename Rhs>
+struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
+{
+ static const int HasEvalTo = 0;
+};
+
+template<typename Lhs, typename Rhs>
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, InnerProduct> >
: public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type
{
typedef Product<Lhs, Rhs> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type evaluator_base;
-
+
+ // TODO: Computation is too early (?)
product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result)
{
m_result.coeffRef(0,0) = (xpr.lhs().transpose().cwiseProduct(xpr.rhs())).sum();
@@ -77,21 +92,30 @@ protected:
// TODO: GeneralProduct should take evaluators, not expression objects.
template<typename Lhs, typename Rhs, int ProductType>
+struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
+{
+ static const int HasEvalTo = 1;
+};
+
+template<typename Lhs, typename Rhs, int ProductType>
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, GeneralProduct<Lhs, Rhs, ProductType> >
- : public evaluator<typename Product<Lhs, Rhs>::PlainObject>::type
{
typedef Product<Lhs, Rhs> XprType;
typedef typename XprType::PlainObject PlainObject;
typedef typename evaluator<PlainObject>::type evaluator_base;
- product_evaluator_dispatcher(const XprType& xpr) : evaluator_base(m_result)
+ product_evaluator_dispatcher(const XprType& xpr) : m_xpr(xpr)
+ { }
+
+ template<typename DstEvaluatorType, typename DstXprType>
+ void evalTo(DstEvaluatorType /* not used */, DstXprType& dst)
{
- m_result.resize(xpr.rows(), xpr.cols());
- GeneralProduct<Lhs, Rhs, ProductType>(xpr.lhs(), xpr.rhs()).evalTo(m_result);
+ dst.resize(m_xpr.rows(), m_xpr.cols());
+ GeneralProduct<Lhs, Rhs, ProductType>(m_xpr.lhs(), m_xpr.rhs()).evalTo(dst);
}
-protected:
- PlainObject m_result;
+protected:
+ const XprType& m_xpr;
};
// Case 2: Evaluate coeff by coeff
@@ -107,6 +131,12 @@ template<int StorageOrder, int UnrollingIndex, typename Lhs, typename Rhs, typen
struct etor_product_packet_impl;
template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
+struct product_evaluator_traits_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
+{
+ static const int HasEvalTo = 0;
+};
+
+template<typename Lhs, typename Rhs, typename LhsNested, typename RhsNested, int Flags>
struct product_evaluator_dispatcher<Product<Lhs, Rhs>, CoeffBasedProduct<LhsNested, RhsNested, Flags> >
: evaluator_impl_base<Product<Lhs, Rhs> >
{
diff --git a/test/evaluators.cpp b/test/evaluators.cpp
index 62ba5b126..3081d7858 100644
--- a/test/evaluators.cpp
+++ b/test/evaluators.cpp
@@ -65,6 +65,11 @@ void test_evaluators()
VERIFY_IS_APPROX_EVALUATOR2(d, s * prod(a,b), s * a*b);
VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b).transpose(), (a*b).transpose());
VERIFY_IS_APPROX_EVALUATOR2(d, prod(a,b) + prod(b,c), a*b + b*c);
+
+ // check that prod works even with aliasing present
+ c = a*a;
+ copy_using_evaluator(a, prod(a,a));
+ VERIFY_IS_APPROX(a,c);
}
{