aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-23 16:00:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-25 04:21:07 -0700
commit084c10784887d7c4d467416430626cf7eb333cb8 (patch)
tree69727c1d14ddbb97fc74b24f69ce2381125ee6c9 /tensorflow/core/kernels/scatter_functor_gpu.cu.cc
parent95a87277174f9fc49b4b5d9c1edbbd149bd0274c (diff)
Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations.
PiperOrigin-RevId: 190289664
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/scatter_functor_gpu.cu.cc9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
index 52972997cc..59911bf0d2 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc
@@ -23,15 +23,18 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
-#define DEFINE_GPU_SPECS_OP(T, Index, op) \
- template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
+#define DEFINE_GPU_SPECS_OP(T, Index, op) \
+ template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \
+ template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>;
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX);
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \