diff options
author | 2016-01-19 11:15:10 -0800 | |
---|---|---|
committer | 2016-01-20 07:48:36 -0800 | |
commit | 2712ed6de036e16b2599fcab2071acd7bbf8b17a (patch) | |
tree | 07d978b8249b431675e2cb93c634152c57b7a27d /tensorflow/core/kernels/aggregate_ops_gpu.cu.cc | |
parent | ee274fdd0af515a09a3a45a1d31065b3edd06087 (diff) |
Implement TensorArray forward ops.
Allows dynamic writing to and reading from an array of Tensors (the size of the array is determined at run time).
This is useful for, e.g., While loops: each while iteration can write to the array, and the final handle can be used with Concat to get all the outputs in one Tensor.
No gradient support yet; this will be implemented in a future CL.
Change: 112493043
Diffstat (limited to 'tensorflow/core/kernels/aggregate_ops_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/aggregate_ops_gpu.cu.cc | 26 |
1 files changed, 16 insertions, 10 deletions
```diff
diff --git a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
index 06f132aed5..47c24ab27b 100644
--- a/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/aggregate_ops_gpu.cu.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/aggregate_ops.h"
 
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/platform/port.h"
 
@@ -140,16 +141,21 @@ struct Add9Functor<GPUDevice, T> {
 
 } // end namespace functor
 
-// Instantiate the GPU implementation for float.
-template struct functor::Add2Functor<GPUDevice, float>;
-template struct functor::Add3Functor<GPUDevice, float>;
-template struct functor::Add4Functor<GPUDevice, float>;
-template struct functor::Add5Functor<GPUDevice, float>;
-template struct functor::Add6Functor<GPUDevice, float>;
-template struct functor::Add7Functor<GPUDevice, float>;
-template struct functor::Add8Functor<GPUDevice, float>;
-template struct functor::Add8pFunctor<GPUDevice, float>;
-template struct functor::Add9Functor<GPUDevice, float>;
+// Instantiate the GPU implementation for GPU number types.
+#define REGISTER_FUNCTORS(type) \
+  template struct functor::Add2Functor<GPUDevice, type>; \
+  template struct functor::Add3Functor<GPUDevice, type>; \
+  template struct functor::Add4Functor<GPUDevice, type>; \
+  template struct functor::Add5Functor<GPUDevice, type>; \
+  template struct functor::Add6Functor<GPUDevice, type>; \
+  template struct functor::Add7Functor<GPUDevice, type>; \
+  template struct functor::Add8Functor<GPUDevice, type>; \
+  template struct functor::Add8pFunctor<GPUDevice, type>; \
+  template struct functor::Add9Functor<GPUDevice, type>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_FUNCTORS);
+
+#undef REGISTER_FUNCTORS
 
 } // end namespace tensorflow
 
```