aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar ShengYang1 <yang.sheng@intel.com>2020-04-07 13:18:00 +0800
committerGravatar ShengYang1 <yang.sheng@intel.com>2020-06-09 08:12:07 +0800
commitb5d66b5e7395be326cbc66434ac0e35da33732e2 (patch)
tree5e1f6baaa3cab1e9e00ffcfa265405213ff23ae0
parentc4059ffcb6763be9108f050df9ef179f4bbbfa73 (diff)
Implement scalar_cmp_with_cast_op
-rw-r--r--Eigen/src/Core/functors/BinaryFunctors.h104
1 files changed, 104 insertions, 0 deletions
diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h
index 697816663..54bcddd34 100644
--- a/Eigen/src/Core/functors/BinaryFunctors.h
+++ b/Eigen/src/Core/functors/BinaryFunctors.h
@@ -251,6 +251,110 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,Rh
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a!=b;}
};
+/** \internal
+ * \brief Template functors for comparison of two scalars and cast the output from boolean to input datatype
+ */
+template<typename LhsScalar, typename RhsScalar, ComparisonName cmp> struct scalar_cmp_with_cast_op;
+
+template<typename LhsScalar, typename RhsScalar, ComparisonName cmp>
+struct functor_traits<scalar_cmp_with_cast_op<LhsScalar,RhsScalar, cmp> > {
+ enum {
+ Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
+ PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && internal::is_same<LhsScalar, float>::value
+ };
+};
+
+template<typename LhsScalar, typename RhsScalar>
+struct scalar_cmp_with_cast_op<LhsScalar, RhsScalar, cmp_EQ> : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_cmp_with_cast_op>::ReturnType result_type;
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_with_cast_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ if(a==b) return static_cast<result_type>(1);
+ else return static_cast<result_type>(0);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pselect(internal::pcmp_eq(a,b), internal::pset1<Packet>(static_cast<result_type>(1)), internal::pzero(a)); }
+};
+template<typename LhsScalar, typename RhsScalar>
+struct scalar_cmp_with_cast_op<LhsScalar, RhsScalar, cmp_LT> : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_cmp_with_cast_op>::ReturnType result_type;
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_with_cast_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ if(a<b) return static_cast<result_type>(1);
+ else return static_cast<result_type>(0);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pselect(internal::pcmp_lt(a,b), internal::pset1<Packet>(static_cast<result_type>(1)), internal::pzero(a)); }
+};
+template<typename LhsScalar, typename RhsScalar>
+struct scalar_cmp_with_cast_op<LhsScalar, RhsScalar, cmp_LE> : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_cmp_with_cast_op>::ReturnType result_type;
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_with_cast_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ if(a<=b) return static_cast<result_type>(1);
+ else return static_cast<result_type>(0);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pselect(internal::pcmp_le(a,b), internal::pset1<Packet>(static_cast<result_type>(1)), internal::pzero(a)); }
+};
+template<typename LhsScalar, typename RhsScalar>
+struct scalar_cmp_with_cast_op<LhsScalar, RhsScalar, cmp_GT> : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_cmp_with_cast_op>::ReturnType result_type;
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_with_cast_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ if(a>b) return static_cast<result_type>(1);
+ else return static_cast<result_type>(0);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pselect(internal::pcmp_le(a,b), internal::pzero(a), internal::pset1<Packet>(static_cast<result_type>(1))); }
+};
+template<typename LhsScalar, typename RhsScalar>
+struct scalar_cmp_with_cast_op<LhsScalar, RhsScalar, cmp_GE> : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_cmp_with_cast_op>::ReturnType result_type;
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_with_cast_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ if(a>=b) return static_cast<result_type>(1);
+ else return static_cast<result_type>(0);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pselect(internal::pcmp_lt(a,b), internal::pzero(a), internal::pset1<Packet>(static_cast<result_type>(1))); }
+};
+template<typename LhsScalar, typename RhsScalar>
+struct scalar_cmp_with_cast_op<LhsScalar, RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_cmp_with_cast_op>::ReturnType result_type;
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_with_cast_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ if(a<=b || b<=a) return static_cast<result_type>(0);
+ else return static_cast<result_type>(1);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pselect(por(internal::pcmp_le(a,b), internal::pcmp_le(b,a)), internal::pzero(a), internal::pset1<Packet>(static_cast<result_type>(1))); }
+};
+template<typename LhsScalar, typename RhsScalar>
+struct scalar_cmp_with_cast_op<LhsScalar, RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,RhsScalar>
+{
+ typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_cmp_with_cast_op>::ReturnType result_type;
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_with_cast_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
+ if(a!=b) return static_cast<result_type>(1);
+ else return static_cast<result_type>(0);
+ }
+ template<typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a, const Packet& b) const
+ { return internal::pselect(internal::pcmp_eq(a,b), internal::pzero(a), internal::pset1<Packet>(static_cast<result_type>(1))); }
+};
/** \internal
* \brief Template functor to compute the hypot of two \b positive \b and \b real scalars