aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen/src/Core/arch/NEON/PacketMath.h
diff options
context:
space:
mode:
authorGravatar Joel Holdsworth <joel@airwebreathe.org.uk>2020-03-26 20:18:19 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-03-26 20:18:19 +0000
commit52d54278beefee8b2f19dcca4fd900916154e174 (patch)
tree8a584a2bb27450b0e2af2ed102473ba265b7aa00 /Eigen/src/Core/arch/NEON/PacketMath.h
parentdeb93ed1bf359ac99923e3a2b90a2920b1101290 (diff)
Additional NEON packet-math operations
Diffstat (limited to 'Eigen/src/Core/arch/NEON/PacketMath.h')
-rw-r--r--Eigen/src/Core/arch/NEON/PacketMath.h149
1 files changed, 148 insertions, 1 deletions
diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h
index 76c61b42f..326873f8a 100644
--- a/Eigen/src/Core/arch/NEON/PacketMath.h
+++ b/Eigen/src/Core/arch/NEON/PacketMath.h
@@ -137,6 +137,7 @@ struct packet_traits<float> : default_packet_traits
size = 4,
HasHalfPacket = 1,
+ HasCast = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@@ -151,6 +152,7 @@ struct packet_traits<float> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
+ HasInsert = 1,
HasReduxp = 1,
HasDiv = 1,
@@ -178,6 +180,7 @@ struct packet_traits<int8_t> : default_packet_traits
size = 16,
HasHalfPacket = 1,
+ HasCast = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@@ -192,6 +195,7 @@ struct packet_traits<int8_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
+ HasInsert = 1,
HasReduxp = 1
};
};
@@ -208,6 +212,7 @@ struct packet_traits<uint8_t> : default_packet_traits
size = 16,
HasHalfPacket = 1,
+ HasCast = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@@ -222,6 +227,7 @@ struct packet_traits<uint8_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
+ HasInsert = 1,
HasReduxp = 1,
HasSqrt = 1
@@ -240,6 +246,7 @@ struct packet_traits<int16_t> : default_packet_traits
size = 8,
HasHalfPacket = 1,
+ HasCast = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@@ -254,6 +261,7 @@ struct packet_traits<int16_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
+ HasInsert = 1,
HasReduxp = 1
};
};
@@ -270,6 +278,7 @@ struct packet_traits<uint16_t> : default_packet_traits
size = 8,
HasHalfPacket = 1,
+ HasCast = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@@ -284,6 +293,7 @@ struct packet_traits<uint16_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
+ HasInsert = 1,
HasReduxp = 1,
HasSqrt = 1
@@ -302,6 +312,7 @@ struct packet_traits<int32_t> : default_packet_traits
size = 4,
HasHalfPacket = 1,
+ HasCast = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@@ -316,6 +327,7 @@ struct packet_traits<int32_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
+ HasInsert = 1,
HasReduxp = 1
};
};
@@ -332,6 +344,7 @@ struct packet_traits<uint32_t> : default_packet_traits
size = 4,
HasHalfPacket = 1,
+ HasCast = 1,
HasAdd = 1,
HasSub = 1,
HasShift = 1,
@@ -346,6 +359,7 @@ struct packet_traits<uint32_t> : default_packet_traits
HasConj = 1,
HasSetLinear = 0,
HasBlend = 0,
+ HasInsert = 1,
HasReduxp = 1,
HasSqrt = 1
@@ -1509,6 +1523,43 @@ template<> EIGEN_STRONG_INLINE Packet2l pandnot<Packet2l>(const Packet2l& a, con
template<> EIGEN_STRONG_INLINE Packet2ul pandnot<Packet2ul>(const Packet2ul& a, const Packet2ul& b)
{ return vbicq_u64(a,b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2f pnot<Packet2f>(const Packet2f& a)
+{ return vreinterpret_f32_u32(vmvn_u32(vreinterpret_u32_f32(a))); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pnot<Packet4f>(const Packet4f& a)
+{ return vreinterpretq_f32_u32(vmvnq_u32(vreinterpretq_u32_f32(a))); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c pnot<Packet4c>(const Packet4c& a)
+{ return ~a; }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c pnot<Packet8c>(const Packet8c& a)
+{ return vmvn_s8(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16c pnot<Packet16c>(const Packet16c& a)
+{ return vmvnq_s8(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4uc pnot<Packet4uc>(const Packet4uc& a)
+{ return ~a; }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc pnot<Packet8uc>(const Packet8uc& a)
+{ return vmvn_u8(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16uc pnot<Packet16uc>(const Packet16uc& a)
+{ return vmvnq_u8(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s pnot<Packet4s>(const Packet4s& a)
+{ return vmvn_s16(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8s pnot<Packet8s>(const Packet8s& a)
+{ return vmvnq_s16(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us pnot<Packet4us>(const Packet4us& a)
+{ return vmvn_u16(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8us pnot<Packet8us>(const Packet8us& a)
+{ return vmvnq_u16(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2i pnot<Packet2i>(const Packet2i& a)
+{ return vmvn_s32(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4i pnot<Packet4i>(const Packet4i& a)
+{ return vmvnq_s32(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ui pnot<Packet2ui>(const Packet2ui& a)
+{ return vmvn_u32(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4ui pnot<Packet4ui>(const Packet4ui& a)
+{ return vmvnq_u32(a); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pnot<Packet2l>(const Packet2l& a)
+{ return vreinterpretq_s64_s32(vmvnq_s32(vreinterpretq_s32_s64(a))); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pnot<Packet2ul>(const Packet2ul& a)
+{ return vreinterpretq_u64_u32(vmvnq_u32(vreinterpretq_u32_u64(a))); }
+
template<int N> EIGEN_STRONG_INLINE Packet4c parithmetic_shift_right(Packet4c& a)
{ return vget_lane_s32(vreinterpret_s32_s8(vshr_n_s8(vreinterpret_s8_s32(vdup_n_s32(a)), N)), 0); }
template<int N> EIGEN_STRONG_INLINE Packet8c parithmetic_shift_right(Packet8c a) { return vshr_n_s8(a,N); }
@@ -3431,6 +3482,82 @@ ptranspose(PacketBlock<Packet2ul, 2>& kernel)
#endif
}
+template<> EIGEN_DEVICE_FUNC inline Packet2f pselect( const Packet2f& mask, const Packet2f& a, const Packet2f& b)
+{ return vbsl_f32(vreinterpret_u32_f32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b)
+{ return vbslq_f32(vreinterpretq_u32_f32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet8c pselect(const Packet8c& mask, const Packet8c& a, const Packet8c& b)
+{ return vbsl_s8(vreinterpret_u8_s8(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet16c pselect(const Packet16c& mask, const Packet16c& a, const Packet16c& b)
+{ return vbslq_s8(vreinterpretq_u8_s8(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet8uc pselect(const Packet8uc& mask, const Packet8uc& a, const Packet8uc& b)
+{ return vbsl_u8(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet16uc pselect(const Packet16uc& mask, const Packet16uc& a, const Packet16uc& b)
+{ return vbslq_u8(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet4s pselect(const Packet4s& mask, const Packet4s& a, const Packet4s& b)
+{ return vbsl_s16(vreinterpret_u16_s16(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet8s pselect(const Packet8s& mask, const Packet8s& a, const Packet8s& b)
+{ return vbslq_s16(vreinterpretq_u16_s16(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet4us pselect(const Packet4us& mask, const Packet4us& a, const Packet4us& b)
+{ return vbsl_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet8us pselect(const Packet8us& mask, const Packet8us& a, const Packet8us& b)
+{ return vbslq_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet2i pselect(const Packet2i& mask, const Packet2i& a, const Packet2i& b)
+{ return vbsl_s32(vreinterpret_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet4i pselect(const Packet4i& mask, const Packet4i& a, const Packet4i& b)
+{ return vbslq_s32(vreinterpretq_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet2ui pselect(const Packet2ui& mask, const Packet2ui& a, const Packet2ui& b)
+{ return vbsl_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet4ui pselect(const Packet4ui& mask, const Packet4ui& a, const Packet4ui& b)
+{ return vbslq_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet2l pselect(const Packet2l& mask, const Packet2l& a, const Packet2l& b)
+{ return vbslq_s64(vreinterpretq_u64_s64(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC inline Packet2ul pselect(const Packet2ul& mask, const Packet2ul& a, const Packet2ul& b)
+{ return vbslq_u64(mask, a, b); }
+
+EIGEN_DEVICE_FUNC inline Packet2f pinsertfirst(const Packet2f& a, float b) { return vset_lane_f32(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet4f pinsertfirst(const Packet4f& a, float b) { return vsetq_lane_f32(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet4c pinsertfirst(const Packet4c& a, int8_t b)
+{
+ return static_cast<int32_t>((static_cast<uint32_t>(a) & 0xffffff00u) |
+ (static_cast<uint32_t>(b) & 0xffu));
+}
+EIGEN_DEVICE_FUNC inline Packet8c pinsertfirst(const Packet8c& a, int8_t b) { return vset_lane_s8(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet16c pinsertfirst(const Packet16c& a, int8_t b) { return vsetq_lane_s8(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet4uc pinsertfirst(const Packet4uc& a, uint8_t b) { return (a & ~0xffu) | b; }
+EIGEN_DEVICE_FUNC inline Packet8uc pinsertfirst(const Packet8uc& a, uint8_t b) { return vset_lane_u8(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet16uc pinsertfirst(const Packet16uc& a, uint8_t b) { return vsetq_lane_u8(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet4s pinsertfirst(const Packet4s& a, int16_t b) { return vset_lane_s16(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet8s pinsertfirst(const Packet8s& a, int16_t b) { return vsetq_lane_s16(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet4us pinsertfirst(const Packet4us& a, uint16_t b) { return vset_lane_u16(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet8us pinsertfirst(const Packet8us& a, uint16_t b) { return vsetq_lane_u16(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet2i pinsertfirst(const Packet2i& a, int32_t b) { return vset_lane_s32(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet4i pinsertfirst(const Packet4i& a, int32_t b) { return vsetq_lane_s32(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet2ui pinsertfirst(const Packet2ui& a, uint32_t b) { return vset_lane_u32(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet4ui pinsertfirst(const Packet4ui& a, uint32_t b) { return vsetq_lane_u32(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet2l pinsertfirst(const Packet2l& a, int64_t b) { return vsetq_lane_s64(b, a, 0); }
+EIGEN_DEVICE_FUNC inline Packet2ul pinsertfirst(const Packet2ul& a, uint64_t b) { return vsetq_lane_u64(b, a, 0); }
+
+EIGEN_DEVICE_FUNC inline Packet2f pinsertlast(const Packet2f& a, float b) { return vset_lane_f32(b, a, 1); }
+EIGEN_DEVICE_FUNC inline Packet4f pinsertlast(const Packet4f& a, float b) { return vsetq_lane_f32(b, a, 3); }
+EIGEN_DEVICE_FUNC inline Packet4c pinsertlast(const Packet4c& a, int8_t b)
+{ return (static_cast<uint32_t>(a) & 0x00ffffffu) | (static_cast<uint32_t>(b) << 24); }
+EIGEN_DEVICE_FUNC inline Packet8c pinsertlast(const Packet8c& a, int8_t b) { return vset_lane_s8(b, a, 7); }
+EIGEN_DEVICE_FUNC inline Packet16c pinsertlast(const Packet16c& a, int8_t b) { return vsetq_lane_s8(b, a, 15); }
+EIGEN_DEVICE_FUNC inline Packet4uc pinsertlast(const Packet4uc& a, uint8_t b) { return (a & ~0xff000000u) | (b << 24); }
+EIGEN_DEVICE_FUNC inline Packet8uc pinsertlast(const Packet8uc& a, uint8_t b) { return vset_lane_u8(b, a, 7); }
+EIGEN_DEVICE_FUNC inline Packet16uc pinsertlast(const Packet16uc& a, uint8_t b) { return vsetq_lane_u8(b, a, 15); }
+EIGEN_DEVICE_FUNC inline Packet4s pinsertlast(const Packet4s& a, int16_t b) { return vset_lane_s16(b, a, 3); }
+EIGEN_DEVICE_FUNC inline Packet8s pinsertlast(const Packet8s& a, int16_t b) { return vsetq_lane_s16(b, a, 7); }
+EIGEN_DEVICE_FUNC inline Packet4us pinsertlast(const Packet4us& a, uint16_t b) { return vset_lane_u16(b, a, 3); }
+EIGEN_DEVICE_FUNC inline Packet8us pinsertlast(const Packet8us& a, uint16_t b) { return vsetq_lane_u16(b, a, 7); }
+EIGEN_DEVICE_FUNC inline Packet2i pinsertlast(const Packet2i& a, int32_t b) { return vset_lane_s32(b, a, 1); }
+EIGEN_DEVICE_FUNC inline Packet4i pinsertlast(const Packet4i& a, int32_t b) { return vsetq_lane_s32(b, a, 3); }
+EIGEN_DEVICE_FUNC inline Packet2ui pinsertlast(const Packet2ui& a, uint32_t b) { return vset_lane_u32(b, a, 1); }
+EIGEN_DEVICE_FUNC inline Packet4ui pinsertlast(const Packet4ui& a, uint32_t b) { return vsetq_lane_u32(b, a, 3); }
+EIGEN_DEVICE_FUNC inline Packet2l pinsertlast(const Packet2l& a, int64_t b) { return vsetq_lane_s64(b, a, 1); }
+EIGEN_DEVICE_FUNC inline Packet2ul pinsertlast(const Packet2ul& a, uint64_t b) { return vsetq_lane_u64(b, a, 1); }
+
/**
* Computes the integer square root
* @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result
@@ -3579,7 +3706,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasReduxp = 1,
HasDiv = 1,
- HasFloor = 0,
+ HasFloor = 1,
HasSin = 0,
HasCos = 0,
@@ -3639,6 +3766,18 @@ template<> EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const
template<> EIGEN_STRONG_INLINE Packet2d pmax<Packet2d>(const Packet2d& a, const Packet2d& b) { return vmaxq_f64(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2d pfloor<Packet2d>(const Packet2d& a)
+{
+ const Packet2d cst_1 = pset1<Packet2d>(1.0);
+ /* perform a floorf */
+ const Packet2d tmp = vcvtq_f64_s64(vcvtq_s64_f64(a));
+
+ /* if greater, substract 1 */
+ uint64x2_t mask = vcgtq_f64(tmp, a);
+ mask = vandq_u64(mask, vreinterpretq_u64_f64(cst_1));
+ return vsubq_f64(tmp, vreinterpretq_f64_u64(mask));
+}
+
// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics
template<> EIGEN_STRONG_INLINE Packet2d pand<Packet2d>(const Packet2d& a, const Packet2d& b)
{ return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); }
@@ -3755,6 +3894,14 @@ ptranspose(PacketBlock<Packet2d, 2>& kernel)
kernel.packet[0] = tmp1;
kernel.packet[1] = tmp2;
}
+
+template<> EIGEN_DEVICE_FUNC inline Packet2d pselect( const Packet2d& mask, const Packet2d& a, const Packet2d& b)
+{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); }
+
+EIGEN_DEVICE_FUNC inline Packet2d pinsertfirst(const Packet2d& a, double b) { return vsetq_lane_f64(b, a, 0); }
+
+EIGEN_DEVICE_FUNC inline Packet2d pinsertlast(const Packet2d& a, double b) { return vsetq_lane_f64(b, a, 1); }
+
#endif // EIGEN_ARCH_ARM64
} // end namespace internal