aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/aggregate_ops_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/aggregate_ops_gpu.cu.cc141
1 files changed, 141 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
new file mode 100644
index 0000000000..5cf2934ac1
--- /dev/null
+++ b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
@@ -0,0 +1,141 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/aggregate_ops.h"
+
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Partial specialization for a GPUDevice, that uses the Eigen implementation.
+namespace functor {
+template <typename T>
+struct Add2Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2) {
+ Add2EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2);
+ }
+};
+
+template <typename T>
+struct Add3Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3) {
+ Add3EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3);
+ }
+};
+
+template <typename T>
+struct Add4Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4) {
+ Add4EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4);
+ }
+};
+
+template <typename T>
+struct Add5Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5) {
+ Add5EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5);
+ }
+};
+
+template <typename T>
+struct Add6Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6) {
+ Add6EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6);
+ }
+};
+
+template <typename T>
+struct Add7Functor<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1,
+ typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3,
+ typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5,
+ typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7) {
+ Add7EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7);
+ }
+};
+
+template <typename T>
+struct Add8Functor<GPUDevice, T> {
+ void operator()(
+ const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add8pFunctor<GPUDevice, T> {
+ void operator()(
+ const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8) {
+ Add8pEigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8);
+ }
+};
+
+template <typename T>
+struct Add9Functor<GPUDevice, T> {
+ void operator()(
+ const GPUDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstFlat in1, typename TTypes<T>::ConstFlat in2,
+ typename TTypes<T>::ConstFlat in3, typename TTypes<T>::ConstFlat in4,
+ typename TTypes<T>::ConstFlat in5, typename TTypes<T>::ConstFlat in6,
+ typename TTypes<T>::ConstFlat in7, typename TTypes<T>::ConstFlat in8,
+ typename TTypes<T>::ConstFlat in9) {
+ Add9EigenImpl<GPUDevice, T>::Compute(d, out, in1, in2, in3, in4, in5, in6,
+ in7, in8, in9);
+ }
+};
+
+} // end namespace functor
+
+// Instantiate the GPU implementation for float.
+template struct functor::Add2Functor<GPUDevice, float>;
+template struct functor::Add3Functor<GPUDevice, float>;
+template struct functor::Add4Functor<GPUDevice, float>;
+template struct functor::Add5Functor<GPUDevice, float>;
+template struct functor::Add6Functor<GPUDevice, float>;
+template struct functor::Add7Functor<GPUDevice, float>;
+template struct functor::Add8Functor<GPUDevice, float>;
+template struct functor::Add8pFunctor<GPUDevice, float>;
+template struct functor::Add9Functor<GPUDevice, float>;
+
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA