aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-01-19 11:15:10 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2016-01-20 07:48:36 -0800
commit2712ed6de036e16b2599fcab2071acd7bbf8b17a (patch)
tree07d978b8249b431675e2cb93c634152c57b7a27d /tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
parentee274fdd0af515a09a3a45a1d31065b3edd06087 (diff)
Implement TensorArray forward ops.
Allows dynamic writing to and reading from an array of Tensors (the size of the array is determined at run time). This is useful for, e.g., While loops: each while iteration can write to the array, and the final handle can be used with Concat to get all the outputs in one Tensor. No gradient support yet; this will be implemented in a future CL. Change: 112493043
Diffstat (limited to 'tensorflow/core/kernels/aggregate_ops_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/aggregate_ops_gpu.cu.cc26
1 files changed, 16 insertions, 10 deletions
diff --git a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
index 06f132aed5..47c24ab27b 100644
--- a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/aggregate_ops.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/port.h"
@@ -140,16 +141,21 @@ struct Add9Functor<GPUDevice, T> {
} // end namespace functor
-// Instantiate the GPU implementation for float.
-template struct functor::Add2Functor<GPUDevice, float>;
-template struct functor::Add3Functor<GPUDevice, float>;
-template struct functor::Add4Functor<GPUDevice, float>;
-template struct functor::Add5Functor<GPUDevice, float>;
-template struct functor::Add6Functor<GPUDevice, float>;
-template struct functor::Add7Functor<GPUDevice, float>;
-template struct functor::Add8Functor<GPUDevice, float>;
-template struct functor::Add8pFunctor<GPUDevice, float>;
-template struct functor::Add9Functor<GPUDevice, float>;
+// Instantiate the GPU implementation for GPU number types.
+#define REGISTER_FUNCTORS(type) \
+ template struct functor::Add2Functor<GPUDevice, type>; \
+ template struct functor::Add3Functor<GPUDevice, type>; \
+ template struct functor::Add4Functor<GPUDevice, type>; \
+ template struct functor::Add5Functor<GPUDevice, type>; \
+ template struct functor::Add6Functor<GPUDevice, type>; \
+ template struct functor::Add7Functor<GPUDevice, type>; \
+ template struct functor::Add8Functor<GPUDevice, type>; \
+ template struct functor::Add8pFunctor<GPUDevice, type>; \
+ template struct functor::Add9Functor<GPUDevice, type>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
+
+#undef REGISTER_FUNCTORS
} // end namespace tensorflow