diff options
author | 2018-03-23 16:00:14 -0700 | |
---|---|---|
committer | 2018-03-25 04:21:07 -0700 | |
commit | 084c10784887d7c4d467416430626cf7eb333cb8 (patch) | |
tree | 69727c1d14ddbb97fc74b24f69ce2381125ee6c9 /tensorflow/core/kernels/scatter_functor_gpu.cu.cc | |
parent | 95a87277174f9fc49b4b5d9c1edbbd149bd0274c (diff) |
Extended scatter operations to work with a scalar update parameter and added scatter-min and scatter-max operations.
PiperOrigin-RevId: 190289664
Diffstat (limited to 'tensorflow/core/kernels/scatter_functor_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_functor_gpu.cu.cc | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc index 52972997cc..59911bf0d2 100644 --- a/tensorflow/core/kernels/scatter_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.cc @@ -23,15 +23,18 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_GPU_SPECS_OP(T, Index, op) \ - template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; +#define DEFINE_GPU_SPECS_OP(T, Index, op) \ + template struct functor::ScatterFunctor<GPUDevice, T, Index, op>; \ + template struct functor::ScatterScalarFunctor<GPUDevice, T, Index, op>; #define DEFINE_GPU_SPECS_INDEX(T, Index) \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ - DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MIN); \ + DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MAX); #define DEFINE_GPU_SPECS(T) \ DEFINE_GPU_SPECS_INDEX(T, int32); \ |