diff options
Diffstat (limited to 'tensorflow/core/kernels/aggregate_ops_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/aggregate_ops_gpu.cu.cc | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc index 06f132aed5..47c24ab27b 100644 --- a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/aggregate_ops.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/port.h" @@ -140,16 +141,21 @@ struct Add9Functor<GPUDevice, T> { } // end namespace functor -// Instantiate the GPU implementation for float. -template struct functor::Add2Functor<GPUDevice, float>; -template struct functor::Add3Functor<GPUDevice, float>; -template struct functor::Add4Functor<GPUDevice, float>; -template struct functor::Add5Functor<GPUDevice, float>; -template struct functor::Add6Functor<GPUDevice, float>; -template struct functor::Add7Functor<GPUDevice, float>; -template struct functor::Add8Functor<GPUDevice, float>; -template struct functor::Add8pFunctor<GPUDevice, float>; -template struct functor::Add9Functor<GPUDevice, float>; +// Instantiate the GPU implementation for GPU number types. +#define REGISTER_FUNCTORS(type) \ + template struct functor::Add2Functor<GPUDevice, type>; \ + template struct functor::Add3Functor<GPUDevice, type>; \ + template struct functor::Add4Functor<GPUDevice, type>; \ + template struct functor::Add5Functor<GPUDevice, type>; \ + template struct functor::Add6Functor<GPUDevice, type>; \ + template struct functor::Add7Functor<GPUDevice, type>; \ + template struct functor::Add8Functor<GPUDevice, type>; \ + template struct functor::Add8pFunctor<GPUDevice, type>; \ + template struct functor::Add9Functor<GPUDevice, type>; + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS); + +#undef REGISTER_FUNCTORS } // end namespace tensorflow |