diff options
Diffstat (limited to 'tensorflow/core/kernels/scatter_op.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_op.cc | 38 |
1 files changed, 29 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 30e81737f8..1516455cc6 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -55,6 +55,20 @@ struct Assign<scatter_op::UpdateOp::SUB> { p -= u; } }; +template <> +struct Assign<scatter_op::UpdateOp::MUL> { + template <typename Params, typename Update> + static void Run(Params p, Update u) { + p *= u; + } +}; +template <> +struct Assign<scatter_op::UpdateOp::DIV> { + template <typename Params, typename Update> + static void Run(Params p, Update u) { + p /= u; + } +}; } // namespace @@ -195,8 +209,10 @@ struct ScatterFunctor<CPUDevice, T, Index, op> { REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \ REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op); -#define REGISTER_SCATTER_ADD_SUB(type, dev) \ +#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \ + REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \ REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB); #define REGISTER_SCATTER_UPDATE(type, dev) \ @@ -204,28 +220,30 @@ struct ScatterFunctor<CPUDevice, T, Index, op> { scatter_op::UpdateOp::ASSIGN); // Registers CPU kernels. -#define REGISTER_SCATTER_ADD_SUB_CPU(type) REGISTER_SCATTER_ADD_SUB(type, CPU); +#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \ + REGISTER_SCATTER_ARITHEMTIC(type, CPU); #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU); // Registers GPU kernels. #if GOOGLE_CUDA -#define REGISTER_SCATTER_ADD_SUB_GPU(type) REGISTER_SCATTER_ADD_SUB(type, GPU); +#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ + REGISTER_SCATTER_ARITHEMTIC(type, GPU); #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ADD_SUB_GPU); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #endif // GOOGLE_CUDA #undef REGISTER_SCATTER_ADD -#undef REGISTER_SCATTER_ADD_SUB -#undef REGISTER_SCATTER_ADD_SUB_CPU -#undef REGISTER_SCATTER_ADD_SUB_GPU +#undef REGISTER_SCATTER_ARITHEMTIC +#undef REGISTER_SCATTER_ARITHEMTIC_CPU +#undef REGISTER_SCATTER_ARITHEMTIC_GPU #undef REGISTER_SCATTER_UPDATE #undef REGISTER_SCATTER_UPDATE_CPU #undef REGISTER_SCATTER_UPDATE_GPU @@ -248,7 +266,9 @@ namespace functor { #define DECLARE_GPU_SPECS_INDEX(T, Index) \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \ - DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \ + DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV); #define DECLARE_GPU_SPECS(T) \ DECLARE_GPU_SPECS_INDEX(T, int32); \ |