aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported
diff options
context:
space:
mode:
authorGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-13 23:53:11 +0000
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2020-10-13 23:53:11 +0000
commit61fc78bbda22f425a97761be57c219928d929ddc (patch)
treea577d1021784676fa5188faa1c987fe4061fa122 /unsupported
parentc6953f799b01d36f4236b64f351cc1446e0abe17 (diff)
Get rid of nested template specialization in TensorReductionGpu.h, which was broken by c6953f799b01d36f4236b64f351cc1446e0abe17.
Diffstat (limited to 'unsupported')
-rw-r--r--unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h9
1 files changed, 4 insertions, 5 deletions
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h b/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h
index 02a514c0f..db4e8d866 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h
@@ -80,8 +80,8 @@ __device__ inline double atomicExchCustom(double* address, double val) {
}
#ifdef EIGEN_HAS_GPU_FP16
-template <template <typename T> class R>
-__device__ inline void atomicReduce(half2* output, half2 accum, R<half>& reducer) {
+template <typename R>
+__device__ inline void atomicReduce(half2* output, half2 accum, R& reducer) {
unsigned int oldval = *reinterpret_cast<unsigned int*>(output);
unsigned int newval = oldval;
reducer.reducePacket(accum, reinterpret_cast<half2*>(&newval));
@@ -99,9 +99,8 @@ __device__ inline void atomicReduce(half2* output, half2 accum, R<half>& reducer
}
}
// reduction should be associative since reduction is not atomic in wide vector but atomic in half2 operations
-template <template <typename T> class R>
-__device__ inline void atomicReduce(Packet4h2* output, Packet4h2 accum,
- R<half>& reducer) {
+template <typename R>
+__device__ inline void atomicReduce(Packet4h2* output, Packet4h2 accum, R& reducer) {
half2* houtput=reinterpret_cast<half2*>(output);
half2* haccum=reinterpret_cast<half2*>(&accum);
for(int i=0;i<4;++i){