diff options
Diffstat (limited to 'tensorflow/core/kernels/scatter_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_op_gpu.cu.cc | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc index 213c62402a..e51579f032 100644 --- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc @@ -54,6 +54,14 @@ __global__ void ScatterOpCustomKernel( CudaAtomicSub(params + params_i, ldg(updates + updates_i)); break; } + case scatter_op::UpdateOp::MUL: { + CudaAtomicMul(params + params_i, ldg(updates + updates_i)); + break; + } + case scatter_op::UpdateOp::DIV: { + CudaAtomicDiv(params + params_i, ldg(updates + updates_i)); + break; + } } } } @@ -86,10 +94,12 @@ struct ScatterFunctor<GPUDevice, T, Index, op> { #define DEFINE_GPU_SPECS_OP(T, Index, op) \ template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; -#define DEFINE_GPU_SPECS_INDEX(T, Index) \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); +#define DEFINE_GPU_SPECS_INDEX(T, Index) \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); #define DEFINE_GPU_SPECS(T) \ DEFINE_GPU_SPECS_INDEX(T, int32); \ |