aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/scatter_op.cc')
-rw-r--r--tensorflow/core/kernels/scatter_op.cc38
1 files changed, 29 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 30e81737f8..1516455cc6 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -55,6 +55,20 @@ struct Assign<scatter_op::UpdateOp::SUB> {
p -= u;
}
};
+template <>
+struct Assign<scatter_op::UpdateOp::MUL> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p *= u;
+ }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::DIV> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p /= u;
+ }
+};
} // namespace
@@ -195,8 +209,10 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-#define REGISTER_SCATTER_ADD_SUB(type, dev) \
+#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
#define REGISTER_SCATTER_UPDATE(type, dev) \
@@ -204,28 +220,30 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ADD_SUB_CPU(type) REGISTER_SCATTER_ADD_SUB(type, CPU);
+#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
+ REGISTER_SCATTER_ARITHEMTIC(type, CPU);
#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ADD_SUB_GPU(type) REGISTER_SCATTER_ADD_SUB(type, GPU);
+#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
+ REGISTER_SCATTER_ARITHEMTIC(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ADD_SUB_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#endif // GOOGLE_CUDA
#undef REGISTER_SCATTER_ADD
-#undef REGISTER_SCATTER_ADD_SUB
-#undef REGISTER_SCATTER_ADD_SUB_CPU
-#undef REGISTER_SCATTER_ADD_SUB_GPU
+#undef REGISTER_SCATTER_ARITHEMTIC
+#undef REGISTER_SCATTER_ARITHEMTIC_CPU
+#undef REGISTER_SCATTER_ARITHEMTIC_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
@@ -248,7 +266,9 @@ namespace functor {
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
- DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \