aboutsummaryrefslogtreecommitdiffhomepage
path: root/Eigen
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-10 11:05:33 -0700
committerGravatar Benoit Steiner <benoit.steiner.goog@gmail.com>2016-05-10 11:05:33 -0700
commit0b9e3dcd06585d28ac4b59dfd518b0a49af3a359 (patch)
tree9604bc5e7798a0ab4869b60e4cd2bc9c30c064c4 /Eigen
parent6bf8273bc0b2ccc1558c35bf358ccd731970f04a (diff)
Added packet primitives to compute exp, log, sqrt and rsqrt on fp16. This improves the performance by 10 to 30%.
Diffstat (limited to 'Eigen')
-rw-r--r--Eigen/src/Core/arch/CUDA/PacketMathHalf.h38
1 files changed, 37 insertions, 1 deletions
diff --git a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h b/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
index 0cebc1017..8873d5357 100644
--- a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
+++ b/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
@@ -34,7 +34,11 @@ template<> struct packet_traits<half> : default_packet_traits
AlignedOnScalar = 1,
size=2,
HasHalfPacket = 0,
- HasDiv = 1
+ HasDiv = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasExp = 1,
+ HasLog = 1
};
};
@@ -267,6 +271,38 @@ template<> EIGEN_DEVICE_FUNC inline half predux_mul<half2>(const half2& a) {
#endif
}
+template<> EIGEN_DEVICE_FUNC inline half2 plog<half2>(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = logf(a1);
+ float r2 = logf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+template<> EIGEN_DEVICE_FUNC inline half2 pexp<half2>(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = expf(a1);
+ float r2 = expf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+template<> EIGEN_DEVICE_FUNC inline half2 psqrt<half2>(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = sqrtf(a1);
+ float r2 = sqrtf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
+template<> EIGEN_DEVICE_FUNC inline half2 prsqrt<half2>(const half2& a) {
+ float a1 = __low2float(a);
+ float a2 = __high2float(a);
+ float r1 = rsqrtf(a1);
+ float r2 = rsqrtf(a2);
+ return __floats2half2_rn(r1, r2);
+}
+
} // end namespace internal
} // end namespace Eigen