about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/scatter_op_gpu.cu.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/scatter_op_gpu.cu.cc')
-rw-r--r-- tensorflow/core/kernels/scatter_op_gpu.cu.cc 18
1 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
index 213c62402a..e51579f032 100644
--- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
@@ -54,6 +54,14 @@ __global__ void ScatterOpCustomKernel(
CudaAtomicSub(params + params_i, ldg(updates + updates_i));
break;
}
+ case scatter_op::UpdateOp::MUL: {
+ CudaAtomicMul(params + params_i, ldg(updates + updates_i));
+ break;
+ }
+ case scatter_op::UpdateOp::DIV: {
+ CudaAtomicDiv(params + params_i, ldg(updates + updates_i));
+ break;
+ }
}
}
}
@@ -86,10 +94,12 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
#define DEFINE_GPU_SPECS_OP(T, Index, op) \
template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
-#define DEFINE_GPU_SPECS_INDEX(T, Index) \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+#define DEFINE_GPU_SPECS_INDEX(T, Index) \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \