aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--Eigen/src/Core/arch/CUDA/TypeCasting.h20
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h15
-rw-r--r--unsupported/test/cxx11_tensor_of_float16_cuda.cu23
3 files changed, 45 insertions, 13 deletions
diff --git a/Eigen/src/Core/arch/CUDA/TypeCasting.h b/Eigen/src/Core/arch/CUDA/TypeCasting.h
index 279fd4fd0..2742a4e7b 100644
--- a/Eigen/src/Core/arch/CUDA/TypeCasting.h
+++ b/Eigen/src/Core/arch/CUDA/TypeCasting.h
@@ -34,6 +34,26 @@ template<>
struct functor_traits<scalar_cast_op<float, half> >
{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+template<>
+struct scalar_cast_op<int, half> {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
+ typedef half result_type;
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half operator() (const int& a) const {
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
+ return __float2half(static_cast<float>(a));
+ #else
+ assert(false && "tbd");
+ return half();
+ #endif
+ }
+};
+
+template<>
+struct functor_traits<scalar_cast_op<int, half> >
+{ enum { Cost = NumTraits<float>::AddCost, PacketAccess = false }; };
+
+
template<>
struct scalar_cast_op<half, float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op)
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
index f94ffa020..e2d876140 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h
@@ -72,11 +72,12 @@ template <typename T> struct SumReducer
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
- return static_cast<T>(0);
+ internal::scalar_cast_op<int, T> conv;
+ return conv(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
- return pset1<Packet>(0);
+ return pset1<Packet>(initialize());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
return accum;
@@ -110,11 +111,12 @@ template <typename T> struct MeanReducer
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
- return static_cast<T>(0);
+ internal::scalar_cast_op<int, T> conv;
+ return conv(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
- return pset1<Packet>(0);
+ return pset1<Packet>(initialize());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
return accum / scalarCount_;
@@ -214,11 +216,12 @@ template <typename T> struct ProdReducer
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
- return static_cast<T>(1);
+ internal::scalar_cast_op<int, T> conv;
+ return conv(1);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
- return pset1<Packet>(1);
+ return pset1<Packet>(initialize());
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
return accum;
diff --git a/unsupported/test/cxx11_tensor_of_float16_cuda.cu b/unsupported/test/cxx11_tensor_of_float16_cuda.cu
index d3cd94cd6..5ce96a1c2 100644
--- a/unsupported/test/cxx11_tensor_of_float16_cuda.cu
+++ b/unsupported/test/cxx11_tensor_of_float16_cuda.cu
@@ -93,7 +93,6 @@ void test_cuda_elementwise() {
gpu_device.deallocate(d_res_half);
gpu_device.deallocate(d_res_float);
}
-
/*
void test_cuda_contractions() {
Eigen::CudaStreamDevice stream;
@@ -139,7 +138,7 @@ void test_cuda_contractions() {
gpu_device.deallocate(d_float2);
gpu_device.deallocate(d_res_half);
gpu_device.deallocate(d_res_float);
-}
+}*/
void test_cuda_reductions() {
@@ -183,7 +182,7 @@ void test_cuda_reductions() {
gpu_device.deallocate(d_res_half);
gpu_device.deallocate(d_res_float);
}
-*/
+
#endif
@@ -191,9 +190,19 @@ void test_cuda_reductions() {
void test_cxx11_tensor_of_float16_cuda()
{
#ifdef EIGEN_HAS_CUDA_FP16
- CALL_SUBTEST_1(test_cuda_conversion());
- CALL_SUBTEST_1(test_cuda_elementwise());
-// CALL_SUBTEST_2(test_cuda_contractions());
-// CALL_SUBTEST_3(test_cuda_reductions());
+ Eigen::CudaStreamDevice stream;
+ Eigen::GpuDevice device(&stream);
+ if (device.majorDeviceVersion() > 5 ||
+ (device.majorDeviceVersion() == 5 && device.minorDeviceVersion() >= 3)) {
+ CALL_SUBTEST_1(test_cuda_conversion());
+ CALL_SUBTEST_1(test_cuda_elementwise());
+// CALL_SUBTEST_2(test_cuda_contractions());
+ CALL_SUBTEST_3(test_cuda_reductions());
+ }
+ else {
+ std::cout << "Half floats require compute capability of at least 5.3. This device only supports " << device.majorDeviceVersion() << "." << device.minorDeviceVersion() << ". Skipping the test" << std::endl;
+ }
+#else
+ std::cout << "Half floats are not supported by this version of cuda: skipping the test" << std::endl;
#endif
}