diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-23 16:00:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-25 04:21:07 -0700 |
commit | 084c10784887d7c4d467416430626cf7eb333cb8 (patch) | |
tree | 69727c1d14ddbb97fc74b24f69ce2381125ee6c9 /tensorflow/core/kernels/scatter_functor_gpu.cu.h | |
parent | 95a87277174f9fc49b4b5d9c1edbbd149bd0274c (diff) |
Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations.
PiperOrigin-RevId: 190289664
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor_gpu.cu.h')
-rw-r--r-- | tensorflow/core/kernels/scatter_functor_gpu.cu.h | 108 |
1 file changed, 86 insertions, 22 deletions
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h index be18658543..70809e4dcf 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h @@ -29,12 +29,53 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; +namespace scatter_op_gpu { + +template <typename T, scatter_op::UpdateOp op> +struct ScatterOpKernelBody; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> { + __device__ void operator()(T* dest, T src) const { *dest = src; } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> { + __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> { + __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> { + __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> { + __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> { + __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); } +}; + +template <typename T> +struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> { + __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); } +}; + template <typename T, typename Index, scatter_op::UpdateOp op> __global__ void ScatterOpCustomKernel(T* params, const T* updates, const Index* indices, Index first_dim_size, Index updates_size, Index indices_size) { Index update_block = updates_size / indices_size; + ScatterOpKernelBody<T, op> body; CUDA_1D_KERNEL_LOOP(i, updates_size) { int indices_i = i / 
update_block; int updates_i = i; @@ -44,31 +85,33 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates, continue; } int params_i = param_first_index * update_block + (i % update_block); - switch (op) { - case scatter_op::UpdateOp::ASSIGN: { - params[params_i] = ldg(updates + updates_i); - break; - } - case scatter_op::UpdateOp::ADD: { - CudaAtomicAdd(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::SUB: { - CudaAtomicSub(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::MUL: { - CudaAtomicMul(params + params_i, ldg(updates + updates_i)); - break; - } - case scatter_op::UpdateOp::DIV: { - CudaAtomicDiv(params + params_i, ldg(updates + updates_i)); - break; - } + body(¶ms[params_i], ldg(updates + updates_i)); + } +} + +template <typename T, typename Index, scatter_op::UpdateOp op> +__global__ void ScatterScalarOpCustomKernel(T* params, const T* update, + const Index* indices, + Index first_dim_size, + Index indices_size, + Index synthesized_updates_size) { + Index update_block = synthesized_updates_size / indices_size; + ScatterOpKernelBody<T, op> body; + CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) { + int indices_i = i / update_block; + int param_first_index = indices[indices_i]; + const T update_val = *update; + if (!(param_first_index >= 0 && param_first_index < first_dim_size)) { + // Ignore indices that are out of range. + continue; } + int params_i = param_first_index * update_block + (i % update_block); + body(¶ms[params_i], update_val); } } +} // namespace scatter_op_gpu + namespace functor { // Specialization for a GPU device. 
template <typename T, typename Index, scatter_op::UpdateOp op> @@ -84,7 +127,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> { const Index indices_size = indices.size(); const Index updates_size = updates.size(); CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d); - ScatterOpCustomKernel<T, Index, op> + scatter_op_gpu::ScatterOpCustomKernel<T, Index, op> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( params.data(), updates.data(), indices.data(), first_dim_size, updates_size, indices_size); @@ -92,6 +135,27 @@ struct ScatterFunctor<GPUDevice, T, Index, op> { } }; +template <typename T, typename Index, scatter_op::UpdateOp op> +struct ScatterScalarFunctor<GPUDevice, T, Index, op> { + Index operator()(OpKernelContext* c, const GPUDevice& d, + typename TTypes<T>::Matrix params, + const typename TTypes<T>::ConstScalar update, + typename TTypes<Index>::ConstFlat indices) { + // TODO(b/31801742): Implement indices range check. The hardest part is + // with returning a value after the range check, as we do not want to do + // device to host memcpy during a stream. + const Index first_dim_size = params.dimension(0); + const Index indices_size = indices.size(); + const Index synthesized_updates_size = indices_size * params.dimension(1); + CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d); + scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op> + <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( + params.data(), update.data(), indices.data(), first_dim_size, + indices_size, synthesized_updates_size); + return -1; + } +}; + } // namespace functor } // namespace tensorflow |