diff options
author | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-13 23:53:11 +0000 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2020-10-13 23:53:11 +0000 |
commit | 61fc78bbda22f425a97761be57c219928d929ddc (patch) | |
tree | a577d1021784676fa5188faa1c987fe4061fa122 /unsupported/Eigen/CXX11/src | |
parent | c6953f799b01d36f4236b64f351cc1446e0abe17 (diff) |
Get rid of nested template specialization in TensorReductionGpu.h, which was broken by c6953f799b01d36f4236b64f351cc1446e0abe17.
Diffstat (limited to 'unsupported/Eigen/CXX11/src')
-rw-r--r-- | unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h index 02a514c0f..db4e8d866 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h @@ -80,8 +80,8 @@ __device__ inline double atomicExchCustom(double* address, double val) { } #ifdef EIGEN_HAS_GPU_FP16 -template <template <typename T> class R> -__device__ inline void atomicReduce(half2* output, half2 accum, R<half>& reducer) { +template <typename R> +__device__ inline void atomicReduce(half2* output, half2 accum, R& reducer) { unsigned int oldval = *reinterpret_cast<unsigned int*>(output); unsigned int newval = oldval; reducer.reducePacket(accum, reinterpret_cast<half2*>(&newval)); @@ -99,9 +99,8 @@ __device__ inline void atomicReduce(half2* output, half2 accum, R<half>& reducer } } // reduction should be associative since reduction is not atomic in wide vector but atomic in half2 operations -template <template <typename T> class R> -__device__ inline void atomicReduce(Packet4h2* output, Packet4h2 accum, - R<half>& reducer) { +template <typename R> +__device__ inline void atomicReduce(Packet4h2* output, Packet4h2 accum, R& reducer) { half2* houtput=reinterpret_cast<half2*>(output); half2* haccum=reinterpret_cast<half2*>(&accum); for(int i=0;i<4;++i){ |