path: root/tensorflow/core/kernels/scatter_functor_gpu.cu.h
author A. Unique TensorFlower <gardener@tensorflow.org> 2018-03-23 16:00:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-03-25 04:21:07 -0700
commit 084c10784887d7c4d467416430626cf7eb333cb8 (patch)
tree 69727c1d14ddbb97fc74b24f69ce2381125ee6c9 /tensorflow/core/kernels/scatter_functor_gpu.cu.h
parent 95a87277174f9fc49b4b5d9c1edbbd149bd0274c (diff)
Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations.
PiperOrigin-RevId: 190289664
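
For context on the semantics this commit adds: a scalar update is broadcast to every element of each indexed row, and the new scatter-min / scatter-max variants keep the elementwise minimum / maximum. Below is a minimal host-side C++ sketch of that behavior (the scatter_min_scalar helper and its signature are hypothetical, for illustration only, not the TensorFlow API); out-of-range indices are skipped, as in the kernels in the diff.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical host-side reference: broadcast a scalar `update` to every
// element of each indexed row and keep the elementwise minimum, mirroring
// what the scalar scatter-min kernel computes on the GPU.
void scatter_min_scalar(std::vector<float>& params, int row_width,
                        const std::vector<int>& indices, float update) {
  const int num_rows = static_cast<int>(params.size()) / row_width;
  for (int row : indices) {
    if (row < 0 || row >= num_rows) continue;  // ignore out-of-range indices
    for (int j = 0; j < row_width; ++j) {
      float& dest = params[row * row_width + j];
      dest = std::min(dest, update);
    }
  }
}

int main() {
  std::vector<float> params = {5, 7, 1, 2, 9, 9};  // 3 rows of width 2
  scatter_min_scalar(params, /*row_width=*/2, /*indices=*/{0, 2}, /*update=*/3.f);
  for (float v : params) std::printf("%g ", v);  // prints: 3 3 1 2 3 3
  std::printf("\n");
}
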
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor_gpu.cu.h')
-rw-r--r-- tensorflow/core/kernels/scatter_functor_gpu.cu.h | 108
1 file changed, 86 insertions(+), 22 deletions(-)
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
index be18658543..70809e4dcf 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
@@ -29,12 +29,53 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
+namespace scatter_op_gpu {
+
+template <typename T, scatter_op::UpdateOp op>
+struct ScatterOpKernelBody;
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ASSIGN> {
+ __device__ void operator()(T* dest, T src) const { *dest = src; }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::ADD> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicAdd(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::SUB> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicSub(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MUL> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMul(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::DIV> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicDiv(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MIN> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMin(dest, src); }
+};
+
+template <typename T>
+struct ScatterOpKernelBody<T, scatter_op::UpdateOp::MAX> {
+ __device__ void operator()(T* dest, T src) const { CudaAtomicMax(dest, src); }
+};
+
template <typename T, typename Index, scatter_op::UpdateOp op>
__global__ void ScatterOpCustomKernel(T* params, const T* updates,
const Index* indices,
Index first_dim_size, Index updates_size,
Index indices_size) {
Index update_block = updates_size / indices_size;
+ ScatterOpKernelBody<T, op> body;
CUDA_1D_KERNEL_LOOP(i, updates_size) {
int indices_i = i / update_block;
int updates_i = i;
@@ -44,31 +85,33 @@ __global__ void ScatterOpCustomKernel(T* params, const T* updates,
continue;
}
int params_i = param_first_index * update_block + (i % update_block);
- switch (op) {
- case scatter_op::UpdateOp::ASSIGN: {
- params[params_i] = ldg(updates + updates_i);
- break;
- }
- case scatter_op::UpdateOp::ADD: {
- CudaAtomicAdd(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::SUB: {
- CudaAtomicSub(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::MUL: {
- CudaAtomicMul(params + params_i, ldg(updates + updates_i));
- break;
- }
- case scatter_op::UpdateOp::DIV: {
- CudaAtomicDiv(params + params_i, ldg(updates + updates_i));
- break;
- }
+ body(&params[params_i], ldg(updates + updates_i));
+ }
+}
+
+template <typename T, typename Index, scatter_op::UpdateOp op>
+__global__ void ScatterScalarOpCustomKernel(T* params, const T* update,
+ const Index* indices,
+ Index first_dim_size,
+ Index indices_size,
+ Index synthesized_updates_size) {
+ Index update_block = synthesized_updates_size / indices_size;
+ ScatterOpKernelBody<T, op> body;
+ CUDA_1D_KERNEL_LOOP(i, synthesized_updates_size) {
+ int indices_i = i / update_block;
+ int param_first_index = indices[indices_i];
+ const T update_val = *update;
+ if (!(param_first_index >= 0 && param_first_index < first_dim_size)) {
+ // Ignore indices that are out of range.
+ continue;
}
+ int params_i = param_first_index * update_block + (i % update_block);
+ body(&params[params_i], update_val);
}
}
+} // namespace scatter_op_gpu
+
namespace functor {
// Specialization for a GPU device.
template <typename T, typename Index, scatter_op::UpdateOp op>
@@ -84,7 +127,7 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
const Index indices_size = indices.size();
const Index updates_size = updates.size();
CudaLaunchConfig config = GetCudaLaunchConfig(updates_size, d);
- ScatterOpCustomKernel<T, Index, op>
+ scatter_op_gpu::ScatterOpCustomKernel<T, Index, op>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
params.data(), updates.data(), indices.data(), first_dim_size,
updates_size, indices_size);
@@ -92,6 +135,27 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
}
};
+template <typename T, typename Index, scatter_op::UpdateOp op>
+struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
+ Index operator()(OpKernelContext* c, const GPUDevice& d,
+ typename TTypes<T>::Matrix params,
+ const typename TTypes<T>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // TODO(b/31801742): Implement indices range check. The hardest part is
+ // with returning a value after the range check, as we do not want to do
+ // device to host memcpy during a stream.
+ const Index first_dim_size = params.dimension(0);
+ const Index indices_size = indices.size();
+ const Index synthesized_updates_size = indices_size * params.dimension(1);
+ CudaLaunchConfig config = GetCudaLaunchConfig(synthesized_updates_size, d);
+ scatter_op_gpu::ScatterScalarOpCustomKernel<T, Index, op>
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+ params.data(), update.data(), indices.data(), first_dim_size,
+ indices_size, synthesized_updates_size);
+ return -1;
+ }
+};
+
} // namespace functor
} // namespace tensorflow
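
A note on the dispatch pattern used in this change: the per-element switch (op) is replaced by a ScatterOpKernelBody functor selected by template specialization, so each kernel instantiation inlines exactly one update operation and carries no runtime branch in the inner loop. Below is a minimal host-only C++ sketch of the same idea (hypothetical names, no CUDA or TensorFlow dependencies), intended only to illustrate the technique, not the actual implementation.

#include <algorithm>
#include <cstdio>

enum class UpdateOp { ASSIGN, ADD, MIN, MAX };

// One specialization per op; the body is chosen at compile time,
// so the inner loop contains no per-element switch.
template <typename T, UpdateOp op> struct KernelBody;

template <typename T> struct KernelBody<T, UpdateOp::ASSIGN> {
  void operator()(T* dest, T src) const { *dest = src; }
};
template <typename T> struct KernelBody<T, UpdateOp::ADD> {
  void operator()(T* dest, T src) const { *dest += src; }
};
template <typename T> struct KernelBody<T, UpdateOp::MIN> {
  void operator()(T* dest, T src) const { *dest = std::min(*dest, src); }
};
template <typename T> struct KernelBody<T, UpdateOp::MAX> {
  void operator()(T* dest, T src) const { *dest = std::max(*dest, src); }
};

template <typename T, UpdateOp op>
void apply(T* data, const T* updates, int n) {
  KernelBody<T, op> body;  // stateless functor, inlined per element
  for (int i = 0; i < n; ++i) body(&data[i], updates[i]);
}

int main() {
  float data[3] = {1, 5, 2};
  const float updates[3] = {4, 3, 3};
  apply<float, UpdateOp::MAX>(data, updates, 3);
  std::printf("%g %g %g\n", data[0], data[1], data[2]);  // prints: 4 5 3
}

The GPU code in the diff follows the same shape: the functor's operator() is __device__, and the ADD/SUB/MUL/DIV/MIN/MAX specializations delegate to the CudaAtomic* primitives so concurrent threads update params safely.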