aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h
diff options
context:
space:
mode:
Diffstat (limited to 'unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h51
1 files changed, 41 insertions, 10 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h
index cd04716bd..fdb5ee6b8 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h
@@ -47,22 +47,39 @@ template <> struct max_n_1<0> {
// Default packet types
template <typename Scalar, typename Device>
-struct PacketType {
+struct PacketType : internal::packet_traits<Scalar> {
typedef typename internal::packet_traits<Scalar>::type type;
- enum { size = internal::unpacket_traits<type>::size };
};
// For CUDA packet types when using a GpuDevice
-#if defined(EIGEN_USE_GPU) && defined(__CUDACC__)
+#if defined(EIGEN_USE_GPU) && defined(__CUDACC__) && defined(EIGEN_HAS_CUDA_FP16)
template <>
-struct PacketType<float, GpuDevice> {
- typedef float4 type;
- static const int size = 4;
-};
-template <>
-struct PacketType<double, GpuDevice> {
- typedef double2 type;
+struct PacketType<half, GpuDevice> {
+ typedef half2 type;
static const int size = 2;
+ enum {
+ HasAdd = 1,
+ HasSub = 1,
+ HasMul = 1,
+ HasNegate = 1,
+ HasAbs = 1,
+ HasArg = 0,
+ HasAbs2 = 0,
+ HasMin = 1,
+ HasMax = 1,
+ HasConj = 0,
+ HasSetLinear = 0,
+ HasBlend = 0,
+
+ HasDiv = 1,
+ HasSqrt = 1,
+ HasRsqrt = 1,
+ HasExp = 1,
+ HasLog = 1,
+ HasLog1p = 0,
+ HasLog10 = 0,
+ HasPow = 1,
+ };
};
#endif
@@ -112,6 +129,20 @@ bool operator!=(const Tuple<U, V>& x, const Tuple<U, V>& y) {
}
+// Can't use std::pairs on cuda devices
+template <typename Idx> struct IndexPair {
+ EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) {}
+ EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Idx f, Idx s) : first(f), second(s) {}
+
+ EIGEN_DEVICE_FUNC void set(IndexPair<Idx> val) {
+ first = val.first;
+ second = val.second;
+ }
+
+ Idx first;
+ Idx second;
+};
+
#ifdef EIGEN_HAS_SFINAE
namespace internal {