about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/aggregate_ops_gpu.cu.cc')
-rw-r--r-- tensorflow/core/kernels/aggregate_ops_gpu.cu.cc 26
1 files changed, 16 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
index 06f132aed5..47c24ab27b 100644
--- a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/aggregate_ops.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/port.h"
@@ -140,16 +141,21 @@ struct Add9Functor<GPUDevice, T> {
} // end namespace functor
-// Instantiate the GPU implementation for float.
-template struct functor::Add2Functor<GPUDevice, float>;
-template struct functor::Add3Functor<GPUDevice, float>;
-template struct functor::Add4Functor<GPUDevice, float>;
-template struct functor::Add5Functor<GPUDevice, float>;
-template struct functor::Add6Functor<GPUDevice, float>;
-template struct functor::Add7Functor<GPUDevice, float>;
-template struct functor::Add8Functor<GPUDevice, float>;
-template struct functor::Add8pFunctor<GPUDevice, float>;
-template struct functor::Add9Functor<GPUDevice, float>;
+// Explicitly instantiate every Add*Functor on GPUDevice for the GPU number types.
+#define REGISTER_FUNCTORS(type) \
+ template struct functor::Add2Functor<GPUDevice, type>; \
+ template struct functor::Add3Functor<GPUDevice, type>; \
+ template struct functor::Add4Functor<GPUDevice, type>; \
+ template struct functor::Add5Functor<GPUDevice, type>; \
+ template struct functor::Add6Functor<GPUDevice, type>; \
+ template struct functor::Add7Functor<GPUDevice, type>; \
+ template struct functor::Add8Functor<GPUDevice, type>; \
+ template struct functor::Add8pFunctor<GPUDevice, type>; \
+ template struct functor::Add9Functor<GPUDevice, type>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
+
+#undef REGISTER_FUNCTORS
} // end namespace tensorflow